Jihoon

구현코드, 모델(.h5 file) 및 데이터셋 업로드

No preview for this file type
1 +{
2 + "nbformat": 4,
3 + "nbformat_minor": 0,
4 + "metadata": {
5 + "colab": {
6 + "name": "LSTM_7.ipynb",
7 + "provenance": [],
8 + "collapsed_sections": []
9 + },
10 + "kernelspec": {
11 + "display_name": "Python 3",
12 + "name": "python3"
13 + },
14 + "accelerator": "GPU"
15 + },
16 + "cells": [
17 + {
18 + "cell_type": "code",
19 + "metadata": {
20 + "id": "aRHde3RC83kB"
21 + },
22 + "source": [
23 + "import pandas as pd\n",
24 + "import numpy as np\n",
25 + "from sklearn.model_selection import train_test_split\n",
26 + "from sklearn import metrics\n",
27 + "# from keras.wrappers.scikit_learn import KerasClassifier  # unused below; removed in Keras >= 2.12\n",
28 + "from tensorflow.keras.models import Sequential\n",
29 + "from tensorflow.keras.layers import Dense, Activation\n",
30 + "from tensorflow.keras.callbacks import EarlyStopping\n",
31 + "import tensorflow.feature_column as feature_column\n",
32 + "# from keras.utils.vis_utils import plot_model  # unused below; module path moved in newer Keras\n",
33 + "from tensorflow.keras import layers\n",
34 + "# from sklearn.externals.six import StringIO  # unused below; removed in scikit-learn >= 0.23\n",
35 + "from sklearn.metrics import confusion_matrix\n",
36 + "from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n",
37 + "from sklearn.metrics import roc_auc_score, roc_curve\n",
38 + "# from sklearn.metrics import plot_confusion_matrix  # unused below; removed in scikit-learn >= 1.2\n",
39 + "from warnings import simplefilter\n",
40 + "import time\n",
41 + "from ast import literal_eval\n",
42 + "\n",
43 + "# Sequence of dataset generating\n",
44 + "# malaria -> flood -> normal -> bruteforce -> malformed\n",
45 + "f_class0 = '/content/drive/MyDrive/Datasets/MQTTset/malaria.csv' # malariaDoS\n",
46 + "f_class1 = '/content/drive/MyDrive/Datasets/MQTTset/flood.csv' # flood\n",
47 + "f_class2 = '/content/drive/MyDrive/Datasets/MQTTset/legitimate1.csv' # normal\n",
48 + "f_class3 = '/content/drive/MyDrive/Datasets/MQTTset/bruteforce.csv' # bruteforce\n",
49 + "f_class4 = '/content/drive/MyDrive/Datasets/MQTTset/malformed.csv' # malformed\n",
50 + "\n",
51 + "# 클래스 분류번호\n",
52 + "# 0: malariaDoS\n",
53 + "# 1: flood\n",
54 + "# 2: normal\n",
55 + "# 3: bruteforce\n",
56 + "# 4: malformed\n",
57 + "\n",
58 + "fileList = [f_class0,f_class1,f_class2,f_class3,f_class4]\n",
59 + "targetList = [0,1,2,3,4] \n",
60 + "\n",
61 + "pd.set_option('display.max_columns', 33) # 출력할 열의 최대개수\n",
62 + "\n",
63 + "# feature 리스트\n",
64 + "allfeatures = ['frame.time_delta', 'frame.time_delta_displayed', 'frame.time_epoch', 'frame.time_invalid', 'frame.time_relative', 'eth.src', 'eth.dst', 'ip.src', 'ip.dst', 'tcp.srcport', 'tcp.dstport', 'tcp.flags', 'frame.cap_len', 'frame.len', 'frame.number', 'tcp.stream', 'tcp.analysis.initial_rtt', 'tcp.time_delta', 'tcp.len', 'tcp.window_size_value', 'tcp.checksum', 'mqtt.clientid', 'mqtt.clientid_len', 'mqtt.conack.flags', 'mqtt.conack.flags.reserved', 'mqtt.conack.flags.sp', 'mqtt.conack.val', 'mqtt.conflag.cleansess', 'mqtt.conflag.passwd', 'mqtt.conflag.qos', 'mqtt.conflag.reserved', 'mqtt.conflag.retain', 'mqtt.conflag.uname', 'mqtt.conflag.willflag', 'mqtt.conflags', 'mqtt.dupflag', 'mqtt.hdrflags', 'mqtt.kalive', 'mqtt.len', 'mqtt.msg', 'mqtt.msgid', 'mqtt.msgtype', 'mqtt.passwd', 'mqtt.passwd_len', 'mqtt.proto_len', 'mqtt.protoname', 'mqtt.qos', 'mqtt.retain', 'mqtt.sub.qos', 'mqtt.suback.qos', 'mqtt.topic', 'mqtt.topic_len', 'mqtt.username', 'mqtt.username_len', 'mqtt.ver', 'mqtt.willmsg', 'mqtt.willmsg_len', 'mqtt.willtopic', 'mqtt.willtopic_len', 'ip.proto']\n",
65 + "usedfeatures = ['tcp.flags', 'tcp.time_delta', 'tcp.len', 'mqtt.conack.flags', 'mqtt.conack.flags.reserved', 'mqtt.conack.flags.sp', 'mqtt.conack.val', 'mqtt.conflag.cleansess', 'mqtt.conflag.passwd', 'mqtt.conflag.qos', 'mqtt.conflag.reserved', 'mqtt.conflag.retain', 'mqtt.conflag.uname', 'mqtt.conflag.willflag', 'mqtt.conflags', 'mqtt.dupflag', 'mqtt.hdrflags', 'mqtt.kalive', 'mqtt.len', 'mqtt.msgid', 'mqtt.msgtype', 'mqtt.proto_len', 'mqtt.protoname', 'mqtt.qos', 'mqtt.retain', 'mqtt.sub.qos', 'mqtt.suback.qos', 'mqtt.ver', 'mqtt.willmsg', 'mqtt.willmsg_len', 'mqtt.willtopic', 'mqtt.willtopic_len']\n",
66 + "droppedfeatures = ['ip.src', 'ip.dst', 'frame.time_relative', 'frame.time_invalid', 'mqtt.username_len', 'frame.len', 'tcp.dstport', 'tcp.window_size_value', 'mqtt.username', 'mqtt.passwd_len', 'mqtt.topic', 'mqtt.topic_len', 'tcp.checksum', 'frame.cap_len', 'mqtt.passwd', 'frame.time_delta', 'eth.dst', 'mqtt.clientid', 'frame.time_epoch', 'frame.number', 'eth.src', 'mqtt.clientid_len', 'mqtt.msg', 'tcp.stream', 'frame.time_delta_displayed', 'tcp.analysis.initial_rtt', 'tcp.srcport', 'ip.proto']\n",
67 + "\n",
68 + "print(\"전체 feature 개수: \", len(allfeatures))\n",
69 + "print(\"사용된 feature 개수: \", len(usedfeatures))\n",
70 + "print(\"제외한 feature 개수: \", len(droppedfeatures))\n",
71 + "\n",
72 + "\n",
73 + "def split_dataset(dataframe, test_size): # split into train/test/validation sets\n",
74 + " train, test = train_test_split(dataframe, test_size=test_size, shuffle=False) # was hard-coded 0.3, ignoring the parameter\n",
75 + " train, val = train_test_split(train, test_size=test_size, shuffle=False)\n",
76 + " print(len(train), '훈련 샘플')\n",
77 + " print(len(val), '검증 샘플')\n",
78 + " print(len(test), '테스트 샘플')\n",
79 + " return train, test, val\n",
80 + "\n",
81 + "\n",
82 + "def modify_features(fileList, dropfeaturelist): # 데이터프레임에서 제외된 feature 열들을 drop, target label 추가\n",
83 + " _dfList = []\n",
84 + " for i in range(len(fileList)):\n",
85 + " df = pd.read_csv(fileList[i], encoding='euc-kr')\n",
86 + " for feature in dropfeaturelist:\n",
87 + " df = df.drop([feature],axis=1)\n",
88 + " df['target'] = i # labeling\n",
89 + " _dfList.append(df)\n",
90 + " return _dfList\n",
91 + "\n",
92 + "\n",
93 + "dfList = modify_features(fileList, droppedfeatures)\n",
94 + "xTrainList,yTrainList = [],[] # train 데이터\n",
95 + "xTestList,yTestList = [],[] # test 데이터\n",
96 + "xValList,yValList = [],[] # validation 데이터\n",
97 + "\n",
98 + "\n",
99 + "for df in dfList:\n",
100 + " df['mqtt.protoname'].fillna('No', inplace=True)\n",
101 + " df['mqtt.protoname'].replace('MQTT', 1, inplace=True)\n",
102 + " df['mqtt.protoname'].replace('No', 0, inplace=True)\n",
103 + " train, test, val = split_dataset(df, 0.3)\n",
104 + " trainLabel = train.pop('target')\n",
105 + " testLabel = test.pop('target')\n",
106 + " valLabel = val.pop('target')\n",
107 + " xTrainList.append(train)\n",
108 + " xTestList.append(test)\n",
109 + " xValList.append(val)\n",
110 + "\n",
111 + " yTrainList.append(trainLabel)\n",
112 + " yTestList.append(testLabel)\n",
113 + " yValList.append(valLabel)\n",
114 + " print(train.head(2))\n",
115 + "\n",
116 + "# trainList,testList,valList 각각 해당 리스트 원소끼리 merge\n",
117 + "\n",
118 + "\n",
119 + "# Concatenate the per-class splits once (the former for-loop repeated this len(dfList) times redundantly)\n",
120 + "xTrain_df = pd.concat(xTrainList, ignore_index=True)\n",
121 + "yTrain_df = pd.concat(yTrainList, ignore_index=True)\n",
122 + "\n",
123 + "xTest_df = pd.concat(xTestList, ignore_index=True)\n",
124 + "yTest_df = pd.concat(yTestList, ignore_index=True)\n",
125 + "\n",
126 + "xVal_df = pd.concat(xValList, ignore_index=True)\n",
127 + "yVal_df = pd.concat(yValList, ignore_index=True)\n",
128 + "\n",
129 + "\n",
130 + "print(\"훈련데이터셋 형상\\n\", xTrain_df.shape)\n",
131 + "print(yTrain_df.shape)\n",
132 + "print(\"테스트데이터셋 형상\\n\", xTest_df.shape)\n",
133 + "print(yTest_df.shape)\n",
134 + "print(\"검증데이터셋 형상\\n\", xVal_df.shape)\n",
135 + "print(yVal_df.shape)\n",
136 + "\n",
137 + "\n"
138 + ],
139 + "execution_count": null,
140 + "outputs": []
141 + },
142 + {
143 + "cell_type": "code",
144 + "metadata": {
145 + "id": "gnzbus9Nsr-F",
146 + "colab": {
147 + "base_uri": "https://localhost:8080/"
148 + },
149 + "outputId": "d9343f9c-fe78-4109-bf13-8c91165332c2"
150 + },
151 + "source": [
152 + "\n",
153 + "# | 16진수 | 문자열 |\n",
154 + "\n",
155 + " # tcp.flags object # mqtt.protoname object\n",
156 + "\n",
157 + "\n",
158 + " # mqtt.conack.flags \n",
159 + " \n",
160 + " # mqtt.conflags\n",
161 + "\n",
162 + " # mqtt.hdrflags\n",
163 + "\n",
164 + "hexfeatures = ['tcp.flags','mqtt.conack.flags','mqtt.conflags','mqtt.hdrflags']\n",
165 + "\n",
166 + "# 문자열 feature인 column들을 embedding\n",
167 + "\n",
168 + "\n",
169 + "\n",
170 + "trainList = [xTrain_df, xTest_df, xVal_df]\n",
171 + "labelList = [yTrain_df, yTest_df, yVal_df]\n",
172 + "\n",
173 + "\n",
174 + "# Vectorized hex-string -> int conversion: avoids the chained-assignment writes\n",
175 + "# (the SettingWithCopyWarning shown in this cell's output) and the slow per-row loop.\n",
176 + "to_int = lambda v: int(v, 16) if isinstance(v, str) and 'x' in v else int(v)\n",
177 + "for hexa in hexfeatures:\n",
178 + "    xTrain_df[hexa] = xTrain_df[hexa].fillna('0000').map(to_int)\n",
182 + "\n"
183 + ],
184 + "execution_count": null,
185 + "outputs": [
186 + {
187 + "output_type": "stream",
188 + "text": [
189 + "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:26: SettingWithCopyWarning: \n",
190 + "A value is trying to be set on a copy of a slice from a DataFrame\n",
191 + "\n",
192 + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
193 + "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:29: SettingWithCopyWarning: \n",
194 + "A value is trying to be set on a copy of a slice from a DataFrame\n",
195 + "\n",
196 + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n"
197 + ],
198 + "name": "stderr"
199 + }
200 + ]
201 + },
202 + {
203 + "cell_type": "code",
204 + "metadata": {
205 + "id": "7Fg_qxJwA0m1"
206 + },
207 + "source": [
208 + "\r\n",
209 + "# Vectorized hex-string -> int conversion for the test/validation frames\r\n",
210 + "# (same fix as for xTrain_df: no chained assignment, no per-row Python loop).\r\n",
211 + "to_int = lambda v: int(v, 16) if isinstance(v, str) and 'x' in v else int(v)\r\n",
212 + "for hexa in hexfeatures:\r\n",
213 + "    xTest_df[hexa] = xTest_df[hexa].fillna('0000').map(to_int)\r\n",
214 + "    xVal_df[hexa] = xVal_df[hexa].fillna('0000').map(to_int)\r\n",
226 + "\r\n",
227 + "print(xTrain_df.shape)\r\n",
228 + "print(xTest_df.shape)\r\n",
229 + "print(xVal_df.shape)"
230 + ],
231 + "execution_count": null,
232 + "outputs": []
233 + },
234 + {
235 + "cell_type": "code",
236 + "metadata": {
237 + "id": "imSEh-WcTvz-"
238 + },
239 + "source": [
240 + "xTrain_df.fillna(0, inplace=True)\r\n",
241 + "xTest_df.fillna(0, inplace=True)\r\n",
242 + "xVal_df.fillna(0, inplace=True)\r\n",
243 + "print(xTrain_df.head(40))"
244 + ],
245 + "execution_count": null,
246 + "outputs": []
247 + },
248 + {
249 + "cell_type": "code",
250 + "metadata": {
251 + "id": "lAjG_YjoSSQU"
252 + },
253 + "source": [
254 + "# 임시 파일저장\r\n",
255 + "xTrain_df.to_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTrain.csv\", mode='w', index=False, header=True, encoding='utf-8')\r\n",
256 + "xTest_df.to_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTest.csv\", mode='w', index=False, header=True, encoding='utf-8')\r\n",
257 + "xVal_df.to_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xVal.csv\", mode='w', index=False, header=True, encoding='utf-8')"
258 + ],
259 + "execution_count": null,
260 + "outputs": []
261 + },
262 + {
263 + "cell_type": "code",
264 + "metadata": {
265 + "id": "SpMsliPqVUjK"
266 + },
267 + "source": [
268 + "xTrain_df = pd.read_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTrain.csv\")\r\n",
269 + "xTest_df = pd.read_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xTest.csv\")\r\n",
270 + "xVal_df = pd.read_csv(\"/content/drive/MyDrive/Datasets/MQTTset/xVal.csv\")\r\n",
271 + "print(xTrain_df.shape)\r\n",
272 + "print(xTest_df.shape)\r\n",
273 + "print(xVal_df.shape)"
274 + ],
275 + "execution_count": null,
276 + "outputs": []
277 + },
278 + {
279 + "cell_type": "code",
280 + "metadata": {
281 + "id": "OqmBTzoYB9zP"
282 + },
283 + "source": [
284 + "import pickle\r\n",
285 + "import os\r\n",
286 + "import sys"
287 + ],
288 + "execution_count": 4,
289 + "outputs": []
290 + },
291 + {
292 + "cell_type": "code",
293 + "metadata": {
294 + "id": "NgTPWw4vaJFU"
295 + },
296 + "source": [
297 + "from sklearn.preprocessing import StandardScaler\r\n",
298 + "# 정규화, sequence화\r\n",
299 + "scaler = StandardScaler()\r\n",
300 + "xTrain_ = scaler.fit_transform(xTrain_df)\r\n",
301 + "xTest_ = scaler.transform(xTest_df)  # reuse train-fitted statistics: re-fitting on test/val leaks split-specific stats\r\n",
302 + "xVal_ = scaler.transform(xVal_df)\r\n",
303 + "print(xTrain_.shape)\r\n",
304 + "print(xTest_.shape)\r\n",
305 + "print(xVal_.shape)\r\n",
306 + "\r\n",
307 + "def ds_to_windows(dataset, timestep) : # dataset을 원하는 window 길이(timestep)로 나누어 저장 \r\n",
308 + " #winset = np.array([])\r\n",
309 + " winset = []\r\n",
310 + " if timestep == 1 : # 사실상 window로 나누는 것이 아니라 3차원으로만 만들어줌\r\n",
311 + " for i in range((len(dataset))) :\r\n",
312 + " win = list(dataset[i]) # 길이가 32인 1개 행(1, 32) [[feature1,feature2,...,feature32]]\r\n",
313 + " winset.append(win)\r\n",
314 + " else:\r\n",
315 + " for i in range(dataset.shape[0]-timestep+1):\r\n",
316 + " win = list(dataset[i:(i+timestep)]) # win 자체가 이미 2차원 (timestep*feature)\r\n",
317 + " winset.append(win)\r\n",
318 + " if i % 50 == 0:\r\n",
319 + " print(i, \"/\", dataset.shape[0]-timestep+1)\r\n",
320 + " np_winset = np.asarray(winset)\r\n",
321 + " return np_winset\r\n",
322 + "\r\n",
323 + "# tmp = xTrain_[:10]\r\n",
324 + "\r\n",
325 + "# result = ds_to_windows(tmp,3)\r\n",
326 + "# resultnp = np.asarray(result)\r\n",
327 + "# print(resultnp.shape)\r\n",
328 + "\r\n",
329 + "# numpy array 형태\r\n",
330 + "# print(xTrainArr.shape)\r\n",
331 + "# print(xTestArr.shape)\r\n",
332 + "# print(xValArr.shape)\r\n",
333 + "\r\n",
334 + "xTrain = ds_to_windows(xTrain_,10)\r\n",
335 + "print(xTrain.shape)\r\n",
336 + "with open(\"xTrain.pickle\",\"wb\") as fw:\r\n",
337 + " pickle.dump(xTrain, fw)\r\n",
338 + "xTest = ds_to_windows(xTest_,10)\r\n",
339 + "print(xTest.shape)\r\n",
340 + "with open(\"xTest.pickle\",\"wb\") as fw:\r\n",
341 + " pickle.dump(xTest, fw)\r\n",
342 + "xVal = ds_to_windows(xVal_,10)\r\n",
343 + "print(xVal.shape)\r\n",
344 + "with open(\"xVal.pickle\",\"wb\") as fw:\r\n",
345 + " pickle.dump(xVal, fw)\r\n",
346 + "\r\n",
347 + "\r\n"
348 + ],
349 + "execution_count": null,
350 + "outputs": []
351 + },
352 + {
353 + "cell_type": "code",
354 + "metadata": {
355 + "id": "Xb16z5AGwjwZ"
356 + },
357 + "source": [
358 + "# Label도 seq화\r\n",
359 + "yTrain_ = np.asarray(yTrain_df)\r\n",
360 + "yTest_ = np.asarray(yTest_df)\r\n",
361 + "yVal_ = np.asarray(yVal_df)\r\n",
362 + "\r\n",
363 + "yTrainTmp = ds_to_windows(yTrain_,10)\r\n",
364 + "yTrain = []\r\n",
365 + "\r\n",
366 + "for i in range(yTrainTmp.shape[0]):\r\n",
367 + " ohv = np.array([0,0,0,0,0])\r\n",
368 + " ohv[(np.bincount(yTrainTmp[i]).argmax())] += 1\r\n",
369 + " yTrain.append(ohv)\r\n",
370 + "yTrain = np.asarray(yTrain)\r\n",
371 + "\r\n",
372 + "with open(\"yTrain.pickle\",\"wb\") as fw:\r\n",
373 + " pickle.dump(yTrain, fw)\r\n",
374 + "\r\n",
375 + "yTestTmp = ds_to_windows(yTest_,10)\r\n",
376 + "yTest = []\r\n",
377 + "for i in range(yTestTmp.shape[0]):\r\n",
378 + " ohv = np.array([0,0,0,0,0])\r\n",
379 + " ohv[(np.bincount(yTestTmp[i]).argmax())] += 1\r\n",
380 + " yTest.append(ohv)\r\n",
381 + "yTest = np.asarray(yTest)\r\n",
382 + "\r\n",
383 + "with open(\"yTest.pickle\",\"wb\") as fw:\r\n",
384 + " pickle.dump(yTest, fw)  # was yTestTmp: save the one-hot labels, consistent with yTrain/yVal\r\n",
385 + "\r\n",
386 + "yValTmp = ds_to_windows(yVal_,10)\r\n",
387 + "yVal = []\r\n",
388 + "for i in range(yValTmp.shape[0]):\r\n",
389 + " ohv = np.array([0,0,0,0,0])\r\n",
390 + " ohv[(np.bincount(yValTmp[i]).argmax())] += 1\r\n",
391 + " yVal.append(ohv)\r\n",
392 + "yVal = np.asarray(yVal)\r\n",
393 + "\r\n",
394 + "with open(\"yVal.pickle\",\"wb\") as fw:\r\n",
395 + " pickle.dump(yVal, fw)\r\n",
396 + "\r\n",
397 + "print(yTrain.shape)\r\n",
398 + "print(yTest.shape)\r\n",
399 + "print(yVal.shape)"
400 + ],
401 + "execution_count": null,
402 + "outputs": []
403 + },
404 + {
405 + "cell_type": "code",
406 + "metadata": {
407 + "id": "yGpaTlxbX-fa"
408 + },
409 + "source": [
410 + "print(\"xTrain: \", xTrain.shape)\r\n",
411 + "print(\"xTest: \", xTest.shape)\r\n",
412 + "print(\"xVal: \", xVal.shape)\r\n",
413 + "print(\"yTrain: \", yTrain.shape)\r\n",
414 + "print(\"yTest: \", yTest.shape)\r\n",
415 + "print(\"yVal: \", yVal.shape)"
416 + ],
417 + "execution_count": null,
418 + "outputs": []
419 + },
420 + {
421 + "cell_type": "code",
422 + "metadata": {
423 + "id": "5FcsBC6iGDrt"
424 + },
425 + "source": [
426 + "# Model\r\n",
427 + "from tensorflow.keras import Sequential\r\n",
428 + "from keras.layers import LSTM, Dense, Activation, Input\r\n",
429 + "from keras.models import Model\r\n",
430 + "from keras.optimizers import Adam\r\n",
431 + "from keras.callbacks import ModelCheckpoint, EarlyStopping\r\n",
432 + "\r\n",
433 + "\r\n",
434 + "model = Sequential()\r\n",
435 + "model.add(LSTM(128, dropout=0.2, return_sequences=False, input_shape=(10, 32)))\r\n",
436 + "model.add(Dense(128, activation='relu'))\r\n",
437 + "model.add(Dense(5, activation='softmax'))\r\n",
438 + "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\r\n",
439 + "model.summary()\r\n",
440 + "\r\n",
441 + "\r\n",
442 + "# input = Input(shape=(10, 32))\r\n",
443 + "# x = LSTM(128, return_sequences=True)(input)\r\n",
444 + "# x = LSTM(128, return_sequences=True)(x)\r\n",
445 + "# x = LSTM(128)(x)\r\n",
446 + "# x = Dense(5, activation='softmax')(x)\r\n",
447 + "# model = Model(input, x)\r\n",
448 + "# model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\r\n",
449 + "# model.summary()\r\n",
450 + "\r\n",
451 + "filename = 'checkpoint-epoch-trial-001.h5'\r\n",
452 + "checkpoint = ModelCheckpoint(filename, # file명을 지정합니다\r\n",
453 + " monitor='val_loss', # val_loss 값이 개선되었을때 호출됩니다\r\n",
454 + " verbose=0, # 로그를 출력합니다\r\n",
455 + " save_best_only=True, # 가장 best 값만 저장합니다\r\n",
456 + " mode='auto' # auto는 알아서 best를 찾습니다. min/max\r\n",
457 + " )\r\n",
458 + "\r\n",
459 + "# earlystopping must be defined: it is passed to model.fit() below (NameError otherwise)\r\n",
460 + "earlystopping = EarlyStopping(monitor='val_loss', # metric to monitor\r\n",
461 + "                              patience=15, # stop after 15 epochs without improvement\r\n",
462 + "                              )\r\n",
462 + "\r\n",
463 + "hist = model.fit(xTrain, yTrain, \r\n",
464 + " validation_data=(xVal, yVal),\r\n",
465 + " epochs=50,\r\n",
466 + " callbacks=[checkpoint, earlystopping], # checkpoint, earlystopping 콜백\r\n",
467 + " )\r\n",
468 + "\r\n",
469 + "\r\n",
470 + "\r\n",
471 + "# ## Load pickle\r\n",
472 + "# with open(\"data.pickle\",\"rb\") as fr:\r\n",
473 + "# data = pickle.load(fr)\r\n",
474 + "# print(data)\r\n",
475 + "# #['a', 'b', 'c']"
476 + ],
477 + "execution_count": null,
478 + "outputs": []
479 + },
480 + {
481 + "cell_type": "code",
482 + "metadata": {
483 + "id": "YPVG1OUZ5voT"
484 + },
485 + "source": [
486 + "\r\n",
487 + "# 학습과정 시각화\r\n",
488 + "import matplotlib.pyplot as plt\r\n",
489 + "\r\n",
490 + "fig, loss_ax = plt.subplots()\r\n",
491 + "\r\n",
492 + "acc_ax = loss_ax.twinx()\r\n",
493 + "\r\n",
494 + "loss_ax.set_ylim([0.0, 0.03])\r\n",
495 + "acc_ax.set_ylim([0.99, 1.0])\r\n",
496 + "\r\n",
497 + "loss_ax.plot(hist.history['loss'], 'y', label='train_loss')\r\n",
498 + "acc_ax.plot(hist.history['accuracy'], 'b', label='train_accracy')\r\n",
499 + "\r\n",
500 + "loss_ax.set_xlabel('epoch')\r\n",
501 + "loss_ax.set_ylabel('loss')\r\n",
502 + "acc_ax.set_ylabel('accuray')\r\n",
503 + "\r\n",
504 + "loss_ax.legend(loc='upper left')\r\n",
505 + "acc_ax.legend(loc='lower left')\r\n",
506 + "\r\n",
507 + "plt.show()"
508 + ],
509 + "execution_count": null,
510 + "outputs": []
511 + },
512 + {
513 + "cell_type": "code",
514 + "metadata": {
515 + "id": "U78Q9bvBXlFi"
516 + },
517 + "source": [
518 + "# # 성능평가\r\n",
519 + "loss_and_metrics = model.evaluate(xTest, yTest, batch_size=50)\r\n",
520 + "\r\n",
521 + "print('loss_and_metrics : ' + str(loss_and_metrics))\r\n"
522 + ],
523 + "execution_count": null,
524 + "outputs": []
525 + }
526 + ]
527 +}
...\ No newline at end of file ...\ No newline at end of file
No preview for this file type