김지훈

Cleaning up the federated learning code

1 -CAN_ID_BIT = 29
\ No newline at end of file
1 +CAN_DATA_LEN = 8
2 +SYNCAN_DATA_LEN = 4
\ No newline at end of file
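The new constants describe payload lengths in bytes instead of the old 29-bit ID bitmap. Other files in this change also reference const.CAN_FRAME_LEN, const.CNN_FRAME_LEN and (in the legacy helpers kept below) const.CAN_ID_BIT, so the full module presumably looks roughly like the sketch below; the frame-length values are assumptions, not part of this diff.

# const.py -- sketch of the assumed complete module (only the first two entries are from this diff)
CAN_DATA_LEN = 8        # bytes in a CAN data field
SYNCAN_DATA_LEN = 4     # bytes in a SynCAN data field
CAN_FRAME_LEN = 8       # packets per frame for the frame-based dataset (assumed value)
CNN_FRAME_LEN = 8       # packets per frame for the CNN dataset (assumed; CsvToTextCNN writes *_CNN8.txt)
CAN_ID_BIT = 29         # still referenced by the legacy ID-bitmap helpers kept in the utils section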
......
1 -import os
2 import torch
3 import pandas as pd
4 import numpy as np
5 from torch.utils.data import Dataset, DataLoader
5 +from torch.utils.data.sampler import Sampler
6 import const
7
8 -'''
9 -def int_to_binary(x, bits):
10 - mask = 2 ** torch.arange(bits).to(x.device, x.dtype)
11 - return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
12 -'''

+# 0, batch * 1, batch * 2 ...
+class BatchIntervalSampler(Sampler):
+
+    def __init__(self, data_length, batch_size):
+        # make data_length divisible by batch_size
+        if data_length % batch_size != 0:
+            data_length = data_length - (data_length % batch_size)
+
+        self.indices = []
+        # print(data_length)
+        batch_group_interval = int(data_length / batch_size)
+        for group_idx in range(batch_group_interval):
+            for local_idx in range(batch_size):
+                self.indices.append(group_idx + local_idx * batch_group_interval)
+        # print('sampler init', self.indices)
+
+    def __iter__(self):
+        return iter(self.indices)
+
+    def __len__(self):
+        return len(self.indices)
-def unpack_bits(x, num_bits):
-    """
-    Args:
-        x (int): integer to convert to bits
-        num_bits (int): number of bits to represent it with
-    """
-    xshape = list(x.shape)
-    x = x.reshape([-1, 1])
-    mask = 2**np.arange(num_bits).reshape([1, num_bits])
-    return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])
-
-
-# def CsvToNumpy(csv_file):
-#     target_csv = pd.read_csv(csv_file)
-#     inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
-#     labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
-#     print(inputs_save_numpy, labels_save_numpy)
-
-#     i = 0
-#     inputs_array = []
-#     labels_array = []
-#     print(len(target_csv))
-
-#     while i + const.CAN_ID_BIT - 1 < len(target_csv):
-
39 -# is_regular = True
40 -# for j in range(const.CAN_ID_BIT):
41 -# l = target_csv.iloc[i + j]
42 -# b = l[2]
43 -# r = (l[b+2+1] == 'R')
44 -
45 -# if not r:
46 -# is_regular = False
47 -# break
48 -
49 -# inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
50 -# for idx in range(const.CAN_ID_BIT):
51 -# can_id = int(target_csv.iloc[i + idx, 1], 16)
52 -# inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT)
53 -# inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
54 -
55 -# if is_regular:
56 -# labels = 1
57 -# else:
58 -# labels = 0
59 -
60 -# inputs_array.append(inputs)
61 -# labels_array.append(labels)
62 -
63 -# i+=1
64 -# if (i % 5000 == 0):
65 -# print(i)
66 -# # break
67 -
68 -# inputs_array = np.array(inputs_array)
69 -# labels_array = np.array(labels_array)
70 -# np.save(inputs_save_numpy, arr=inputs_array)
71 -# np.save(labels_save_numpy, arr=labels_array)
72 -# print('done')
73 -
74 -
75 -def CsvToText(csv_file):
76 - target_csv = pd.read_csv(csv_file)
77 - text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt'
78 - print(text_file_name)
79 - target_text = open(text_file_name, mode='wt', encoding='utf-8')
80 -
81 - i = 0
82 - datum = [ [], [] ]
83 - print(len(target_csv))
84 -
85 - while i + const.CAN_ID_BIT - 1 < len(target_csv):
86 -
87 - is_regular = True
88 - for j in range(const.CAN_ID_BIT):
89 - l = target_csv.iloc[i + j]
90 - b = l[2]
91 - r = (l[b+2+1] == 'R')
92 -
93 - if not r:
94 - is_regular = False
95 - break
96 -
97 - if is_regular:
98 - target_text.write("%d R\n" % i)
99 - else:
100 - target_text.write("%d T\n" % i)
101 -
102 - i+=1
103 - if (i % 5000 == 0):
104 - print(i)
105 -
106 - target_text.close()
107 - print('done')
108 30
109 31
 def record_net_data_stats(label_temp, data_idx_map):
@@ -120,205 +42,92 @@ def record_net_data_stats(label_temp, data_idx_map):
     return net_class_count, net_data_count


-def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs):
-    ...  # body unchanged; this helper (and CanDatasetKwarg below) is re-added verbatim in the utils section later in this change
+def GetCanDataset(total_edge, fold_num, packet_num, csv_path, txt_path):
+    csv = pd.read_csv(csv_path)
+    txt = open(txt_path, "r")
+    lines = txt.read().splitlines()
+
+    idx = 0
+    datum = []
+    label_temp = []
+    # [cur_idx ~ cur_idx + packet_num)
+    while idx + packet_num - 1 < len(csv) // 2:
+        line = lines[idx + packet_num - 1]
+        if not line:
+            break
+
+        if line.split(' ')[1] == 'R':
+            datum.append((idx, 1))
+            label_temp.append(1)
+        else:
+            datum.append((idx, 0))
+            label_temp.append(0)
+
+        idx += 1
+        if (idx % 1000000 == 0):
+            print(idx)
+
+    fold_length = int(len(label_temp) / 5)
+    train_datum = []
+    train_label_temp = []
+    for i in range(5):
+        if i != fold_num:
+            train_datum += datum[i*fold_length:(i+1)*fold_length]
+            train_label_temp += label_temp[i*fold_length:(i+1)*fold_length]
+        else:
+            test_datum = datum[i*fold_length:(i+1)*fold_length]
+
+    N = len(train_label_temp)
+    train_label_temp = np.array(train_label_temp)
+
+    proportions = np.random.dirichlet(np.repeat(1, total_edge))
+    proportions = np.cumsum(proportions)
+    idx_batch = [[] for _ in range(total_edge)]
+    data_idx_map = {}
+    prev = 0.0
+    for j in range(total_edge):
+        idx_batch[j] = [idx for idx in range(int(prev * N), int(proportions[j] * N))]
+        prev = proportions[j]
+        data_idx_map[j] = idx_batch[j]
+
+    _, net_data_count = record_net_data_stats(train_label_temp, data_idx_map)
+
+    return CanDataset(csv, train_datum, packet_num), data_idx_map, net_data_count, CanDataset(csv, test_datum, packet_num, False)
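Unlike the class-balanced Dirichlet loop in the legacy helper, the new split draws one Dirichlet vector and hands each edge a contiguous, time-ordered slice of the training indices. A quick sketch (seed and sizes are arbitrary) of what the resulting per-edge ranges look like:

import numpy as np

np.random.seed(0)
N, total_edge = 1000, 4                      # arbitrary example sizes
proportions = np.cumsum(np.random.dirichlet(np.repeat(1, total_edge)))

prev = 0.0
for j in range(total_edge):
    lo, hi = int(prev * N), int(proportions[j] * N)
    print('edge %d -> indices [%d, %d), %d samples' % (j, lo, hi, hi - lo))
    prev = proportions[j]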

-class CanDatasetKwarg(Dataset):
-    ...  # class body unchanged; re-added verbatim in the utils section later in this change
+class CanDataset(Dataset):
+
+    def __init__(self, csv, datum, packet_num, is_train=True):
+        self.csv = csv
+        self.datum = datum
+        self.packet_num = packet_num
+        if is_train:
+            self.idx_map = []
+        else:
+            self.idx_map = [idx for idx in range(len(self.datum))]
+
+    def __len__(self):
+        return len(self.idx_map) - self.packet_num + 1
+
+    def set_idx_map(self, data_idx_map):
+        self.idx_map = data_idx_map
+
+    def __getitem__(self, idx):
+        # [cur_idx ~ cur_idx + packet_num)
+        start_i = self.datum[self.idx_map[idx]][0]
+        is_regular = self.datum[self.idx_map[idx + self.packet_num - 1]][1]
226 -def GetCanDatasetUsingTxt(csv_file, txt_path, length):
227 - csv = pd.read_csv(csv_file)
228 - txt = open(txt_path, "r")
229 - lines = txt.read().splitlines()
230 -
231 - idx = 0
232 - datum = [ [], [] ]
233 - while idx + const.CAN_ID_BIT - 1 < len(csv):
234 - if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
235 - break
236 -
237 - line = lines[idx]
238 - if not line:
239 - break
240 -
241 - if line.split(' ')[1] == 'R':
242 - if len(datum[0]) < length//2:
243 - datum[0].append((idx, 1))
244 - else:
245 - if len(datum[1]) < length//2:
246 - datum[1].append((idx, 0))
247 -
248 - idx += 1
249 - if (idx % 5000 == 0):
250 - print(idx, len(datum[0]), len(datum[1]))
251 -
252 - l = int((length // 2) * 0.9)
253 - return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
254 - CanDataset(csv, datum[0][l:] + datum[1][l:])
255 -
256 -
257 -def GetCanDataset(csv_file, length):
258 - csv = pd.read_csv(csv_file)
259 -
260 - i = 0
261 - datum = [ [], [] ]
262 -
263 - while i + const.CAN_ID_BIT - 1 < len(csv):
264 - if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
265 - break
266 -
267 - is_regular = True
268 - for j in range(const.CAN_ID_BIT):
269 - l = csv.iloc[i + j]
270 - b = l[2]
271 - r = (l[b+2+1] == 'R')
272 -
273 - if not r:
274 - is_regular = False
275 - break
276 -
277 - if is_regular:
278 - if len(datum[0]) < length//2:
279 - datum[0].append((i, 1))
280 - else:
281 - if len(datum[1]) < length//2:
282 - datum[1].append((i, 0))
283 - i+=1
284 - if (i % 5000 == 0):
285 - print(i, len(datum[0]), len(datum[1]))
286 -
287 - l = int((length // 2) * 0.9)
288 - return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
289 - CanDataset(csv, datum[0][l:] + datum[1][l:])
290 -
291 -
292 -class CanDataset(Dataset):
293 -
294 - def __init__(self, csv, datum):
295 - self.csv = csv
296 - self.datum = datum
297 -
298 - def __len__(self):
299 - return len(self.datum)
300 -
301 - def __getitem__(self, idx):
302 - start_i = self.datum[idx][0]
303 - is_regular = self.datum[idx][1]
304 120
-        l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
-        for i in range(const.CAN_ID_BIT):
-            id = int(self.csv.iloc[start_i + i, 1], 16)
-            bits = unpack_bits(np.array(id), const.CAN_ID_BIT)
-            l[i] = bits
-        l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
-
-        return (l, is_regular)
+        # each payload byte is normalized from 0 ~ 255 to 0.0 ~ 1.0
+        packet = np.zeros((const.CAN_DATA_LEN * self.packet_num))
+        for next_i in range(self.packet_num):
+            data_len = self.csv.iloc[start_i + next_i, 1]
+            for j in range(data_len):
+                data_value = int(self.csv.iloc[start_i + next_i, 2 + j], 16) / 255.0
+                packet[j + const.CAN_DATA_LEN * next_i] = data_value
+
+        return torch.from_numpy(packet).float(), is_regular


 if __name__ == "__main__":
-    kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
-    test_data_set = dataset.GetCanDatasetUsingTxtKwarg(-1, -1, False, **kwargs)
-    testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size,
-                                             shuffle=False, num_workers=2)
-
-    for x, y in testloader:
-        print(x)
-        print(y)
-        break
+    pass
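A minimal usage sketch of the new pieces together (paths and sizes follow the training script further below): GetCanDataset builds the 5-fold split and the per-edge index map, BatchIntervalSampler makes row i of every batch continue the same packet stream, and drop_last=True keeps the batch dimension fixed so a recurrent state tensor can be carried across steps.

import torch
import dataset

train_set, data_idx_map, net_data_count, test_set = dataset.GetCanDataset(
    100, 0, 1, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")

train_set.set_idx_map(data_idx_map[0])                       # shard of edge 0
sampler = dataset.BatchIntervalSampler(len(train_set), 128)
loader = torch.utils.data.DataLoader(train_set, batch_size=128, sampler=sampler,
                                     shuffle=False, num_workers=2, drop_last=True)

x, y = next(iter(loader))
# x: (128, const.CAN_DATA_LEN * packet_num) floats in [0, 1]; y: (128,) with 1 = regular, 0 = injected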
......
+import os
+
+import numpy as np
+import pandas as pd
+from torch.utils.data import Dataset
+
+import const
+
+#### utils ####
+# for mixed dataset
3 +def CsvToTextCNN(csv_file):
4 + target_csv = pd.read_csv(csv_file)
5 + file_name, extension = os.path.splitext(csv_file)
6 + print(file_name, extension)
7 + target_text = open(file_name + '_CNN8.txt', mode='wt', encoding='utf-8')
8 +
9 + idx = 0
10 + print(len(target_csv))
11 +
12 + while idx + const.CNN_FRAME_LEN - 1 < len(target_csv):
13 +
14 + is_regular = True
15 + for j in range(const.CNN_FRAME_LEN):
16 + l = target_csv.iloc[idx + j]
17 + b = l[1]
18 + r = (l[b+2] == 'R')
19 +
20 + if not r:
21 + is_regular = False
22 + break
23 +
24 + if is_regular:
25 + target_text.write("%d R\n" % idx)
26 + else:
27 + target_text.write("%d T\n" % idx)
28 +
29 + idx += 1
30 + if idx % 300000 == 0:
31 + print(idx)
32 +
33 + target_text.close()
34 + print('done')
35 +
36 +
37 +
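CsvToTextCNN scans the mixed CSV in windows of const.CNN_FRAME_LEN rows and writes one "index R/T" line per window, so the expensive labeling pass runs once and the txt file can be reused by the dataset loaders. Hypothetical usage (the csv path mirrors the one used by the training script):

CsvToTextCNN('./dataset/Mixed_dataset.csv')   # writes ./dataset/Mixed_dataset_CNN8.txt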
38 +#### dataset ####
39 +def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs):
40 + csv_list = []
41 + total_datum = []
42 + total_label_temp = []
43 + csv_idx = 0
44 + for csv_file, txt_file in kwargs.items():
45 + csv = pd.read_csv(csv_file)
46 + csv_list.append(csv)
47 +
48 + txt = open(txt_file, "r")
49 + lines = txt.read().splitlines()
50 +
51 + idx = 0
52 + local_datum = []
53 + while idx + const.CAN_ID_BIT - 1 < len(csv):
54 + line = lines[idx]
55 + if not line:
56 + break
57 +
58 + if line.split(' ')[1] == 'R':
59 + local_datum.append((csv_idx, idx, 1))
60 + total_label_temp.append(1)
61 + else:
62 + local_datum.append((csv_idx, idx, 0))
63 + total_label_temp.append(0)
64 +
65 + idx += 1
66 + if (idx % 1000000 == 0):
67 + print(idx)
68 +
69 + csv_idx += 1
70 + total_datum += local_datum
71 +
72 + fold_length = int(len(total_label_temp) / 5)
73 + datum = []
74 + label_temp = []
75 + for i in range(5):
76 + if i != fold_num:
77 + datum += total_datum[i*fold_length:(i+1)*fold_length]
78 + label_temp += total_label_temp[i*fold_length:(i+1)*fold_length]
79 + else:
80 + test_datum = total_datum[i*fold_length:(i+1)*fold_length]
81 +
82 + min_size = 0
83 + output_class_num = 2
84 + N = len(label_temp)
85 + label_temp = np.array(label_temp)
86 + data_idx_map = {}
87 +
88 + while min_size < 512:
89 + idx_batch = [[] for _ in range(total_edge)]
90 + # for each class in the dataset
91 + for k in range(output_class_num):
92 + idx_k = np.where(label_temp == k)[0]
93 + np.random.shuffle(idx_k)
94 + proportions = np.random.dirichlet(np.repeat(1, total_edge))
95 + ## Balance
96 + proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)])
97 + proportions = proportions/proportions.sum()
98 + proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
99 + idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
100 + min_size = min([len(idx_j) for idx_j in idx_batch])
101 +
102 + for j in range(total_edge):
103 + np.random.shuffle(idx_batch[j])
104 + data_idx_map[j] = idx_batch[j]
105 +
106 + net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map)
107 +
108 + return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False)
109 +
110 +
111 +class CanDatasetKwarg(Dataset):
112 +
113 + def __init__(self, csv_list, datum, is_train=True):
114 + self.csv_list = csv_list
115 + self.datum = datum
116 + if is_train:
117 + self.idx_map = []
118 + else:
119 + self.idx_map = [idx for idx in range(len(self.datum))]
120 +
121 + def __len__(self):
122 + return len(self.idx_map)
123 +
124 + def set_idx_map(self, data_idx_map):
125 + self.idx_map = data_idx_map
126 +
127 + def __getitem__(self, idx):
128 + csv_idx = self.datum[self.idx_map[idx]][0]
129 + start_i = self.datum[self.idx_map[idx]][1]
130 + is_regular = self.datum[self.idx_map[idx]][2]
131 +
132 + l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
133 + for i in range(const.CAN_ID_BIT):
134 + id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16)
135 + bits = unpack_bits(np.array(id_), const.CAN_ID_BIT)
136 + l[i] = bits
137 + l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
138 +
139 + return (l, is_regular)
140 +
141 +
142 +def GetCanDataset(total_edge, fold_num, csv_path, txt_path):
143 + csv = pd.read_csv(csv_path)
144 + txt = open(txt_path, "r")
145 + lines = txt.read().splitlines()
146 + frame_size = const.CAN_FRAME_LEN
147 + idx = 0
148 + datum = []
149 + label_temp = []
150 + while idx + frame_size - 1 < len(csv) // 2:
151 + # csv_row = csv.iloc[idx + frame_size - 1]
152 + # data_len = csv_row[1]
153 + # is_regular = (csv_row[data_len + 2] == 'R')
154 +
155 + # if is_regular:
156 + # datum.append((idx, 1))
157 + # label_temp.append(1)
158 + # else:
159 + # datum.append((idx, 0))
160 + # label_temp.append(0)
161 + line = lines[idx]
162 + if not line:
163 + break
164 +
165 + if line.split(' ')[1] == 'R':
166 + datum.append((idx, 1))
167 + label_temp.append(1)
168 + else:
169 + datum.append((idx, 0))
170 + label_temp.append(0)
171 +
172 + idx += 1
173 + if (idx % 1000000 == 0):
174 + print(idx)
175 +
176 + fold_length = int(len(label_temp) / 5)
177 + train_datum = []
178 + train_label_temp = []
179 + for i in range(5):
180 + if i != fold_num:
181 + train_datum += datum[i*fold_length:(i+1)*fold_length]
182 + train_label_temp += label_temp[i*fold_length:(i+1)*fold_length]
183 + else:
184 + test_datum = datum[i*fold_length:(i+1)*fold_length]
185 +
186 + min_size = 0
187 + output_class_num = 2
188 + N = len(train_label_temp)
189 + train_label_temp = np.array(train_label_temp)
190 + data_idx_map = {}
191 +
192 + # proportions = np.random.dirichlet(np.repeat(1, total_edge))
193 + # proportions = np.cumsum(proportions)
194 + # idx_batch = [[] for _ in range(total_edge)]
195 + # prev = 0.0
196 + # for j in range(total_edge):
197 + # idx_batch[j] = [idx for idx in range(int(prev * N), int(proportions[j] * N))]
198 + # prev = proportions[j]
199 + # np.random.shuffle(idx_batch[j])
200 + # data_idx_map[j] = idx_batch[j]
201 +
202 + while min_size < 512:
203 + idx_batch = [[] for _ in range(total_edge)]
204 + # for each class in the dataset
205 + for k in range(output_class_num):
206 + idx_k = np.where(train_label_temp == k)[0]
207 + np.random.shuffle(idx_k)
208 + proportions = np.random.dirichlet(np.repeat(1, total_edge))
209 + ## Balance
210 + proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)])
211 + proportions = proportions/proportions.sum()
212 + proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
213 + idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
214 + min_size = min([len(idx_j) for idx_j in idx_batch])
215 +
216 + for j in range(total_edge):
217 + np.random.shuffle(idx_batch[j])
218 + data_idx_map[j] = idx_batch[j]
219 +
220 + _, net_data_count = record_net_data_stats(train_label_temp, data_idx_map)
221 +
222 + return CanDataset(csv, train_datum), data_idx_map, net_data_count, CanDataset(csv, test_datum, False)
223 +
224 +
225 +class CanDataset(Dataset):
226 +
227 + def __init__(self, csv, datum, is_train=True):
228 + self.csv = csv
229 + self.datum = datum
230 + self.is_train = is_train
231 + if self.is_train:
232 + self.idx_map = []
233 + else:
234 + self.idx_map = [idx for idx in range(len(self.datum))]
235 +
236 + def __len__(self):
237 + return len(self.idx_map)
238 +
239 + def set_idx_map(self, data_idx_map):
240 + self.idx_map = data_idx_map
241 +
242 + def __getitem__(self, idx):
243 + start_i = self.datum[self.idx_map[idx]][0]
244 + if self.is_train:
245 + is_regular = self.datum[self.idx_map[idx]][1]
246 + l = np.zeros((const.CAN_FRAME_LEN, const.CAN_DATA_LEN))
247 + '''
248 +        every byte value is normalized:
249 +        0 ~ 255  ->  0.0 ~ 1.0
250 + '''
251 + for i in range(const.CAN_FRAME_LEN):
252 + data_len = self.csv.iloc[start_i + i, 1]
253 + for j in range(data_len):
254 + k = int(self.csv.iloc[start_i + i, 2 + j], 16) / 255.0
255 + l[i][j] = k
256 + l = np.reshape(l, (1, const.CAN_FRAME_LEN, const.CAN_DATA_LEN))
257 + else:
258 + l = np.zeros((const.CAN_DATA_LEN))
259 + data_len = self.csv.iloc[start_i, 1]
260 + is_regular = self.csv.iloc[start_i, data_len + 2] == 'R'
261 + if is_regular:
262 + is_regular = 1
263 + else:
264 + is_regular = 0
265 + for j in range(data_len):
266 + k = int(self.csv.iloc[start_i, 2 + j], 16) / 255.0
267 + l[j] = k
268 + l = np.reshape(l, (1, const.CAN_DATA_LEN))
269 +
270 + return (l, is_regular)
271 +
272 +
273 +def GetCanDatasetCNN(total_edge, fold_num, csv_path, txt_path):
274 + csv = pd.read_csv(csv_path)
275 + txt = open(txt_path, "r")
276 + lines = txt.read().splitlines()
277 +
278 + idx = 0
279 + datum = []
280 + label_temp = []
281 + while idx < len(csv) // 2:
282 + line = lines[idx]
283 + if not line:
284 + break
285 +
286 + if line.split(' ')[1] == 'R':
287 + datum.append((idx, 1))
288 + label_temp.append(1)
289 + else:
290 + datum.append((idx, 0))
291 + label_temp.append(0)
292 +
293 + idx += 1
294 + if (idx % 1000000 == 0):
295 + print(idx)
296 +
297 + fold_length = int(len(label_temp) / 5)
298 + train_datum = []
299 + train_label_temp = []
300 + for i in range(5):
301 + if i != fold_num:
302 + train_datum += datum[i*fold_length:(i+1)*fold_length]
303 + train_label_temp += label_temp[i*fold_length:(i+1)*fold_length]
304 + else:
305 + test_datum = datum[i*fold_length:(i+1)*fold_length]
306 +
307 +
308 + N = len(train_label_temp)
309 + train_label_temp = np.array(train_label_temp)
310 +
311 + proportions = np.random.dirichlet(np.repeat(1, total_edge))
312 + proportions = np.cumsum(proportions)
313 + idx_batch = [[] for _ in range(total_edge)]
314 + data_idx_map = {}
315 + prev = 0.0
316 + for j in range(total_edge):
317 + idx_batch[j] = [idx for idx in range(int(prev * N), int(proportions[j] * N))]
318 + prev = proportions[j]
319 + data_idx_map[j] = idx_batch[j]
320 +
321 + _, net_data_count = record_net_data_stats(train_label_temp, data_idx_map)
322 +
323 + return CanDatasetCNN(csv, train_datum), data_idx_map, net_data_count, CanDatasetCNN(csv, test_datum, False)
324 +
325 +
326 +class CanDatasetCNN(Dataset):
327 +
328 + def __init__(self, csv, datum, is_train=True):
329 + self.csv = csv
330 + self.datum = datum
331 + if is_train:
332 + self.idx_map = []
333 + else:
334 + self.idx_map = [idx for idx in range(len(self.datum))]
335 +
336 + def __len__(self):
337 + return len(self.idx_map)
338 +
339 + def set_idx_map(self, data_idx_map):
340 + self.idx_map = data_idx_map
341 +
342 + def __getitem__(self, idx):
343 + start_i = self.datum[self.idx_map[idx]][0]
344 + is_regular = self.datum[self.idx_map[idx]][1]
345 +
346 + packet = np.zeros((const.CNN_FRAME_LEN, const.CNN_FRAME_LEN))
347 + for i in range(const.CNN_FRAME_LEN):
348 + data_len = self.csv.iloc[start_i + i, 1]
349 + for j in range(data_len):
350 + k = int(self.csv.iloc[start_i + i, 2 + j], 16) / 255.0
351 + packet[i][j] = k
352 + packet = np.reshape(packet, (1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN))
353 + return (packet, is_regular)
354 +
355 +
356 +def unpack_bits(x, num_bits):
357 + """
358 + Args:
359 +        x (int): integer to convert to bits
360 +        num_bits (int): number of bits to represent it with
361 + """
362 + xshape = list(x.shape)
363 + x = x.reshape([-1, 1])
364 + mask = 2**np.arange(num_bits).reshape([1, num_bits])
365 + return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])
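A quick sanity check of unpack_bits with an arbitrary CAN ID; the mask is LSB-first, so bit 0 comes out first:

import numpy as np

bits = unpack_bits(np.array(0x1A3), 29)
print(bits.shape)    # (29,)
print(bits[:9])      # [1 1 0 0 0 1 0 1 1]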
\ No newline at end of file
-import utils
 import copy
+import argparse
+import time
+import math
+import numpy as np
+import os
 from collections import OrderedDict
+import torch
+import torch.optim as optim
+import torch.nn as nn

 import model
+import utils
 import dataset

+# for google colab reload
 import importlib
-importlib.reload(utils)
 importlib.reload(model)
+importlib.reload(utils)
 importlib.reload(dataset)

-from utils import *
+
23 +## parameters
24 +
25 +# shared
26 +criterion = nn.CrossEntropyLoss()
27 +C = 0.1
28 +#
29 +
30 +# prox
31 +mu = 0.001
32 +#
33 +
34 +# time weight
35 +twa_exp = 1.1
36 +#
37 +
38 +# dynamic weight
39 +H = 0.5
40 +P = 0.1
41 +G = 0.1
42 +R = 0.1
43 +alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0
44 +#
45 +
46 +## end
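mu above is the FedProx proximal coefficient. The "# prox term" step inside start_fedprox (collapsed in this diff) adds a penalty that keeps each local model close to the current global model; a sketch of that term as FedProx defines it:

# (mu / 2) * || w_local - w_global ||^2, added to the local cross-entropy loss
prox_term = 0.0
for w_local, w_global in zip(edge_model.parameters(), fed_model.parameters()):
    prox_term += (mu / 2.0) * torch.norm(w_local - w_global) ** 2
edge_loss = criterion(edge_pred, labels) + prox_term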
14 47
15 48
 def add_args(parser):
-    # parser.add_argument('--model', type=str, default='moderate-cnn',
-    #                     help='neural network used in training')
-    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
-                        help='dataset used for training')
+    parser.add_argument('--packet_num', type=int, default=1,
+                        help='packet number used in training, 1 ~ 3')
+    parser.add_argument('--dataset', type=str, default='can',
+                        help='dataset used for training, can or syncan')
     parser.add_argument('--fold_num', type=int, default=0,
                         help='5-fold, 0 ~ 4')
-    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
+    parser.add_argument('--batch_size', type=int, default=128,
                         help='input batch size for training')
-    parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
+    parser.add_argument('--lr', type=float, default=0.001,
                         help='learning rate')
-    parser.add_argument('--n_nets', type=int, default=100, metavar='NN',
+    parser.add_argument('--n_nets', type=int, default=100,
                         help='number of workers in a distributed cluster')
-    parser.add_argument('--comm_type', type=str, default='fedtwa',
-                        help='which type of communication strategy is going to be used: layerwise/blockwise')
+    parser.add_argument('--comm_type', type=str, default='fedprox',
+                        help='type of communication, [fedavg, fedprox, fedtwa, feddw, edge]')
-    parser.add_argument('--comm_round', type=int, default=10,
+    parser.add_argument('--comm_round', type=int, default=50,
                         help='how many rounds of communication we should use')
+    parser.add_argument('--weight_save_path', type=str, default='./weights',
+                        help='model weight save path')
     args = parser.parse_args(args=[])
     return args

72 +def test_model(fed_model, args, testloader, device):
73 + fed_model.to(device)
74 + fed_model.eval()
75 +
76 + cnt = 0
77 + step_acc = 0.0
78 + with torch.no_grad():
79 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
80 + for i, (inputs, labels) in enumerate(testloader):
81 + inputs, labels = inputs.to(device), labels.to(device)
82 +
83 + outputs, packet_state = fed_model(inputs, packet_state)
84 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
85 +
86 + _, preds = torch.max(outputs, 1)
87 +
88 + cnt += inputs.shape[0]
89 + corr_sum = torch.sum(preds == labels.data)
90 + step_acc += corr_sum.double()
91 + if i % 200 == 0:
92 + print('test [%4d/%4d] acc: %.3f' % (i, len(testloader), (step_acc / cnt).item()))
93 + # break
94 + fed_accuracy = (step_acc / cnt).item()
95 + print('test acc', fed_accuracy)
96 + fed_model.to('cpu')
97 + fed_model.train()
98 + torch.save(fed_model.state_dict(), os.path.join(args.weight_save_path, '%s_%d_%.4f.pth' % (args.comm_type, cr, fed_accuracy)))
99 +
100 +
37 def start_fedavg(fed_model, args, 101 def start_fedavg(fed_model, args,
38 train_data_set, 102 train_data_set,
39 data_idx_map, 103 data_idx_map,
...@@ -42,8 +106,7 @@ def start_fedavg(fed_model, args, ...@@ -42,8 +106,7 @@ def start_fedavg(fed_model, args,
42 edges, 106 edges,
43 device): 107 device):
44 print("start fed avg") 108 print("start fed avg")
45 - criterion = nn.CrossEntropyLoss() 109 +
46 - C = 0.1
47 num_edge = int(max(C * args.n_nets, 1)) 110 num_edge = int(max(C * args.n_nets, 1))
48 total_data_count = 0 111 total_data_count = 0
49 for _, data_count in net_data_count.items(): 112 for _, data_count in net_data_count.items():
...@@ -59,20 +122,25 @@ def start_fedavg(fed_model, args, ...@@ -59,20 +122,25 @@ def start_fedavg(fed_model, args,
59 122
60 for edge_progress, edge_index in enumerate(selected_edge): 123 for edge_progress, edge_index in enumerate(selected_edge):
61 train_data_set.set_idx_map(data_idx_map[edge_index]) 124 train_data_set.set_idx_map(data_idx_map[edge_index])
-            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
-                                                       shuffle=True, num_workers=2)
+            sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size)
+            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler,
+                                                       shuffle=False, num_workers=2, drop_last=True)
64 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) 128 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
65 129
66 edges[edge_index] = copy.deepcopy(fed_model) 130 edges[edge_index] = copy.deepcopy(fed_model)
67 edges[edge_index].to(device) 131 edges[edge_index].to(device)
68 edges[edge_index].train() 132 edges[edge_index].train()
69 edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) 133 edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
134 +
70 # train 135 # train
136 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
71 for data_idx, (inputs, labels) in enumerate(train_loader): 137 for data_idx, (inputs, labels) in enumerate(train_loader):
72 - inputs, labels = inputs.float().to(device), labels.long().to(device) 138 + inputs, labels = inputs.to(device), labels.to(device)
139 +
140 + edge_pred, packet_state = edges[edge_index](inputs, packet_state)
141 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
73 142
74 edge_opt.zero_grad() 143 edge_opt.zero_grad()
75 - edge_pred = edges[edge_index](inputs)
76 144
77 edge_loss = criterion(edge_pred, labels) 145 edge_loss = criterion(edge_pred, labels)
78 edge_loss.backward() 146 edge_loss.backward()
...@@ -90,39 +158,13 @@ def start_fedavg(fed_model, args, ...@@ -90,39 +158,13 @@ def start_fedavg(fed_model, args,
90 local_state = edge.state_dict() 158 local_state = edge.state_dict()
91 for key in fed_model.state_dict().keys(): 159 for key in fed_model.state_dict().keys():
                 if k == 0:
-                    update_state[key] = local_state[key] * net_data_count[k] / total_data_count
+                    update_state[key] = local_state[key] * (net_data_count[k] / total_data_count)
                 else:
-                    update_state[key] += local_state[key] * net_data_count[k] / total_data_count
+                    update_state[key] += local_state[key] * (net_data_count[k] / total_data_count)

         fed_model.load_state_dict(update_state)
         if cr % 10 == 0:
-            fed_model.to(device)
+            test_model(fed_model, args, testloader, device)
100 - fed_model.eval()
101 -
102 - total_loss = 0.0
103 - cnt = 0
104 - step_acc = 0.0
105 - with torch.no_grad():
106 - for i, data in enumerate(testloader):
107 - inputs, labels = data
108 - inputs, labels = inputs.float().to(device), labels.long().to(device)
109 -
110 - outputs = fed_model(inputs)
111 - _, preds = torch.max(outputs, 1)
112 -
113 - loss = criterion(outputs, labels)
114 - cnt += inputs.shape[0]
115 -
116 - corr_sum = torch.sum(preds == labels.data)
117 - step_acc += corr_sum.double()
118 - running_loss = loss.item() * inputs.shape[0]
119 - total_loss += running_loss
120 - if i % 200 == 0:
121 - print('test [%4d] loss: %.3f' % (i, loss.item()))
122 - # break
123 - print((step_acc / cnt).data)
124 - print(total_loss / cnt)
125 - fed_model.to('cpu')
126 168
127 169
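The update_state loop above is plain FedAvg: each edge's weights enter the average in proportion to its share of the training data. A toy check with made-up counts and a single one-parameter "model":

import torch
from collections import OrderedDict

counts = {0: 300, 1: 700}                               # made-up per-edge sample counts
states = [{'w': torch.tensor([1.0])}, {'w': torch.tensor([3.0])}]
total = sum(counts.values())

update = OrderedDict()
for k, local in enumerate(states):
    for key in local:
        term = local[key] * (counts[k] / total)
        update[key] = term if k == 0 else update[key] + term

print(update['w'])   # tensor([2.4000])  = 0.3 * 1.0 + 0.7 * 3.0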
128 def start_fedprox(fed_model, args, 170 def start_fedprox(fed_model, args,
...@@ -131,9 +173,7 @@ def start_fedprox(fed_model, args, ...@@ -131,9 +173,7 @@ def start_fedprox(fed_model, args,
131 testloader, 173 testloader,
132 device): 174 device):
133 print("start fed prox") 175 print("start fed prox")
134 - criterion = nn.CrossEntropyLoss() 176 +
135 - mu = 0.001
136 - C = 0.1
137 num_edge = int(max(C * args.n_nets, 1)) 177 num_edge = int(max(C * args.n_nets, 1))
138 fed_model.to(device) 178 fed_model.to(device)
139 179
...@@ -149,25 +189,26 @@ def start_fedprox(fed_model, args, ...@@ -149,25 +189,26 @@ def start_fedprox(fed_model, args,
149 selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) 189 selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
150 print("selected edge", selected_edge) 190 print("selected edge", selected_edge)
151 191
152 - total_data_length = 0
153 - edge_data_len = []
154 for edge_progress, edge_index in enumerate(selected_edge): 192 for edge_progress, edge_index in enumerate(selected_edge):
155 train_data_set.set_idx_map(data_idx_map[edge_index]) 193 train_data_set.set_idx_map(data_idx_map[edge_index])
-            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
-                                                       shuffle=True, num_workers=2)
+            sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size)
+            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler,
+                                                       shuffle=False, num_workers=2, drop_last=True)
158 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) 197 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
159 - total_data_length += len(train_data_set)
160 - edge_data_len.append(len(train_data_set))
161 198
162 edge_model = copy.deepcopy(fed_model) 199 edge_model = copy.deepcopy(fed_model)
163 edge_model.to(device) 200 edge_model.to(device)
201 + edge_model.train()
164 edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr) 202 edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr)
165 # train 203 # train
204 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
166 for data_idx, (inputs, labels) in enumerate(train_loader): 205 for data_idx, (inputs, labels) in enumerate(train_loader):
167 - inputs, labels = inputs.float().to(device), labels.long().to(device) 206 + inputs, labels = inputs.to(device), labels.to(device)
207 +
208 + edge_pred, packet_state = edge_model(inputs, packet_state)
209 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
168 210
169 edge_opt.zero_grad() 211 edge_opt.zero_grad()
170 - edge_pred = edge_model(inputs)
171 212
172 edge_loss = criterion(edge_pred, labels) 213 edge_loss = criterion(edge_pred, labels)
173 # prox term 214 # prox term
...@@ -196,30 +237,8 @@ def start_fedprox(fed_model, args, ...@@ -196,30 +237,8 @@ def start_fedprox(fed_model, args,
196 fed_model.to(device) 237 fed_model.to(device)
197 238
198 if cr % 10 == 0: 239 if cr % 10 == 0:
199 - fed_model.eval() 240 + test_model(fed_model, args, testloader, device)
200 - total_loss = 0.0 241 + fed_model.to(device)
201 - cnt = 0
202 - step_acc = 0.0
203 - with torch.no_grad():
204 - for i, data in enumerate(testloader):
205 - inputs, labels = data
206 - inputs, labels = inputs.float().to(device), labels.long().to(device)
207 -
208 - outputs = fed_model(inputs)
209 - _, preds = torch.max(outputs, 1)
210 -
211 - loss = criterion(outputs, labels)
212 - cnt += inputs.shape[0]
213 -
214 - corr_sum = torch.sum(preds == labels.data)
215 - step_acc += corr_sum.double()
216 - running_loss = loss.item() * inputs.shape[0]
217 - total_loss += running_loss
218 - if i % 200 == 0:
219 - print('test [%4d] loss: %.3f' % (i, loss.item()))
220 - # break
221 - print((step_acc / cnt).data)
222 - print(total_loss / cnt)
223 242
224 243
225 def start_fedtwa(fed_model, args, 244 def start_fedtwa(fed_model, args,
...@@ -231,10 +250,8 @@ def start_fedtwa(fed_model, args, ...@@ -231,10 +250,8 @@ def start_fedtwa(fed_model, args,
231 device): 250 device):
232 # TEFL, without asynchronous model update 251 # TEFL, without asynchronous model update
233 print("start fed temporally weighted aggregation") 252 print("start fed temporally weighted aggregation")
234 - criterion = nn.CrossEntropyLoss() 253 +
235 time_stamp = [0 for worker in range(args.n_nets)] 254 time_stamp = [0 for worker in range(args.n_nets)]
236 - twa_exp = math.e / 2.0
237 - C = 0.1
238 num_edge = int(max(C * args.n_nets, 1)) 255 num_edge = int(max(C * args.n_nets, 1))
239 total_data_count = 0 256 total_data_count = 0
240 for _, data_count in net_data_count.items(): 257 for _, data_count in net_data_count.items():
...@@ -251,20 +268,25 @@ def start_fedtwa(fed_model, args, ...@@ -251,20 +268,25 @@ def start_fedtwa(fed_model, args,
251 for edge_progress, edge_index in enumerate(selected_edge): 268 for edge_progress, edge_index in enumerate(selected_edge):
252 time_stamp[edge_index] = cr 269 time_stamp[edge_index] = cr
253 train_data_set.set_idx_map(data_idx_map[edge_index]) 270 train_data_set.set_idx_map(data_idx_map[edge_index])
-            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
-                                                       shuffle=True, num_workers=2)
+            sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size)
+            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler,
+                                                       shuffle=False, num_workers=2, drop_last=True)
256 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) 274 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
257 275
258 edges[edge_index] = copy.deepcopy(fed_model) 276 edges[edge_index] = copy.deepcopy(fed_model)
259 edges[edge_index].to(device) 277 edges[edge_index].to(device)
260 edges[edge_index].train() 278 edges[edge_index].train()
261 edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) 279 edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
280 +
262 # train 281 # train
282 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
263 for data_idx, (inputs, labels) in enumerate(train_loader): 283 for data_idx, (inputs, labels) in enumerate(train_loader):
264 inputs, labels = inputs.float().to(device), labels.long().to(device) 284 inputs, labels = inputs.float().to(device), labels.long().to(device)
265 285
286 + edge_pred, packet_state = edges[edge_index](inputs, packet_state)
287 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
288 +
266 edge_opt.zero_grad() 289 edge_opt.zero_grad()
267 - edge_pred = edges[edge_index](inputs)
268 290
269 edge_loss = criterion(edge_pred, labels) 291 edge_loss = criterion(edge_pred, labels)
270 edge_loss.backward() 292 edge_loss.backward()
...@@ -277,44 +299,19 @@ def start_fedtwa(fed_model, args, ...@@ -277,44 +299,19 @@ def start_fedtwa(fed_model, args,
277 edges[edge_index].to('cpu') 299 edges[edge_index].to('cpu')
278 300
279 # cal weight using time stamp 301 # cal weight using time stamp
302 +        # the paper uses cr - time_stamp[k]; that exponent gave higher error here, hence cr - 2 - time_stamp[k]
280 update_state = OrderedDict() 303 update_state = OrderedDict()
281 for k, edge in enumerate(edges): 304 for k, edge in enumerate(edges):
282 local_state = edge.state_dict() 305 local_state = edge.state_dict()
283 for key in fed_model.state_dict().keys(): 306 for key in fed_model.state_dict().keys():
284 if k == 0: 307 if k == 0:
-                    update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k]))
+                    update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - 2 - time_stamp[k]))
                 else:
-                    update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k]))
+                    update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - 2 - time_stamp[k]))
288 311
289 fed_model.load_state_dict(update_state) 312 fed_model.load_state_dict(update_state)
290 if cr % 10 == 0: 313 if cr % 10 == 0:
291 - fed_model.to(device) 314 + test_model(fed_model, args, testloader, device)
292 - fed_model.eval()
293 -
294 - total_loss = 0.0
295 - cnt = 0
296 - step_acc = 0.0
297 - with torch.no_grad():
298 - for i, data in enumerate(testloader):
299 - inputs, labels = data
300 - inputs, labels = inputs.float().to(device), labels.long().to(device)
301 -
302 - outputs = fed_model(inputs)
303 - _, preds = torch.max(outputs, 1)
304 -
305 - loss = criterion(outputs, labels)
306 - cnt += inputs.shape[0]
307 -
308 - corr_sum = torch.sum(preds == labels.data)
309 - step_acc += corr_sum.double()
310 - running_loss = loss.item() * inputs.shape[0]
311 - total_loss += running_loss
312 - if i % 200 == 0:
313 - print('test [%4d] loss: %.3f' % (i, loss.item()))
314 - # break
315 - print((step_acc / cnt).data)
316 - print(total_loss / cnt)
317 - fed_model.to('cpu')
318 315
319 316
320 def start_feddw(fed_model, args, 317 def start_feddw(fed_model, args,
...@@ -326,13 +323,8 @@ def start_feddw(fed_model, args, ...@@ -326,13 +323,8 @@ def start_feddw(fed_model, args,
326 edges, 323 edges,
327 device): 324 device):
328 print("start fed Node-aware Dynamic Weighting") 325 print("start fed Node-aware Dynamic Weighting")
326 +
329 worker_selected_frequency = [0 for worker in range(args.n_nets)] 327 worker_selected_frequency = [0 for worker in range(args.n_nets)]
330 - criterion = nn.CrossEntropyLoss()
331 - H = 0.5
332 - P = 0.5
333 - G = 0.1
334 - R = 0.1
335 - alpha, beta, gamma = 30.0/100.0, 50.0/100.0, 20.0/100.0
336 num_edge = int(max(G * args.n_nets, 1)) 328 num_edge = int(max(G * args.n_nets, 1))
337 329
338 # cal data weight for selecting participants 330 # cal data weight for selecting participants
...@@ -382,20 +374,25 @@ def start_feddw(fed_model, args, ...@@ -382,20 +374,25 @@ def start_feddw(fed_model, args,
382 for edge_progress, edge_index in enumerate(selected_edge): 374 for edge_progress, edge_index in enumerate(selected_edge):
383 worker_selected_frequency[edge_index] += 1 375 worker_selected_frequency[edge_index] += 1
384 train_data_set.set_idx_map(data_idx_map[edge_index]) 376 train_data_set.set_idx_map(data_idx_map[edge_index])
-            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
-                                                       shuffle=True, num_workers=2)
+            sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size)
+            train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler,
+                                                       shuffle=False, num_workers=2, drop_last=True)
387 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) 380 print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
388 381
389 edges[edge_index] = copy.deepcopy(fed_model) 382 edges[edge_index] = copy.deepcopy(fed_model)
390 edges[edge_index].to(device) 383 edges[edge_index].to(device)
391 edges[edge_index].train() 384 edges[edge_index].train()
392 edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) 385 edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
386 +
393 # train 387 # train
388 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
394 for data_idx, (inputs, labels) in enumerate(train_loader): 389 for data_idx, (inputs, labels) in enumerate(train_loader):
395 inputs, labels = inputs.float().to(device), labels.long().to(device) 390 inputs, labels = inputs.float().to(device), labels.long().to(device)
396 391
392 + edge_pred, packet_state = edges[edge_index](inputs, packet_state)
393 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
394 +
397 edge_opt.zero_grad() 395 edge_opt.zero_grad()
398 - edge_pred = edges[edge_index](inputs)
399 396
400 edge_loss = criterion(edge_pred, labels) 397 edge_loss = criterion(edge_pred, labels)
401 edge_loss.backward() 398 edge_loss.backward()
...@@ -408,17 +405,20 @@ def start_feddw(fed_model, args, ...@@ -408,17 +405,20 @@ def start_feddw(fed_model, args,
408 405
409 # get edge accuracy using subset of testset 406 # get edge accuracy using subset of testset
410 edges[edge_index].eval() 407 edges[edge_index].eval()
411 - print("[%2d/%2d] edge: %d, cal accuracy" % (edge_progress, len(selected_edge), edge_index)) 408 + print("[%2d/%2d] edge: %d, cal local accuracy" % (edge_progress, len(selected_edge), edge_index))
412 cnt = 0 409 cnt = 0
413 step_acc = 0.0 410 step_acc = 0.0
414 with torch.no_grad(): 411 with torch.no_grad():
412 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
415 for inputs, labels in local_test_loader: 413 for inputs, labels in local_test_loader:
416 inputs, labels = inputs.float().to(device), labels.long().to(device) 414 inputs, labels = inputs.float().to(device), labels.long().to(device)
417 415
418 - outputs = edges[edge_index](inputs) 416 + edge_pred, packet_state = edges[edge_index](inputs, packet_state)
419 - _, preds = torch.max(outputs, 1) 417 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
418 +
419 + _, preds = torch.max(edge_pred, 1)
420 420
421 - loss = criterion(outputs, labels) 421 + loss = criterion(edge_pred, labels)
422 cnt += inputs.shape[0] 422 cnt += inputs.shape[0]
423 423
424 corr_sum = torch.sum(preds == labels.data) 424 corr_sum = torch.sum(preds == labels.data)
...@@ -426,7 +426,7 @@ def start_feddw(fed_model, args, ...@@ -426,7 +426,7 @@ def start_feddw(fed_model, args,
426 # break 426 # break
427 427
428 worker_local_accuracy[edge_index] = (step_acc / cnt).item() 428 worker_local_accuracy[edge_index] = (step_acc / cnt).item()
429 - print(worker_local_accuracy[edge_index]) 429 + print('edge local accuracy', worker_local_accuracy[edge_index])
430 edges[edge_index].to('cpu') 430 edges[edge_index].to('cpu')
431 431
432 # cal weight dynamically 432 # cal weight dynamically
...@@ -449,59 +449,123 @@ def start_feddw(fed_model, args, ...@@ -449,59 +449,123 @@ def start_feddw(fed_model, args,
449 449
450 fed_model.load_state_dict(update_state) 450 fed_model.load_state_dict(update_state)
451 if cr % 10 == 0: 451 if cr % 10 == 0:
452 - fed_model.to(device) 452 + test_model(fed_model, args, testloader, device)
453 - fed_model.eval()
454 -
455 - total_loss = 0.0
456 - cnt = 0
457 - step_acc = 0.0
458 - with torch.no_grad():
459 - for i, data in enumerate(testloader):
460 - inputs, labels = data
461 - inputs, labels = inputs.float().to(device), labels.long().to(device)
462 453
463 - outputs = fed_model(inputs)
464 - _, preds = torch.max(outputs, 1)
465 454
466 - loss = criterion(outputs, labels) 455 +def start_only_edge(args,
467 - cnt += inputs.shape[0] 456 + train_data_set,
457 + data_idx_map,
458 + testloader,
459 + edges,
460 + device):
461 + print("start only edge")
462 + total_epoch = int(args.comm_round * C)
468 463
469 - corr_sum = torch.sum(preds == labels.data) 464 + for cr in range(1, total_epoch + 1):
470 - step_acc += corr_sum.double() 465 + print("Edge round : %d" % (cr))
471 - running_loss = loss.item() * inputs.shape[0] 466 + edge_accuracy_list = []
472 - total_loss += running_loss 467 +
473 - if i % 200 == 0: 468 + for edge_index, edge_model in enumerate(edges):
474 - print('test [%4d] loss: %.3f' % (i, loss.item())) 469 + train_data_set.set_idx_map(data_idx_map[edge_index])
470 + sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size)
471 + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler,
472 + shuffle=False, num_workers=2, drop_last=True)
473 + print("edge[%2d/%2d] data len: %d" % (edge_index, len(edges), len(train_data_set)))
474 +
475 + edge_model.to(device)
476 + edge_model.train()
477 + edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr)
478 +
479 + # train
480 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
481 + for data_idx, (inputs, labels) in enumerate(train_loader):
482 + inputs, labels = inputs.float().to(device), labels.long().to(device)
483 +
484 + edge_pred, packet_state = edge_model(inputs, packet_state)
485 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
486 +
487 + edge_opt.zero_grad()
488 +
489 + edge_loss = criterion(edge_pred, labels)
490 + edge_loss.backward()
491 +
492 + edge_opt.step()
493 + edge_loss = edge_loss.item()
494 + if data_idx % 100 == 0:
495 + print('[%4d] loss: %.3f' % (data_idx, edge_loss))
475 # break 496 # break
476 - print((step_acc / cnt).data) 497 +
477 - print(total_loss / cnt) 498 + # test
478 - fed_model.to('cpu') 499 + # if cr < 4:
500 + # continue
501 + edge_model.eval()
502 + total_loss = 0.0
503 + cnt = 0
504 + step_acc = 0.0
505 + with torch.no_grad():
506 + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
507 + for i, (inputs, labels) in enumerate(testloader):
508 + inputs, labels = inputs.float().to(device), labels.long().to(device)
509 +
510 + outputs, packet_state = edge_model(inputs, packet_state)
511 + packet_state = torch.autograd.Variable(packet_state, requires_grad=False)
512 +
513 + _, preds = torch.max(outputs, 1)
514 +
515 + loss = criterion(outputs, labels)
516 + cnt += inputs.shape[0]
517 +
518 + corr_sum = torch.sum(preds == labels.data)
519 + step_acc += corr_sum.double()
520 + running_loss = loss.item() * inputs.shape[0]
521 + total_loss += running_loss
522 + if i % 200 == 0:
523 + print('test [%4d] loss: %.3f' % (i, loss.item()))
524 + # break
525 + edge_accuracy = (step_acc / cnt).item()
526 + edge_accuracy_list.append(edge_accuracy)
527 + print("edge[%2d/%2d] acc: %.4f" % (edge_index, len(edges), edge_accuracy))
528 + edge_model.to('cpu')
529 +
530 + # if cr < 4:
531 + # continue
532 + edge_accuracy_avg = sum(edge_accuracy_list) / len(edge_accuracy_list)
533 +        torch.save(edges[0].state_dict(), os.path.join(args.weight_save_path, 'edge_%d_%.4f.pth' % (cr, edge_accuracy_avg)))
534 +
479 535
480 536
481 def start_train(): 537 def start_train():
482 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 538 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
483 - print(device) 539 + print('device:', device)
540 +
484 args = add_args(argparse.ArgumentParser()) 541 args = add_args(argparse.ArgumentParser())
485 542
543 + # make weight folder
544 + os.makedirs(args.weight_save_path, exist_ok=True)
545 +
546 +    # for reproducibility
486 seed = 0 547 seed = 0
487 np.random.seed(seed) 548 np.random.seed(seed)
488 torch.manual_seed(seed) 549 torch.manual_seed(seed)
489 550
     print("Loading data...")
-    # kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt',
-    #           "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt',
-    #           "./dataset/RPM_dataset.csv" : './RPM_dataset.txt',
-    #           "./dataset/gear_dataset.csv" : './gear_dataset.txt'
-    #           }
-    kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
-    train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs)
-    testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size,
-                                             shuffle=False, num_workers=2)
-
-    fed_model = model.Net()
-    args.comm_type = 'feddw'
+    if args.dataset == 'can':
+        train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")
+    elif args.dataset == 'syncan':
+        train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/test_mixed.csv", "./dataset/Mixed_dataset_1.txt")
+
+    sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size)
+    testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler,
+                                             shuffle=False, num_workers=2, drop_last=True)
+
+    if args.dataset == 'can':
+        fed_model = model.OneNet(args.packet_num)
+        edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
+    elif args.dataset == 'syncan':
+        fed_model = model.OneNet(args.packet_num)
+        edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
+
     if args.comm_type == "fedavg":
504 - edges, _, _ = init_models(args.n_nets, args)
505 start_fedavg(fed_model, args, 569 start_fedavg(fed_model, args,
506 train_data_set, 570 train_data_set,
507 data_idx_map, 571 data_idx_map,
...@@ -516,7 +580,6 @@ def start_train(): ...@@ -516,7 +580,6 @@ def start_train():
516 testloader, 580 testloader,
517 device) 581 device)
518 elif args.comm_type == "fedtwa": 582 elif args.comm_type == "fedtwa":
519 - edges, _, _ = init_models(args.n_nets, args)
520 start_fedtwa(fed_model, args, 583 start_fedtwa(fed_model, args,
521 train_data_set, 584 train_data_set,
522 data_idx_map, 585 data_idx_map,
...@@ -526,14 +589,14 @@ def start_train(): ...@@ -526,14 +589,14 @@ def start_train():
526 device) 589 device)
527 elif args.comm_type == "feddw": 590 elif args.comm_type == "feddw":
528 local_test_set = copy.deepcopy(test_data_set) 591 local_test_set = copy.deepcopy(test_data_set)
529 - # mnist train 60,000 / test 10,000 / 1,000 592 + # in paper, mnist train 60,000 / test 10,000 / 1,000 - 10%
530 - # CAN train ~ 13,000,000 / test 2,000,000 / for speed 40,000 593 + # CAN train ~ 1,400,000 / test 300,000 / for speed 15,000 - 5%
531 - local_test_idx = np.random.choice(len(local_test_set), len(local_test_set) // 50, replace=False) 594 + local_test_idx = [idx for idx in range(0, len(local_test_set) // 20)]
532 local_test_set.set_idx_map(local_test_idx) 595 local_test_set.set_idx_map(local_test_idx)
533 - local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size, 596 + sampler = dataset.BatchIntervalSampler(len(local_test_set), args.batch_size)
534 - shuffle=False, num_workers=2) 597 + local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size, sampler=sampler,
598 + shuffle=False, num_workers=2, drop_last=True)
535 599
536 - edges, _, _ = init_models(args.n_nets, args)
537 start_feddw(fed_model, args, 600 start_feddw(fed_model, args,
538 train_data_set, 601 train_data_set,
539 data_idx_map, 602 data_idx_map,
...@@ -542,6 +605,13 @@ def start_train(): ...@@ -542,6 +605,13 @@ def start_train():
542 local_test_loader, 605 local_test_loader,
543 edges, 606 edges,
544 device) 607 device)
608 + elif args.comm_type == "edge":
609 + start_only_edge(args,
610 + train_data_set,
611 + data_idx_map,
612 + testloader,
613 + edges,
614 + device)
545 615
 if __name__ == "__main__":
     start_train()
\ No newline at end of file
......
1 import torch.nn as nn 1 import torch.nn as nn
2 -import torch.nn.functional as F
3 import torch 2 import torch
4 import const 3 import const
5 4
6 -class Net(nn.Module):
7 - def __init__(self):
8 - super(Net, self).__init__()
9 5
10 - self.f1 = nn.Sequential( 6 +STATE_DIM = 8 * 32
11 - nn.Conv2d(1, 2, 3), 7 +class OneNet(nn.Module):
12 - nn.ReLU(True), 8 + def __init__(self, packet_num):
9 + super(OneNet, self).__init__()
10 + IN_DIM = 8 * packet_num # byte
11 + FEATURE_DIM = 32
12 +
13 + # transform the given packet into a tensor which is in a good feature space
14 + self.feature_layer = nn.Sequential(
15 + nn.Linear(IN_DIM, 32),
16 + nn.ReLU(),
17 + nn.Linear(32, FEATURE_DIM),
18 + nn.ReLU()
13 ) 19 )
14 - self.f2 = nn.Sequential( 20 +
15 - nn.Conv2d(2, 4, 3), 21 + # generates the current state 's'
16 - nn.ReLU(True), 22 + self.f = nn.Sequential(
17 - ) 23 + nn.Linear(STATE_DIM + FEATURE_DIM, STATE_DIM),
18 - self.f3 = nn.Sequential( 24 + nn.ReLU(),
19 - nn.Conv2d(4, 8, 3), 25 + nn.Linear(STATE_DIM, STATE_DIM),
20 - nn.ReLU(True), 26 + nn.ReLU()
21 ) 27 )
22 - self.f4 = nn.Sequential( 28 +
23 - nn.Linear(8 * 23 * 23, 2), 29 + # check whether the given packet is malicious
30 + self.g = nn.Sequential(
31 + nn.Linear(STATE_DIM + FEATURE_DIM, 64),
32 + nn.ReLU(),
33 + nn.Linear(64, 64),
34 + nn.ReLU(),
35 + nn.Linear(64, 2),
24 ) 36 )
25 37
26 - def forward(self, x):
27 - x = self.f1(x)
28 - x = self.f2(x)
29 - x = self.f3(x)
30 - x = torch.flatten(x, 1)
31 - x = self.f4(x)
32 - return x
...\ No newline at end of file ...\ No newline at end of file
38 + def forward(self, x, s):
39 + x = self.feature_layer(x)
40 + x = torch.cat((x, s), 1)
41 + s2 = self.f(x)
42 + x2 = self.g(x)
43 +
44 + return x2, s2
...\ No newline at end of file ...\ No newline at end of file
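OneNet above is stateful: forward(x, s) consumes the 8 * packet_num input bytes together with the previous hidden state and returns (logits, next_state), so consecutive packets of a stream must be fed in order and the returned state threaded back in. A minimal usage sketch, assuming packet_num = 1 and time-ordered batches; the variable names are illustrative, not from the repo:

import torch
import model

batch_size = 4
net = model.OneNet(packet_num=1).eval()

# one zero-initialized state per stream in the batch (STATE_DIM = 8 * 32)
state = torch.zeros(batch_size, 8 * 32)
# ten consecutive 8-byte packets per stream, values already scaled to [0, 1]
packets = torch.rand(10, batch_size, 8)

with torch.no_grad():
    for x in packets:
        logits, state = net(x, state)   # logits: (batch_size, 2) normal/attack scores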
......
1 +import tensorrt as trt
2 +
3 +onnx_file_name = 'bert.onnx'
4 +tensorrt_file_name = 'bert.plan'
5 +fp16_mode = True
6 +# int8_mode = True
7 +TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
8 +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
9 +
10 +builder = trt.Builder(TRT_LOGGER)
11 +network = builder.create_network(EXPLICIT_BATCH)
12 +parser = trt.OnnxParser(network, TRT_LOGGER)
13 +
14 +builder.max_workspace_size = (1 << 30)
15 +builder.fp16_mode = fp16_mode
16 +# builder.int8_mode = int8_mode
17 +
18 +with open(onnx_file_name, 'rb') as model:
19 + if not parser.parse(model.read()):
20 + for error in range(parser.num_errors):
21 + print (parser.get_error(error))
22 +
23 +# for int8 mode
24 +# print(network.num_layers, network.num_inputs , network.num_outputs)
25 +# for layer_index in range(network.num_layers):
26 +# layer = network[layer_index]
27 +# print(layer.name)
28 +# tensor = layer.get_output(0)
29 +# print(tensor.name)
30 +# tensor.dynamic_range = (0, 255)
31 +
32 + # input_tensor = layer.get_input(0)
33 + # print(input_tensor)
34 + # input_tensor.dynamic_range = (0, 255)
35 +
36 +engine = builder.build_cuda_engine(network)
37 +buf = engine.serialize()
38 +with open(tensorrt_file_name, 'wb') as f:
39 + f.write(buf)
40 +
41 +print('done, trt model')
...\ No newline at end of file ...\ No newline at end of file
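The builder attributes used above (max_workspace_size, fp16_mode, build_cuda_engine) are the legacy TensorRT 6/7 interface. On TensorRT 7.x/8.x the same build is normally expressed through an IBuilderConfig; a sketch under that version assumption, reusing the same bert.onnx / bert.plan names:

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(EXPLICIT_BATCH)
parser = trt.OnnxParser(network, TRT_LOGGER)

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30      # newest releases replace this with set_memory_pool_limit()
config.set_flag(trt.BuilderFlag.FP16)

with open('bert.onnx', 'rb') as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))

engine = builder.build_engine(network, config)
with open('bert.plan', 'wb') as f:
    f.write(engine.serialize())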
1 +import tensorrt as trt
2 +import pycuda.driver as cuda
3 +import numpy as np
4 +import torch
5 +import pycuda.autoinit
6 +import dataset
7 +import model
8 +import time
9 +# print(dir(trt))
10 +
11 +tensorrt_file_name = 'bert.plan'
12 +TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
13 +trt_runtime = trt.Runtime(TRT_LOGGER)
14 +
15 +with open(tensorrt_file_name, 'rb') as f:
16 + engine_data = f.read()
17 +engine = trt_runtime.deserialize_cuda_engine(engine_data)
18 +context = engine.create_execution_context()
19 +
20 +# class HostDeviceMem(object):
21 +# def __init__(self, host_mem, device_mem):
22 +# self.host = host_mem
23 +# self.device = device_mem
24 +
25 +# def __str__(self):
26 +# return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
27 +
28 +# def __repr__(self):
29 +# return self.__str__()
30 +
31 +# inputs, outputs, bindings, stream = [], [], [], []
32 +# for binding in engine:
33 +# size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
34 +# dtype = trt.nptype(engine.get_binding_dtype(binding))
35 +# host_mem = cuda.pagelocked_empty(size, dtype)
36 +# device_mem = cuda.mem_alloc(host_mem.nbytes)
37 +# bindings.append(int(device_mem))
38 +# if engine.binding_is_input(binding):
39 +# inputs.append( HostDeviceMem(host_mem, device_mem) )
40 +# else:
41 +# outputs.append(HostDeviceMem(host_mem, device_mem))
42 +
43 +# input_ids = np.ones([1, 1, 29, 29])
44 +
45 +# numpy_array_input = [input_ids]
46 +# hosts = [input.host for input in inputs]
47 +# trt_types = [trt.int32]
48 +
49 +# for numpy_array, host, trt_types in zip(numpy_array_input, hosts, trt_types):
50 +# numpy_array = np.asarray(numpy_array).ravel()
51 +# np.copyto(host, numpy_array)
52 +
53 +# def do_inference(context, bindings, inputs, outputs, stream):
54 +# [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
55 +# context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
56 +# [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
57 +# stream.synchronize()
58 +# return [out.host for out in outputs]
59 +
60 +# trt_outputs = do_inference(
61 +# context=context,
62 +# bindings=bindings,
63 +# inputs=inputs,
64 +# outputs=outputs,
65 +# stream=stream)
66 +
67 +def infer(context, input_img, output_size, batch_size):
68 + # Load engine
69 + # engine = context.get_engine()
70 + # assert(engine.get_nb_bindings() == 2)
71 + # Convert input data to float32
72 + input_img = input_img.astype(np.float32)
73 + # Create host buffer to receive data
74 + output = np.empty(output_size, dtype = np.float32)
75 + # Allocate device memory
76 + d_input = cuda.mem_alloc(batch_size * input_img.size * input_img.dtype.itemsize)
77 + d_output = cuda.mem_alloc(batch_size * output.size * output.dtype.itemsize)
78 +
79 + bindings = [int(d_input), int(d_output)]
80 + stream = cuda.Stream()
81 + # Transfer input data to device
82 + cuda.memcpy_htod_async(d_input, input_img, stream)
83 + # Execute model
 84 +    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)  # explicit-batch engine, so use the v2 API
85 + # Transfer predictions back
86 + cuda.memcpy_dtoh_async(output, d_output, stream)
87 + # Synchronize threads
88 + stream.synchronize()
89 + # Return predictions
90 + return output
91 +
92 +
93 +# kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
94 +# train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(100, 0, **kwargs)
95 +# testloader = torch.utils.data.DataLoader(test_data_set, batch_size=256,
96 +# shuffle=False, num_workers=2)
97 +
98 +check_time = time.time()
99 +cnt = 0
100 +temp = np.ones([256, 1, 29, 29])
101 +for idx in range(100):
102 +# for i, (inputs, labels) in enumerate(testloader):
103 + trt_outputs = infer(context, temp, (256, 2), 256)
104 +
105 + print(trt_outputs.shape)
106 + # print(trt_outputs)
107 + # print(np.argmax(trt_outputs, axis=0))
108 + # cnt += 1
109 + # if cnt == 100:
110 + # break
111 +print(time.time() - check_time)
112 +
113 +
114 +tensorrt_file_name = 'bert_int.plan'
115 +TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
116 +trt_runtime = trt.Runtime(TRT_LOGGER)
117 +
118 +with open(tensorrt_file_name, 'rb') as f:
119 + engine_data = f.read()
120 +engine = trt_runtime.deserialize_cuda_engine(engine_data)
121 +context = engine.create_execution_context()
122 +check_time = time.time()
123 +cnt = 0
124 +temp = np.ones([256, 1, 29, 29])
125 +for idx in range(100):
126 +# for i, (inputs, labels) in enumerate(testloader):
127 + trt_outputs = infer(context, temp, (256, 2), 256)
128 +
129 + print(trt_outputs.shape)
130 + # print(trt_outputs)
131 + # print(np.argmax(trt_outputs, axis=0))
132 + # cnt += 1
133 + # if cnt == 100:
134 + # break
135 +print(time.time() - check_time)
136 +
137 +
138 +test_model = model.Net().cuda()  # note: Net is the old 29x29 CNN removed in the model.py diff above
139 +check_time = time.time()
140 +cnt = 0
141 +temp = torch.randn(256, 1, 29, 29).cuda()
142 +for idx in range(100):
143 +# for i, (inputs, labels) in enumerate(testloader):
144 + # inputs = inputs.float().cuda()
145 + normal_outputs = test_model(temp)
146 + # print(normal_outputs)
147 + print(normal_outputs.shape)
148 + cnt += 1
149 + if cnt == 100:
150 + break
151 +print(time.time() - check_time)
152 +
153 +
154 +
155 +import tensorrt as trt
156 +import numpy as np
157 +import pycuda.autoinit
158 +import pycuda.driver as cuda
159 +import time
160 +
161 +model_path = "bert.onnx"
162 +input_size = 32
163 +
164 +TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
165 +
166 +# def build_engine(model_path):
167 +# with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
168 +# builder.max_workspace_size = 1<<20
169 +# builder.max_batch_size = 1
170 +# with open(model_path, "rb") as f:
171 +# parser.parse(f.read())
172 +# engine = builder.build_cuda_engine(network)
173 +# return engine
174 +
175 +def alloc_buf(engine):
176 + # host cpu mem
177 + h_in_size = trt.volume(engine.get_binding_shape(0))
178 + h_out_size = trt.volume(engine.get_binding_shape(1))
179 + h_in_dtype = trt.nptype(engine.get_binding_dtype(0))
180 + h_out_dtype = trt.nptype(engine.get_binding_dtype(1))
181 + in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype)
182 + out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype)
183 + # allocate gpu mem
184 + in_gpu = cuda.mem_alloc(in_cpu.nbytes)
185 + out_gpu = cuda.mem_alloc(out_cpu.nbytes)
186 + stream = cuda.Stream()
187 + return in_cpu, out_cpu, in_gpu, out_gpu, stream
188 +
189 +
190 +def inference(engine, context, inputs, out_cpu, in_gpu, out_gpu, stream):
191 + # async version
192 + # with engine.create_execution_context() as context: # cost time to initialize
193 + # cuda.memcpy_htod_async(in_gpu, inputs, stream)
194 + # context.execute_async(1, [int(in_gpu), int(out_gpu)], stream.handle, None)
195 + # cuda.memcpy_dtoh_async(out_cpu, out_gpu, stream)
196 + # stream.synchronize()
197 +
198 + # sync version
199 + cuda.memcpy_htod(in_gpu, inputs)
200 +    context.execute_v2([int(in_gpu), int(out_gpu)])  # explicit-batch engine, so use the v2 API
201 + cuda.memcpy_dtoh(out_cpu, out_gpu)
202 + return out_cpu
203 +
204 +if __name__ == "__main__":
205 + inputs = np.random.random((1, 1, 29, 29)).astype(np.float32)
206 +
207 + tensorrt_file_name = '/content/drive/My Drive/capstone1/CAN/bert.plan'
208 + TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
209 + trt_runtime = trt.Runtime(TRT_LOGGER)
210 +
211 + with open(tensorrt_file_name, 'rb') as f:
212 + engine_data = f.read()
213 + engine = trt_runtime.deserialize_cuda_engine(engine_data)
214 + # engine = build_engine(model_path)
215 + context = engine.create_execution_context()
216 + for _ in range(10):
217 + t1 = time.time()
218 + in_cpu, out_cpu, in_gpu, out_gpu, stream = alloc_buf(engine)
219 + res = inference(engine, context, inputs.reshape(-1), out_cpu, in_gpu, out_gpu, stream)
220 + print(res)
221 + print("cost time: ", time.time()-t1)
...\ No newline at end of file ...\ No newline at end of file
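One caveat on the timing loops above: infer() re-allocates device memory and a CUDA stream on every call, so that allocation overhead is counted in the measured time. A minimal sketch (illustrative only, names are not from the repo) of reusing one set of buffers across calls, in the same spirit as alloc_buf()/inference() later in this file:

import numpy as np
import pycuda.driver as cuda

def make_infer_fn(context, input_shape, output_size):
    # allocate page-locked host buffers, device buffers and a stream once, reuse them per batch
    host_in = cuda.pagelocked_empty(input_shape, np.float32)
    host_out = cuda.pagelocked_empty(output_size, np.float32)
    d_input = cuda.mem_alloc(host_in.nbytes)
    d_output = cuda.mem_alloc(host_out.nbytes)
    stream = cuda.Stream()
    bindings = [int(d_input), int(d_output)]

    def run(batch):
        np.copyto(host_in, batch.astype(np.float32))
        cuda.memcpy_htod_async(d_input, host_in, stream)
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_out, d_output, stream)
        stream.synchronize()
        return host_out
    return run

# usage: infer_fn = make_infer_fn(context, (256, 1, 29, 29), (256, 2)); out = infer_fn(temp)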
1 +import model
2 +import torch
3 +
4 +import importlib
5 +importlib.reload(model)
6 +
7 +batch_size = 256
8 +model = model.Net().cuda().eval()
9 +inputs = torch.randn(batch_size, 1, 29, 29, requires_grad=True).cuda()
10 +torch_out = model(inputs)
11 +
12 +torch.onnx.export(
13 + model,
14 + inputs,
15 + 'bert.onnx',
16 + input_names=['inputs'],
17 + output_names=['outputs'],
18 + export_params=True)
19 +
20 +print('done, onnx model')
...\ No newline at end of file ...\ No newline at end of file
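Note that this export script still builds model.Net(), while the model.py diff above replaces Net with OneNet, which takes a second state input and returns two outputs. A sketch of exporting OneNet instead, assuming packet_num = 1; the file and tensor names are illustrative:

import torch
import model

batch_size = 256
net = model.OneNet(packet_num=1).cuda().eval()
x = torch.randn(batch_size, 8).cuda()        # 8 data bytes per packet
s = torch.zeros(batch_size, 8 * 32).cuda()   # initial state

torch.onnx.export(
    net,
    (x, s),
    'onenet.onnx',
    input_names=['packet', 'state_in'],
    output_names=['logits', 'state_out'],
    export_params=True)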
1 +import pandas as pd
2 +import numpy as np
3 +import csv
4 +import os
5 +import const
6 +from matplotlib import pyplot as plt
7 +
8 +
9 +def run_benchmark_cnn():
10 + import sys
11 + sys.path.append("/content/drive/My Drive/capstone1/CAN/torch2trt")
12 + from torch2trt import torch2trt
13 + import model
14 + import time
15 + import torch
16 + import dataset
17 + import torch.nn as nn
18 +
19 + test_model = model.CnnNet()
20 + test_model.eval().cuda()
21 +
22 + batch_size = 1
23 + # _, _, _, test_data_set = dataset.GetCanDataset(100, 0, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")
24 +
25 + # sampler = dataset.BatchIntervalSampler(len(test_data_set), batch_size)
26 + # testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler,
27 + # shuffle=False, num_workers=2, drop_last=True)
28 +
29 + # create model and input data
30 + # for inputs, labels in testloader:
31 + # trt_x = inputs.float().cuda()
32 + # trt_state = torch.zeros((batch_size, 8 * 32)).float().cuda()
33 + # trt_model = model.OneNet()
34 + # trt_model.load_state_dict(torch.load(weight_path))
35 + # trt_model.float().eval().cuda()
36 +
37 + # trt_f16_x = inputs.half().cuda()
38 + # trt_f16_state = torch.zeros((batch_size, 8 * 32)).half().cuda()
39 + # trt_f16_model = model.OneNet()
40 + # trt_f16_model.load_state_dict(torch.load(weight_path))
41 + # trt_f16_model.half().eval().cuda()
42 +
43 + # trt_int8_strict_x = inputs.float().cuda()
44 + # trt_int8_strict_state = torch.zeros((batch_size, 8 * 32)).float().cuda() # match model weight
45 + # trt_int8_strict_model = model.OneNet()
46 + # trt_int8_strict_model.load_state_dict(torch.load(weight_path))
47 + # trt_int8_strict_model.eval().cuda() # no attribute 'char'
48 +
49 + # break
50 +
51 + inputs = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN))
52 +
53 +
54 + trt_x = inputs.half().cuda() # ??? densenet error?
55 + trt_model = model.CnnNet()
56 + # trt_model.load_state_dict(torch.load(weight_path))
57 + trt_model.eval().cuda()
58 +
59 + trt_f16_x = inputs.half().cuda()
60 + trt_f16_model = model.CnnNet().half()
61 + # trt_f16_model.load_state_dict(torch.load(weight_path))
62 + trt_f16_model.half().eval().cuda()
63 +
64 + trt_int8_strict_x = inputs.half().cuda() # match model weight
65 + trt_int8_strict_model = model.CnnNet()
66 + # trt_int8_strict_model.load_state_dict(torch.load(weight_path))
67 + trt_int8_strict_model.eval().cuda() # no attribute 'char'
68 +
69 + # convert to TensorRT feeding sample data as input
70 + model_trt = torch2trt(trt_model, [trt_x], max_batch_size=batch_size)
71 + model_trt_f16 = torch2trt(trt_f16_model, [trt_f16_x], fp16_mode=True, max_batch_size=batch_size)
72 + model_trt_int8_strict = torch2trt(trt_int8_strict_model, [trt_int8_strict_x], fp16_mode=False, int8_mode=True, strict_type_constraints=True, max_batch_size=batch_size)
73 +
74 + # testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler,
75 + # shuffle=False, num_workers=2, drop_last=True)
76 +
77 + with torch.no_grad():
78 + ### test inference time
79 + dummy_x = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)).half().cuda()
80 + dummy_cnt = 10000
 81 +        print('ignoring data loading time; running inference on random data')
82 +
83 + check_time = time.time()
84 + for i in range(dummy_cnt):
85 + _ = test_model(dummy_x)
86 + print('torch model: %.6f' % ((time.time() - check_time) / dummy_cnt))
87 +
88 + check_time = time.time()
89 + for i in range(dummy_cnt):
90 + _ = model_trt(dummy_x)
91 + print('trt model: %.6f' % ((time.time() - check_time) / dummy_cnt))
92 +
93 + dummy_x = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)).half().cuda()
94 + check_time = time.time()
95 + for i in range(dummy_cnt):
96 + _ = model_trt_f16(dummy_x)
97 + print('trt float 16 model: %.6f' % ((time.time() - check_time) / dummy_cnt))
98 +
99 + dummy_x = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)).char().cuda()
100 + check_time = time.time()
101 + for i in range(dummy_cnt):
102 + _ = model_trt_int8_strict(dummy_x)
103 + print('trt int8 strict model: %.6f' % ((time.time() - check_time) / dummy_cnt))
104 + return
105 + ## end
106 +
107 + criterion = nn.CrossEntropyLoss()
108 + state_temp = torch.zeros((batch_size, 8 * 32)).cuda()
109 + step_acc = 0.0
110 + step_loss = 0.0
111 + cnt = 0
112 + loss_cnt = 0
113 + for i, (inputs, labels) in enumerate(testloader):
114 + inputs, labels = inputs.float().cuda(), labels.long().cuda()
115 + normal_outputs, state_temp = test_model(inputs, state_temp)
116 +
117 + _, preds = torch.max(normal_outputs, 1)
118 + edge_loss = criterion(normal_outputs, labels)
119 + step_loss += edge_loss.item()
120 + loss_cnt += 1
121 +
122 + corr_sum = torch.sum(preds == labels.data)
123 + step_acc += corr_sum.double()
124 + cnt += batch_size
125 + print('torch', step_acc.item() / cnt, step_loss / loss_cnt)
126 +
127 + state_temp = torch.zeros((batch_size, 8 * 32)).cuda()
128 + step_acc = 0.0
129 + cnt = 0
130 + step_loss = 0.0
131 + loss_cnt = 0
132 + for i, (inputs, labels) in enumerate(testloader):
133 + inputs, labels = inputs.float().cuda(), labels.long().cuda()
134 + normal_outputs, state_temp = model_trt(inputs, state_temp)
135 +
136 + _, preds = torch.max(normal_outputs, 1)
137 + edge_loss = criterion(normal_outputs, labels)
138 + step_loss += edge_loss.item()
139 + loss_cnt += 1
140 +
141 + corr_sum = torch.sum(preds == labels.data)
142 + step_acc += corr_sum.double()
143 + cnt += batch_size
144 + print('trt', step_acc.item() / cnt, step_loss / loss_cnt)
145 +
146 + state_temp = torch.zeros((batch_size, 8 * 32)).half().cuda()
147 + step_acc = 0.0
148 + cnt = 0
149 + step_loss = 0.0
150 + loss_cnt = 0
151 + for i, (inputs, labels) in enumerate(testloader):
152 + inputs, labels = inputs.half().cuda(), labels.long().cuda()
153 + normal_outputs, state_temp = model_trt_f16(inputs, state_temp)
154 +
155 + _, preds = torch.max(normal_outputs, 1)
156 + edge_loss = criterion(normal_outputs, labels)
157 + step_loss += edge_loss.item()
158 + loss_cnt += 1
159 +
160 + corr_sum = torch.sum(preds == labels.data)
161 + step_acc += corr_sum.double()
162 + cnt += batch_size
163 + print('float16', step_acc.item() / cnt, step_loss / loss_cnt)
164 +
165 + state_temp = torch.zeros((batch_size, 8 * 32)).float().cuda()
166 + step_acc = 0.0
167 + cnt = 0
168 + step_loss = 0.0
169 + loss_cnt = 0
170 + for i, (inputs, labels) in enumerate(testloader):
171 + inputs, labels = inputs.float().cuda(), labels.long().cuda()
172 + normal_outputs, state_temp = model_trt_int8_strict(inputs, state_temp)
173 +
174 + _, preds = torch.max(normal_outputs, 1)
175 + edge_loss = criterion(normal_outputs, labels)
176 + step_loss += edge_loss.item()
177 + loss_cnt += 1
178 +
179 + corr_sum = torch.sum(preds == labels.data)
180 + step_acc += corr_sum.double()
181 + cnt += batch_size
182 + print('int8 strict', step_acc.item() / cnt, step_loss / loss_cnt)
183 +
184 +
185 +def run_benchmark(weight_path):
186 + import sys
187 + sys.path.append("/content/drive/My Drive/capstone1/CAN/torch2trt")
188 + from torch2trt import torch2trt
189 + import model
190 + import time
191 + import torch
192 + import dataset
193 + import torch.nn as nn
194 +
195 + test_model = model.OneNet()
196 + test_model.load_state_dict(torch.load(weight_path))
197 + test_model.eval().cuda()
198 +
199 + batch_size = 1
200 + _, _, _, test_data_set = dataset.GetCanDataset(100, 0, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")
201 +
202 + sampler = dataset.BatchIntervalSampler(len(test_data_set), batch_size)
203 + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler,
204 + shuffle=False, num_workers=2, drop_last=True)
205 +
206 + # create model and input data
207 + for inputs, labels in testloader:
208 + # inputs = torch.cat([inputs, inputs, inputs], 1)
209 +
210 + trt_x = inputs.float().cuda()
211 + trt_state = torch.zeros((batch_size, 8 * 32)).float().cuda()
212 + trt_model = model.OneNet()
213 + trt_model.load_state_dict(torch.load(weight_path))
214 + trt_model.float().eval().cuda()
215 +
216 + trt_f16_x = inputs.half().cuda()
217 + trt_f16_state = torch.zeros((batch_size, 8 * 32)).half().cuda()
218 + trt_f16_model = model.OneNet().half()
219 + trt_f16_model.load_state_dict(torch.load(weight_path))
220 + trt_f16_model.half().eval().cuda()
221 +
222 + trt_int8_strict_x = inputs.float().cuda()
223 + trt_int8_strict_state = torch.zeros((batch_size, 8 * 32)).float().cuda() # match model weight
224 + trt_int8_strict_model = model.OneNet()
225 + trt_int8_strict_model.load_state_dict(torch.load(weight_path))
226 + trt_int8_strict_model.eval().cuda() # no attribute 'char'
227 +
228 + break
229 +
230 + # convert to TensorRT feeding sample data as input
231 + model_trt = torch2trt(trt_model, [trt_x, trt_state], max_batch_size=batch_size)
232 + model_trt_f16 = torch2trt(trt_f16_model, [trt_f16_x, trt_f16_state], fp16_mode=True, max_batch_size=batch_size)
233 + model_trt_int8_strict = torch2trt(trt_int8_strict_model, [trt_int8_strict_x, trt_int8_strict_state], fp16_mode=False, int8_mode=True, strict_type_constraints=True, max_batch_size=batch_size)
234 +
235 + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler,
236 + shuffle=False, num_workers=2, drop_last=True)
237 +
238 + with torch.no_grad():
239 + ### test inference time
240 + dummy_x = torch.ones((batch_size, 8)).cuda()
241 + dummy_state = torch.zeros(batch_size, model.STATE_DIM).cuda()
242 + dummy_cnt = 10000
243 +        print('ignoring data loading time; running inference on random data')
244 +
245 + check_time = time.time()
246 + for i in range(dummy_cnt):
247 + _, _ = test_model(dummy_x, dummy_state)
248 + print('torch model: %.6f' % ((time.time() - check_time) / dummy_cnt))
249 +
250 + check_time = time.time()
251 + for i in range(dummy_cnt):
252 + _, _ = model_trt(dummy_x, dummy_state)
253 + print('trt model: %.6f' % ((time.time() - check_time) / dummy_cnt))
254 +
255 + dummy_x = torch.ones((batch_size, 8)).half().cuda()
256 + dummy_state = torch.zeros(batch_size, model.STATE_DIM).half().cuda()
257 + check_time = time.time()
258 + for i in range(dummy_cnt):
259 + _, _ = model_trt_f16(dummy_x, dummy_state)
260 + print('trt float 16 model: %.6f' % ((time.time() - check_time) / dummy_cnt))
261 +
262 + dummy_x = torch.ones((batch_size, 8)).char().cuda()
263 + dummy_state = torch.zeros(batch_size, model.STATE_DIM).char().cuda()
264 + check_time = time.time()
265 + for i in range(dummy_cnt):
266 + _, _ = model_trt_int8_strict(dummy_x, dummy_state)
267 + print('trt int8 strict model: %.6f' % ((time.time() - check_time) / dummy_cnt))
268 + return
269 + ## end
270 +
271 + criterion = nn.CrossEntropyLoss()
272 + state_temp = torch.zeros((batch_size, 8 * 32)).cuda()
273 + step_acc = 0.0
274 + step_loss = 0.0
275 + cnt = 0
276 + loss_cnt = 0
277 + for i, (inputs, labels) in enumerate(testloader):
278 + inputs, labels = inputs.float().cuda(), labels.long().cuda()
279 + normal_outputs, state_temp = test_model(inputs, state_temp)
280 +
281 + _, preds = torch.max(normal_outputs, 1)
282 + edge_loss = criterion(normal_outputs, labels)
283 + step_loss += edge_loss.item()
284 + loss_cnt += 1
285 +
286 + corr_sum = torch.sum(preds == labels.data)
287 + step_acc += corr_sum.double()
288 + cnt += batch_size
289 + print('torch', step_acc.item() / cnt, step_loss / loss_cnt)
290 +
291 + state_temp = torch.zeros((batch_size, 8 * 32)).cuda()
292 + step_acc = 0.0
293 + cnt = 0
294 + step_loss = 0.0
295 + loss_cnt = 0
296 + for i, (inputs, labels) in enumerate(testloader):
297 + inputs, labels = inputs.float().cuda(), labels.long().cuda()
298 + normal_outputs, state_temp = model_trt(inputs, state_temp)
299 +
300 + _, preds = torch.max(normal_outputs, 1)
301 + edge_loss = criterion(normal_outputs, labels)
302 + step_loss += edge_loss.item()
303 + loss_cnt += 1
304 +
305 + corr_sum = torch.sum(preds == labels.data)
306 + step_acc += corr_sum.double()
307 + cnt += batch_size
308 + print('trt', step_acc.item() / cnt, step_loss / loss_cnt)
309 +
310 + state_temp = torch.zeros((batch_size, 8 * 32)).half().cuda()
311 + step_acc = 0.0
312 + cnt = 0
313 + step_loss = 0.0
314 + loss_cnt = 0
315 + for i, (inputs, labels) in enumerate(testloader):
316 + inputs, labels = inputs.half().cuda(), labels.long().cuda()
317 + normal_outputs, state_temp = model_trt_f16(inputs, state_temp)
318 +
319 + _, preds = torch.max(normal_outputs, 1)
320 + edge_loss = criterion(normal_outputs, labels)
321 + step_loss += edge_loss.item()
322 + loss_cnt += 1
323 +
324 + corr_sum = torch.sum(preds == labels.data)
325 + step_acc += corr_sum.double()
326 + cnt += batch_size
327 + print('float16', step_acc.item() / cnt, step_loss / loss_cnt)
328 +
329 + state_temp = torch.zeros((batch_size, 8 * 32)).float().cuda()
330 + step_acc = 0.0
331 + cnt = 0
332 + step_loss = 0.0
333 + loss_cnt = 0
334 + for i, (inputs, labels) in enumerate(testloader):
335 + inputs, labels = inputs.float().cuda(), labels.long().cuda()
336 + normal_outputs, state_temp = model_trt_int8_strict(inputs, state_temp)
337 +
338 + _, preds = torch.max(normal_outputs, 1)
339 + edge_loss = criterion(normal_outputs, labels)
340 + step_loss += edge_loss.item()
341 + loss_cnt += 1
342 +
343 + corr_sum = torch.sum(preds == labels.data)
344 + step_acc += corr_sum.double()
345 + cnt += batch_size
346 + print('int8 strict', step_acc.item() / cnt, step_loss / loss_cnt)
347 +
348 +
349 +def drawGraph(x_value, x_label, y_axis, y_label):
350 + pass
351 +
352 +
353 +def CsvToTextOne(csv_file):
354 + target_csv = pd.read_csv(csv_file)
355 + file_name, extension = os.path.splitext(csv_file)
356 + print(file_name, extension)
357 + target_text = open(file_name + '_1.txt', mode='wt', encoding='utf-8')
358 +
359 + idx = 0
360 + print(len(target_csv))
361 +
362 + while idx < len(target_csv):
363 + csv_row = target_csv.iloc[idx]
364 + data_len = csv_row[1]
365 + is_regular = (csv_row[data_len + 2] == 'R')
366 +
367 + if is_regular:
368 + target_text.write("%d R\n" % idx)
369 + else:
370 + target_text.write("%d T\n" % idx)
371 +
372 + idx += 1
373 + if (idx % 1000000 == 0):
374 + print(idx)
375 +
376 + target_text.close()
377 + print('done')
378 +
379 +
380 +def Mix_Four_CANDataset():
381 + Dos_csv = pd.read_csv('./dataset/DoS_dataset.csv')
382 + Other_csv = [pd.read_csv('./dataset/Fuzzy_dataset.csv'),
383 + pd.read_csv('./dataset/RPM_dataset.csv'),
384 + pd.read_csv('./dataset/gear_dataset.csv')]
385 + Other_csv_idx = [0, 0, 0]
386 +
387 + save_csv = open('./dataset/Mixed_dataset.csv', 'w')
388 + save_csv_file = csv.writer(save_csv)
389 +
390 +    # Change the period of the injected DoS attack traffic.
391 +    # Replace the next three DoS attack slots with attack frames from the other traces.
392 +    # Pattern: DoS / three picks from (Fuzzy, RPM, gear), order and counts random / DoS / ...
393 + dos_idx = 0
394 +    dos_period = 3
395 + while dos_idx < len(Dos_csv):
396 + dos_row = Dos_csv.iloc[dos_idx]
397 + number_of_data = dos_row[2]
398 + is_regular = (dos_row[number_of_data + 3] == 'R')
399 + dos_row.dropna(inplace=True)
400 +
401 + if is_regular:
402 + save_csv_file.writerow(dos_row[1:])
403 + else:
404 +            if dos_period == 3:
405 + save_csv_file.writerow(dos_row[1:])
406 + np.random.seed(dos_idx)
407 + selected_edge = np.random.choice([0, 1, 2], 3, replace=True)
408 + else:
409 +                selected_idx = selected_edge[dos_period]
410 + local_csv = Other_csv[selected_idx]
411 + local_idx = Other_csv_idx[selected_idx]
412 +
413 + while True:
414 + local_row = local_csv.iloc[local_idx]
415 + local_number_of_data = local_row[2]
416 + is_injected = (local_row[local_number_of_data + 3] == 'T')
417 + local_idx += 1
418 + if is_injected:
419 + local_row.dropna(inplace=True)
420 + save_csv_file.writerow(local_row[1:])
421 + break
422 + Other_csv_idx[selected_idx] = local_idx
423 +
424 +        dos_period -= 1
425 +        if dos_period == -1:
426 +            dos_period = 3
427 +
428 + dos_idx += 1
429 + if dos_idx % 100000 == 0:
430 + print(dos_idx)
431 + # break
432 + save_csv.close()
433 +
434 +def Mix_Six_SynCANDataset():
435 + normal_csv = pd.read_csv('./dataset/test_normal.csv')
436 + normal_idx = 0
437 + target_len = len(normal_csv)
438 +
439 + save_csv = open('./dataset/test_mixed.csv', 'w')
440 + save_csv_file = csv.writer(save_csv)
441 +
442 + other_csv = [pd.read_csv('./dataset/test_continuous.csv'),
443 + pd.read_csv('./dataset/test_flooding.csv'),
444 + pd.read_csv('./dataset/test_plateau.csv'),
445 + pd.read_csv('./dataset/test_playback.csv'),
446 + pd.read_csv('./dataset/test_suppress.csv')]
447 + other_csv_idx = [0, 0, 0, 0, 0]
448 +
449 + while normal_idx < target_len:
450 + np.random.seed(normal_idx)
451 + selected_csv = np.random.choice([0, 1, 2, 3, 4], 5, replace=True)
452 + all_done = True
453 + for csv_idx in selected_csv:
454 + now_csv = other_csv[csv_idx]
455 + now_idx = other_csv_idx[csv_idx]
456 +
457 + start_normal_idx = now_idx
458 +            while now_idx + 1 < len(now_csv):  # keep room for the iloc[now_idx + 1] look-ahead
459 + csv_row_ahead = now_csv.iloc[now_idx + 1]
460 + label_ahead = csv_row_ahead[0]
461 +
462 + csv_row_behind = now_csv.iloc[now_idx]
463 + label_behind = csv_row_behind[0]
464 +
465 + if label_ahead == 1 and label_behind == 0:
466 + print(now_idx, 'start error')
467 + add_normal_len = (now_idx - start_normal_idx) // 9
468 + start_abnormal_idx = now_idx + 1
469 + elif label_ahead == 0 and label_behind == 1:
470 + print(now_idx, 'end error')
471 + add_abnormal_len = (now_idx - start_abnormal_idx) // 6
472 +
473 + for _ in range(6):
474 + # done
475 + if normal_idx + add_normal_len >= target_len:
476 + save_csv.close()
477 + return
478 +
479 + # write normal
480 + for idx in range(normal_idx, normal_idx + add_normal_len):
481 + row = normal_csv.iloc[idx]
482 + row = row.fillna(0)
483 + save_csv_file.writerow(row[0:1].append(row[2:]))
484 + normal_idx += add_normal_len
485 + # write abnormal
486 + for idx in range(start_abnormal_idx, start_abnormal_idx + add_abnormal_len):
487 + row = now_csv.iloc[idx]
488 + row = row.fillna(0)
489 + save_csv_file.writerow(row[0:1].append(row[2:]))
490 + start_abnormal_idx += add_abnormal_len
491 +
492 + other_csv_idx[csv_idx] = now_idx + 1
493 + # check other csv not end
494 + all_done = False
495 + break
496 +
497 + now_idx += 1
498 +
499 + if all_done:
500 + break
501 +
502 + save_csv.close()
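For reference, the intended chaining of these utilities for the CAN case, inferred from the paths used in run_benchmark above (so treat the order as an assumption): Mix_Four_CANDataset() writes ./dataset/Mixed_dataset.csv, and CsvToTextOne() then writes the Mixed_dataset_1.txt label index that dataset.GetCanDataset() loads.

if __name__ == '__main__':
    Mix_Four_CANDataset()                          # -> ./dataset/Mixed_dataset.csv
    CsvToTextOne('./dataset/Mixed_dataset.csv')    # -> ./dataset/Mixed_dataset_1.txt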