신은섭(Shin Eun Seop)

add Aggregation class to aggrigate new feature

Java-Cesco/Detecting_fraud_clicks#3
1 +import org.apache.spark.sql.Dataset;
2 +import org.apache.spark.sql.Row;
3 +import org.apache.spark.sql.SparkSession;
4 +import org.apache.spark.sql.expressions.Window;
5 +import org.apache.spark.sql.expressions.WindowSpec;
6 +
7 +import static org.apache.spark.sql.functions.*;
8 +import static org.apache.spark.sql.functions.lit;
9 +import static org.apache.spark.sql.functions.when;
10 +
11 +public class Aggregation {
12 +
13 + public static void main(String[] args) throws Exception {
14 +
15 + //Create Session
16 + SparkSession spark = SparkSession
17 + .builder()
18 + .appName("Detecting Fraud Clicks")
19 + .master("local")
20 + .getOrCreate();
21 +
22 + Aggregation agg = new Aggregation();
23 +
24 + Dataset<Row> dataset = agg.loadCSVDataSet("./train_sample.csv", spark);
25 + dataset = agg.changeTimestempToLong(dataset);
26 + dataset = agg.averageValidClickCount(dataset);
27 + dataset = agg.clickTimeDelta(dataset);
28 +
29 + dataset.where("ip == '5348' and app == '19'").show();
30 +
31 + }
32 +
33 +
34 + private Dataset<Row> loadCSVDataSet(String path, SparkSession spark){
35 + // Read SCV to DataSet
36 + Dataset<Row> dataset = spark.read().format("csv")
37 + .option("inferSchema", "true")
38 + .option("header", "true")
39 + .load("train_sample.csv");
40 + return dataset;
41 + }
42 +
43 + private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
44 + // cast timestamp to long
45 + Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
46 + newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
47 + newDF = newDF.drop("click_time").drop("attributed_time");
48 + return newDF;
49 + }
50 +
51 + private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
52 + // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row
53 + WindowSpec w = Window.partitionBy("ip", "app")
54 + .orderBy("utc_click_time")
55 + .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
56 +
57 + // aggregation
58 + Dataset<Row> newDF = dataset.withColumn("cum_count_click", count("utc_click_time").over(w));
59 + newDF = newDF.withColumn("cum_sum_attributed", sum("is_attributed").over(w));
60 + newDF = newDF.withColumn("avg_valid_click_count", col("cum_sum_attributed").divide(col("cum_count_click")));
61 + newDF = newDF.drop("cum_count_click", "cum_sum_attributed");
62 + return newDF;
63 + }
64 +
65 + private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
66 + WindowSpec w = Window.partitionBy ("ip")
67 + .orderBy("utc_click_time");
68 +
69 + Dataset<Row> newDF = dataset.withColumn("lag(utc_click_time)", lag("utc_click_time",1).over(w));
70 + newDF = newDF.withColumn("click_time_delta", when(col("lag(utc_click_time)").isNull(),
71 + lit(0)).otherwise(col("utc_click_time")).minus(when(col("lag(utc_click_time)").isNull(),
72 + lit(0)).otherwise(col("lag(utc_click_time)"))));
73 + newDF = newDF.drop("lag(utc_click_time)");
74 + return newDF;
75 + }
76 +}