add Aggregation class to aggregate new features
Java-Cesco/Detecting_fraud_clicks#3
Showing 1 changed file with 76 additions and 0 deletions
src/main/java/Aggregation.java
0 → 100644
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;

import static org.apache.spark.sql.functions.*;

public class Aggregation {

    public static void main(String[] args) throws Exception {

        // Create the local Spark session
        SparkSession spark = SparkSession
                .builder()
                .appName("Detecting Fraud Clicks")
                .master("local")
                .getOrCreate();

        Aggregation agg = new Aggregation();

        // Load the sample, normalise the timestamps, then derive the two new features
        Dataset<Row> dataset = agg.loadCSVDataSet("./train_sample.csv", spark);
        dataset = agg.changeTimestampToLong(dataset);
        dataset = agg.averageValidClickCount(dataset);
        dataset = agg.clickTimeDelta(dataset);

        // Spot-check a single (ip, app) pair
        dataset.where("ip == '5348' and app == '19'").show();
    }

    private Dataset<Row> loadCSVDataSet(String path, SparkSession spark) {
        // Read the CSV at the given path into a Dataset
        return spark.read().format("csv")
                .option("inferSchema", "true")
                .option("header", "true")
                .load(path);
    }

    private Dataset<Row> changeTimestampToLong(Dataset<Row> dataset) {
        // Cast both timestamp columns to long (seconds since the epoch)
        Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
        newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
        newDF = newDF.drop("click_time", "attributed_time");
        return newDF;
    }

    private Dataset<Row> averageValidClickCount(Dataset<Row> dataset) {
        // Expanding window per (ip, app), ordered by click time:
        // each row sees every row from the first click up to itself
        WindowSpec w = Window.partitionBy("ip", "app")
                .orderBy("utc_click_time")
                .rowsBetween(Window.unboundedPreceding(), Window.currentRow());

        // Running ratio of attributed (valid) clicks to clicks seen so far
        Dataset<Row> newDF = dataset.withColumn("cum_count_click", count("utc_click_time").over(w));
        newDF = newDF.withColumn("cum_sum_attributed", sum("is_attributed").over(w));
        newDF = newDF.withColumn("avg_valid_click_count", col("cum_sum_attributed").divide(col("cum_count_click")));
        newDF = newDF.drop("cum_count_click", "cum_sum_attributed");
        return newDF;
    }

    private Dataset<Row> clickTimeDelta(Dataset<Row> dataset) {
        // Seconds elapsed since the previous click from the same ip; 0 for the first click
        WindowSpec w = Window.partitionBy("ip")
                .orderBy("utc_click_time");

        Dataset<Row> newDF = dataset.withColumn("lag_utc_click_time", lag("utc_click_time", 1).over(w));
        newDF = newDF.withColumn("click_time_delta",
                when(col("lag_utc_click_time").isNull(), lit(0))
                        .otherwise(col("utc_click_time").minus(col("lag_utc_click_time"))));
        newDF = newDF.drop("lag_utc_click_time");
        return newDF;
    }
}
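
For a quick sanity check of the two window transforms, here is a minimal, self-contained sketch that runs them over three hand-built rows. It is not part of this commit: the class name AggregationDemo and the sample values are invented for illustration, and it uses coalesce for the first-click case instead of the when/otherwise chain above (both give a delta of 0 for an ip's first click).

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.*;

public class AggregationDemo {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("Aggregation demo")
                .master("local")
                .getOrCreate();

        // Three clicks from one (ip, app) pair; only the second was attributed
        StructType schema = new StructType()
                .add("ip", DataTypes.StringType)
                .add("app", DataTypes.StringType)
                .add("is_attributed", DataTypes.IntegerType)
                .add("utc_click_time", DataTypes.LongType);
        List<Row> rows = Arrays.asList(
                RowFactory.create("5348", "19", 0, 100L),
                RowFactory.create("5348", "19", 1, 160L),
                RowFactory.create("5348", "19", 0, 220L));
        Dataset<Row> df = spark.createDataFrame(rows, schema);

        // Same expanding window as averageValidClickCount
        WindowSpec w = Window.partitionBy("ip", "app")
                .orderBy("utc_click_time")
                .rowsBetween(Window.unboundedPreceding(), Window.currentRow());

        df.withColumn("avg_valid_click_count",
                        sum("is_attributed").over(w).divide(count("utc_click_time").over(w)))
          .withColumn("click_time_delta",
                        col("utc_click_time").minus(coalesce(
                                lag("utc_click_time", 1).over(Window.partitionBy("ip").orderBy("utc_click_time")),
                                col("utc_click_time"))))
          .show();

        spark.stop();
    }
}

On these rows avg_valid_click_count climbs 0.0 → 0.5 → 0.333…, and click_time_delta comes out 0, 60, 60, matching what the full job computes per (ip, app).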