Showing
3 changed files
with
200 additions
and
52 deletions
... | @@ -19,7 +19,12 @@ | ... | @@ -19,7 +19,12 @@ |
19 | <dependency> | 19 | <dependency> |
20 | <groupId>org.apache.spark</groupId> | 20 | <groupId>org.apache.spark</groupId> |
21 | <artifactId>spark-sql_2.11</artifactId> | 21 | <artifactId>spark-sql_2.11</artifactId> |
22 | - <version>2.2.0</version> | 22 | + <version>2.3.0</version> |
23 | + </dependency> | ||
24 | + <dependency> | ||
25 | + <groupId>org.apache.spark</groupId> | ||
26 | + <artifactId>spark-mllib_2.11</artifactId> | ||
27 | + <version>2.3.0</version> | ||
23 | </dependency> | 28 | </dependency> |
24 | 29 | ||
25 | </dependencies> | 30 | </dependencies> | ... | ... |
1 | import org.apache.spark.SparkConf; | 1 | import org.apache.spark.SparkConf; |
2 | -import org.apache.spark.api.java.JavaPairRDD; | ||
3 | import org.apache.spark.api.java.JavaRDD; | 2 | import org.apache.spark.api.java.JavaRDD; |
4 | import org.apache.spark.api.java.JavaSparkContext; | 3 | import org.apache.spark.api.java.JavaSparkContext; |
5 | import org.apache.spark.api.java.function.Function; | 4 | import org.apache.spark.api.java.function.Function; |
5 | +import org.apache.spark.ml.Pipeline; | ||
6 | +import org.apache.spark.ml.PipelineModel; | ||
7 | +import org.apache.spark.ml.PipelineStage; | ||
8 | +import org.apache.spark.ml.evaluation.RegressionEvaluator; | ||
9 | +import org.apache.spark.ml.feature.VectorAssembler; | ||
10 | +import org.apache.spark.ml.feature.VectorIndexer; | ||
11 | +import org.apache.spark.ml.feature.VectorIndexerModel; | ||
12 | +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; | ||
13 | +import org.apache.spark.ml.regression.DecisionTreeRegressor; | ||
6 | import org.apache.spark.sql.Dataset; | 14 | import org.apache.spark.sql.Dataset; |
15 | +import org.apache.spark.sql.Encoders; | ||
7 | import org.apache.spark.sql.Row; | 16 | import org.apache.spark.sql.Row; |
8 | import org.apache.spark.sql.SQLContext; | 17 | import org.apache.spark.sql.SQLContext; |
9 | -import org.apache.spark.sql.SparkSession; | ||
10 | -import org.apache.spark.sql.types.StructType; | ||
11 | import scala.Serializable; | 18 | import scala.Serializable; |
12 | -import scala.Tuple2; | ||
13 | 19 | ||
14 | import java.util.*; | 20 | import java.util.*; |
15 | 21 | ||
22 | + | ||
23 | +// ml | ||
24 | + | ||
16 | //ip,app,device,os,channel,click_time,attributed_time,is_attributed | 25 | //ip,app,device,os,channel,click_time,attributed_time,is_attributed |
17 | //87540,12,1,13,497,2017-11-07 09:30:38,,0 | 26 | //87540,12,1,13,497,2017-11-07 09:30:38,,0 |
18 | -class Record implements Serializable { | ||
19 | - Integer ip; | ||
20 | - Integer app; | ||
21 | - Integer device; | ||
22 | - Integer os; | ||
23 | - Integer channel; | ||
24 | - Calendar clickTime; | ||
25 | - Calendar attributedTime; | ||
26 | - Boolean isAttributed; | ||
27 | - Integer clickInTenMins; | ||
28 | - | ||
29 | - // constructor , getters and setters | ||
30 | - public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, Calendar pClickTime, Calendar pAttributedTime, boolean pIsAttributed) { | ||
31 | - ip = new Integer(pIp); | ||
32 | - app = new Integer(pApp); | ||
33 | - device = new Integer(pDevice); | ||
34 | - os = new Integer(pOs); | ||
35 | - channel = new Integer(pChannel); | ||
36 | - clickTime = pClickTime; | ||
37 | - attributedTime = pAttributedTime; | ||
38 | - isAttributed = new Boolean(pIsAttributed); | ||
39 | - clickInTenMins = new Integer(0); | ||
40 | - } | ||
41 | - | ||
42 | - public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, Calendar pClickTime, Calendar pAttributedTime, boolean pIsAttributed, int pClickInTenMins) { | ||
43 | - ip = new Integer(pIp); | ||
44 | - app = new Integer(pApp); | ||
45 | - device = new Integer(pDevice); | ||
46 | - os = new Integer(pOs); | ||
47 | - channel = new Integer(pChannel); | ||
48 | - clickTime = pClickTime; | ||
49 | - attributedTime = pAttributedTime; | ||
50 | - isAttributed = new Boolean(pIsAttributed); | ||
51 | - clickInTenMins = new Integer(pClickInTenMins); | ||
52 | - } | ||
53 | -} | ||
54 | 27 | ||
55 | class RecordComparator implements Comparator<Record> { | 28 | class RecordComparator implements Comparator<Record> { |
56 | @Override | 29 | @Override |
... | @@ -72,14 +45,14 @@ public class MapExample { | ... | @@ -72,14 +45,14 @@ public class MapExample { |
72 | static SQLContext sqlContext = new SQLContext(sc); | 45 | static SQLContext sqlContext = new SQLContext(sc); |
73 | 46 | ||
74 | public static void main(String[] args) throws Exception { | 47 | public static void main(String[] args) throws Exception { |
75 | - JavaRDD<String> file = sc.textFile("/Users/hyeongyunmun/Dropbox/DetectFraudClick/data/train.csv", 1); | 48 | + JavaRDD<String> file = sc.textFile("data/train.csv", 1); |
76 | 49 | ||
77 | final String header = file.first(); | 50 | final String header = file.first(); |
78 | JavaRDD<String> data = file.filter(line -> !line.equalsIgnoreCase(header)); | 51 | JavaRDD<String> data = file.filter(line -> !line.equalsIgnoreCase(header)); |
79 | 52 | ||
80 | JavaRDD<Record> records = data.map(line -> { | 53 | JavaRDD<Record> records = data.map(line -> { |
81 | String[] fields = line.split(","); | 54 | String[] fields = line.split(","); |
82 | - Record sd = new Record(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), Integer.parseInt(fields[2]), Integer.parseInt(fields[3]), Integer.parseInt(fields[4]), DateUtil.CalendarFromString(fields[5]), DateUtil.CalendarFromString(fields[6]), "1".equalsIgnoreCase(fields[7].trim())); | 55 | + Record sd = new Record(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), Integer.parseInt(fields[2]), Integer.parseInt(fields[3]), Integer.parseInt(fields[4]), fields[5], fields[6], Integer.parseInt(fields[7].trim())); |
83 | return sd; | 56 | return sd; |
84 | }); | 57 | }); |
85 | 58 | ||
... | @@ -89,9 +62,9 @@ public class MapExample { | ... | @@ -89,9 +62,9 @@ public class MapExample { |
89 | // return new Tuple2(value._2(),value._3()); | 62 | // return new Tuple2(value._2(),value._3()); |
90 | // }}).sortByKey(new TupleComparator()).values(); | 63 | // }}).sortByKey(new TupleComparator()).values(); |
91 | 64 | ||
92 | - JavaRDD<Record> firstSorted = records.sortBy(new Function<Record, Calendar>() { | 65 | + JavaRDD<Record> firstSorted = records.sortBy(new Function<Record, String>() { |
93 | @Override | 66 | @Override |
94 | - public Calendar call(Record record) throws Exception { | 67 | + public String call(Record record) throws Exception { |
95 | return record.clickTime; | 68 | return record.clickTime; |
96 | } | 69 | } |
97 | }, true, 1); | 70 | }, true, 1); |
... | @@ -161,23 +134,83 @@ public class MapExample { | ... | @@ -161,23 +134,83 @@ public class MapExample { |
161 | 134 | ||
162 | Record record = list.get(i); | 135 | Record record = list.get(i); |
163 | 136 | ||
137 | + Calendar recordI = DateUtil.CalendarFromString(record.clickTime); | ||
138 | + | ||
164 | Calendar addTen = Calendar.getInstance(); | 139 | Calendar addTen = Calendar.getInstance(); |
165 | - addTen.setTime(record.clickTime.getTime()); | 140 | + addTen.setTime(recordI.getTime()); |
166 | addTen.add(Calendar.MINUTE, 10); | 141 | addTen.add(Calendar.MINUTE, 10); |
167 | 142 | ||
168 | int count = 0; | 143 | int count = 0; |
169 | 144 | ||
170 | - for (int j = i+1; j < list.size() && list.get(j).ip.compareTo(record.ip) == 0 | 145 | + for (int j = i+1; j < list.size() && list.get(j).ip.compareTo(record.ip) == 0; j++) { |
171 | - && list.get(j).clickTime.compareTo(record.clickTime) > 0 &&list.get(j).clickTime.compareTo(addTen) < 0; j++) | 146 | + Calendar recordJ = DateUtil.CalendarFromString(list.get(j).clickTime); |
172 | - count++; | 147 | + if (recordJ.compareTo(recordI) > 0 && recordJ.compareTo(addTen) < 0) { |
148 | + count++; | ||
149 | + } else { | ||
150 | + break; | ||
151 | + } | ||
152 | + } | ||
173 | 153 | ||
174 | resultList.add(new Record(record.ip, record.app, record.device, record.os, record.channel, record.clickTime, record.attributedTime, record.isAttributed, count)); | 154 | resultList.add(new Record(record.ip, record.app, record.device, record.os, record.channel, record.clickTime, record.attributedTime, record.isAttributed, count)); |
175 | 155 | ||
176 | } | 156 | } |
177 | 157 | ||
178 | - | ||
179 | JavaRDD<Record> result = sc.parallelize(resultList); | 158 | JavaRDD<Record> result = sc.parallelize(resultList); |
180 | - result.foreach(record -> {System.out.println(record.ip + " " + record.clickTime.getTime() + " " + record.clickInTenMins);}); | 159 | +// result.foreach(record -> {System.out.println(record.ip + " " + record.clickTime.getTime() + " " + record.clickInTenMins);}); |
181 | - | 160 | + |
161 | + // Automatically identify categorical features, and index them. | ||
162 | + // Set maxCategories so features with > 2 distinct values are treated as continuous. | |
163 | + Dataset<Row> resultds = sqlContext.createDataFrame(result, Record.class); | ||
164 | + | ||
165 | + System.out.println("schema start"); | ||
166 | + resultds.printSchema(); | ||
167 | + System.out.println("schema end"); | ||
168 | + | ||
169 | + VectorAssembler assembler = new VectorAssembler() | ||
170 | + .setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"}) | ||
171 | + .setOutputCol("features"); | ||
172 | + | ||
173 | + Dataset<Row> output = assembler.transform(resultds); | ||
174 | + | ||
175 | + VectorIndexerModel featureIndexer = new VectorIndexer() | ||
176 | + .setInputCol("features") | ||
177 | + .setOutputCol("indexedFeatures") | ||
178 | + .setMaxCategories(2) | ||
179 | + .fit(output); | ||
180 | + | ||
181 | + // Split the result into training and test sets (30% held out for testing). | ||
182 | + Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3}); | ||
183 | + Dataset<Row> trainingData = splits[0]; | ||
184 | + Dataset<Row> testData = splits[1]; | ||
185 | + | ||
186 | + // Train a DecisionTree model. | ||
187 | + DecisionTreeRegressor dt = new DecisionTreeRegressor() | ||
188 | + .setFeaturesCol("indexedFeatures").setLabelCol("attributed"); | ||
189 | + | ||
190 | + // Chain indexer and tree in a Pipeline. | ||
191 | + Pipeline pipeline = new Pipeline() | ||
192 | + .setStages(new PipelineStage[]{featureIndexer, dt}); | ||
193 | + | ||
194 | + // Train model. This also runs the indexer. | ||
195 | + PipelineModel model = pipeline.fit(trainingData); | ||
196 | + | ||
197 | + // Make predictions. | ||
198 | + Dataset<Row> predictions = model.transform(testData); | ||
199 | + | ||
200 | + // Select example rows to display. | ||
201 | + predictions.select("attributed", "features").show(5); | ||
202 | + | ||
203 | + // Select (prediction, true label) and compute test error. | ||
204 | + RegressionEvaluator evaluator = new RegressionEvaluator() | ||
205 | + .setLabelCol("attributed") | ||
206 | + .setPredictionCol("prediction") | ||
207 | + .setMetricName("rmse"); | ||
208 | + double rmse = evaluator.evaluate(predictions); | ||
209 | + System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse); | ||
210 | + | ||
211 | + DecisionTreeRegressionModel treeModel = | ||
212 | + (DecisionTreeRegressionModel) (model.stages()[1]); | ||
213 | + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); | ||
214 | + | ||
182 | } | 215 | } |
183 | } | 216 | } | ... | ... |
src/main/java/Record.java
0 → 100644
1 | +import scala.Serializable; | ||
2 | + | ||
/**
 * One click record, mirroring a row of the training CSV
 * ({@code ip,app,device,os,channel,click_time,attributed_time,is_attributed}).
 *
 * <p>The class follows JavaBean getter/setter naming so it can be handed to
 * {@code SQLContext.createDataFrame(rdd, Record.class)}. Note that the label is exposed
 * through {@link #getAttributed()} / {@link #setAttributed(Integer)}, so its bean property
 * (and therefore its DataFrame column) is named {@code attributed}, not {@code isAttributed}.
 *
 * <p>Fields are deliberately left package-private: other classes in this package read them
 * directly (for example {@code record.ip} and {@code record.clickTime}).
 */
public class Record implements Serializable {
    Integer ip;
    Integer app;
    Integer device;
    Integer os;
    Integer channel;
    /** Click timestamp kept as the raw text from the input; callers parse it when needed. */
    String clickTime;
    /** Attribution timestamp kept as the raw text from the input; may be an empty string. */
    String attributedTime;
    /** Label value taken from the {@code is_attributed} column. */
    Integer isAttributed;
    /** Derived click count supplied by the caller; 0 when not yet computed. */
    Integer clickInTenMins;

    /**
     * Creates a record whose {@code clickInTenMins} is initialised to 0.
     *
     * @param pIp             ip column
     * @param pApp            app column
     * @param pDevice         device column
     * @param pOs             os column
     * @param pChannel        channel column
     * @param pClickTime      click_time column, unparsed
     * @param pAttributedTime attributed_time column, unparsed
     * @param pIsAttributed   is_attributed column; must not be null
     * @throws NullPointerException if {@code pIsAttributed} is null
     */
    public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, String pClickTime, String pAttributedTime, Integer pIsAttributed) {
        // Single initialisation path: delegate to the full constructor.
        this(pIp, pApp, pDevice, pOs, pChannel, pClickTime, pAttributedTime, pIsAttributed, 0);
    }

    /**
     * Creates a fully populated record.
     *
     * @param pIp             ip column
     * @param pApp            app column
     * @param pDevice         device column
     * @param pOs             os column
     * @param pChannel        channel column
     * @param pClickTime      click_time column, unparsed
     * @param pAttributedTime attributed_time column, unparsed
     * @param pIsAttributed   is_attributed column; must not be null
     * @param pClickInTenMins derived click count for this record
     * @throws NullPointerException if {@code pIsAttributed} is null
     */
    public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, String pClickTime, String pAttributedTime, Integer pIsAttributed, int pClickInTenMins) {
        // Integer.valueOf replaces the deprecated new Integer(int) boxing constructor.
        ip = Integer.valueOf(pIp);
        app = Integer.valueOf(pApp);
        device = Integer.valueOf(pDevice);
        os = Integer.valueOf(pOs);
        channel = Integer.valueOf(pChannel);
        clickTime = pClickTime;
        attributedTime = pAttributedTime;
        // The previous code unboxed this argument implicitly, so null already failed with a
        // bare NullPointerException; keep rejecting null but say which argument was wrong.
        isAttributed = java.util.Objects.requireNonNull(pIsAttributed, "pIsAttributed must not be null");
        clickInTenMins = Integer.valueOf(pClickInTenMins);
    }

    /** @return the ip value */
    public Integer getIp() {
        return ip;
    }

    /** @param ip the ip value to store */
    public void setIp(Integer ip) {
        this.ip = ip;
    }

    /** @return the app value */
    public Integer getApp() {
        return app;
    }

    /** @param app the app value to store */
    public void setApp(Integer app) {
        this.app = app;
    }

    /** @return the device value */
    public Integer getDevice() {
        return device;
    }

    /** @param device the device value to store */
    public void setDevice(Integer device) {
        this.device = device;
    }

    /** @return the os value */
    public Integer getOs() {
        return os;
    }

    /** @param os the os value to store */
    public void setOs(Integer os) {
        this.os = os;
    }

    /** @return the channel value */
    public Integer getChannel() {
        return channel;
    }

    /** @param channel the channel value to store */
    public void setChannel(Integer channel) {
        this.channel = channel;
    }

    /** @return the click time as unparsed text */
    public String getClickTime() {
        return clickTime;
    }

    /** @param clickTime the click time as unparsed text */
    public void setClickTime(String clickTime) {
        this.clickTime = clickTime;
    }

    /** @return the attributed time as unparsed text */
    public String getAttributedTime() {
        return attributedTime;
    }

    /** @param attributedTime the attributed time as unparsed text */
    public void setAttributedTime(String attributedTime) {
        this.attributedTime = attributedTime;
    }

    /** @return the label value; exposed as bean property {@code attributed} */
    public Integer getAttributed() {
        return isAttributed;
    }

    /** @param attributed the label value to store */
    public void setAttributed(Integer attributed) {
        isAttributed = attributed;
    }

    /** @return the derived click count */
    public Integer getClickInTenMins() {
        return clickInTenMins;
    }

    /** @param clickInTenMins the derived click count to store */
    public void setClickInTenMins(Integer clickInTenMins) {
        this.clickInTenMins = clickInTenMins;
    }
}
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment