add decision tree ml model
Java-Cesco/Detecting_fraud_clicks#10

Showing 4 changed files with 60 additions and 35 deletions
.gitignore
@@ -75,3 +75,8 @@ fabric.properties
 
 # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
 hs_err_pid*
+
+
+# datafile
+train.zip
+train.csv
\ No newline at end of file
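
Both Java files in this changeset call Utill.loadCSVDataSet and Utill.saveCSVDataSet, but the Utill class itself is not part of the diff. Below is a minimal sketch of what those helpers presumably look like; only the signatures are grounded in the call sites, and the header/inferSchema options are assumptions.

package detact;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Hypothetical reconstruction of the Utill helpers referenced by the
// diffs below; not part of this changeset.
public class Utill {

    // Load a CSV file into a Dataset<Row>. header/inferSchema are assumed,
    // since the aggregation code relies on named, typed columns.
    public static Dataset<Row> loadCSVDataSet(String path, SparkSession spark) {
        return spark.read()
                .option("header", "true")
                .option("inferSchema", "true")
                .csv(path);
    }

    // Write a Dataset<Row> back out as CSV with a header row.
    public static void saveCSVDataSet(Dataset<Row> dataset, String path) {
        dataset.write()
                .option("header", "true")
                .csv(path);
    }
}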
detact/Aggregation.java
@@ -1,3 +1,5 @@
+package detact;
+
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
@@ -5,12 +7,13 @@ import org.apache.spark.sql.expressions.Window;
 import org.apache.spark.sql.expressions.WindowSpec;
 
 import static org.apache.spark.sql.functions.*;
-import static org.apache.spark.sql.functions.lit;
-import static org.apache.spark.sql.functions.when;
 
 public class Aggregation {
 
-    public static void main(String[] args) throws Exception {
+    public static String AGGREGATED_PATH = "agg_data";
+    public static String ORIGINAL_DATA_PATH = "train_sample.csv";
+
+    public static void main(String[] args) {
 
         //Create Session
         SparkSession spark = SparkSession
@@ -19,10 +22,10 @@ public class Aggregation {
                 .master("local")
                 .getOrCreate();
 
-        // Aggregation
+        // detact.Aggregation
         Aggregation agg = new Aggregation();
 
-        Dataset<Row> dataset = Utill.loadCSVDataSet("./train_sample.csv", spark);
+        Dataset<Row> dataset = Utill.loadCSVDataSet(Aggregation.ORIGINAL_DATA_PATH, spark);
         dataset = agg.changeTimestempToLong(dataset);
         dataset = agg.averageValidClickCount(dataset);
         dataset = agg.clickTimeDelta(dataset);
@@ -32,7 +35,7 @@ public class Aggregation {
         dataset.where("ip == '5348' and app == '19'").show(10);
 
         // Save to csv
-        Utill.saveCSVDataSet(dataset, "./agg_data");
+        Utill.saveCSVDataSet(dataset, Aggregation.AGGREGATED_PATH);
     }
 
     private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
@@ -75,7 +78,7 @@ public class Aggregation {
                 .rangeBetween(Window.currentRow(),Window.currentRow()+600);
 
         Dataset<Row> newDF = dataset.withColumn("count_click_in_ten_mins",
-                (count("utc_click_time").over(w)).minus(1)); //TODO decide whether to include the current click itself
+                (count("utc_click_time").over(w)).minus(1));
         return newDF;
     }
 
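The count_click_in_ten_mins hunk above uses a range-based window frame: rangeBetween(currentRow, currentRow + 600) spans the current click plus every click up to 600 seconds later on the ordering column, and minus(1) removes the row's own click from the count. A standalone sketch of the technique follows; the partition keys and the class/method names are assumptions, since the diff only shows the frame and the withColumn call.

package detact;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;

import static org.apache.spark.sql.functions.count;

// Hypothetical standalone version of the windowed count used above.
public class TenMinuteWindowSketch {

    public static Dataset<Row> countClicksInTenMins(Dataset<Row> dataset) {
        // Frame covers the current click plus every click up to 600 seconds
        // later, measured on the numeric utc_click_time column (epoch seconds).
        WindowSpec w = Window
                .partitionBy("ip", "app") // assumed grouping keys
                .orderBy("utc_click_time")
                .rangeBetween(Window.currentRow(), Window.currentRow() + 600);

        // minus(1) excludes the current row from its own count.
        return dataset.withColumn("count_click_in_ten_mins",
                count("utc_click_time").over(w).minus(1));
    }
}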
MapExample.java → detact/ML/DecisionTree.java
@@ -1,7 +1,7 @@
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
+package detact.ML;
+
+import detact.Aggregation;
+import detact.Utill;
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineModel;
 import org.apache.spark.ml.PipelineStage;
@@ -12,35 +12,47 @@ import org.apache.spark.ml.feature.VectorIndexerModel;
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
 import org.apache.spark.ml.regression.DecisionTreeRegressor;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import scala.Serializable;
-
-import java.util.*;
+import org.apache.spark.sql.SparkSession;
 
 
-// ml
+// DecisionTree Model
 
-public class MapExample {
+public class DecisionTree {
 
     public static void main(String[] args) throws Exception {
 
-        // Automatically identify categorical features, and index them.
-        // Set maxCategories so features with > 4 distinct values are treated as continuous.
-
-        Aggregation agg = new Aggregation();
-
-        agg.
-
-        Dataset<Row> resultds = sqlContext.createDataFrame(result);
-
-        System.out.println("schema start");
-        resultds.printSchema();
-        System.out.println("schema end");
+        //Create Session
+        SparkSession spark = SparkSession
+                .builder()
+                .appName("Detecting Fraud Clicks")
+                .master("local")
+                .getOrCreate();
+
+        // load aggregated dataset
+        Dataset<Row> resultds = Utill.loadCSVDataSet(Aggregation.AGGREGATED_PATH, spark);
+
+        // show Dataset schema
+//        System.out.println("schema start");
+//        resultds.printSchema();
+//        String[] cols = resultds.columns();
+//        for (String col : cols) {
+//            System.out.println(col);
+//        }
+//        System.out.println("schema end");
 
         VectorAssembler assembler = new VectorAssembler()
-                .setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"})
+                .setInputCols(new String[]{
+                        "ip",
+                        "app",
+                        "device",
+                        "os",
+                        "channel",
+                        "utc_click_time",
+                        "avg_valid_click_count",
+                        "click_time_delta",
+                        "count_click_in_ten_mins"
+                })
                 .setOutputCol("features");
 
         Dataset<Row> output = assembler.transform(resultds);
@@ -56,9 +68,11 @@ public class MapExample {
         Dataset<Row> trainingData = splits[0];
         Dataset<Row> testData = splits[1];
 
-        // Train a DecisionTree model.
+        // Train a detact.ML.DecisionTree model.
         DecisionTreeRegressor dt = new DecisionTreeRegressor()
-                .setFeaturesCol("indexedFeatures").setLabelCol("attributed");
+                .setFeaturesCol("indexedFeatures")
+                .setLabelCol("is_attributed")
+                .setMaxDepth(10);
 
         // Chain indexer and tree in a Pipeline.
         Pipeline pipeline = new Pipeline()
@@ -71,11 +85,11 @@ public class MapExample {
         Dataset<Row> predictions = model.transform(testData);
 
         // Select example rows to display.
-        predictions.select("attributed", "features").show(5);
+        predictions.select("is_attributed", "features").show(5);
 
         // Select (prediction, true label) and compute test error.
         RegressionEvaluator evaluator = new RegressionEvaluator()
-                .setLabelCol("attributed")
+                .setLabelCol("is_attributed")
                 .setPredictionCol("prediction")
                 .setMetricName("rmse");
         double rmse = evaluator.evaluate(predictions);
@@ -86,4 +100,5 @@ public class MapExample {
         System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
 
     }
+
 }
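
The pipeline above feeds the assembled "features" vector through a VectorIndexer stage that produces "indexedFeatures", but that stage's construction sits in unchanged lines and is not visible in the diff. The sketch below is a hedged reconstruction based on the VectorIndexer/VectorIndexerModel imports and the removed comment about maxCategories; the class name, helper name, and exact wiring are assumptions.

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper; "output" is the assembled dataset and "dt" the
// regressor from the diff above.
public class IndexerPipelineSketch {

    public static PipelineModel buildModel(Dataset<Row> output,
                                           Dataset<Row> trainingData,
                                           DecisionTreeRegressor dt) {
        // Index the assembled feature vector; per the removed comment in the
        // diff, features with > 4 distinct values are treated as continuous.
        VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(output);

        // Chain indexer and tree in a Pipeline and fit on the training split.
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{featureIndexer, dt});
        return pipeline.fit(trainingData);
    }
}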