신은섭(Shin Eun Seop)

add decision tree ml model

Java-Cesco/Detecting_fraud_clicks/#10
@@ -74,4 +74,9 @@ fabric.properties
*.rar
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
\ No newline at end of file
hs_err_pid*
# datafile
train.zip
train.csv
\ No newline at end of file
......
package detact;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
@@ -5,12 +7,13 @@ import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.when;
public class Aggregation {
public static String AGGREGATED_PATH = "agg_data";
public static String ORIGINAL_DATA_PATH = "train_sample.csv";
public static void main(String[] args) throws Exception {
public static void main(String[] args) {
//Create Session
SparkSession spark = SparkSession
@@ -19,10 +22,10 @@ public class Aggregation {
.master("local")
.getOrCreate();
// Aggregation
// detact.Aggregation
Aggregation agg = new Aggregation();
Dataset<Row> dataset = Utill.loadCSVDataSet("./train_sample.csv", spark);
Dataset<Row> dataset = Utill.loadCSVDataSet(Aggregation.ORIGINAL_DATA_PATH, spark);
dataset = agg.changeTimestempToLong(dataset);
dataset = agg.averageValidClickCount(dataset);
dataset = agg.clickTimeDelta(dataset);
@@ -32,7 +35,7 @@ public class Aggregation {
dataset.where("ip == '5348' and app == '19'").show(10);
// Save to csv
Utill.saveCSVDataSet(dataset, "./agg_data");
Utill.saveCSVDataSet(dataset, Aggregation.AGGREGATED_PATH);
}
private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
@@ -75,7 +78,7 @@ public class Aggregation {
.rangeBetween(Window.currentRow(),Window.currentRow()+600);
Dataset<Row> newDF = dataset.withColumn("count_click_in_ten_mins",
(count("utc_click_time").over(w)).minus(1)); //TODO 본인것 포함할 것인지 정해야함.
(count("utc_click_time").over(w)).minus(1));
return newDF;
}
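
For context, here is a self-contained sketch of how the two window features referenced above (count_click_in_ten_mins and click_time_delta) could be computed end to end. This is illustration only, not code from the PR: the partitionBy("ip") key, the unix_timestamp conversion, and the lag-based delta are assumptions inferred from the column names in this diff.

package detact;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import static org.apache.spark.sql.functions.*;

public class WindowFeatureSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("window-feature-sketch")
                .master("local")
                .getOrCreate();

        Dataset<Row> dataset = Utill.loadCSVDataSet(Aggregation.ORIGINAL_DATA_PATH, spark);

        // assumed conversion: click_time parsed to epoch seconds, matching
        // what changeTimestempToLong appears to produce
        dataset = dataset.withColumn("utc_click_time", unix_timestamp(col("click_time")));

        // clicks from the same ip within the next 10 minutes (600 seconds);
        // partitioning by ip is an assumption, the hunk above does not show it
        WindowSpec tenMin = Window.partitionBy("ip")
                .orderBy("utc_click_time")
                .rangeBetween(Window.currentRow(), Window.currentRow() + 600);
        dataset = dataset.withColumn("count_click_in_ten_mins",
                count("utc_click_time").over(tenMin).minus(1)); // minus(1) excludes the click itself

        // assumed definition of click_time_delta: seconds since the previous
        // click from the same ip
        WindowSpec byIp = Window.partitionBy("ip").orderBy("utc_click_time");
        dataset = dataset.withColumn("click_time_delta",
                col("utc_click_time").minus(lag("utc_click_time", 1).over(byIp)));

        dataset.show(10);
    }
}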
......
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
package detact.ML;
import detact.Aggregation;
import detact.Utill;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
@@ -12,35 +12,47 @@ import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import scala.Serializable;
import java.util.*;
import org.apache.spark.sql.SparkSession;
// ml
// DecisionTree Model
public class MapExample {
public class DecisionTree {
public static void main(String[] args) throws Exception {
//Create Session
SparkSession spark = SparkSession
.builder()
.appName("Detecting Fraud Clicks")
.master("local")
.getOrCreate();
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
Aggregation agg = new Aggregation();
Dataset<Row> resultds = sqlContext.createDataFrame(result);
// load aggregated dataset
Dataset<Row> resultds = Utill.loadCSVDataSet(Aggregation.AGGREGATED_PATH, spark);
System.out.println("schema start");
resultds.printSchema();
System.out.println("schema end");
// show Dataset schema
// System.out.println("schema start");
// resultds.printSchema();
// String[] cols = resultds.columns();
// for (String col : cols) {
// System.out.println(col);
// }
// System.out.println("schema end");
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"})
.setInputCols(new String[]{
"ip",
"app",
"device",
"os",
"channel",
"utc_click_time",
"avg_valid_click_count",
"click_time_delta",
"count_click_in_ten_mins"
})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(resultds);
@@ -56,9 +68,11 @@ public class MapExample {
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];
// Train a DecisionTree model.
// Train a detact.DecisionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures").setLabelCol("attributed");
.setFeaturesCol("indexedFeatures")
.setLabelCol("is_attributed")
.setMaxDepth(10);
// Chain indexer and tree in a Pipeline.
Pipeline pipeline = new Pipeline()
@@ -71,19 +85,20 @@ public class MapExample {
Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("attributed", "features").show(5);
predictions.select("is_attributed", "features").show(5);
// Select (prediction, true label) and compute test error.
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("attributed")
.setLabelCol("is_attributed")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
DecisionTreeRegressionModel treeModel =
(DecisionTreeRegressionModel) (model.stages()[1]);
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
}
}
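
The hunks above elide the feature indexer and the creation of the train/test splits that sit between the assembler and the training step. The following is a hedged sketch of that scaffolding, not the PR's actual lines: it follows the surviving maxCategories comment and the cast of model.stages()[1] to the tree model, while the 0.7/0.3 split ratio and the import of org.apache.spark.ml.feature.VectorIndexer are assumptions. It reuses the output and dt variables from the surrounding code.

// hypothetical reconstruction; the real elided lines are not shown in this diff
VectorIndexerModel featureIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(4) // per the comment above: > 4 distinct values => continuous
        .fit(output);

// the splits array used above; the 0.7/0.3 ratio is an assumption
Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];

// chain indexer and tree; the tree lands at stages()[1], matching
// the (DecisionTreeRegressionModel) (model.stages()[1]) cast above
Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[]{featureIndexer, dt});
PipelineModel model = pipeline.fit(trainingData);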
......
package detact;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
......
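
The detact.Utill helper is collapsed in this view. Below is a minimal sketch consistent with the call sites above, loadCSVDataSet(path, spark) and saveCSVDataSet(dataset, path); the header and inferSchema options and the overwrite mode are assumptions, not confirmed by this diff.

package detact;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class Utill {

    // read a CSV into a DataFrame; options are assumed defaults for this dataset
    public static Dataset<Row> loadCSVDataSet(String path, SparkSession spark) {
        return spark.read()
                .option("header", "true")      // assumed: train CSV has a header row
                .option("inferSchema", "true") // assumed: let Spark type the columns
                .csv(path);
    }

    // write a DataFrame back out as CSV
    public static void saveCSVDataSet(Dataset<Row> dataset, String path) {
        dataset.write()
                .option("header", "true") // assumed
                .mode("overwrite")        // assumed
                .csv(path);
    }
}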