DecisionTree.java
3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package detact.ML;
import detact.Aggregation;
import detact.Utill;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// DecisionTree Model
/**
 * Command-line driver that trains and evaluates a Spark ML decision-tree
 * regressor on an aggregated click dataset.
 *
 * <p>The single argument is the path of the aggregated CSV, which is loaded
 * through {@code Utill.loadCSVDataSet}. The columns listed in
 * {@link #FEATURE_COLUMNS} are assembled into a feature vector, the data is
 * split 70/30 into training and test sets, a {@link DecisionTreeRegressor}
 * is fitted against the {@code is_attributed} column, and the RMSE on the
 * test set plus the learned tree are printed to standard output.
 */
public class DecisionTree {

    /** Input columns assembled into the feature vector. */
    private static final String[] FEATURE_COLUMNS = {
        "ip",
        "app",
        "device",
        "os",
        "channel",
        "utc_click_time",
        "avg_valid_click_count",
        "click_time_delta",
        "count_click_in_ten_mins"
    };

    /** Column used as the regression label. */
    private static final String LABEL_COLUMN = "is_attributed";

    /** Output column of the {@link VectorAssembler}. */
    private static final String FEATURES_COLUMN = "features";

    /** Output column of the {@link VectorIndexer}; input of the regressor. */
    private static final String INDEXED_FEATURES_COLUMN = "indexedFeatures";

    /**
     * Program entry point.
     *
     * @param args exactly one element: the path of the aggregated dataset
     * @throws Exception if loading the dataset or running the Spark job fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 1) {
            // A usage error is a failure: report it on stderr with a non-zero status
            // so calling scripts do not mistake it for a successful run.
            System.err.println("Usage: java -jar decisionTree.jar <agg_path>");
            System.exit(1);
        }
        String aggPath = args[0];

        // Create the session.
        SparkSession spark = SparkSession
            .builder()
            .appName("Detecting Fraud Clicks")
            .master("local")
            .getOrCreate();
        try {
            run(aggPath, spark);
        } finally {
            // Always release the session, even when the job throws.
            spark.stop();
        }
    }

    /**
     * Loads the aggregated dataset, trains the decision-tree regressor and
     * prints the evaluation results.
     *
     * @param aggPath path of the aggregated dataset
     * @param spark   active session used to load the data
     * @throws Exception if loading the dataset or running the Spark job fails
     */
    private static void run(String aggPath, SparkSession spark) throws Exception {
        // Load the aggregated dataset.
        Dataset<Row> aggregated = Utill.loadCSVDataSet(aggPath, spark);

        // Merge the individual feature columns into a single vector column.
        VectorAssembler assembler = new VectorAssembler()
            .setInputCols(FEATURE_COLUMNS)
            .setOutputCol(FEATURES_COLUMN);
        Dataset<Row> assembled = assembler.transform(aggregated);

        // With maxCategories = 2, only features having at most two distinct
        // values are indexed as categorical; the rest stay continuous.
        // The indexer is fitted here on the full assembled dataset (before the
        // split); the pipeline below re-uses this already-fitted model.
        VectorIndexerModel featureIndexer = new VectorIndexer()
            .setInputCol(FEATURES_COLUMN)
            .setOutputCol(INDEXED_FEATURES_COLUMN)
            .setMaxCategories(2)
            .fit(assembled);

        // Split into training and test sets (30% held out for testing).
        // No seed is passed, so the split differs from run to run.
        Dataset<Row>[] splits = assembled.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Configure the decision-tree regressor.
        DecisionTreeRegressor regressor = new DecisionTreeRegressor()
            .setFeaturesCol(INDEXED_FEATURES_COLUMN)
            .setLabelCol(LABEL_COLUMN)
            .setMaxDepth(10);

        // Chain indexer and tree in a Pipeline.
        Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{featureIndexer, regressor});

        // Train the model on the training split.
        PipelineModel model = pipeline.fit(trainingData);

        // Make predictions on the held-out split.
        Dataset<Row> predictions = model.transform(testData);

        // Display a few example rows.
        predictions.select(LABEL_COLUMN, FEATURES_COLUMN).show(5);

        // Compare predictions against the true label and compute the test error.
        RegressionEvaluator evaluator = new RegressionEvaluator()
            .setLabelCol(LABEL_COLUMN)
            .setPredictionCol("prediction")
            .setMetricName("rmse");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);

        // Stage 1 of the fitted pipeline is the trained tree (stage 0 is the indexer).
        DecisionTreeRegressionModel treeModel =
            (DecisionTreeRegressionModel) (model.stages()[1]);
        System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
    }
}