신은섭(Shin Eun Seop)

apply ml

...@@ -19,7 +19,12 @@ ...@@ -19,7 +19,12 @@
19 <dependency> 19 <dependency>
20 <groupId>org.apache.spark</groupId> 20 <groupId>org.apache.spark</groupId>
21 <artifactId>spark-sql_2.11</artifactId> 21 <artifactId>spark-sql_2.11</artifactId>
22 - <version>2.2.0</version> 22 + <version>2.3.0</version>
23 + </dependency>
24 + <dependency>
25 + <groupId>org.apache.spark</groupId>
26 + <artifactId>spark-mllib_2.11</artifactId>
27 + <version>2.3.0</version>
23 </dependency> 28 </dependency>
24 29
25 </dependencies> 30 </dependencies>
......
1 import org.apache.spark.SparkConf; 1 import org.apache.spark.SparkConf;
2 -import org.apache.spark.api.java.JavaPairRDD;
3 import org.apache.spark.api.java.JavaRDD; 2 import org.apache.spark.api.java.JavaRDD;
4 import org.apache.spark.api.java.JavaSparkContext; 3 import org.apache.spark.api.java.JavaSparkContext;
5 import org.apache.spark.api.java.function.Function; 4 import org.apache.spark.api.java.function.Function;
5 +import org.apache.spark.ml.Pipeline;
6 +import org.apache.spark.ml.PipelineModel;
7 +import org.apache.spark.ml.PipelineStage;
8 +import org.apache.spark.ml.evaluation.RegressionEvaluator;
9 +import org.apache.spark.ml.feature.VectorAssembler;
10 +import org.apache.spark.ml.feature.VectorIndexer;
11 +import org.apache.spark.ml.feature.VectorIndexerModel;
12 +import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
13 +import org.apache.spark.ml.regression.DecisionTreeRegressor;
6 import org.apache.spark.sql.Dataset; 14 import org.apache.spark.sql.Dataset;
15 +import org.apache.spark.sql.Encoders;
7 import org.apache.spark.sql.Row; 16 import org.apache.spark.sql.Row;
8 import org.apache.spark.sql.SQLContext; 17 import org.apache.spark.sql.SQLContext;
9 -import org.apache.spark.sql.SparkSession;
10 -import org.apache.spark.sql.types.StructType;
11 import scala.Serializable; 18 import scala.Serializable;
12 -import scala.Tuple2;
13 19
14 import java.util.*; 20 import java.util.*;
15 21
22 +
23 +// ml
24 +
16 //ip,app,device,os,channel,click_time,attributed_time,is_attributed 25 //ip,app,device,os,channel,click_time,attributed_time,is_attributed
17 //87540,12,1,13,497,2017-11-07 09:30:38,,0 26 //87540,12,1,13,497,2017-11-07 09:30:38,,0
18 -class Record implements Serializable {
19 - Integer ip;
20 - Integer app;
21 - Integer device;
22 - Integer os;
23 - Integer channel;
24 - Calendar clickTime;
25 - Calendar attributedTime;
26 - Boolean isAttributed;
27 - Integer clickInTenMins;
28 -
29 - // constructor , getters and setters
30 - public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, Calendar pClickTime, Calendar pAttributedTime, boolean pIsAttributed) {
31 - ip = new Integer(pIp);
32 - app = new Integer(pApp);
33 - device = new Integer(pDevice);
34 - os = new Integer(pOs);
35 - channel = new Integer(pChannel);
36 - clickTime = pClickTime;
37 - attributedTime = pAttributedTime;
38 - isAttributed = new Boolean(pIsAttributed);
39 - clickInTenMins = new Integer(0);
40 - }
41 -
42 - public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, Calendar pClickTime, Calendar pAttributedTime, boolean pIsAttributed, int pClickInTenMins) {
43 - ip = new Integer(pIp);
44 - app = new Integer(pApp);
45 - device = new Integer(pDevice);
46 - os = new Integer(pOs);
47 - channel = new Integer(pChannel);
48 - clickTime = pClickTime;
49 - attributedTime = pAttributedTime;
50 - isAttributed = new Boolean(pIsAttributed);
51 - clickInTenMins = new Integer(pClickInTenMins);
52 - }
53 -}
54 27
55 class RecordComparator implements Comparator<Record> { 28 class RecordComparator implements Comparator<Record> {
56 @Override 29 @Override
...@@ -72,14 +45,14 @@ public class MapExample { ...@@ -72,14 +45,14 @@ public class MapExample {
72 static SQLContext sqlContext = new SQLContext(sc); 45 static SQLContext sqlContext = new SQLContext(sc);
73 46
74 public static void main(String[] args) throws Exception { 47 public static void main(String[] args) throws Exception {
75 - JavaRDD<String> file = sc.textFile("/Users/hyeongyunmun/Dropbox/DetectFraudClick/data/train.csv", 1); 48 + JavaRDD<String> file = sc.textFile("data/train.csv", 1);
76 49
77 final String header = file.first(); 50 final String header = file.first();
78 JavaRDD<String> data = file.filter(line -> !line.equalsIgnoreCase(header)); 51 JavaRDD<String> data = file.filter(line -> !line.equalsIgnoreCase(header));
79 52
80 JavaRDD<Record> records = data.map(line -> { 53 JavaRDD<Record> records = data.map(line -> {
81 String[] fields = line.split(","); 54 String[] fields = line.split(",");
82 - Record sd = new Record(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), Integer.parseInt(fields[2]), Integer.parseInt(fields[3]), Integer.parseInt(fields[4]), DateUtil.CalendarFromString(fields[5]), DateUtil.CalendarFromString(fields[6]), "1".equalsIgnoreCase(fields[7].trim())); 55 + Record sd = new Record(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), Integer.parseInt(fields[2]), Integer.parseInt(fields[3]), Integer.parseInt(fields[4]), fields[5], fields[6], Integer.parseInt(fields[7].trim()));
83 return sd; 56 return sd;
84 }); 57 });
85 58
...@@ -89,9 +62,9 @@ public class MapExample { ...@@ -89,9 +62,9 @@ public class MapExample {
89 // return new Tuple2(value._2(),value._3()); 62 // return new Tuple2(value._2(),value._3());
90 // }}).sortByKey(new TupleComparator()).values(); 63 // }}).sortByKey(new TupleComparator()).values();
91 64
92 - JavaRDD<Record> firstSorted = records.sortBy(new Function<Record, Calendar>() { 65 + JavaRDD<Record> firstSorted = records.sortBy(new Function<Record, String>() {
93 @Override 66 @Override
94 - public Calendar call(Record record) throws Exception { 67 + public String call(Record record) throws Exception {
95 return record.clickTime; 68 return record.clickTime;
96 } 69 }
97 }, true, 1); 70 }, true, 1);
...@@ -161,23 +134,83 @@ public class MapExample { ...@@ -161,23 +134,83 @@ public class MapExample {
161 134
162 Record record = list.get(i); 135 Record record = list.get(i);
163 136
137 + Calendar recordI = DateUtil.CalendarFromString(record.clickTime);
138 +
164 Calendar addTen = Calendar.getInstance(); 139 Calendar addTen = Calendar.getInstance();
165 - addTen.setTime(record.clickTime.getTime()); 140 + addTen.setTime(recordI.getTime());
166 addTen.add(Calendar.MINUTE, 10); 141 addTen.add(Calendar.MINUTE, 10);
167 142
168 int count = 0; 143 int count = 0;
169 144
170 - for (int j = i+1; j < list.size() && list.get(j).ip.compareTo(record.ip) == 0 145 + for (int j = i+1; j < list.size() && list.get(j).ip.compareTo(record.ip) == 0; j++) {
171 - && list.get(j).clickTime.compareTo(record.clickTime) > 0 &&list.get(j).clickTime.compareTo(addTen) < 0; j++) 146 + Calendar recordJ = DateUtil.CalendarFromString(list.get(j).clickTime);
172 - count++; 147 + if (recordJ.compareTo(recordI) > 0 && recordJ.compareTo(addTen) < 0) {
148 + count++;
149 + } else {
150 + break;
151 + }
152 + }
173 153
174 resultList.add(new Record(record.ip, record.app, record.device, record.os, record.channel, record.clickTime, record.attributedTime, record.isAttributed, count)); 154 resultList.add(new Record(record.ip, record.app, record.device, record.os, record.channel, record.clickTime, record.attributedTime, record.isAttributed, count));
175 155
176 } 156 }
177 157
178 -
179 JavaRDD<Record> result = sc.parallelize(resultList); 158 JavaRDD<Record> result = sc.parallelize(resultList);
180 - result.foreach(record -> {System.out.println(record.ip + " " + record.clickTime.getTime() + " " + record.clickInTenMins);}); 159 +// result.foreach(record -> {System.out.println(record.ip + " " + record.clickTime.getTime() + " " + record.clickInTenMins);});
181 - 160 +
161 + // Automatically identify categorical features, and index them.
162 + // Set maxCategories so features with > 2 distinct values are treated as continuous.
163 + Dataset<Row> resultds = sqlContext.createDataFrame(result, Record.class);
164 +
165 + System.out.println("schema start");
166 + resultds.printSchema();
167 + System.out.println("schema end");
168 +
169 + VectorAssembler assembler = new VectorAssembler()
170 + .setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"})
171 + .setOutputCol("features");
172 +
173 + Dataset<Row> output = assembler.transform(resultds);
174 +
175 + VectorIndexerModel featureIndexer = new VectorIndexer()
176 + .setInputCol("features")
177 + .setOutputCol("indexedFeatures")
178 + .setMaxCategories(2)
179 + .fit(output);
180 +
181 + // Split the result into training and test sets (30% held out for testing).
182 + Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
183 + Dataset<Row> trainingData = splits[0];
184 + Dataset<Row> testData = splits[1];
185 +
186 + // Train a DecisionTree model.
187 + DecisionTreeRegressor dt = new DecisionTreeRegressor()
188 + .setFeaturesCol("indexedFeatures").setLabelCol("attributed");
189 +
190 + // Chain indexer and tree in a Pipeline.
191 + Pipeline pipeline = new Pipeline()
192 + .setStages(new PipelineStage[]{featureIndexer, dt});
193 +
194 + // Train model. This also runs the indexer.
195 + PipelineModel model = pipeline.fit(trainingData);
196 +
197 + // Make predictions.
198 + Dataset<Row> predictions = model.transform(testData);
199 +
200 + // Select example rows to display.
201 + predictions.select("attributed", "features").show(5);
202 +
203 + // Select (prediction, true label) and compute test error.
204 + RegressionEvaluator evaluator = new RegressionEvaluator()
205 + .setLabelCol("attributed")
206 + .setPredictionCol("prediction")
207 + .setMetricName("rmse");
208 + double rmse = evaluator.evaluate(predictions);
209 + System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
210 +
211 + DecisionTreeRegressionModel treeModel =
212 + (DecisionTreeRegressionModel) (model.stages()[1]);
213 + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
214 +
182 } 215 }
183 } 216 }
......
1 +import scala.Serializable;
2 +
3 +public class Record implements Serializable {
4 + Integer ip;
5 + Integer app;
6 + Integer device;
7 + Integer os;
8 + Integer channel;
9 + String clickTime;
10 + String attributedTime;
11 + Integer isAttributed;
12 + Integer clickInTenMins;
13 +
14 + // Constructors, getters and setters
15 + public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, String pClickTime, String pAttributedTime, Integer pIsAttributed) {
16 + ip = new Integer(pIp);
17 + app = new Integer(pApp);
18 + device = new Integer(pDevice);
19 + os = new Integer(pOs);
20 + channel = new Integer(pChannel);
21 + clickTime = pClickTime;
22 + attributedTime = pAttributedTime;
23 + isAttributed = new Integer(pIsAttributed);
24 + clickInTenMins = new Integer(0);
25 + }
26 +
27 + public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, String pClickTime, String pAttributedTime, Integer pIsAttributed, int pClickInTenMins) {
28 + ip = new Integer(pIp);
29 + app = new Integer(pApp);
30 + device = new Integer(pDevice);
31 + os = new Integer(pOs);
32 + channel = new Integer(pChannel);
33 + clickTime = pClickTime;
34 + attributedTime = pAttributedTime;
35 + isAttributed = new Integer(pIsAttributed);
36 + clickInTenMins = new Integer(pClickInTenMins);
37 + }
38 +
39 + public Integer getIp() {
40 + return ip;
41 + }
42 +
43 + public void setIp(Integer ip) {
44 + this.ip = ip;
45 + }
46 +
47 + public Integer getApp() {
48 + return app;
49 + }
50 +
51 + public void setApp(Integer app) {
52 + this.app = app;
53 + }
54 +
55 + public Integer getDevice() {
56 + return device;
57 + }
58 +
59 + public void setDevice(Integer device) {
60 + this.device = device;
61 + }
62 +
63 + public Integer getOs() {
64 + return os;
65 + }
66 +
67 + public void setOs(Integer os) {
68 + this.os = os;
69 + }
70 +
71 + public Integer getChannel() {
72 + return channel;
73 + }
74 +
75 + public void setChannel(Integer channel) {
76 + this.channel = channel;
77 + }
78 +
79 + public String getClickTime() {
80 + return clickTime;
81 + }
82 +
83 + public void setClickTime(String clickTime) {
84 + this.clickTime = clickTime;
85 + }
86 +
87 + public String getAttributedTime() {
88 + return attributedTime;
89 + }
90 +
91 + public void setAttributedTime(String attributedTime) {
92 + this.attributedTime = attributedTime;
93 + }
94 +
95 + public Integer getAttributed() {
96 + return isAttributed;
97 + }
98 +
99 + public void setAttributed(Integer attributed) {
100 + isAttributed = attributed;
101 + }
102 +
103 + public Integer getClickInTenMins() {
104 + return clickInTenMins;
105 + }
106 +
107 + public void setClickInTenMins(Integer clickInTenMins) {
108 + this.clickInTenMins = clickInTenMins;
109 + }
110 +}
...\ No newline at end of file ...\ No newline at end of file