Showing
3 changed files
with
527 additions
and
0 deletions
Model/Dataset/dataset.zip
0 → 100644
No preview for this file type
Model/LSTM_model.ipynb
0 → 100644
1 | +{ | ||
2 | + "nbformat": 4, | ||
3 | + "nbformat_minor": 0, | ||
4 | + "metadata": { | ||
5 | + "colab": { | ||
6 | + "name": "LSTM_7.ipynb", | ||
7 | + "provenance": [], | ||
8 | + "collapsed_sections": [] | ||
9 | + }, | ||
10 | + "kernelspec": { | ||
11 | + "display_name": "Python 3", | ||
12 | + "name": "python3" | ||
13 | + }, | ||
14 | + "accelerator": "GPU" | ||
15 | + }, | ||
16 | + "cells": [ | ||
17 | + { | ||
18 | + "cell_type": "code", | ||
19 | + "metadata": { | ||
20 | + "id": "aRHde3RC83kB" | ||
21 | + }, | ||
22 | + "source": [ | ||
23 | + "import pandas as pd\n", | ||
24 | + "import numpy as np\n", | ||
25 | + "from sklearn.model_selection import train_test_split\n", | ||
26 | + "from sklearn import metrics\n", | ||
27 | + "from keras.wrappers.scikit_learn import KerasClassifier\n", | ||
28 | + "from tensorflow.keras.models import Sequential\n", | ||
29 | + "from tensorflow.keras.layers import Dense, Activation\n", | ||
30 | + "from tensorflow.keras.callbacks import EarlyStopping\n", | ||
31 | + "import tensorflow.feature_column as feature_column\n", | ||
32 | + "from keras.utils.vis_utils import plot_model\n", | ||
33 | + "from tensorflow.keras import layers\n", | ||
34 | + "from io import StringIO  # sklearn.externals.six was removed in scikit-learn 0.23\n", | ||
35 | + "from sklearn.metrics import confusion_matrix\n", | ||
36 | + "from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n", | ||
37 | + "from sklearn.metrics import roc_auc_score, roc_curve\n", | ||
38 | + "from sklearn.metrics import plot_confusion_matrix\n", | ||
39 | + "from warnings import simplefilter\n", | ||
40 | + "import time\n", | ||
41 | + "from ast import literal_eval\n", | ||
42 | + "\n", | ||
43 | + "# Sequence of dataset generating\n", | ||
44 | + "# malaria -> flood -> normal -> bruteforce -> malformed\n", | ||
45 | + "f_class0 = '/content/drive/MyDrive/Datasets/MQTTset/malaria.csv' # malariaDoS\n", | ||
46 | + "f_class1 = '/content/drive/MyDrive/Datasets/MQTTset/flood.csv' # flood\n", | ||
47 | + "f_class2 = '/content/drive/MyDrive/Datasets/MQTTset/legitimate1.csv' # normal\n", | ||
48 | + "f_class3 = '/content/drive/MyDrive/Datasets/MQTTset/bruteforce.csv' # bruteforce\n", | ||
49 | + "f_class4 = '/content/drive/MyDrive/Datasets/MQTTset/malformed.csv' # malformed\n", | ||
50 | + "\n", | ||
51 | + "# 클래스 분류번호\n", | ||
52 | + "# 0: malariaDoS\n", | ||
53 | + "# 1: flood\n", | ||
54 | + "# 2: normal\n", | ||
55 | + "# 3: bruteforce\n", | ||
56 | + "# 4: malformed\n", | ||
57 | + "\n", | ||
58 | + "fileList = [f_class0,f_class1,f_class2,f_class3,f_class4]\n", | ||
59 | + "targetList = [0,1,2,3,4] \n", | ||
60 | + "\n", | ||
61 | + "pd.set_option('display.max_columns', 33) # 출력할 열의 최대개수\n", | ||
62 | + "\n", | ||
63 | + "# feature 리스트\n", | ||
64 | + "allfeatures = ['frame.time_delta', 'frame.time_delta_displayed', 'frame.time_epoch', 'frame.time_invalid', 'frame.time_relative', 'eth.src', 'eth.dst', 'ip.src', 'ip.dst', 'tcp.srcport', 'tcp.dstport', 'tcp.flags', 'frame.cap_len', 'frame.len', 'frame.number', 'tcp.stream', 'tcp.analysis.initial_rtt', 'tcp.time_delta', 'tcp.len', 'tcp.window_size_value', 'tcp.checksum', 'mqtt.clientid', 'mqtt.clientid_len', 'mqtt.conack.flags', 'mqtt.conack.flags.reserved', 'mqtt.conack.flags.sp', 'mqtt.conack.val', 'mqtt.conflag.cleansess', 'mqtt.conflag.passwd', 'mqtt.conflag.qos', 'mqtt.conflag.reserved', 'mqtt.conflag.retain', 'mqtt.conflag.uname', 'mqtt.conflag.willflag', 'mqtt.conflags', 'mqtt.dupflag', 'mqtt.hdrflags', 'mqtt.kalive', 'mqtt.len', 'mqtt.msg', 'mqtt.msgid', 'mqtt.msgtype', 'mqtt.passwd', 'mqtt.passwd_len', 'mqtt.proto_len', 'mqtt.protoname', 'mqtt.qos', 'mqtt.retain', 'mqtt.sub.qos', 'mqtt.suback.qos', 'mqtt.topic', 'mqtt.topic_len', 'mqtt.username', 'mqtt.username_len', 'mqtt.ver', 'mqtt.willmsg', 'mqtt.willmsg_len', 'mqtt.willtopic', 'mqtt.willtopic_len', 'ip.proto']\n", | ||
65 | + "usedfeatures = ['tcp.flags', 'tcp.time_delta', 'tcp.len', 'mqtt.conack.flags', 'mqtt.conack.flags.reserved', 'mqtt.conack.flags.sp', 'mqtt.conack.val', 'mqtt.conflag.cleansess', 'mqtt.conflag.passwd', 'mqtt.conflag.qos', 'mqtt.conflag.reserved', 'mqtt.conflag.retain', 'mqtt.conflag.uname', 'mqtt.conflag.willflag', 'mqtt.conflags', 'mqtt.dupflag', 'mqtt.hdrflags', 'mqtt.kalive', 'mqtt.len', 'mqtt.msgid', 'mqtt.msgtype', 'mqtt.proto_len', 'mqtt.protoname', 'mqtt.qos', 'mqtt.retain', 'mqtt.sub.qos', 'mqtt.suback.qos', 'mqtt.ver', 'mqtt.willmsg', 'mqtt.willmsg_len', 'mqtt.willtopic', 'mqtt.willtopic_len']\n", | ||
66 | + "droppedfeatures = ['ip.src', 'ip.dst', 'frame.time_relative', 'frame.time_invalid', 'mqtt.username_len', 'frame.len', 'tcp.dstport', 'tcp.window_size_value', 'mqtt.username', 'mqtt.passwd_len', 'mqtt.topic', 'mqtt.topic_len', 'tcp.checksum', 'frame.cap_len', 'mqtt.passwd', 'frame.time_delta', 'eth.dst', 'mqtt.clientid', 'frame.time_epoch', 'frame.number', 'eth.src', 'mqtt.clientid_len', 'mqtt.msg', 'tcp.stream', 'frame.time_delta_displayed', 'tcp.analysis.initial_rtt', 'tcp.srcport', 'ip.proto']\n", | ||
67 | + "\n", | ||
68 | + "print(\"전체 feature 개수: \", len(allfeatures))\n", | ||
69 | + "print(\"사용된 feature 개수: \", len(usedfeatures))\n", | ||
70 | + "print(\"제외한 feature 개수: \", len(droppedfeatures))\n", | ||
71 | + "\n", | ||
72 | + "\n", | ||
73 | + "def split_dataset(dataframe, test_size): # 훈련/테스트/검증 데이터셋으로 분할\n", | ||
74 | + "    train, test = train_test_split(dataframe, test_size=test_size, shuffle=False)\n", | ||
75 | + "    train, val = train_test_split(train, test_size=test_size, shuffle=False)\n", | ||
76 | + " print(len(train), '훈련 샘플')\n", | ||
77 | + " print(len(val), '검증 샘플')\n", | ||
78 | + " print(len(test), '테스트 샘플')\n", | ||
79 | + " return train, test, val\n", | ||
80 | + "\n", | ||
81 | + "\n", | ||
82 | + "def modify_features(fileList, dropfeaturelist): # 데이터프레임에서 제외된 feature 열들을 drop, target label 추가\n", | ||
83 | + " _dfList = []\n", | ||
84 | + " for i in range(len(fileList)):\n", | ||
85 | + " df = pd.read_csv(fileList[i], encoding='euc-kr')\n", | ||
86 | + " for feature in dropfeaturelist:\n", | ||
87 | + " df = df.drop([feature],axis=1)\n", | ||
88 | + " df['target'] = i # labeling\n", | ||
89 | + " _dfList.append(df)\n", | ||
90 | + " return _dfList\n", | ||
91 | + "\n", | ||
92 | + "\n", | ||
93 | + "dfList = modify_features(fileList, droppedfeatures)\n", | ||
94 | + "xTrainList,yTrainList = [],[] # train 데이터\n", | ||
95 | + "xTestList,yTestList = [],[] # test 데이터\n", | ||
96 | + "xValList,yValList = [],[] # validation 데이터\n", | ||
97 | + "\n", | ||
98 | + "\n", | ||
99 | + "for df in dfList:\n", | ||
100 | + " df['mqtt.protoname'].fillna('No', inplace=True)\n", | ||
101 | + " df['mqtt.protoname'].replace('MQTT', 1, inplace=True)\n", | ||
102 | + " df['mqtt.protoname'].replace('No', 0, inplace=True)\n", | ||
103 | + " train, test, val = split_dataset(df, 0.3)\n", | ||
104 | + " trainLabel = train.pop('target')\n", | ||
105 | + " testLabel = test.pop('target')\n", | ||
106 | + " valLabel = val.pop('target')\n", | ||
107 | + " xTrainList.append(train)\n", | ||
108 | + " xTestList.append(test)\n", | ||
109 | + " xValList.append(val)\n", | ||
110 | + "\n", | ||
111 | + " yTrainList.append(trainLabel)\n", | ||
112 | + " yTestList.append(testLabel)\n", | ||
113 | + " yValList.append(valLabel)\n", | ||
114 | + " print(train.head(2))\n", | ||
115 | + "\n", | ||
116 | + "# trainList,testList,valList 각각 해당 리스트 원소끼리 merge\n", | ||
117 | + "\n", | ||
118 | + "\n", | ||
119 | + "# Merge the per-class train/test/val pieces into single DataFrames (one concat is enough)\n", | ||
120 | + "xTrain_df = pd.concat(xTrainList, ignore_index=True)\n", | ||
121 | + "yTrain_df = pd.concat(yTrainList, ignore_index=True)\n", | ||
122 | + "\n", | ||
123 | + "xTest_df = pd.concat(xTestList, ignore_index=True)\n", | ||
124 | + "yTest_df = pd.concat(yTestList, ignore_index=True)\n", | ||
125 | + "\n", | ||
126 | + "xVal_df = pd.concat(xValList, ignore_index=True)\n", | ||
127 | + "yVal_df = pd.concat(yValList, ignore_index=True)\n", | ||
128 | + "\n", | ||
129 | + "\n", | ||
130 | + "print(\"훈련데이터셋 형상\\n\", xTrain_df.shape)\n", | ||
131 | + "print(yTrain_df.shape)\n", | ||
132 | + "print(\"테스트데이터셋 형상\\n\", xTest_df.shape)\n", | ||
133 | + "print(yTest_df.shape)\n", | ||
134 | + "print(\"검증데이터셋 형상\\n\", xVal_df.shape)\n", | ||
135 | + "print(yVal_df.shape)\n", | ||
136 | + "\n", | ||
137 | + "\n" | ||
138 | + ], | ||
139 | + "execution_count": null, | ||
140 | + "outputs": [] | ||
141 | + }, | ||
142 | + { | ||
143 | + "cell_type": "code", | ||
144 | + "metadata": { | ||
145 | + "id": "gnzbus9Nsr-F", | ||
146 | + "colab": { | ||
147 | + "base_uri": "https://localhost:8080/" | ||
148 | + }, | ||
149 | + "outputId": "d9343f9c-fe78-4109-bf13-8c91165332c2" | ||
150 | + }, | ||
151 | + "source": [ | ||
152 | + "\n", | ||
153 | + "# | 16진수 | 문자열 |\n", | ||
154 | + "\n", | ||
155 | + " # tcp.flags object # mqtt.protoname object\n", | ||
156 | + "\n", | ||
157 | + "\n", | ||
158 | + " # mqtt.conack.flags \n", | ||
159 | + " \n", | ||
160 | + " # mqtt.conflags\n", | ||
161 | + "\n", | ||
162 | + " # mqtt.hdrflags\n", | ||
163 | + "\n", | ||
164 | + "hexfeatures = ['tcp.flags','mqtt.conack.flags','mqtt.conflags','mqtt.hdrflags']\n", | ||
165 | + "\n", | ||
166 | + "# 문자열 feature인 column들을 embedding\n", | ||
167 | + "\n", | ||
168 | + "\n", | ||
169 | + "\n", | ||
170 | + "trainList = [xTrain_df, xTest_df, xVal_df]\n", | ||
171 | + "labelList = [yTrain_df, yTest_df, yVal_df]\n", | ||
172 | + "\n", | ||
173 | + "\n", | ||
174 | + "for hexa in hexfeatures:\n", | ||
175 | + " xTrain_df[hexa].fillna('0000', inplace=True)\n", | ||
176 | + " for i in range(xTrain_df.shape[0]):\n", | ||
177 | + "        if (type(xTrain_df[hexa][i]) == str) and ('x' in xTrain_df[hexa][i]):\n", | ||
178 | + "            xTrain_df.loc[i, hexa] = int(xTrain_df[hexa][i],16)\n", | ||
179 | + "            #print(type(xTrain_df[hexa][i]))\n", | ||
180 | + "        else:\n", | ||
181 | + "            xTrain_df.loc[i, hexa] = int(xTrain_df[hexa][i])\n", | ||
182 | + "\n" | ||
183 | + ], | ||
184 | + "execution_count": null, | ||
185 | + "outputs": [ | ||
186 | + { | ||
187 | + "output_type": "stream", | ||
188 | + "text": [ | ||
189 | + "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:26: SettingWithCopyWarning: \n", | ||
190 | + "A value is trying to be set on a copy of a slice from a DataFrame\n", | ||
191 | + "\n", | ||
192 | + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", | ||
193 | + "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:29: SettingWithCopyWarning: \n", | ||
194 | + "A value is trying to be set on a copy of a slice from a DataFrame\n", | ||
195 | + "\n", | ||
196 | + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n" | ||
197 | + ], | ||
198 | + "name": "stderr" | ||
199 | + } | ||
200 | + ] | ||
201 | + }, | ||
202 | + { | ||
203 | + "cell_type": "code", | ||
204 | + "metadata": { | ||
205 | + "id": "7Fg_qxJwA0m1" | ||
206 | + }, | ||
207 | + "source": [ | ||
208 | + "\r\n", | ||
209 | + "for hexa in hexfeatures:\r\n", | ||
210 | + " xTest_df[hexa].fillna('0000', inplace=True)\r\n", | ||
211 | + " for i in range(xTest_df.shape[0]):\r\n", | ||
212 | + "        if (type(xTest_df[hexa][i]) == str) and ('x' in xTest_df[hexa][i]):\r\n", | ||
213 | + "            xTest_df.loc[i, hexa] = int(xTest_df[hexa][i],16)\r\n", | ||
214 | + "            #print(type(xTest_df[hexa][i]))\r\n", | ||
215 | + "        else:\r\n", | ||
216 | + "            xTest_df.loc[i, hexa] = int(xTest_df[hexa][i])\r\n", | ||
217 | + "\r\n", | ||
218 | + "for hexa in hexfeatures:\r\n", | ||
219 | + " xVal_df[hexa].fillna('0000', inplace=True)\r\n", | ||
220 | + " for i in range(xVal_df.shape[0]):\r\n", | ||
221 | + "        if (type(xVal_df[hexa][i]) == str) and ('x' in xVal_df[hexa][i]):\r\n", | ||
222 | + "            xVal_df.loc[i, hexa] = int(xVal_df[hexa][i],16)\r\n", | ||
223 | + "            #print(type(xVal_df[hexa][i]))\r\n", | ||
224 | + "        else:\r\n", | ||
225 | + "            xVal_df.loc[i, hexa] = int(xVal_df[hexa][i])\r\n", | ||
226 | + "\r\n", | ||
227 | + "print(xTrain_df.shape)\r\n", | ||
228 | + "print(xTest_df.shape)\r\n", | ||
229 | + "print(xVal_df.shape)" | ||
230 | + ], | ||
231 | + "execution_count": null, | ||
232 | + "outputs": [] | ||
233 | + }, | ||
234 | + { | ||
235 | + "cell_type": "code", | ||
236 | + "metadata": { | ||
237 | + "id": "imSEh-WcTvz-" | ||
238 | + }, | ||
239 | + "source": [ | ||
240 | + "xTrain_df.fillna(0, inplace=True)\r\n", | ||
241 | + "xTest_df.fillna(0, inplace=True)\r\n", | ||
242 | + "xVal_df.fillna(0, inplace=True)\r\n", | ||
243 | + "print(xTrain_df.head(40))" | ||
244 | + ], | ||
245 | + "execution_count": null, | ||
246 | + "outputs": [] | ||
247 | + }, | ||
248 | + { | ||
249 | + "cell_type": "code", | ||
250 | + "metadata": { | ||
251 | + "id": "lAjG_YjoSSQU" | ||
252 | + }, | ||
253 | + "source": [ | ||
254 | + "# 임시 파일저장\r\n", | ||
255 | + "xTrain_df.to_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTrain.csv\", mode='w', index=False, header=True, encoding='utf-8')\r\n", | ||
256 | + "xTest_df.to_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTest.csv\", mode='w', index=False, header=True, encoding='utf-8')\r\n", | ||
257 | + "xVal_df.to_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xVal.csv\", mode='w', index=False, header=True, encoding='utf-8')" | ||
258 | + ], | ||
259 | + "execution_count": null, | ||
260 | + "outputs": [] | ||
261 | + }, | ||
262 | + { | ||
263 | + "cell_type": "code", | ||
264 | + "metadata": { | ||
265 | + "id": "SpMsliPqVUjK" | ||
266 | + }, | ||
267 | + "source": [ | ||
268 | + "xTrain_df = pd.read_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTrain.csv\")\r\n", | ||
269 | + "xTest_df = pd.read_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTest.csv\")\r\n", | ||
270 | + "xVal_df = pd.read_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xVal.csv\")\r\n", | ||
271 | + "print(xTrain_df.shape)\r\n", | ||
272 | + "print(xTest_df.shape)\r\n", | ||
273 | + "print(xVal_df.shape)" | ||
274 | + ], | ||
275 | + "execution_count": null, | ||
276 | + "outputs": [] | ||
277 | + }, | ||
278 | + { | ||
279 | + "cell_type": "code", | ||
280 | + "metadata": { | ||
281 | + "id": "OqmBTzoYB9zP" | ||
282 | + }, | ||
283 | + "source": [ | ||
284 | + "import pickle\r\n", | ||
285 | + "import os\r\n", | ||
286 | + "import sys" | ||
287 | + ], | ||
288 | + "execution_count": 4, | ||
289 | + "outputs": [] | ||
290 | + }, | ||
291 | + { | ||
292 | + "cell_type": "code", | ||
293 | + "metadata": { | ||
294 | + "id": "NgTPWw4vaJFU" | ||
295 | + }, | ||
296 | + "source": [ | ||
297 | + "from sklearn.preprocessing import StandardScaler\r\n", | ||
298 | + "# 정규화, sequence화\r\n", | ||
299 | + "scaler = StandardScaler()\r\n", | ||
300 | + "xTrain_ = scaler.fit_transform(xTrain_df)\r\n", | ||
301 | + "xTest_ = scaler.transform(xTest_df)  # reuse train-fitted scaler; refitting per split leaks/skews scaling\r\n", | ||
302 | + "xVal_ = scaler.transform(xVal_df)\r\n", | ||
303 | + "print(xTrain_.shape)\r\n", | ||
304 | + "print(xTest_.shape)\r\n", | ||
305 | + "print(xVal_.shape)\r\n", | ||
306 | + "\r\n", | ||
307 | + "def ds_to_windows(dataset, timestep) : # dataset을 원하는 window 길이(timestep)로 나누어 저장 \r\n", | ||
308 | + " #winset = np.array([])\r\n", | ||
309 | + " winset = []\r\n", | ||
310 | + " if timestep == 1 : # 사실상 window로 나누는 것이 아니라 3차원으로만 만들어줌\r\n", | ||
311 | + " for i in range((len(dataset))) :\r\n", | ||
312 | + " win = list(dataset[i]) # 길이가 32인 1개 행(1, 32) [[feature1,feature2,...,feature32]]\r\n", | ||
313 | + " winset.append(win)\r\n", | ||
314 | + " else:\r\n", | ||
315 | + " for i in range(dataset.shape[0]-timestep+1):\r\n", | ||
316 | + " win = list(dataset[i:(i+timestep)]) # win 자체가 이미 2차원 (timestep*feature)\r\n", | ||
317 | + " winset.append(win)\r\n", | ||
318 | + " if i % 50 == 0:\r\n", | ||
319 | + " print(i, \"/\", dataset.shape[0]-timestep+1)\r\n", | ||
320 | + " np_winset = np.asarray(winset)\r\n", | ||
321 | + " return np_winset\r\n", | ||
322 | + "\r\n", | ||
323 | + "# tmp = xTrain_[:10]\r\n", | ||
324 | + "\r\n", | ||
325 | + "# result = ds_to_windows(tmp,3)\r\n", | ||
326 | + "# resultnp = np.asarray(result)\r\n", | ||
327 | + "# print(resultnp.shape)\r\n", | ||
328 | + "\r\n", | ||
329 | + "# numpy array 형태\r\n", | ||
330 | + "# print(xTrainArr.shape)\r\n", | ||
331 | + "# print(xTestArr.shape)\r\n", | ||
332 | + "# print(xValArr.shape)\r\n", | ||
333 | + "\r\n", | ||
334 | + "xTrain = ds_to_windows(xTrain_,10)\r\n", | ||
335 | + "print(xTrain.shape)\r\n", | ||
336 | + "with open(\"xTrain.pickle\",\"wb\") as fw:\r\n", | ||
337 | + " pickle.dump(xTrain, fw)\r\n", | ||
338 | + "xTest = ds_to_windows(xTest_,10)\r\n", | ||
339 | + "print(xTest.shape)\r\n", | ||
340 | + "with open(\"xTest.pickle\",\"wb\") as fw:\r\n", | ||
341 | + " pickle.dump(xTest, fw)\r\n", | ||
342 | + "xVal = ds_to_windows(xVal_,10)\r\n", | ||
343 | + "print(xVal.shape)\r\n", | ||
344 | + "with open(\"xVal.pickle\",\"wb\") as fw:\r\n", | ||
345 | + " pickle.dump(xVal, fw)\r\n", | ||
346 | + "\r\n", | ||
347 | + "\r\n" | ||
348 | + ], | ||
349 | + "execution_count": null, | ||
350 | + "outputs": [] | ||
351 | + }, | ||
352 | + { | ||
353 | + "cell_type": "code", | ||
354 | + "metadata": { | ||
355 | + "id": "Xb16z5AGwjwZ" | ||
356 | + }, | ||
357 | + "source": [ | ||
358 | + "# Label도 seq화\r\n", | ||
359 | + "yTrain_ = np.asarray(yTrain_df)\r\n", | ||
360 | + "yTest_ = np.asarray(yTest_df)\r\n", | ||
361 | + "yVal_ = np.asarray(yVal_df)\r\n", | ||
362 | + "\r\n", | ||
363 | + "yTrainTmp = ds_to_windows(yTrain_,10)\r\n", | ||
364 | + "yTrain = []\r\n", | ||
365 | + "\r\n", | ||
366 | + "for i in range(yTrainTmp.shape[0]):\r\n", | ||
367 | + " ohv = np.array([0,0,0,0,0])\r\n", | ||
368 | + " ohv[(np.bincount(yTrainTmp[i]).argmax())] += 1\r\n", | ||
369 | + " yTrain.append(ohv)\r\n", | ||
370 | + "yTrain = np.asarray(yTrain)\r\n", | ||
371 | + "\r\n", | ||
372 | + "with open(\"yTrain.pickle\",\"wb\") as fw:\r\n", | ||
373 | + " pickle.dump(yTrain, fw)\r\n", | ||
374 | + "\r\n", | ||
375 | + "yTestTmp = ds_to_windows(yTest_,10)\r\n", | ||
376 | + "yTest = []\r\n", | ||
377 | + "for i in range(yTestTmp.shape[0]):\r\n", | ||
378 | + " ohv = np.array([0,0,0,0,0])\r\n", | ||
379 | + " ohv[(np.bincount(yTestTmp[i]).argmax())] += 1\r\n", | ||
380 | + " yTest.append(ohv)\r\n", | ||
381 | + "yTest = np.asarray(yTest)\r\n", | ||
382 | + "\r\n", | ||
383 | + "with open(\"yTest.pickle\",\"wb\") as fw:\r\n", | ||
384 | + "    pickle.dump(yTest, fw)  # was yTestTmp: that saved raw windowed labels, not the one-hot yTest\r\n", | ||
385 | + "\r\n", | ||
386 | + "yValTmp = ds_to_windows(yVal_,10)\r\n", | ||
387 | + "yVal = []\r\n", | ||
388 | + "for i in range(yValTmp.shape[0]):\r\n", | ||
389 | + " ohv = np.array([0,0,0,0,0])\r\n", | ||
390 | + " ohv[(np.bincount(yValTmp[i]).argmax())] += 1\r\n", | ||
391 | + " yVal.append(ohv)\r\n", | ||
392 | + "yVal = np.asarray(yVal)\r\n", | ||
393 | + "\r\n", | ||
394 | + "with open(\"yVal.pickle\",\"wb\") as fw:\r\n", | ||
395 | + " pickle.dump(yVal, fw)\r\n", | ||
396 | + "\r\n", | ||
397 | + "print(yTrain.shape)\r\n", | ||
398 | + "print(yTest.shape)\r\n", | ||
399 | + "print(yVal.shape)" | ||
400 | + ], | ||
401 | + "execution_count": null, | ||
402 | + "outputs": [] | ||
403 | + }, | ||
404 | + { | ||
405 | + "cell_type": "code", | ||
406 | + "metadata": { | ||
407 | + "id": "yGpaTlxbX-fa" | ||
408 | + }, | ||
409 | + "source": [ | ||
410 | + "print(\"xTrain: \", xTrain.shape)\r\n", | ||
411 | + "print(\"xTest: \", xTest.shape)\r\n", | ||
412 | + "print(\"xVal: \", xVal.shape)\r\n", | ||
413 | + "print(\"yTrain: \", yTrain.shape)\r\n", | ||
414 | + "print(\"yTest: \", yTest.shape)\r\n", | ||
415 | + "print(\"yVal: \", yVal.shape)" | ||
416 | + ], | ||
417 | + "execution_count": null, | ||
418 | + "outputs": [] | ||
419 | + }, | ||
420 | + { | ||
421 | + "cell_type": "code", | ||
422 | + "metadata": { | ||
423 | + "id": "5FcsBC6iGDrt" | ||
424 | + }, | ||
425 | + "source": [ | ||
426 | + "# Model\r\n", | ||
427 | + "from tensorflow.keras import Sequential\r\n", | ||
428 | + "from keras.layers import LSTM, Dense, Activation, Input\r\n", | ||
429 | + "from keras.models import Model\r\n", | ||
430 | + "from keras.optimizers import Adam\r\n", | ||
431 | + "from keras.callbacks import ModelCheckpoint, EarlyStopping\r\n", | ||
432 | + "\r\n", | ||
433 | + "\r\n", | ||
434 | + "model = Sequential()\r\n", | ||
435 | + "model.add(LSTM(128, dropout=0.2, return_sequences=False, input_shape=(10, 32)))\r\n", | ||
436 | + "model.add(Dense(128, activation='relu'))\r\n", | ||
437 | + "model.add(Dense(5, activation='softmax'))\r\n", | ||
438 | + "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\r\n", | ||
439 | + "model.summary()\r\n", | ||
440 | + "\r\n", | ||
441 | + "\r\n", | ||
442 | + "# input = Input(shape=(10, 32))\r\n", | ||
443 | + "# x = LSTM(128, return_sequences=True)(input)\r\n", | ||
444 | + "# x = LSTM(128, return_sequences=True)(x)\r\n", | ||
445 | + "# x = LSTM(128)(x)\r\n", | ||
446 | + "# x = Dense(5, activation='softmax')(x)\r\n", | ||
447 | + "# model = Model(input, x)\r\n", | ||
448 | + "# model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\r\n", | ||
449 | + "# model.summary()\r\n", | ||
450 | + "\r\n", | ||
451 | + "filename = 'checkpoint-epoch-trial-001.h5'\r\n", | ||
452 | + "checkpoint = ModelCheckpoint(filename, # file명을 지정합니다\r\n", | ||
453 | + " monitor='val_loss', # val_loss 값이 개선되었을때 호출됩니다\r\n", | ||
454 | + " verbose=0, # 로그를 출력합니다\r\n", | ||
455 | + " save_best_only=True, # 가장 best 값만 저장합니다\r\n", | ||
456 | + " mode='auto' # auto는 알아서 best를 찾습니다. min/max\r\n", | ||
457 | + " )\r\n", | ||
458 | + "\r\n", | ||
459 | + "earlystopping = EarlyStopping(monitor='val_loss', # metric to monitor\r\n", | ||
460 | + "                              patience=15, # stop after 15 epochs without val_loss improvement\r\n", | ||
461 | + "                              ) # must be defined: model.fit below passes it in callbacks\r\n", | ||
462 | + "\r\n", | ||
463 | + "hist = model.fit(xTrain, yTrain, \r\n", | ||
464 | + " validation_data=(xVal, yVal),\r\n", | ||
465 | + " epochs=50,\r\n", | ||
466 | + " callbacks=[checkpoint, earlystopping], # checkpoint, earlystopping 콜백\r\n", | ||
467 | + " )\r\n", | ||
468 | + "\r\n", | ||
469 | + "\r\n", | ||
470 | + "\r\n", | ||
471 | + "# ## Load pickle\r\n", | ||
472 | + "# with open(\"data.pickle\",\"rb\") as fr:\r\n", | ||
473 | + "# data = pickle.load(fr)\r\n", | ||
474 | + "# print(data)\r\n", | ||
475 | + "# #['a', 'b', 'c']" | ||
476 | + ], | ||
477 | + "execution_count": null, | ||
478 | + "outputs": [] | ||
479 | + }, | ||
480 | + { | ||
481 | + "cell_type": "code", | ||
482 | + "metadata": { | ||
483 | + "id": "YPVG1OUZ5voT" | ||
484 | + }, | ||
485 | + "source": [ | ||
486 | + "\r\n", | ||
487 | + "# 학습과정 시각화\r\n", | ||
488 | + "import matplotlib.pyplot as plt\r\n", | ||
489 | + "\r\n", | ||
490 | + "fig, loss_ax = plt.subplots()\r\n", | ||
491 | + "\r\n", | ||
492 | + "acc_ax = loss_ax.twinx()\r\n", | ||
493 | + "\r\n", | ||
494 | + "loss_ax.set_ylim([0.0, 0.03])\r\n", | ||
495 | + "acc_ax.set_ylim([0.99, 1.0])\r\n", | ||
496 | + "\r\n", | ||
497 | + "loss_ax.plot(hist.history['loss'], 'y', label='train_loss')\r\n", | ||
498 | + "acc_ax.plot(hist.history['accuracy'], 'b', label='train_accracy')\r\n", | ||
499 | + "\r\n", | ||
500 | + "loss_ax.set_xlabel('epoch')\r\n", | ||
501 | + "loss_ax.set_ylabel('loss')\r\n", | ||
502 | + "acc_ax.set_ylabel('accuray')\r\n", | ||
503 | + "\r\n", | ||
504 | + "loss_ax.legend(loc='upper left')\r\n", | ||
505 | + "acc_ax.legend(loc='lower left')\r\n", | ||
506 | + "\r\n", | ||
507 | + "plt.show()" | ||
508 | + ], | ||
509 | + "execution_count": null, | ||
510 | + "outputs": [] | ||
511 | + }, | ||
512 | + { | ||
513 | + "cell_type": "code", | ||
514 | + "metadata": { | ||
515 | + "id": "U78Q9bvBXlFi" | ||
516 | + }, | ||
517 | + "source": [ | ||
518 | + "# # 성능평가\r\n", | ||
519 | + "loss_and_metrics = model.evaluate(xTest, yTest, batch_size=50)\r\n", | ||
520 | + "\r\n", | ||
521 | + "print('loss_and_metrics : ' + str(loss_and_metrics))\r\n" | ||
522 | + ], | ||
523 | + "execution_count": null, | ||
524 | + "outputs": [] | ||
525 | + } | ||
526 | + ] | ||
527 | +} | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Model/lstm_model.h5
0 → 100644
No preview for this file type
-
Please register or login to post a comment