Toggle navigation
Toggle navigation
This project
Loading...
Sign in
신은섭(Shin Eun Seop)
/
Detecting_fraud_clicks
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
2
Merge Requests
0
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
신은섭(Shin Eun Seop)
2018-06-05 08:16:38 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
4d4027061ffee6dd06d2049cf6fb9b6c735cc880
4d402706
1 parent
bb1e9781
add decision tree ml model
Java-Cesco/Detecting_fraud_clicks/
#10
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
61 additions
and
36 deletions
.gitignore
src/main/java/Aggregation.java → src/main/java/detact/Aggregation.java
src/main/java/MLModel.java → src/main/java/detact/ML/DecisionTree.java
src/main/java/Utill.java → src/main/java/detact/Utill.java
.gitignore
View file @
4d40270
...
...
@@ -74,4 +74,9 @@ fabric.properties
*.rar
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
\ No newline at end of file
hs_err_pid*
# datafile
train.zip
train.csv
\ No newline at end of file
...
...
src/main/java/Aggregation.java
→
src/main/java/
detact/
Aggregation.java
View file @
4d40270
package
detact
;
import
org.apache.spark.sql.Dataset
;
import
org.apache.spark.sql.Row
;
import
org.apache.spark.sql.SparkSession
;
...
...
@@ -5,12 +7,13 @@ import org.apache.spark.sql.expressions.Window;
import
org.apache.spark.sql.expressions.WindowSpec
;
import
static
org
.
apache
.
spark
.
sql
.
functions
.*;
import
static
org
.
apache
.
spark
.
sql
.
functions
.
lit
;
import
static
org
.
apache
.
spark
.
sql
.
functions
.
when
;
public
class
Aggregation
{
public
static
String
AGGREGATED_PATH
=
"agg_data"
;
public
static
String
ORIGINAL_DATA_PATH
=
"train_sample.csv"
;
public
static
void
main
(
String
[]
args
)
throws
Exception
{
public
static
void
main
(
String
[]
args
)
{
//Create Session
SparkSession
spark
=
SparkSession
...
...
@@ -19,10 +22,10 @@ public class Aggregation {
.
master
(
"local"
)
.
getOrCreate
();
// Aggregation
//
detact.
Aggregation
Aggregation
agg
=
new
Aggregation
();
Dataset
<
Row
>
dataset
=
Utill
.
loadCSVDataSet
(
"./train_sample.csv"
,
spark
);
Dataset
<
Row
>
dataset
=
Utill
.
loadCSVDataSet
(
Aggregation
.
ORIGINAL_DATA_PATH
,
spark
);
dataset
=
agg
.
changeTimestempToLong
(
dataset
);
dataset
=
agg
.
averageValidClickCount
(
dataset
);
dataset
=
agg
.
clickTimeDelta
(
dataset
);
...
...
@@ -32,7 +35,7 @@ public class Aggregation {
dataset
.
where
(
"ip == '5348' and app == '19'"
).
show
(
10
);
// Save to csv
Utill
.
saveCSVDataSet
(
dataset
,
"./agg_data"
);
Utill
.
saveCSVDataSet
(
dataset
,
Aggregation
.
AGGREGATED_PATH
);
}
private
Dataset
<
Row
>
changeTimestempToLong
(
Dataset
<
Row
>
dataset
){
...
...
@@ -75,7 +78,7 @@ public class Aggregation {
.
rangeBetween
(
Window
.
currentRow
(),
Window
.
currentRow
()+
600
);
Dataset
<
Row
>
newDF
=
dataset
.
withColumn
(
"count_click_in_ten_mins"
,
(
count
(
"utc_click_time"
).
over
(
w
)).
minus
(
1
));
//TODO decide whether the count should include the current click itself.
(
count
(
"utc_click_time"
).
over
(
w
)).
minus
(
1
));
return
newDF
;
}
...
...
src/main/java/
MLModel
.java
→
src/main/java/
detact/ML/DecisionTree
.java
View file @
4d40270
import
org.apache.spark.SparkConf
;
import
org.apache.spark.api.java.JavaRDD
;
import
org.apache.spark.api.java.JavaSparkContext
;
import
org.apache.spark.api.java.function.Function
;
package
detact
.
ML
;
import
detact.Aggregation
;
import
detact.Utill
;
import
org.apache.spark.ml.Pipeline
;
import
org.apache.spark.ml.PipelineModel
;
import
org.apache.spark.ml.PipelineStage
;
...
...
@@ -12,35 +12,47 @@ import org.apache.spark.ml.feature.VectorIndexerModel;
import
org.apache.spark.ml.regression.DecisionTreeRegressionModel
;
import
org.apache.spark.ml.regression.DecisionTreeRegressor
;
import
org.apache.spark.sql.Dataset
;
import
org.apache.spark.sql.Encoders
;
import
org.apache.spark.sql.Row
;
import
org.apache.spark.sql.SQLContext
;
import
scala.Serializable
;
import
java.util.*
;
import
org.apache.spark.sql.SparkSession
;
//
m
l
//
DecisionTree Mode
l
public
class
MapExampl
e
{
public
class
DecisionTre
e
{
public
static
void
main
(
String
[]
args
)
throws
Exception
{
//Create Session
SparkSession
spark
=
SparkSession
.
builder
()
.
appName
(
"Detecting Fraud Clicks"
)
.
master
(
"local"
)
.
getOrCreate
();
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
Aggregation
agg
=
new
Aggregation
();
agg
.
Dataset
<
Row
>
resultds
=
sqlContext
.
createDataFrame
(
result
);
// load aggregated dataset
Dataset
<
Row
>
resultds
=
Utill
.
loadCSVDataSet
(
Aggregation
.
AGGREGATED_PATH
,
spark
);
System
.
out
.
println
(
"schema start"
);
resultds
.
printSchema
();
System
.
out
.
println
(
"schema end"
);
// show Dataset schema
// System.out.println("schema start");
// resultds.printSchema();
// String[] cols = resultds.columns();
// for (String col : cols) {
// System.out.println(col);
// }
// System.out.println("schema end");
VectorAssembler
assembler
=
new
VectorAssembler
()
.
setInputCols
(
new
String
[]{
"ip"
,
"app"
,
"device"
,
"os"
,
"channel"
,
"clickInTenMins"
})
.
setInputCols
(
new
String
[]{
"ip"
,
"app"
,
"device"
,
"os"
,
"channel"
,
"utc_click_time"
,
"avg_valid_click_count"
,
"click_time_delta"
,
"count_click_in_ten_mins"
})
.
setOutputCol
(
"features"
);
Dataset
<
Row
>
output
=
assembler
.
transform
(
resultds
);
...
...
@@ -56,9 +68,11 @@ public class MapExample {
Dataset
<
Row
>
trainingData
=
splits
[
0
];
Dataset
<
Row
>
testData
=
splits
[
1
];
// Train a
Decis
ionTree model.
// Train a
DecisionTree
model.
DecisionTreeRegressor
dt
=
new
DecisionTreeRegressor
()
.
setFeaturesCol
(
"indexedFeatures"
).
setLabelCol
(
"attributed"
);
.
setFeaturesCol
(
"indexedFeatures"
)
.
setLabelCol
(
"is_attributed"
)
.
setMaxDepth
(
10
);
// Chain indexer and tree in a Pipeline.
Pipeline
pipeline
=
new
Pipeline
()
...
...
@@ -71,19 +85,20 @@ public class MapExample {
Dataset
<
Row
>
predictions
=
model
.
transform
(
testData
);
// Select example rows to display.
predictions
.
select
(
"attributed"
,
"features"
).
show
(
5
);
predictions
.
select
(
"
is_
attributed"
,
"features"
).
show
(
5
);
// Select (prediction, true label) and compute test error.
RegressionEvaluator
evaluator
=
new
RegressionEvaluator
()
.
setLabelCol
(
"attributed"
)
.
setLabelCol
(
"
is_
attributed"
)
.
setPredictionCol
(
"prediction"
)
.
setMetricName
(
"rmse"
);
double
rmse
=
evaluator
.
evaluate
(
predictions
);
System
.
out
.
println
(
"Root Mean Squared Error (RMSE) on test result = "
+
rmse
);
DecisionTreeRegressionModel
treeModel
=
(
DecisionTreeRegressionModel
)
(
model
.
stages
()[
1
]);
System
.
out
.
println
(
"Learned regression tree model:\n"
+
treeModel
.
toDebugString
());
}
}
...
...
src/main/java/Utill.java
→
src/main/java/
detact/
Utill.java
View file @
4d40270
package
detact
;
import
org.apache.spark.sql.Dataset
;
import
org.apache.spark.sql.Row
;
import
org.apache.spark.sql.SparkSession
;
...
...
Please
register
or
login
to post a comment