Showing 9 changed files with 1506 additions and 465 deletions
1 | -import os | ||
2 | import torch | 1 | import torch |
3 | import pandas as pd | 2 | import pandas as pd |
4 | import numpy as np | 3 | import numpy as np |
5 | from torch.utils.data import Dataset, DataLoader | 4 | from torch.utils.data import Dataset, DataLoader |
5 | +from torch.utils.data.sampler import Sampler | ||
6 | import const | 6 | import const |
7 | 7 | ||
8 | -''' | ||
9 | -def int_to_binary(x, bits): | ||
10 | - mask = 2 ** torch.arange(bits).to(x.device, x.dtype) | ||
11 | - return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte() | ||
12 | -''' | ||
13 | 8 | ||
14 | -def unpack_bits(x, num_bits): | 9 | +# 0, batch * 1, batch * 2 ... |
15 | - """ | 10 | +class BatchIntervalSampler(Sampler): |
16 | - Args: | ||
17 | - x (int): integer to convert to bits | ||
18 | - num_bits (int): number of bits to represent | ||
19 | - """ | ||
20 | - xshape = list(x.shape) | ||
21 | - x = x.reshape([-1, 1]) | ||
22 | - mask = 2**np.arange(num_bits).reshape([1, num_bits]) | ||
23 | - return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits]) | ||
24 | 11 | ||
12 | + def __init__(self, data_length, batch_size): | ||
13 | + # make data_length divisible by batch_size | ||
14 | + if data_length % batch_size != 0: | ||
15 | + data_length = data_length - (data_length % batch_size) | ||
25 | 16 | ||
26 | -# def CsvToNumpy(csv_file): | 17 | + self.indices =[] |
27 | -# target_csv = pd.read_csv(csv_file) | 18 | + # print(data_length) |
28 | -# inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy' | 19 | + batch_group_interval = int(data_length / batch_size) |
29 | -# labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy' | 20 | + for group_idx in range(batch_group_interval): |
30 | -# print(inputs_save_numpy, labels_save_numpy) | 21 | + for local_idx in range(batch_size): |
22 | + self.indices.append(group_idx + local_idx * batch_group_interval) | ||
23 | + # print('sampler init', self.indices) | ||
31 | 24 | ||
32 | -# i = 0 | 25 | + def __iter__(self): |
33 | -# inputs_array = [] | 26 | + return iter(self.indices) |
34 | -# labels_array = [] | ||
35 | -# print(len(target_csv)) | ||
36 | 27 | ||
37 | -# while i + const.CAN_ID_BIT - 1 < len(target_csv): | 28 | + def __len__(self): |
38 | - | 29 | + return len(self.indices) |
39 | -# is_regular = True | ||
40 | -# for j in range(const.CAN_ID_BIT): | ||
41 | -# l = target_csv.iloc[i + j] | ||
42 | -# b = l[2] | ||
43 | -# r = (l[b+2+1] == 'R') | ||
44 | - | ||
45 | -# if not r: | ||
46 | -# is_regular = False | ||
47 | -# break | ||
48 | - | ||
49 | -# inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
50 | -# for idx in range(const.CAN_ID_BIT): | ||
51 | -# can_id = int(target_csv.iloc[i + idx, 1], 16) | ||
52 | -# inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT) | ||
53 | -# inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
54 | - | ||
55 | -# if is_regular: | ||
56 | -# labels = 1 | ||
57 | -# else: | ||
58 | -# labels = 0 | ||
59 | - | ||
60 | -# inputs_array.append(inputs) | ||
61 | -# labels_array.append(labels) | ||
62 | - | ||
63 | -# i+=1 | ||
64 | -# if (i % 5000 == 0): | ||
65 | -# print(i) | ||
66 | -# # break | ||
67 | - | ||
68 | -# inputs_array = np.array(inputs_array) | ||
69 | -# labels_array = np.array(labels_array) | ||
70 | -# np.save(inputs_save_numpy, arr=inputs_array) | ||
71 | -# np.save(labels_save_numpy, arr=labels_array) | ||
72 | -# print('done') | ||
73 | - | ||
74 | - | ||
75 | -def CsvToText(csv_file): | ||
76 | - target_csv = pd.read_csv(csv_file) | ||
77 | - text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt' | ||
78 | - print(text_file_name) | ||
79 | - target_text = open(text_file_name, mode='wt', encoding='utf-8') | ||
80 | - | ||
81 | - i = 0 | ||
82 | - datum = [ [], [] ] | ||
83 | - print(len(target_csv)) | ||
84 | - | ||
85 | - while i + const.CAN_ID_BIT - 1 < len(target_csv): | ||
86 | - | ||
87 | - is_regular = True | ||
88 | - for j in range(const.CAN_ID_BIT): | ||
89 | - l = target_csv.iloc[i + j] | ||
90 | - b = l[2] | ||
91 | - r = (l[b+2+1] == 'R') | ||
92 | - | ||
93 | - if not r: | ||
94 | - is_regular = False | ||
95 | - break | ||
96 | - | ||
97 | - if is_regular: | ||
98 | - target_text.write("%d R\n" % i) | ||
99 | - else: | ||
100 | - target_text.write("%d T\n" % i) | ||
101 | - | ||
102 | - i+=1 | ||
103 | - if (i % 5000 == 0): | ||
104 | - print(i) | ||
105 | - | ||
106 | - target_text.close() | ||
107 | - print('done') | ||
108 | 30 | ||
109 | 31 | ||
110 | def record_net_data_stats(label_temp, data_idx_map): | 32 | def record_net_data_stats(label_temp, data_idx_map): |
... | @@ -120,205 +42,92 @@ def record_net_data_stats(label_temp, data_idx_map): | ... | @@ -120,205 +42,92 @@ def record_net_data_stats(label_temp, data_idx_map): |
120 | return net_class_count, net_data_count | 42 | return net_class_count, net_data_count |
121 | 43 | ||
122 | 44 | ||
123 | -def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs): | 45 | +def GetCanDataset(total_edge, fold_num, packet_num, csv_path, txt_path): |
124 | - csv_list = [] | 46 | + csv = pd.read_csv(csv_path) |
125 | - total_datum = [] | 47 | + txt = open(txt_path, "r") |
126 | - total_label_temp = [] | 48 | + lines = txt.read().splitlines() |
127 | - csv_idx = 0 | ||
128 | - for csv_file, txt_file in kwargs.items(): | ||
129 | - csv = pd.read_csv(csv_file) | ||
130 | - csv_list.append(csv) | ||
131 | - | ||
132 | - txt = open(txt_file, "r") | ||
133 | - lines = txt.read().splitlines() | ||
134 | 49 | ||
135 | - idx = 0 | 50 | + idx = 0 |
136 | - local_datum = [] | 51 | + datum = [] |
137 | - while idx + const.CAN_ID_BIT - 1 < len(csv): | 52 | + label_temp = [] |
138 | - line = lines[idx] | 53 | + # [cur_idx ~ cur_idx + packet_num) |
139 | - if not line: | 54 | + while idx + packet_num - 1 < len(csv) // 2: |
140 | - break | 55 | + line = lines[idx + packet_num - 1] |
56 | + if not line: | ||
57 | + break | ||
141 | 58 | ||
142 | - if line.split(' ')[1] == 'R': | 59 | + if line.split(' ')[1] == 'R': |
143 | - local_datum.append((csv_idx, idx, 1)) | 60 | + datum.append((idx, 1)) |
144 | - total_label_temp.append(1) | 61 | + label_temp.append(1) |
145 | - else: | 62 | + else: |
146 | - local_datum.append((csv_idx, idx, 0)) | 63 | + datum.append((idx, 0)) |
147 | - total_label_temp.append(0) | 64 | + label_temp.append(0) |
148 | 65 | ||
149 | - idx += 1 | 66 | + idx += 1 |
150 | - if (idx % 1000000 == 0): | 67 | + if (idx % 1000000 == 0): |
151 | - print(idx) | 68 | + print(idx) |
152 | 69 | ||
153 | - csv_idx += 1 | 70 | + fold_length = int(len(label_temp) / 5) |
154 | - total_datum += local_datum | 71 | + train_datum = [] |
155 | - | 72 | + train_label_temp = [] |
156 | - fold_length = int(len(total_label_temp) / 5) | ||
157 | - datum = [] | ||
158 | - label_temp = [] | ||
159 | for i in range(5): | 73 | for i in range(5): |
160 | if i != fold_num: | 74 | if i != fold_num: |
161 | - datum += total_datum[i*fold_length:(i+1)*fold_length] | 75 | + train_datum += datum[i*fold_length:(i+1)*fold_length] |
162 | - label_temp += total_label_temp[i*fold_length:(i+1)*fold_length] | 76 | + train_label_temp += label_temp[i*fold_length:(i+1)*fold_length] |
163 | else: | 77 | else: |
164 | - test_datum = total_datum[i*fold_length:(i+1)*fold_length] | 78 | + test_datum = datum[i*fold_length:(i+1)*fold_length] |
165 | 79 | ||
166 | - min_size = 0 | ||
167 | - output_class_num = 2 | ||
168 | - N = len(label_temp) | ||
169 | - label_temp = np.array(label_temp) | ||
170 | - data_idx_map = {} | ||
171 | 80 | ||
172 | - while min_size < 512: | 81 | + N = len(train_label_temp) |
173 | - idx_batch = [[] for _ in range(total_edge)] | 82 | + train_label_temp = np.array(train_label_temp) |
174 | - # for each class in the dataset | ||
175 | - for k in range(output_class_num): | ||
176 | - idx_k = np.where(label_temp == k)[0] | ||
177 | - np.random.shuffle(idx_k) | ||
178 | - proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
179 | - ## Balance | ||
180 | - proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)]) | ||
181 | - proportions = proportions/proportions.sum() | ||
182 | - proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1] | ||
183 | - idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))] | ||
184 | - min_size = min([len(idx_j) for idx_j in idx_batch]) | ||
185 | 83 | ||
84 | + proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
85 | + proportions = np.cumsum(proportions) | ||
86 | + idx_batch = [[] for _ in range(total_edge)] | ||
87 | + data_idx_map = {} | ||
88 | + prev = 0.0 | ||
186 | for j in range(total_edge): | 89 | for j in range(total_edge): |
187 | - np.random.shuffle(idx_batch[j]) | 90 | + idx_batch[j] = [idx for idx in range(int(prev * N), int(proportions[j] * N))] |
91 | + prev = proportions[j] | ||
188 | data_idx_map[j] = idx_batch[j] | 92 | data_idx_map[j] = idx_batch[j] |
189 | 93 | ||
190 | - net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map) | 94 | + _, net_data_count = record_net_data_stats(train_label_temp, data_idx_map) |
191 | 95 | ||
192 | - return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False) | 96 | + return CanDataset(csv, train_datum, packet_num), data_idx_map, net_data_count, CanDataset(csv, test_datum, packet_num, False) |
193 | 97 | ||
194 | 98 | ||
195 | -class CanDatasetKwarg(Dataset): | 99 | +class CanDataset(Dataset): |
196 | 100 | ||
197 | - def __init__(self, csv_list, datum, is_train=True): | 101 | + def __init__(self, csv, datum, packet_num, is_train=True): |
198 | - self.csv_list = csv_list | 102 | + self.csv = csv |
199 | self.datum = datum | 103 | self.datum = datum |
104 | + self.packet_num = packet_num | ||
200 | if is_train: | 105 | if is_train: |
201 | self.idx_map = [] | 106 | self.idx_map = [] |
202 | else: | 107 | else: |
203 | self.idx_map = [idx for idx in range(len(self.datum))] | 108 | self.idx_map = [idx for idx in range(len(self.datum))] |
204 | 109 | ||
205 | def __len__(self): | 110 | def __len__(self): |
206 | - return len(self.idx_map) | 111 | + return len(self.idx_map) - self.packet_num + 1 |
207 | 112 | ||
208 | def set_idx_map(self, data_idx_map): | 113 | def set_idx_map(self, data_idx_map): |
209 | self.idx_map = data_idx_map | 114 | self.idx_map = data_idx_map |
210 | 115 | ||
211 | def __getitem__(self, idx): | 116 | def __getitem__(self, idx): |
212 | - csv_idx = self.datum[self.idx_map[idx]][0] | 117 | + # [cur_idx ~ cur_idx + packet_num) |
213 | - start_i = self.datum[self.idx_map[idx]][1] | 118 | + start_i = self.datum[self.idx_map[idx]][0] |
214 | - is_regular = self.datum[self.idx_map[idx]][2] | 119 | + is_regular = self.datum[self.idx_map[idx + self.packet_num - 1]][1] |
215 | - | ||
216 | - l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
217 | - for i in range(const.CAN_ID_BIT): | ||
218 | - id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16) | ||
219 | - bits = unpack_bits(np.array(id_), const.CAN_ID_BIT) | ||
220 | - l[i] = bits | ||
221 | - l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
222 | - | ||
223 | - return (l, is_regular) | ||
224 | - | ||
225 | - | ||
226 | -def GetCanDatasetUsingTxt(csv_file, txt_path, length): | ||
227 | - csv = pd.read_csv(csv_file) | ||
228 | - txt = open(txt_path, "r") | ||
229 | - lines = txt.read().splitlines() | ||
230 | - | ||
231 | - idx = 0 | ||
232 | - datum = [ [], [] ] | ||
233 | - while idx + const.CAN_ID_BIT - 1 < len(csv): | ||
234 | - if len(datum[0]) >= length//2 and len(datum[1]) >= length//2: | ||
235 | - break | ||
236 | - | ||
237 | - line = lines[idx] | ||
238 | - if not line: | ||
239 | - break | ||
240 | - | ||
241 | - if line.split(' ')[1] == 'R': | ||
242 | - if len(datum[0]) < length//2: | ||
243 | - datum[0].append((idx, 1)) | ||
244 | - else: | ||
245 | - if len(datum[1]) < length//2: | ||
246 | - datum[1].append((idx, 0)) | ||
247 | - | ||
248 | - idx += 1 | ||
249 | - if (idx % 5000 == 0): | ||
250 | - print(idx, len(datum[0]), len(datum[1])) | ||
251 | - | ||
252 | - l = int((length // 2) * 0.9) | ||
253 | - return CanDataset(csv, datum[0][:l] + datum[1][:l]), \ | ||
254 | - CanDataset(csv, datum[0][l:] + datum[1][l:]) | ||
255 | - | ||
256 | - | ||
257 | -def GetCanDataset(csv_file, length): | ||
258 | - csv = pd.read_csv(csv_file) | ||
259 | - | ||
260 | - i = 0 | ||
261 | - datum = [ [], [] ] | ||
262 | - | ||
263 | - while i + const.CAN_ID_BIT - 1 < len(csv): | ||
264 | - if len(datum[0]) >= length//2 and len(datum[1]) >= length//2: | ||
265 | - break | ||
266 | - | ||
267 | - is_regular = True | ||
268 | - for j in range(const.CAN_ID_BIT): | ||
269 | - l = csv.iloc[i + j] | ||
270 | - b = l[2] | ||
271 | - r = (l[b+2+1] == 'R') | ||
272 | - | ||
273 | - if not r: | ||
274 | - is_regular = False | ||
275 | - break | ||
276 | - | ||
277 | - if is_regular: | ||
278 | - if len(datum[0]) < length//2: | ||
279 | - datum[0].append((i, 1)) | ||
280 | - else: | ||
281 | - if len(datum[1]) < length//2: | ||
282 | - datum[1].append((i, 0)) | ||
283 | - i+=1 | ||
284 | - if (i % 5000 == 0): | ||
285 | - print(i, len(datum[0]), len(datum[1])) | ||
286 | - | ||
287 | - l = int((length // 2) * 0.9) | ||
288 | - return CanDataset(csv, datum[0][:l] + datum[1][:l]), \ | ||
289 | - CanDataset(csv, datum[0][l:] + datum[1][l:]) | ||
290 | - | ||
291 | - | ||
292 | -class CanDataset(Dataset): | ||
293 | - | ||
294 | - def __init__(self, csv, datum): | ||
295 | - self.csv = csv | ||
296 | - self.datum = datum | ||
297 | - | ||
298 | - def __len__(self): | ||
299 | - return len(self.datum) | ||
300 | - | ||
301 | - def __getitem__(self, idx): | ||
302 | - start_i = self.datum[idx][0] | ||
303 | - is_regular = self.datum[idx][1] | ||
304 | 120 | ||
305 | - l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | 121 | + packet = np.zeros((const.CAN_DATA_LEN * self.packet_num)) |
306 | - for i in range(const.CAN_ID_BIT): | 122 | + for next_i in range(self.packet_num): |
307 | - id = int(self.csv.iloc[start_i + i, 1], 16) | 123 | + # fill this packet's CAN_DATA_LEN-byte slice of the flat vector |
308 | - bits = unpack_bits(np.array(id), const.CAN_ID_BIT) | 124 | + data_len = self.csv.iloc[start_i + next_i, 1] |
309 | - l[i] = bits | 125 | + for j in range(data_len): |
310 | - l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | 126 | + data_value = int(self.csv.iloc[start_i + next_i, 2 + j], 16) / 255.0 |
127 | + packet[j + const.CAN_DATA_LEN * next_i] = data_value | ||
311 | 128 | ||
312 | - return (l, is_regular) | 129 | + return torch.from_numpy(packet).float(), is_regular |
313 | 130 | ||
314 | 131 | ||
315 | if __name__ == "__main__": | 132 | if __name__ == "__main__": |
316 | - kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'} | 133 | + pass |
317 | - test_data_set = dataset.GetCanDatasetUsingTxtKwarg(-1, -1, False, **kwargs) | ||
318 | - testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, | ||
319 | - shuffle=False, num_workers=2) | ||
320 | - | ||
321 | - for x, y in testloader: | ||
322 | - print(x) | ||
323 | - print(y) | ||
324 | - break | ... | ... |
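
Note on the new BatchIntervalSampler (an illustrative sketch, not part of the diff): with the interval set to data_length // batch_size, slot i of consecutive batches walks through consecutive dataset indices, which appears to be what the recurrent packet_state in the training script later in this diff relies on. A minimal, self-contained usage example, assuming the class is importable from dataset.py as added above:

import torch
from torch.utils.data import DataLoader, TensorDataset
from dataset import BatchIntervalSampler  # the sampler class added above

# toy dataset: 12 samples whose value equals their index
toy = TensorDataset(torch.arange(12).float(), torch.zeros(12, dtype=torch.long))

sampler = BatchIntervalSampler(data_length=len(toy), batch_size=4)  # interval = 12 // 4 = 3
loader = DataLoader(toy, batch_size=4, sampler=sampler, shuffle=False, drop_last=True)

for x, _ in loader:
    print(x.tolist())
# batch 1: [0.0, 3.0, 6.0, 9.0]    <- slot i of each batch advances by one index per step
# batch 2: [1.0, 4.0, 7.0, 10.0]
# batch 3: [2.0, 5.0, 8.0, 11.0]
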
코드/연합학습/deprecated.py
0 → 100644
1 | +#### utils #### | ||
2 | +# for mixed dataset | ||
3 | +def CsvToTextCNN(csv_file): | ||
4 | + target_csv = pd.read_csv(csv_file) | ||
5 | + file_name, extension = os.path.splitext(csv_file) | ||
6 | + print(file_name, extension) | ||
7 | + target_text = open(file_name + '_CNN8.txt', mode='wt', encoding='utf-8') | ||
8 | + | ||
9 | + idx = 0 | ||
10 | + print(len(target_csv)) | ||
11 | + | ||
12 | + while idx + const.CNN_FRAME_LEN - 1 < len(target_csv): | ||
13 | + | ||
14 | + is_regular = True | ||
15 | + for j in range(const.CNN_FRAME_LEN): | ||
16 | + l = target_csv.iloc[idx + j] | ||
17 | + b = l[1] | ||
18 | + r = (l[b+2] == 'R') | ||
19 | + | ||
20 | + if not r: | ||
21 | + is_regular = False | ||
22 | + break | ||
23 | + | ||
24 | + if is_regular: | ||
25 | + target_text.write("%d R\n" % idx) | ||
26 | + else: | ||
27 | + target_text.write("%d T\n" % idx) | ||
28 | + | ||
29 | + idx += 1 | ||
30 | + if idx % 300000 == 0: | ||
31 | + print(idx) | ||
32 | + | ||
33 | + target_text.close() | ||
34 | + print('done') | ||
35 | + | ||
36 | + | ||
37 | + | ||
38 | +#### dataset #### | ||
39 | +def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs): | ||
40 | + csv_list = [] | ||
41 | + total_datum = [] | ||
42 | + total_label_temp = [] | ||
43 | + csv_idx = 0 | ||
44 | + for csv_file, txt_file in kwargs.items(): | ||
45 | + csv = pd.read_csv(csv_file) | ||
46 | + csv_list.append(csv) | ||
47 | + | ||
48 | + txt = open(txt_file, "r") | ||
49 | + lines = txt.read().splitlines() | ||
50 | + | ||
51 | + idx = 0 | ||
52 | + local_datum = [] | ||
53 | + while idx + const.CAN_ID_BIT - 1 < len(csv): | ||
54 | + line = lines[idx] | ||
55 | + if not line: | ||
56 | + break | ||
57 | + | ||
58 | + if line.split(' ')[1] == 'R': | ||
59 | + local_datum.append((csv_idx, idx, 1)) | ||
60 | + total_label_temp.append(1) | ||
61 | + else: | ||
62 | + local_datum.append((csv_idx, idx, 0)) | ||
63 | + total_label_temp.append(0) | ||
64 | + | ||
65 | + idx += 1 | ||
66 | + if (idx % 1000000 == 0): | ||
67 | + print(idx) | ||
68 | + | ||
69 | + csv_idx += 1 | ||
70 | + total_datum += local_datum | ||
71 | + | ||
72 | + fold_length = int(len(total_label_temp) / 5) | ||
73 | + datum = [] | ||
74 | + label_temp = [] | ||
75 | + for i in range(5): | ||
76 | + if i != fold_num: | ||
77 | + datum += total_datum[i*fold_length:(i+1)*fold_length] | ||
78 | + label_temp += total_label_temp[i*fold_length:(i+1)*fold_length] | ||
79 | + else: | ||
80 | + test_datum = total_datum[i*fold_length:(i+1)*fold_length] | ||
81 | + | ||
82 | + min_size = 0 | ||
83 | + output_class_num = 2 | ||
84 | + N = len(label_temp) | ||
85 | + label_temp = np.array(label_temp) | ||
86 | + data_idx_map = {} | ||
87 | + | ||
88 | + while min_size < 512: | ||
89 | + idx_batch = [[] for _ in range(total_edge)] | ||
90 | + # for each class in the dataset | ||
91 | + for k in range(output_class_num): | ||
92 | + idx_k = np.where(label_temp == k)[0] | ||
93 | + np.random.shuffle(idx_k) | ||
94 | + proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
95 | + ## Balance | ||
96 | + proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)]) | ||
97 | + proportions = proportions/proportions.sum() | ||
98 | + proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1] | ||
99 | + idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))] | ||
100 | + min_size = min([len(idx_j) for idx_j in idx_batch]) | ||
101 | + | ||
102 | + for j in range(total_edge): | ||
103 | + np.random.shuffle(idx_batch[j]) | ||
104 | + data_idx_map[j] = idx_batch[j] | ||
105 | + | ||
106 | + net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map) | ||
107 | + | ||
108 | + return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False) | ||
109 | + | ||
110 | + | ||
111 | +class CanDatasetKwarg(Dataset): | ||
112 | + | ||
113 | + def __init__(self, csv_list, datum, is_train=True): | ||
114 | + self.csv_list = csv_list | ||
115 | + self.datum = datum | ||
116 | + if is_train: | ||
117 | + self.idx_map = [] | ||
118 | + else: | ||
119 | + self.idx_map = [idx for idx in range(len(self.datum))] | ||
120 | + | ||
121 | + def __len__(self): | ||
122 | + return len(self.idx_map) | ||
123 | + | ||
124 | + def set_idx_map(self, data_idx_map): | ||
125 | + self.idx_map = data_idx_map | ||
126 | + | ||
127 | + def __getitem__(self, idx): | ||
128 | + csv_idx = self.datum[self.idx_map[idx]][0] | ||
129 | + start_i = self.datum[self.idx_map[idx]][1] | ||
130 | + is_regular = self.datum[self.idx_map[idx]][2] | ||
131 | + | ||
132 | + l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
133 | + for i in range(const.CAN_ID_BIT): | ||
134 | + id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16) | ||
135 | + bits = unpack_bits(np.array(id_), const.CAN_ID_BIT) | ||
136 | + l[i] = bits | ||
137 | + l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
138 | + | ||
139 | + return (l, is_regular) | ||
140 | + | ||
141 | + | ||
142 | +def GetCanDataset(total_edge, fold_num, csv_path, txt_path): | ||
143 | + csv = pd.read_csv(csv_path) | ||
144 | + txt = open(txt_path, "r") | ||
145 | + lines = txt.read().splitlines() | ||
146 | + frame_size = const.CAN_FRAME_LEN | ||
147 | + idx = 0 | ||
148 | + datum = [] | ||
149 | + label_temp = [] | ||
150 | + while idx + frame_size - 1 < len(csv) // 2: | ||
151 | + # csv_row = csv.iloc[idx + frame_size - 1] | ||
152 | + # data_len = csv_row[1] | ||
153 | + # is_regular = (csv_row[data_len + 2] == 'R') | ||
154 | + | ||
155 | + # if is_regular: | ||
156 | + # datum.append((idx, 1)) | ||
157 | + # label_temp.append(1) | ||
158 | + # else: | ||
159 | + # datum.append((idx, 0)) | ||
160 | + # label_temp.append(0) | ||
161 | + line = lines[idx] | ||
162 | + if not line: | ||
163 | + break | ||
164 | + | ||
165 | + if line.split(' ')[1] == 'R': | ||
166 | + datum.append((idx, 1)) | ||
167 | + label_temp.append(1) | ||
168 | + else: | ||
169 | + datum.append((idx, 0)) | ||
170 | + label_temp.append(0) | ||
171 | + | ||
172 | + idx += 1 | ||
173 | + if (idx % 1000000 == 0): | ||
174 | + print(idx) | ||
175 | + | ||
176 | + fold_length = int(len(label_temp) / 5) | ||
177 | + train_datum = [] | ||
178 | + train_label_temp = [] | ||
179 | + for i in range(5): | ||
180 | + if i != fold_num: | ||
181 | + train_datum += datum[i*fold_length:(i+1)*fold_length] | ||
182 | + train_label_temp += label_temp[i*fold_length:(i+1)*fold_length] | ||
183 | + else: | ||
184 | + test_datum = datum[i*fold_length:(i+1)*fold_length] | ||
185 | + | ||
186 | + min_size = 0 | ||
187 | + output_class_num = 2 | ||
188 | + N = len(train_label_temp) | ||
189 | + train_label_temp = np.array(train_label_temp) | ||
190 | + data_idx_map = {} | ||
191 | + | ||
192 | + # proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
193 | + # proportions = np.cumsum(proportions) | ||
194 | + # idx_batch = [[] for _ in range(total_edge)] | ||
195 | + # prev = 0.0 | ||
196 | + # for j in range(total_edge): | ||
197 | + # idx_batch[j] = [idx for idx in range(int(prev * N), int(proportions[j] * N))] | ||
198 | + # prev = proportions[j] | ||
199 | + # np.random.shuffle(idx_batch[j]) | ||
200 | + # data_idx_map[j] = idx_batch[j] | ||
201 | + | ||
202 | + while min_size < 512: | ||
203 | + idx_batch = [[] for _ in range(total_edge)] | ||
204 | + # for each class in the dataset | ||
205 | + for k in range(output_class_num): | ||
206 | + idx_k = np.where(train_label_temp == k)[0] | ||
207 | + np.random.shuffle(idx_k) | ||
208 | + proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
209 | + ## Balance | ||
210 | + proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)]) | ||
211 | + proportions = proportions/proportions.sum() | ||
212 | + proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1] | ||
213 | + idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))] | ||
214 | + min_size = min([len(idx_j) for idx_j in idx_batch]) | ||
215 | + | ||
216 | + for j in range(total_edge): | ||
217 | + np.random.shuffle(idx_batch[j]) | ||
218 | + data_idx_map[j] = idx_batch[j] | ||
219 | + | ||
220 | + _, net_data_count = record_net_data_stats(train_label_temp, data_idx_map) | ||
221 | + | ||
222 | + return CanDataset(csv, train_datum), data_idx_map, net_data_count, CanDataset(csv, test_datum, False) | ||
223 | + | ||
224 | + | ||
225 | +class CanDataset(Dataset): | ||
226 | + | ||
227 | + def __init__(self, csv, datum, is_train=True): | ||
228 | + self.csv = csv | ||
229 | + self.datum = datum | ||
230 | + self.is_train = is_train | ||
231 | + if self.is_train: | ||
232 | + self.idx_map = [] | ||
233 | + else: | ||
234 | + self.idx_map = [idx for idx in range(len(self.datum))] | ||
235 | + | ||
236 | + def __len__(self): | ||
237 | + return len(self.idx_map) | ||
238 | + | ||
239 | + def set_idx_map(self, data_idx_map): | ||
240 | + self.idx_map = data_idx_map | ||
241 | + | ||
242 | + def __getitem__(self, idx): | ||
243 | + start_i = self.datum[self.idx_map[idx]][0] | ||
244 | + if self.is_train: | ||
245 | + is_regular = self.datum[self.idx_map[idx]][1] | ||
246 | + l = np.zeros((const.CAN_FRAME_LEN, const.CAN_DATA_LEN)) | ||
247 | + ''' | ||
248 | + each byte value is normalized: | ||
249 | + 0 ~ 255 -> 0.0 ~ 1.0 | ||
250 | + ''' | ||
251 | + for i in range(const.CAN_FRAME_LEN): | ||
252 | + data_len = self.csv.iloc[start_i + i, 1] | ||
253 | + for j in range(data_len): | ||
254 | + k = int(self.csv.iloc[start_i + i, 2 + j], 16) / 255.0 | ||
255 | + l[i][j] = k | ||
256 | + l = np.reshape(l, (1, const.CAN_FRAME_LEN, const.CAN_DATA_LEN)) | ||
257 | + else: | ||
258 | + l = np.zeros((const.CAN_DATA_LEN)) | ||
259 | + data_len = self.csv.iloc[start_i, 1] | ||
260 | + is_regular = self.csv.iloc[start_i, data_len + 2] == 'R' | ||
261 | + if is_regular: | ||
262 | + is_regular = 1 | ||
263 | + else: | ||
264 | + is_regular = 0 | ||
265 | + for j in range(data_len): | ||
266 | + k = int(self.csv.iloc[start_i, 2 + j], 16) / 255.0 | ||
267 | + l[j] = k | ||
268 | + l = np.reshape(l, (1, const.CAN_DATA_LEN)) | ||
269 | + | ||
270 | + return (l, is_regular) | ||
271 | + | ||
272 | + | ||
273 | +def GetCanDatasetCNN(total_edge, fold_num, csv_path, txt_path): | ||
274 | + csv = pd.read_csv(csv_path) | ||
275 | + txt = open(txt_path, "r") | ||
276 | + lines = txt.read().splitlines() | ||
277 | + | ||
278 | + idx = 0 | ||
279 | + datum = [] | ||
280 | + label_temp = [] | ||
281 | + while idx < len(csv) // 2: | ||
282 | + line = lines[idx] | ||
283 | + if not line: | ||
284 | + break | ||
285 | + | ||
286 | + if line.split(' ')[1] == 'R': | ||
287 | + datum.append((idx, 1)) | ||
288 | + label_temp.append(1) | ||
289 | + else: | ||
290 | + datum.append((idx, 0)) | ||
291 | + label_temp.append(0) | ||
292 | + | ||
293 | + idx += 1 | ||
294 | + if (idx % 1000000 == 0): | ||
295 | + print(idx) | ||
296 | + | ||
297 | + fold_length = int(len(label_temp) / 5) | ||
298 | + train_datum = [] | ||
299 | + train_label_temp = [] | ||
300 | + for i in range(5): | ||
301 | + if i != fold_num: | ||
302 | + train_datum += datum[i*fold_length:(i+1)*fold_length] | ||
303 | + train_label_temp += label_temp[i*fold_length:(i+1)*fold_length] | ||
304 | + else: | ||
305 | + test_datum = datum[i*fold_length:(i+1)*fold_length] | ||
306 | + | ||
307 | + | ||
308 | + N = len(train_label_temp) | ||
309 | + train_label_temp = np.array(train_label_temp) | ||
310 | + | ||
311 | + proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
312 | + proportions = np.cumsum(proportions) | ||
313 | + idx_batch = [[] for _ in range(total_edge)] | ||
314 | + data_idx_map = {} | ||
315 | + prev = 0.0 | ||
316 | + for j in range(total_edge): | ||
317 | + idx_batch[j] = [idx for idx in range(int(prev * N), int(proportions[j] * N))] | ||
318 | + prev = proportions[j] | ||
319 | + data_idx_map[j] = idx_batch[j] | ||
320 | + | ||
321 | + _, net_data_count = record_net_data_stats(train_label_temp, data_idx_map) | ||
322 | + | ||
323 | + return CanDatasetCNN(csv, train_datum), data_idx_map, net_data_count, CanDatasetCNN(csv, test_datum, False) | ||
324 | + | ||
325 | + | ||
326 | +class CanDatasetCNN(Dataset): | ||
327 | + | ||
328 | + def __init__(self, csv, datum, is_train=True): | ||
329 | + self.csv = csv | ||
330 | + self.datum = datum | ||
331 | + if is_train: | ||
332 | + self.idx_map = [] | ||
333 | + else: | ||
334 | + self.idx_map = [idx for idx in range(len(self.datum))] | ||
335 | + | ||
336 | + def __len__(self): | ||
337 | + return len(self.idx_map) | ||
338 | + | ||
339 | + def set_idx_map(self, data_idx_map): | ||
340 | + self.idx_map = data_idx_map | ||
341 | + | ||
342 | + def __getitem__(self, idx): | ||
343 | + start_i = self.datum[self.idx_map[idx]][0] | ||
344 | + is_regular = self.datum[self.idx_map[idx]][1] | ||
345 | + | ||
346 | + packet = np.zeros((const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)) | ||
347 | + for i in range(const.CNN_FRAME_LEN): | ||
348 | + data_len = self.csv.iloc[start_i + i, 1] | ||
349 | + for j in range(data_len): | ||
350 | + k = int(self.csv.iloc[start_i + i, 2 + j], 16) / 255.0 | ||
351 | + packet[i][j] = k | ||
352 | + packet = np.reshape(packet, (1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)) | ||
353 | + return (packet, is_regular) | ||
354 | + | ||
355 | + | ||
356 | +def unpack_bits(x, num_bits): | ||
357 | + """ | ||
358 | + Args: | ||
359 | + x (int): integer to convert to bits | ||
360 | + num_bits (int): number of bits to represent | ||
361 | + """ | ||
362 | + xshape = list(x.shape) | ||
363 | + x = x.reshape([-1, 1]) | ||
364 | + mask = 2**np.arange(num_bits).reshape([1, num_bits]) | ||
365 | + return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits]) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
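
For reference, a small check of the unpack_bits helper that closes deprecated.py (a sketch; the 11-bit width is an assumption for a standard CAN arbitration ID, not a value taken from const):

import numpy as np

def unpack_bits(x, num_bits):
    # same helper as above: little-endian bit expansion of an integer ndarray
    xshape = list(x.shape)
    x = x.reshape([-1, 1])
    mask = 2 ** np.arange(num_bits).reshape([1, num_bits])
    return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])

can_id = int("2C0", 16)                              # 0x2C0 = 704
print(unpack_bits(np.array(can_id), num_bits=11))    # [0 0 0 0 0 0 1 1 0 1 0], LSB first
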
1 | -import utils | ||
2 | import copy | 1 | import copy |
2 | +import argparse | ||
3 | +import time | ||
4 | +import math | ||
5 | +import numpy as np | ||
6 | +import os | ||
3 | from collections import OrderedDict | 7 | from collections import OrderedDict |
8 | +import torch | ||
9 | +import torch.optim as optim | ||
10 | +import torch.nn as nn | ||
4 | 11 | ||
5 | import model | 12 | import model |
13 | +import utils | ||
6 | import dataset | 14 | import dataset |
7 | 15 | ||
16 | +# for google colab reload | ||
8 | import importlib | 17 | import importlib |
9 | -importlib.reload(utils) | ||
10 | importlib.reload(model) | 18 | importlib.reload(model) |
19 | +importlib.reload(utils) | ||
11 | importlib.reload(dataset) | 20 | importlib.reload(dataset) |
12 | 21 | ||
13 | -from utils import * | 22 | + |
23 | +## parameters | ||
24 | + | ||
25 | +# shared | ||
26 | +criterion = nn.CrossEntropyLoss() | ||
27 | +C = 0.1 | ||
28 | +# | ||
29 | + | ||
30 | +# prox | ||
31 | +mu = 0.001 | ||
32 | +# | ||
33 | + | ||
34 | +# time weight | ||
35 | +twa_exp = 1.1 | ||
36 | +# | ||
37 | + | ||
38 | +# dynamic weight | ||
39 | +H = 0.5 | ||
40 | +P = 0.1 | ||
41 | +G = 0.1 | ||
42 | +R = 0.1 | ||
43 | +alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0 | ||
44 | +# | ||
45 | + | ||
46 | +## end | ||
14 | 47 | ||
15 | 48 | ||
16 | def add_args(parser): | 49 | def add_args(parser): |
17 | - # parser.add_argument('--model', type=str, default='moderate-cnn', | 50 | + parser.add_argument('--packet_num', type=int, default=1, |
18 | - # help='neural network used in training') | 51 | + help='packet number used in training, 1 ~ 3') |
19 | - parser.add_argument('--dataset', type=str, default='cifar10', metavar='N', | 52 | + parser.add_argument('--dataset', type=str, default='can', |
20 | - help='dataset used for training') | 53 | + help='dataset used for training, can or syncan') |
21 | parser.add_argument('--fold_num', type=int, default=0, | 54 | parser.add_argument('--fold_num', type=int, default=0, |
22 | help='5-fold, 0 ~ 4') | 55 | help='5-fold, 0 ~ 4') |
23 | - parser.add_argument('--batch_size', type=int, default=256, metavar='N', | 56 | + parser.add_argument('--batch_size', type=int, default=128, |
24 | help='input batch size for training') | 57 | help='input batch size for training') |
25 | - parser.add_argument('--lr', type=float, default=0.002, metavar='LR', | 58 | + parser.add_argument('--lr', type=float, default=0.001, |
26 | help='learning rate') | 59 | help='learning rate') |
27 | - parser.add_argument('--n_nets', type=int, default=100, metavar='NN', | 60 | + parser.add_argument('--n_nets', type=int, default=100, |
28 | help='number of workers in a distributed cluster') | 61 | help='number of workers in a distributed cluster') |
29 | - parser.add_argument('--comm_type', type=str, default='fedtwa', | 62 | + parser.add_argument('--comm_type', type=str, default='fedprox', |
30 | - help='which type of communication strategy is going to be used: layerwise/blockwise') | 63 | + help='type of communication, [fedavg, fedprox, fedtwa, feddw, edge]') |
31 | - parser.add_argument('--comm_round', type=int, default=10, | 64 | + parser.add_argument('--comm_round', type=int, default=50, |
32 | help='how many rounds of communication we should use') | 65 | help='how many rounds of communication we should use') |
66 | + parser.add_argument('--weight_save_path', type=str, default='./weights', | ||
67 | + help='model weight save path') | ||
33 | args = parser.parse_args(args=[]) | 68 | args = parser.parse_args(args=[]) |
34 | return args | 69 | return args |
35 | 70 | ||
36 | 71 | ||
72 | +def test_model(fed_model, cr, args, testloader, device): | ||
73 | + fed_model.to(device) | ||
74 | + fed_model.eval() | ||
75 | + | ||
76 | + cnt = 0 | ||
77 | + step_acc = 0.0 | ||
78 | + with torch.no_grad(): | ||
79 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
80 | + for i, (inputs, labels) in enumerate(testloader): | ||
81 | + inputs, labels = inputs.to(device), labels.to(device) | ||
82 | + | ||
83 | + outputs, packet_state = fed_model(inputs, packet_state) | ||
84 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
85 | + | ||
86 | + _, preds = torch.max(outputs, 1) | ||
87 | + | ||
88 | + cnt += inputs.shape[0] | ||
89 | + corr_sum = torch.sum(preds == labels.data) | ||
90 | + step_acc += corr_sum.double() | ||
91 | + if i % 200 == 0: | ||
92 | + print('test [%4d/%4d] acc: %.3f' % (i, len(testloader), (step_acc / cnt).item())) | ||
93 | + # break | ||
94 | + fed_accuracy = (step_acc / cnt).item() | ||
95 | + print('test acc', fed_accuracy) | ||
96 | + fed_model.to('cpu') | ||
97 | + fed_model.train() | ||
98 | + torch.save(fed_model.state_dict(), os.path.join(args.weight_save_path, '%s_%d_%.4f.pth' % (args.comm_type, cr, fed_accuracy))) | ||
99 | + | ||
100 | + | ||
37 | def start_fedavg(fed_model, args, | 101 | def start_fedavg(fed_model, args, |
38 | train_data_set, | 102 | train_data_set, |
39 | data_idx_map, | 103 | data_idx_map, |
... | @@ -42,8 +106,7 @@ def start_fedavg(fed_model, args, | ... | @@ -42,8 +106,7 @@ def start_fedavg(fed_model, args, |
42 | edges, | 106 | edges, |
43 | device): | 107 | device): |
44 | print("start fed avg") | 108 | print("start fed avg") |
45 | - criterion = nn.CrossEntropyLoss() | 109 | + |
46 | - C = 0.1 | ||
47 | num_edge = int(max(C * args.n_nets, 1)) | 110 | num_edge = int(max(C * args.n_nets, 1)) |
48 | total_data_count = 0 | 111 | total_data_count = 0 |
49 | for _, data_count in net_data_count.items(): | 112 | for _, data_count in net_data_count.items(): |
... | @@ -59,20 +122,25 @@ def start_fedavg(fed_model, args, | ... | @@ -59,20 +122,25 @@ def start_fedavg(fed_model, args, |
59 | 122 | ||
60 | for edge_progress, edge_index in enumerate(selected_edge): | 123 | for edge_progress, edge_index in enumerate(selected_edge): |
61 | train_data_set.set_idx_map(data_idx_map[edge_index]) | 124 | train_data_set.set_idx_map(data_idx_map[edge_index]) |
62 | - train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | 125 | + sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size) |
63 | - shuffle=True, num_workers=2) | 126 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler, |
127 | + shuffle=False, num_workers=2, drop_last=True) | ||
64 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | 128 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) |
65 | 129 | ||
66 | edges[edge_index] = copy.deepcopy(fed_model) | 130 | edges[edge_index] = copy.deepcopy(fed_model) |
67 | edges[edge_index].to(device) | 131 | edges[edge_index].to(device) |
68 | edges[edge_index].train() | 132 | edges[edge_index].train() |
69 | edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | 133 | edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) |
134 | + | ||
70 | # train | 135 | # train |
136 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
71 | for data_idx, (inputs, labels) in enumerate(train_loader): | 137 | for data_idx, (inputs, labels) in enumerate(train_loader): |
72 | - inputs, labels = inputs.float().to(device), labels.long().to(device) | 138 | + inputs, labels = inputs.to(device), labels.to(device) |
139 | + | ||
140 | + edge_pred, packet_state = edges[edge_index](inputs, packet_state) | ||
141 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
73 | 142 | ||
74 | edge_opt.zero_grad() | 143 | edge_opt.zero_grad() |
75 | - edge_pred = edges[edge_index](inputs) | ||
76 | 144 | ||
77 | edge_loss = criterion(edge_pred, labels) | 145 | edge_loss = criterion(edge_pred, labels) |
78 | edge_loss.backward() | 146 | edge_loss.backward() |
... | @@ -90,39 +158,13 @@ def start_fedavg(fed_model, args, | ... | @@ -90,39 +158,13 @@ def start_fedavg(fed_model, args, |
90 | local_state = edge.state_dict() | 158 | local_state = edge.state_dict() |
91 | for key in fed_model.state_dict().keys(): | 159 | for key in fed_model.state_dict().keys(): |
92 | if k == 0: | 160 | if k == 0: |
93 | - update_state[key] = local_state[key] * net_data_count[k] / total_data_count | 161 | + update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) |
94 | else: | 162 | else: |
95 | - update_state[key] += local_state[key] * net_data_count[k] / total_data_count | 163 | + update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) |
96 | - | 164 | + |
97 | fed_model.load_state_dict(update_state) | 165 | fed_model.load_state_dict(update_state) |
98 | if cr % 10 == 0: | 166 | if cr % 10 == 0: |
99 | - fed_model.to(device) | 167 | + test_model(fed_model, cr, args, testloader, device) |
100 | - fed_model.eval() | ||
101 | - | ||
102 | - total_loss = 0.0 | ||
103 | - cnt = 0 | ||
104 | - step_acc = 0.0 | ||
105 | - with torch.no_grad(): | ||
106 | - for i, data in enumerate(testloader): | ||
107 | - inputs, labels = data | ||
108 | - inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
109 | - | ||
110 | - outputs = fed_model(inputs) | ||
111 | - _, preds = torch.max(outputs, 1) | ||
112 | - | ||
113 | - loss = criterion(outputs, labels) | ||
114 | - cnt += inputs.shape[0] | ||
115 | - | ||
116 | - corr_sum = torch.sum(preds == labels.data) | ||
117 | - step_acc += corr_sum.double() | ||
118 | - running_loss = loss.item() * inputs.shape[0] | ||
119 | - total_loss += running_loss | ||
120 | - if i % 200 == 0: | ||
121 | - print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
122 | - # break | ||
123 | - print((step_acc / cnt).data) | ||
124 | - print(total_loss / cnt) | ||
125 | - fed_model.to('cpu') | ||
126 | 168 | ||
127 | 169 | ||
128 | def start_fedprox(fed_model, args, | 170 | def start_fedprox(fed_model, args, |
... | @@ -131,9 +173,7 @@ def start_fedprox(fed_model, args, | ... | @@ -131,9 +173,7 @@ def start_fedprox(fed_model, args, |
131 | testloader, | 173 | testloader, |
132 | device): | 174 | device): |
133 | print("start fed prox") | 175 | print("start fed prox") |
134 | - criterion = nn.CrossEntropyLoss() | 176 | + |
135 | - mu = 0.001 | ||
136 | - C = 0.1 | ||
137 | num_edge = int(max(C * args.n_nets, 1)) | 177 | num_edge = int(max(C * args.n_nets, 1)) |
138 | fed_model.to(device) | 178 | fed_model.to(device) |
139 | 179 | ||
... | @@ -149,25 +189,26 @@ def start_fedprox(fed_model, args, | ... | @@ -149,25 +189,26 @@ def start_fedprox(fed_model, args, |
149 | selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) | 189 | selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) |
150 | print("selected edge", selected_edge) | 190 | print("selected edge", selected_edge) |
151 | 191 | ||
152 | - total_data_length = 0 | ||
153 | - edge_data_len = [] | ||
154 | for edge_progress, edge_index in enumerate(selected_edge): | 192 | for edge_progress, edge_index in enumerate(selected_edge): |
155 | train_data_set.set_idx_map(data_idx_map[edge_index]) | 193 | train_data_set.set_idx_map(data_idx_map[edge_index]) |
156 | - train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | 194 | + sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size) |
157 | - shuffle=True, num_workers=2) | 195 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler, |
196 | + shuffle=False, num_workers=2, drop_last=True) | ||
158 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | 197 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) |
159 | - total_data_length += len(train_data_set) | ||
160 | - edge_data_len.append(len(train_data_set)) | ||
161 | 198 | ||
162 | edge_model = copy.deepcopy(fed_model) | 199 | edge_model = copy.deepcopy(fed_model) |
163 | edge_model.to(device) | 200 | edge_model.to(device) |
201 | + edge_model.train() | ||
164 | edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr) | 202 | edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr) |
165 | # train | 203 | # train |
204 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
166 | for data_idx, (inputs, labels) in enumerate(train_loader): | 205 | for data_idx, (inputs, labels) in enumerate(train_loader): |
167 | - inputs, labels = inputs.float().to(device), labels.long().to(device) | 206 | + inputs, labels = inputs.to(device), labels.to(device) |
207 | + | ||
208 | + edge_pred, packet_state = edge_model(inputs, packet_state) | ||
209 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
168 | 210 | ||
169 | edge_opt.zero_grad() | 211 | edge_opt.zero_grad() |
170 | - edge_pred = edge_model(inputs) | ||
171 | 212 | ||
172 | edge_loss = criterion(edge_pred, labels) | 213 | edge_loss = criterion(edge_pred, labels) |
173 | # prox term | 214 | # prox term |
... | @@ -196,30 +237,8 @@ def start_fedprox(fed_model, args, | ... | @@ -196,30 +237,8 @@ def start_fedprox(fed_model, args, |
196 | fed_model.to(device) | 237 | fed_model.to(device) |
197 | 238 | ||
198 | if cr % 10 == 0: | 239 | if cr % 10 == 0: |
199 | - fed_model.eval() | 240 | + test_model(fed_model, cr, args, testloader, device) |
200 | - total_loss = 0.0 | 241 | + fed_model.to(device) |
201 | - cnt = 0 | ||
202 | - step_acc = 0.0 | ||
203 | - with torch.no_grad(): | ||
204 | - for i, data in enumerate(testloader): | ||
205 | - inputs, labels = data | ||
206 | - inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
207 | - | ||
208 | - outputs = fed_model(inputs) | ||
209 | - _, preds = torch.max(outputs, 1) | ||
210 | - | ||
211 | - loss = criterion(outputs, labels) | ||
212 | - cnt += inputs.shape[0] | ||
213 | - | ||
214 | - corr_sum = torch.sum(preds == labels.data) | ||
215 | - step_acc += corr_sum.double() | ||
216 | - running_loss = loss.item() * inputs.shape[0] | ||
217 | - total_loss += running_loss | ||
218 | - if i % 200 == 0: | ||
219 | - print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
220 | - # break | ||
221 | - print((step_acc / cnt).data) | ||
222 | - print(total_loss / cnt) | ||
223 | 242 | ||
224 | 243 | ||
225 | def start_fedtwa(fed_model, args, | 244 | def start_fedtwa(fed_model, args, |
... | @@ -231,10 +250,8 @@ def start_fedtwa(fed_model, args, | ... | @@ -231,10 +250,8 @@ def start_fedtwa(fed_model, args, |
231 | device): | 250 | device): |
232 | # TEFL, without asynchronous model update | 251 | # TEFL, without asynchronous model update |
233 | print("start fed temporally weighted aggregation") | 252 | print("start fed temporally weighted aggregation") |
234 | - criterion = nn.CrossEntropyLoss() | 253 | + |
235 | time_stamp = [0 for worker in range(args.n_nets)] | 254 | time_stamp = [0 for worker in range(args.n_nets)] |
236 | - twa_exp = math.e / 2.0 | ||
237 | - C = 0.1 | ||
238 | num_edge = int(max(C * args.n_nets, 1)) | 255 | num_edge = int(max(C * args.n_nets, 1)) |
239 | total_data_count = 0 | 256 | total_data_count = 0 |
240 | for _, data_count in net_data_count.items(): | 257 | for _, data_count in net_data_count.items(): |
... | @@ -251,20 +268,25 @@ def start_fedtwa(fed_model, args, | ... | @@ -251,20 +268,25 @@ def start_fedtwa(fed_model, args, |
251 | for edge_progress, edge_index in enumerate(selected_edge): | 268 | for edge_progress, edge_index in enumerate(selected_edge): |
252 | time_stamp[edge_index] = cr | 269 | time_stamp[edge_index] = cr |
253 | train_data_set.set_idx_map(data_idx_map[edge_index]) | 270 | train_data_set.set_idx_map(data_idx_map[edge_index]) |
254 | - train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | 271 | + sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size) |
255 | - shuffle=True, num_workers=2) | 272 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler, |
273 | + shuffle=False, num_workers=2, drop_last=True) | ||
256 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | 274 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) |
257 | 275 | ||
258 | edges[edge_index] = copy.deepcopy(fed_model) | 276 | edges[edge_index] = copy.deepcopy(fed_model) |
259 | edges[edge_index].to(device) | 277 | edges[edge_index].to(device) |
260 | edges[edge_index].train() | 278 | edges[edge_index].train() |
261 | edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | 279 | edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) |
280 | + | ||
262 | # train | 281 | # train |
282 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
263 | for data_idx, (inputs, labels) in enumerate(train_loader): | 283 | for data_idx, (inputs, labels) in enumerate(train_loader): |
264 | inputs, labels = inputs.float().to(device), labels.long().to(device) | 284 | inputs, labels = inputs.float().to(device), labels.long().to(device) |
265 | 285 | ||
286 | + edge_pred, packet_state = edges[edge_index](inputs, packet_state) | ||
287 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
288 | + | ||
266 | edge_opt.zero_grad() | 289 | edge_opt.zero_grad() |
267 | - edge_pred = edges[edge_index](inputs) | ||
268 | 290 | ||
269 | edge_loss = criterion(edge_pred, labels) | 291 | edge_loss = criterion(edge_pred, labels) |
270 | edge_loss.backward() | 292 | edge_loss.backward() |
... | @@ -277,44 +299,19 @@ def start_fedtwa(fed_model, args, | ... | @@ -277,44 +299,19 @@ def start_fedtwa(fed_model, args, |
277 | edges[edge_index].to('cpu') | 299 | edges[edge_index].to('cpu') |
278 | 300 | ||
279 | # cal weight using time stamp | 301 | # cal weight using time stamp |
302 | + # the paper uses cr - time_stamp[k], but the error was higher | ||
280 | update_state = OrderedDict() | 303 | update_state = OrderedDict() |
281 | for k, edge in enumerate(edges): | 304 | for k, edge in enumerate(edges): |
282 | local_state = edge.state_dict() | 305 | local_state = edge.state_dict() |
283 | for key in fed_model.state_dict().keys(): | 306 | for key in fed_model.state_dict().keys(): |
284 | if k == 0: | 307 | if k == 0: |
285 | - update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k])) | 308 | + update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr -2 - time_stamp[k])) |
286 | else: | 309 | else: |
287 | - update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k])) | 310 | + update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr -2 - time_stamp[k])) |
288 | 311 | ||
289 | fed_model.load_state_dict(update_state) | 312 | fed_model.load_state_dict(update_state) |
290 | if cr % 10 == 0: | 313 | if cr % 10 == 0: |
291 | - fed_model.to(device) | 314 | + test_model(fed_model, cr, args, testloader, device) |
292 | - fed_model.eval() | ||
293 | - | ||
294 | - total_loss = 0.0 | ||
295 | - cnt = 0 | ||
296 | - step_acc = 0.0 | ||
297 | - with torch.no_grad(): | ||
298 | - for i, data in enumerate(testloader): | ||
299 | - inputs, labels = data | ||
300 | - inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
301 | - | ||
302 | - outputs = fed_model(inputs) | ||
303 | - _, preds = torch.max(outputs, 1) | ||
304 | - | ||
305 | - loss = criterion(outputs, labels) | ||
306 | - cnt += inputs.shape[0] | ||
307 | - | ||
308 | - corr_sum = torch.sum(preds == labels.data) | ||
309 | - step_acc += corr_sum.double() | ||
310 | - running_loss = loss.item() * inputs.shape[0] | ||
311 | - total_loss += running_loss | ||
312 | - if i % 200 == 0: | ||
313 | - print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
314 | - # break | ||
315 | - print((step_acc / cnt).data) | ||
316 | - print(total_loss / cnt) | ||
317 | - fed_model.to('cpu') | ||
318 | 315 | ||
319 | 316 | ||
320 | def start_feddw(fed_model, args, | 317 | def start_feddw(fed_model, args, |
... | @@ -326,13 +323,8 @@ def start_feddw(fed_model, args, | ... | @@ -326,13 +323,8 @@ def start_feddw(fed_model, args, |
326 | edges, | 323 | edges, |
327 | device): | 324 | device): |
328 | print("start fed Node-aware Dynamic Weighting") | 325 | print("start fed Node-aware Dynamic Weighting") |
326 | + | ||
329 | worker_selected_frequency = [0 for worker in range(args.n_nets)] | 327 | worker_selected_frequency = [0 for worker in range(args.n_nets)] |
330 | - criterion = nn.CrossEntropyLoss() | ||
331 | - H = 0.5 | ||
332 | - P = 0.5 | ||
333 | - G = 0.1 | ||
334 | - R = 0.1 | ||
335 | - alpha, beta, gamma = 30.0/100.0, 50.0/100.0, 20.0/100.0 | ||
336 | num_edge = int(max(G * args.n_nets, 1)) | 328 | num_edge = int(max(G * args.n_nets, 1)) |
337 | 329 | ||
338 | # cal data weight for selecting participants | 330 | # cal data weight for selecting participants |
... | @@ -382,20 +374,25 @@ def start_feddw(fed_model, args, | ... | @@ -382,20 +374,25 @@ def start_feddw(fed_model, args, |
382 | for edge_progress, edge_index in enumerate(selected_edge): | 374 | for edge_progress, edge_index in enumerate(selected_edge): |
383 | worker_selected_frequency[edge_index] += 1 | 375 | worker_selected_frequency[edge_index] += 1 |
384 | train_data_set.set_idx_map(data_idx_map[edge_index]) | 376 | train_data_set.set_idx_map(data_idx_map[edge_index]) |
385 | - train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | 377 | + sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size) |
386 | - shuffle=True, num_workers=2) | 378 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler, |
379 | + shuffle=False, num_workers=2, drop_last=True) | ||
387 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | 380 | print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) |
388 | 381 | ||
389 | edges[edge_index] = copy.deepcopy(fed_model) | 382 | edges[edge_index] = copy.deepcopy(fed_model) |
390 | edges[edge_index].to(device) | 383 | edges[edge_index].to(device) |
391 | edges[edge_index].train() | 384 | edges[edge_index].train() |
392 | edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | 385 | edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) |
386 | + | ||
393 | # train | 387 | # train |
388 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
394 | for data_idx, (inputs, labels) in enumerate(train_loader): | 389 | for data_idx, (inputs, labels) in enumerate(train_loader): |
395 | inputs, labels = inputs.float().to(device), labels.long().to(device) | 390 | inputs, labels = inputs.float().to(device), labels.long().to(device) |
396 | 391 | ||
392 | + edge_pred, packet_state = edges[edge_index](inputs, packet_state) | ||
393 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
394 | + | ||
397 | edge_opt.zero_grad() | 395 | edge_opt.zero_grad() |
398 | - edge_pred = edges[edge_index](inputs) | ||
399 | 396 | ||
400 | edge_loss = criterion(edge_pred, labels) | 397 | edge_loss = criterion(edge_pred, labels) |
401 | edge_loss.backward() | 398 | edge_loss.backward() |
... | @@ -408,17 +405,20 @@ def start_feddw(fed_model, args, | ... | @@ -408,17 +405,20 @@ def start_feddw(fed_model, args, |
408 | 405 | ||
409 | # get edge accuracy using subset of testset | 406 | # get edge accuracy using subset of testset |
410 | edges[edge_index].eval() | 407 | edges[edge_index].eval() |
411 | - print("[%2d/%2d] edge: %d, cal accuracy" % (edge_progress, len(selected_edge), edge_index)) | 408 | + print("[%2d/%2d] edge: %d, cal local accuracy" % (edge_progress, len(selected_edge), edge_index)) |
412 | cnt = 0 | 409 | cnt = 0 |
413 | step_acc = 0.0 | 410 | step_acc = 0.0 |
414 | with torch.no_grad(): | 411 | with torch.no_grad(): |
412 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
415 | for inputs, labels in local_test_loader: | 413 | for inputs, labels in local_test_loader: |
416 | inputs, labels = inputs.float().to(device), labels.long().to(device) | 414 | inputs, labels = inputs.float().to(device), labels.long().to(device) |
417 | 415 | ||
418 | - outputs = edges[edge_index](inputs) | 416 | + edge_pred, packet_state = edges[edge_index](inputs, packet_state) |
419 | - _, preds = torch.max(outputs, 1) | 417 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) |
418 | + | ||
419 | + _, preds = torch.max(edge_pred, 1) | ||
420 | 420 | ||
421 | - loss = criterion(outputs, labels) | 421 | + loss = criterion(edge_pred, labels) |
422 | cnt += inputs.shape[0] | 422 | cnt += inputs.shape[0] |
423 | 423 | ||
424 | corr_sum = torch.sum(preds == labels.data) | 424 | corr_sum = torch.sum(preds == labels.data) |
... | @@ -426,7 +426,7 @@ def start_feddw(fed_model, args, | ... | @@ -426,7 +426,7 @@ def start_feddw(fed_model, args, |
426 | # break | 426 | # break |
427 | 427 | ||
428 | worker_local_accuracy[edge_index] = (step_acc / cnt).item() | 428 | worker_local_accuracy[edge_index] = (step_acc / cnt).item() |
429 | - print(worker_local_accuracy[edge_index]) | 429 | + print('edge local accuracy', worker_local_accuracy[edge_index]) |
430 | edges[edge_index].to('cpu') | 430 | edges[edge_index].to('cpu') |
431 | 431 | ||
432 | # cal weight dynamically | 432 | # cal weight dynamically |
... | @@ -449,59 +449,123 @@ def start_feddw(fed_model, args, | ... | @@ -449,59 +449,123 @@ def start_feddw(fed_model, args, |
449 | 449 | ||
450 | fed_model.load_state_dict(update_state) | 450 | fed_model.load_state_dict(update_state) |
451 | if cr % 10 == 0: | 451 | if cr % 10 == 0: |
452 | - fed_model.to(device) | 452 | + test_model(fed_model, cr, args, testloader, device) |
453 | - fed_model.eval() | ||
454 | - | ||
455 | - total_loss = 0.0 | ||
456 | - cnt = 0 | ||
457 | - step_acc = 0.0 | ||
458 | - with torch.no_grad(): | ||
459 | - for i, data in enumerate(testloader): | ||
460 | - inputs, labels = data | ||
461 | - inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
462 | 453 | ||
463 | - outputs = fed_model(inputs) | ||
464 | - _, preds = torch.max(outputs, 1) | ||
465 | 454 | ||
466 | - loss = criterion(outputs, labels) | 455 | +def start_only_edge(args, |
467 | - cnt += inputs.shape[0] | 456 | + train_data_set, |
457 | + data_idx_map, | ||
458 | + testloader, | ||
459 | + edges, | ||
460 | + device): | ||
461 | + print("start only edge") | ||
462 | + total_epoch = int(args.comm_round * C) | ||
468 | 463 | ||
469 | - corr_sum = torch.sum(preds == labels.data) | 464 | + for cr in range(1, total_epoch + 1): |
470 | - step_acc += corr_sum.double() | 465 | + print("Edge round : %d" % (cr)) |
471 | - running_loss = loss.item() * inputs.shape[0] | 466 | + edge_accuracy_list = [] |
472 | - total_loss += running_loss | 467 | + |
473 | - if i % 200 == 0: | 468 | + for edge_index, edge_model in enumerate(edges): |
474 | - print('test [%4d] loss: %.3f' % (i, loss.item())) | 469 | + train_data_set.set_idx_map(data_idx_map[edge_index]) |
470 | + sampler = dataset.BatchIntervalSampler(len(train_data_set), args.batch_size) | ||
471 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, sampler=sampler, | ||
472 | + shuffle=False, num_workers=2, drop_last=True) | ||
473 | + print("edge[%2d/%2d] data len: %d" % (edge_index, len(edges), len(train_data_set))) | ||
474 | + | ||
475 | + edge_model.to(device) | ||
476 | + edge_model.train() | ||
477 | + edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr) | ||
478 | + | ||
479 | + # train | ||
480 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
481 | + for data_idx, (inputs, labels) in enumerate(train_loader): | ||
482 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
483 | + | ||
484 | + edge_pred, packet_state = edge_model(inputs, packet_state) | ||
485 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
486 | + | ||
487 | + edge_opt.zero_grad() | ||
488 | + | ||
489 | + edge_loss = criterion(edge_pred, labels) | ||
490 | + edge_loss.backward() | ||
491 | + | ||
492 | + edge_opt.step() | ||
493 | + edge_loss = edge_loss.item() | ||
494 | + if data_idx % 100 == 0: | ||
495 | + print('[%4d] loss: %.3f' % (data_idx, edge_loss)) | ||
475 | # break | 496 | # break |
476 | - print((step_acc / cnt).data) | 497 | + |
477 | - print(total_loss / cnt) | 498 | + # test |
478 | - fed_model.to('cpu') | 499 | + # if cr < 4: |
500 | + # continue | ||
501 | + edge_model.eval() | ||
502 | + total_loss = 0.0 | ||
503 | + cnt = 0 | ||
504 | + step_acc = 0.0 | ||
505 | + with torch.no_grad(): | ||
506 | + packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device) | ||
507 | + for i, (inputs, labels) in enumerate(testloader): | ||
508 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
509 | + | ||
510 | + outputs, packet_state = edge_model(inputs, packet_state) | ||
511 | + packet_state = torch.autograd.Variable(packet_state, requires_grad=False) | ||
512 | + | ||
513 | + _, preds = torch.max(outputs, 1) | ||
514 | + | ||
515 | + loss = criterion(outputs, labels) | ||
516 | + cnt += inputs.shape[0] | ||
517 | + | ||
518 | + corr_sum = torch.sum(preds == labels.data) | ||
519 | + step_acc += corr_sum.double() | ||
520 | + running_loss = loss.item() * inputs.shape[0] | ||
521 | + total_loss += running_loss | ||
522 | + if i % 200 == 0: | ||
523 | + print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
524 | + # break | ||
525 | + edge_accuracy = (step_acc / cnt).item() | ||
526 | + edge_accuracy_list.append(edge_accuracy) | ||
527 | + print("edge[%2d/%2d] acc: %.4f" % (edge_index, len(edges), edge_accuracy)) | ||
528 | + edge_model.to('cpu') | ||
529 | + | ||
530 | + # if cr < 4: | ||
531 | + # continue | ||
532 | + edge_accuracy_avg = sum(edge_accuracy_list) / len(edge_accuracy_list) | ||
533 | + torch.save(edges[0].state_dict(), os.path.join(weight_path, 'edge_%d_%.4f.pth' % (cr, edge_accuracy_avg))) | ||
534 | + | ||
479 | 535 | ||
480 | 536 | ||
481 | def start_train(): | 537 | def start_train(): |
482 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | 538 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
483 | - print(device) | 539 | + print('device:', device) |
540 | + | ||
484 | args = add_args(argparse.ArgumentParser()) | 541 | args = add_args(argparse.ArgumentParser()) |
485 | 542 | ||
543 | + # make weight folder | ||
544 | + os.makedirs(args.weight_save_path, exist_ok=True) | ||
545 | + | ||
546 | + # for reproducibility | ||
486 | seed = 0 | 547 | seed = 0 |
487 | np.random.seed(seed) | 548 | np.random.seed(seed) |
488 | torch.manual_seed(seed) | 549 | torch.manual_seed(seed) |
489 | 550 | ||
490 | print("Loading data...") | 551 | print("Loading data...") |
491 | - # kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt', | 552 | + if args.dataset == 'can': |
492 | - # "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt', | 553 | + train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") |
493 | - # "./dataset/RPM_dataset.csv" : './RPM_dataset.txt', | 554 | + elif args.dataset == 'syncan': |
494 | - # "./dataset/gear_dataset.csv" : './gear_dataset.txt' | 555 | + train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/test_mixed.csv", "./dataset/Mixed_dataset_1.txt") |
495 | - # } | 556 | + |
496 | - kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'} | 557 | + sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size) |
497 | - train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs) | 558 | + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler, |
498 | - testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, | 559 | + shuffle=False, num_workers=2, drop_last=True) |
499 | - shuffle=False, num_workers=2) | 560 | + |
500 | - | 561 | + if args.dataset == 'can': |
501 | - fed_model = model.Net() | 562 | + fed_model = model.OneNet(args.packet_num) |
502 | - args.comm_type = 'feddw' | 563 | + edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] |
564 | + elif args.dataset == 'syncan': | ||
565 | + fed_model = model.OneNet(args.packet_num) | ||
566 | + edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] | ||
567 | + | ||
503 | if args.comm_type == "fedavg": | 568 | if args.comm_type == "fedavg": |
504 | - edges, _, _ = init_models(args.n_nets, args) | ||
505 | start_fedavg(fed_model, args, | 569 | start_fedavg(fed_model, args, |
506 | train_data_set, | 570 | train_data_set, |
507 | data_idx_map, | 571 | data_idx_map, |
... | @@ -516,7 +580,6 @@ def start_train(): | ... | @@ -516,7 +580,6 @@ def start_train(): |
516 | testloader, | 580 | testloader, |
517 | device) | 581 | device) |
518 | elif args.comm_type == "fedtwa": | 582 | elif args.comm_type == "fedtwa": |
519 | - edges, _, _ = init_models(args.n_nets, args) | ||
520 | start_fedtwa(fed_model, args, | 583 | start_fedtwa(fed_model, args, |
521 | train_data_set, | 584 | train_data_set, |
522 | data_idx_map, | 585 | data_idx_map, |
... | @@ -526,14 +589,14 @@ def start_train(): | ... | @@ -526,14 +589,14 @@ def start_train(): |
526 | device) | 589 | device) |
527 | elif args.comm_type == "feddw": | 590 | elif args.comm_type == "feddw": |
528 | local_test_set = copy.deepcopy(test_data_set) | 591 | local_test_set = copy.deepcopy(test_data_set) |
529 | - # mnist train 60,000 / test 10,000 / 1,000 | 592 | + # in paper, mnist train 60,000 / test 10,000 / 1,000 - 10% |
530 | - # CAN train ~ 13,000,000 / test 2,000,000 / for speed 40,000 | 593 | + # CAN train ~ 1,400,000 / test 300,000 / for speed 15,000 - 5% |
531 | - local_test_idx = np.random.choice(len(local_test_set), len(local_test_set) // 50, replace=False) | 594 | + local_test_idx = [idx for idx in range(0, len(local_test_set) // 20)] |
532 | local_test_set.set_idx_map(local_test_idx) | 595 | local_test_set.set_idx_map(local_test_idx) |
533 | - local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size, | 596 | + sampler = dataset.BatchIntervalSampler(len(local_test_set), args.batch_size) |
534 | - shuffle=False, num_workers=2) | 597 | + local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size, sampler=sampler, |
598 | + shuffle=False, num_workers=2, drop_last=True) | ||
535 | 599 | ||
536 | - edges, _, _ = init_models(args.n_nets, args) | ||
537 | start_feddw(fed_model, args, | 600 | start_feddw(fed_model, args, |
538 | train_data_set, | 601 | train_data_set, |
539 | data_idx_map, | 602 | data_idx_map, |
... | @@ -542,6 +605,13 @@ def start_train(): | ... | @@ -542,6 +605,13 @@ def start_train(): |
542 | local_test_loader, | 605 | local_test_loader, |
543 | edges, | 606 | edges, |
544 | device) | 607 | device) |
608 | + elif args.comm_type == "edge": | ||
609 | + start_only_edge(args, | ||
610 | + train_data_set, | ||
611 | + data_idx_map, | ||
612 | + testloader, | ||
613 | + edges, | ||
614 | + device) | ||
545 | 615 | ||
546 | if __name__ == "__main__": | 616 | if __name__ == "__main__": |
547 | start_train() | 617 | start_train() |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
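The test_model() helper invoked every 10 communication rounds above is defined outside this hunk; the following is only a sketch of what it presumably does, mirroring the stateful per-edge evaluation loop in this diff (function body and print format are assumptions, not the repository's actual code).

import torch
import torch.nn as nn
import model

def test_model(fed_model, args, testloader, device):
    # hedged sketch: evaluate the aggregated model on the global test loader
    criterion = nn.CrossEntropyLoss()
    fed_model.to(device)
    fed_model.eval()
    total_loss, step_acc, cnt = 0.0, 0.0, 0
    with torch.no_grad():
        packet_state = torch.zeros(args.batch_size, model.STATE_DIM).to(device)
        for inputs, labels in testloader:
            inputs, labels = inputs.float().to(device), labels.long().to(device)
            outputs, packet_state = fed_model(inputs, packet_state)
            packet_state = packet_state.detach()  # stop gradients flowing through the carried state
            _, preds = torch.max(outputs, 1)
            total_loss += criterion(outputs, labels).item() * inputs.shape[0]
            step_acc += torch.sum(preds == labels.data).double()
            cnt += inputs.shape[0]
    print('global test acc: %.4f, loss: %.4f' % ((step_acc / cnt).item(), total_loss / cnt))
    fed_model.to('cpu')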
1 | import torch.nn as nn | 1 | import torch.nn as nn |
2 | -import torch.nn.functional as F | ||
3 | import torch | 2 | import torch |
4 | import const | 3 | import const |
5 | 4 | ||
6 | -class Net(nn.Module): | ||
7 | - def __init__(self): | ||
8 | - super(Net, self).__init__() | ||
9 | 5 | ||
10 | - self.f1 = nn.Sequential( | 6 | +STATE_DIM = 8 * 32 |
11 | - nn.Conv2d(1, 2, 3), | 7 | +class OneNet(nn.Module): |
12 | - nn.ReLU(True), | 8 | + def __init__(self, packet_num): |
9 | + super(OneNet, self).__init__() | ||
10 | + IN_DIM = 8 * packet_num # byte | ||
11 | + FEATURE_DIM = 32 | ||
12 | + | ||
13 | + # map the given packet into a learned feature space | ||
14 | + self.feature_layer = nn.Sequential( | ||
15 | + nn.Linear(IN_DIM, 32), | ||
16 | + nn.ReLU(), | ||
17 | + nn.Linear(32, FEATURE_DIM), | ||
18 | + nn.ReLU() | ||
13 | ) | 19 | ) |
14 | - self.f2 = nn.Sequential( | 20 | + |
15 | - nn.Conv2d(2, 4, 3), | 21 | + # generates the current state 's' |
16 | - nn.ReLU(True), | 22 | + self.f = nn.Sequential( |
17 | - ) | 23 | + nn.Linear(STATE_DIM + FEATURE_DIM, STATE_DIM), |
18 | - self.f3 = nn.Sequential( | 24 | + nn.ReLU(), |
19 | - nn.Conv2d(4, 8, 3), | 25 | + nn.Linear(STATE_DIM, STATE_DIM), |
20 | - nn.ReLU(True), | 26 | + nn.ReLU() |
21 | ) | 27 | ) |
22 | - self.f4 = nn.Sequential( | 28 | + |
23 | - nn.Linear(8 * 23 * 23, 2), | 29 | + # check whether the given packet is malicious |
30 | + self.g = nn.Sequential( | ||
31 | + nn.Linear(STATE_DIM + FEATURE_DIM, 64), | ||
32 | + nn.ReLU(), | ||
33 | + nn.Linear(64, 64), | ||
34 | + nn.ReLU(), | ||
35 | + nn.Linear(64, 2), | ||
24 | ) | 36 | ) |
25 | 37 | ||
26 | - def forward(self, x): | ||
27 | - x = self.f1(x) | ||
28 | - x = self.f2(x) | ||
29 | - x = self.f3(x) | ||
30 | - x = torch.flatten(x, 1) | ||
31 | - x = self.f4(x) | ||
32 | - return x | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
38 | + def forward(self, x, s): | ||
39 | + x = self.feature_layer(x) | ||
40 | + x = torch.cat((x, s), 1) | ||
41 | + s2 = self.f(x) | ||
42 | + x2 = self.g(x) | ||
43 | + | ||
44 | + return x2, s2 | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
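A minimal usage sketch for the stateful OneNet above; batch size, packet_num, and the random input are placeholders. The state tensor is carried across consecutive packets and detached each step, matching the training loops in this diff.

import torch
import model

batch_size, packet_num = 4, 1
net = model.OneNet(packet_num)
state = torch.zeros(batch_size, model.STATE_DIM)

for _ in range(10):                                # ten consecutive packets per stream
    x = torch.rand(batch_size, 8 * packet_num)     # 8 payload bytes per packet (placeholder values)
    logits, state = net(x, state)
    state = state.detach()                         # truncate backprop through the packet stream
    preds = logits.argmax(dim=1)                   # two-class output: normal vs. injected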
코드/연합학습/quantization/onnx2tensorRT.py
0 → 100644
1 | +import tensorrt as trt | ||
2 | + | ||
3 | +onnx_file_name = 'bert.onnx' | ||
4 | +tensorrt_file_name = 'bert.plan' | ||
5 | +fp16_mode = True | ||
6 | +# int8_mode = True | ||
7 | +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | ||
8 | +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) | ||
9 | + | ||
10 | +builder = trt.Builder(TRT_LOGGER) | ||
11 | +network = builder.create_network(EXPLICIT_BATCH) | ||
12 | +parser = trt.OnnxParser(network, TRT_LOGGER) | ||
13 | + | ||
14 | +builder.max_workspace_size = (1 << 30) | ||
15 | +builder.fp16_mode = fp16_mode | ||
16 | +# builder.int8_mode = int8_mode | ||
17 | + | ||
18 | +with open(onnx_file_name, 'rb') as model: | ||
19 | + if not parser.parse(model.read()): | ||
20 | + for error in range(parser.num_errors): | ||
21 | + print (parser.get_error(error)) | ||
22 | + | ||
23 | +# for int8 mode | ||
24 | +# print(network.num_layers, network.num_inputs , network.num_outputs) | ||
25 | +# for layer_index in range(network.num_layers): | ||
26 | +# layer = network[layer_index] | ||
27 | +# print(layer.name) | ||
28 | +# tensor = layer.get_output(0) | ||
29 | +# print(tensor.name) | ||
30 | +# tensor.dynamic_range = (0, 255) | ||
31 | + | ||
32 | + # input_tensor = layer.get_input(0) | ||
33 | + # print(input_tensor) | ||
34 | + # input_tensor.dynamic_range = (0, 255) | ||
35 | + | ||
36 | +engine = builder.build_cuda_engine(network) | ||
37 | +buf = engine.serialize() | ||
38 | +with open(tensorrt_file_name, 'wb') as f: | ||
39 | + f.write(buf) | ||
40 | + | ||
41 | +print('done, trt model') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
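builder.max_workspace_size and builder.fp16_mode are the legacy builder attributes and were removed in newer TensorRT releases; the sketch below shows the builder-config form of the same conversion. It assumes a TensorRT 8.x-style API and is not verified against this project's environment.

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(EXPLICIT_BATCH)
parser = trt.OnnxParser(network, TRT_LOGGER)
with open('bert.onnx', 'rb') as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30        # deprecated on the most recent releases
config.set_flag(trt.BuilderFlag.FP16)      # replaces builder.fp16_mode
plan = builder.build_serialized_network(network, config)
with open('bert.plan', 'wb') as f:
    f.write(plan)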
코드/연합학습/quantization/tensorRT_test.py
0 → 100644
1 | +import tensorrt as trt | ||
2 | +import pycuda.driver as cuda | ||
3 | +import numpy as np | ||
4 | +import torch | ||
5 | +import pycuda.autoinit | ||
6 | +import dataset | ||
7 | +import model | ||
8 | +import time | ||
9 | +# print(dir(trt)) | ||
10 | + | ||
11 | +tensorrt_file_name = 'bert.plan' | ||
12 | +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | ||
13 | +trt_runtime = trt.Runtime(TRT_LOGGER) | ||
14 | + | ||
15 | +with open(tensorrt_file_name, 'rb') as f: | ||
16 | + engine_data = f.read() | ||
17 | +engine = trt_runtime.deserialize_cuda_engine(engine_data) | ||
18 | +context = engine.create_execution_context() | ||
19 | + | ||
20 | +# class HostDeviceMem(object): | ||
21 | +# def __init__(self, host_mem, device_mem): | ||
22 | +# self.host = host_mem | ||
23 | +# self.device = device_mem | ||
24 | + | ||
25 | +# def __str__(self): | ||
26 | +# return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) | ||
27 | + | ||
28 | +# def __repr__(self): | ||
29 | +# return self.__str__() | ||
30 | + | ||
31 | +# inputs, outputs, bindings, stream = [], [], [], [] | ||
32 | +# for binding in engine: | ||
33 | +# size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size | ||
34 | +# dtype = trt.nptype(engine.get_binding_dtype(binding)) | ||
35 | +# host_mem = cuda.pagelocked_empty(size, dtype) | ||
36 | +# device_mem = cuda.mem_alloc(host_mem.nbytes) | ||
37 | +# bindings.append(int(device_mem)) | ||
38 | +# if engine.binding_is_input(binding): | ||
39 | +# inputs.append( HostDeviceMem(host_mem, device_mem) ) | ||
40 | +# else: | ||
41 | +# outputs.append(HostDeviceMem(host_mem, device_mem)) | ||
42 | + | ||
43 | +# input_ids = np.ones([1, 1, 29, 29]) | ||
44 | + | ||
45 | +# numpy_array_input = [input_ids] | ||
46 | +# hosts = [input.host for input in inputs] | ||
47 | +# trt_types = [trt.int32] | ||
48 | + | ||
49 | +# for numpy_array, host, trt_types in zip(numpy_array_input, hosts, trt_types): | ||
50 | +# numpy_array = np.asarray(numpy_array).ravel() | ||
51 | +# np.copyto(host, numpy_array) | ||
52 | + | ||
53 | +# def do_inference(context, bindings, inputs, outputs, stream): | ||
54 | +# [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] | ||
55 | +# context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) | ||
56 | +# [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] | ||
57 | +# stream.synchronize() | ||
58 | +# return [out.host for out in outputs] | ||
59 | + | ||
60 | +# trt_outputs = do_inference( | ||
61 | +# context=context, | ||
62 | +# bindings=bindings, | ||
63 | +# inputs=inputs, | ||
64 | +# outputs=outputs, | ||
65 | +# stream=stream) | ||
66 | + | ||
67 | +def infer(context, input_img, output_size, batch_size): | ||
68 | + # Load engine | ||
69 | + # engine = context.get_engine() | ||
70 | + # assert(engine.get_nb_bindings() == 2) | ||
71 | + # Convert input data to float32 | ||
72 | + input_img = input_img.astype(np.float32) | ||
73 | + # Create host buffer to receive data | ||
74 | + output = np.empty(output_size, dtype = np.float32) | ||
75 | + # Allocate device memory | ||
76 | + d_input = cuda.mem_alloc(batch_size * input_img.size * input_img.dtype.itemsize) | ||
77 | + d_output = cuda.mem_alloc(batch_size * output.size * output.dtype.itemsize) | ||
78 | + | ||
79 | + bindings = [int(d_input), int(d_output)] | ||
80 | + stream = cuda.Stream() | ||
81 | + # Transfer input data to device | ||
82 | + cuda.memcpy_htod_async(d_input, input_img, stream) | ||
83 | + # Execute model | ||
84 | + context.execute_async(batch_size, bindings, stream.handle, None) | ||
85 | + # Transfer predictions back | ||
86 | + cuda.memcpy_dtoh_async(output, d_output, stream) | ||
87 | + # Synchronize threads | ||
88 | + stream.synchronize() | ||
89 | + # Return predictions | ||
90 | + return output | ||
91 | + | ||
92 | + | ||
93 | +# kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'} | ||
94 | +# train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(100, 0, **kwargs) | ||
95 | +# testloader = torch.utils.data.DataLoader(test_data_set, batch_size=256, | ||
96 | +# shuffle=False, num_workers=2) | ||
97 | + | ||
98 | +check_time = time.time() | ||
99 | +cnt = 0 | ||
100 | +temp = np.ones([256, 1, 29, 29]) | ||
101 | +for idx in range(100): | ||
102 | +# for i, (inputs, labels) in enumerate(testloader): | ||
103 | + trt_outputs = infer(context, temp, (256, 2), 256) | ||
104 | + | ||
105 | + print(trt_outputs.shape) | ||
106 | + # print(trt_outputs) | ||
107 | + # print(np.argmax(trt_outputs, axis=0)) | ||
108 | + # cnt += 1 | ||
109 | + # if cnt == 100: | ||
110 | + # break | ||
111 | +print(time.time() - check_time) | ||
112 | + | ||
113 | + | ||
114 | +tensorrt_file_name = 'bert_int.plan' | ||
115 | +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | ||
116 | +trt_runtime = trt.Runtime(TRT_LOGGER) | ||
117 | + | ||
118 | +with open(tensorrt_file_name, 'rb') as f: | ||
119 | + engine_data = f.read() | ||
120 | +engine = trt_runtime.deserialize_cuda_engine(engine_data) | ||
121 | +context = engine.create_execution_context() | ||
122 | +check_time = time.time() | ||
123 | +cnt = 0 | ||
124 | +temp = np.ones([256, 1, 29, 29]) | ||
125 | +for idx in range(100): | ||
126 | +# for i, (inputs, labels) in enumerate(testloader): | ||
127 | + trt_outputs = infer(context, temp, (256, 2), 256) | ||
128 | + | ||
129 | + print(trt_outputs.shape) | ||
130 | + # print(trt_outputs) | ||
131 | + # print(np.argmax(trt_outputs, axis=0)) | ||
132 | + # cnt += 1 | ||
133 | + # if cnt == 100: | ||
134 | + # break | ||
135 | +print(time.time() - check_time) | ||
136 | + | ||
137 | + | ||
138 | +test_model = model.Net().cuda() | ||
139 | +check_time = time.time() | ||
140 | +cnt = 0 | ||
141 | +temp = torch.randn(256, 1, 29, 29).cuda() | ||
142 | +for idx in range(100): | ||
143 | +# for i, (inputs, labels) in enumerate(testloader): | ||
144 | + # inputs = inputs.float().cuda() | ||
145 | + normal_outputs = test_model(temp) | ||
146 | + # print(normal_outputs) | ||
147 | + print(normal_outputs.shape) | ||
148 | + cnt += 1 | ||
149 | + if cnt == 100: | ||
150 | + break | ||
151 | +print(time.time() - check_time) | ||
152 | + | ||
153 | + | ||
154 | + | ||
155 | +import tensorrt as trt | ||
156 | +import numpy as np | ||
157 | +import pycuda.autoinit | ||
158 | +import pycuda.driver as cuda | ||
159 | +import time | ||
160 | + | ||
161 | +model_path = "bert.onnx" | ||
162 | +input_size = 32 | ||
163 | + | ||
164 | +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | ||
165 | + | ||
166 | +# def build_engine(model_path): | ||
167 | +# with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser: | ||
168 | +# builder.max_workspace_size = 1<<20 | ||
169 | +# builder.max_batch_size = 1 | ||
170 | +# with open(model_path, "rb") as f: | ||
171 | +# parser.parse(f.read()) | ||
172 | +# engine = builder.build_cuda_engine(network) | ||
173 | +# return engine | ||
174 | + | ||
175 | +def alloc_buf(engine): | ||
176 | + # host cpu mem | ||
177 | + h_in_size = trt.volume(engine.get_binding_shape(0)) | ||
178 | + h_out_size = trt.volume(engine.get_binding_shape(1)) | ||
179 | + h_in_dtype = trt.nptype(engine.get_binding_dtype(0)) | ||
180 | + h_out_dtype = trt.nptype(engine.get_binding_dtype(1)) | ||
181 | + in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype) | ||
182 | + out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype) | ||
183 | + # allocate gpu mem | ||
184 | + in_gpu = cuda.mem_alloc(in_cpu.nbytes) | ||
185 | + out_gpu = cuda.mem_alloc(out_cpu.nbytes) | ||
186 | + stream = cuda.Stream() | ||
187 | + return in_cpu, out_cpu, in_gpu, out_gpu, stream | ||
188 | + | ||
189 | + | ||
190 | +def inference(engine, context, inputs, out_cpu, in_gpu, out_gpu, stream): | ||
191 | + # async version | ||
192 | + # with engine.create_execution_context() as context: # cost time to initialize | ||
193 | + # cuda.memcpy_htod_async(in_gpu, inputs, stream) | ||
194 | + # context.execute_async(1, [int(in_gpu), int(out_gpu)], stream.handle, None) | ||
195 | + # cuda.memcpy_dtoh_async(out_cpu, out_gpu, stream) | ||
196 | + # stream.synchronize() | ||
197 | + | ||
198 | + # sync version | ||
199 | + cuda.memcpy_htod(in_gpu, inputs) | ||
200 | + context.execute(1, [int(in_gpu), int(out_gpu)]) | ||
201 | + cuda.memcpy_dtoh(out_cpu, out_gpu) | ||
202 | + return out_cpu | ||
203 | + | ||
204 | +if __name__ == "__main__": | ||
205 | + inputs = np.random.random((1, 1, 29, 29)).astype(np.float32) | ||
206 | + | ||
207 | + tensorrt_file_name = '/content/drive/My Drive/capstone1/CAN/bert.plan' | ||
208 | + TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | ||
209 | + trt_runtime = trt.Runtime(TRT_LOGGER) | ||
210 | + | ||
211 | + with open(tensorrt_file_name, 'rb') as f: | ||
212 | + engine_data = f.read() | ||
213 | + engine = trt_runtime.deserialize_cuda_engine(engine_data) | ||
214 | + # engine = build_engine(model_path) | ||
215 | + context = engine.create_execution_context() | ||
216 | + for _ in range(10): | ||
217 | + t1 = time.time() | ||
218 | + in_cpu, out_cpu, in_gpu, out_gpu, stream = alloc_buf(engine) | ||
219 | + res = inference(engine, context, inputs.reshape(-1), out_cpu, in_gpu, out_gpu, stream) | ||
220 | + print(res) | ||
221 | + print("cost time: ", time.time()-t1) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/연합학습/quantization/torch2onnx.py
0 → 100644
1 | +import model | ||
2 | +import torch | ||
3 | + | ||
4 | +import importlib | ||
5 | +importlib.reload(model) | ||
6 | + | ||
7 | +batch_size = 256 | ||
8 | +model = model.Net().cuda().eval() | ||
9 | +inputs = torch.randn(batch_size, 1, 29, 29, requires_grad=True).cuda() | ||
10 | +torch_out = model(inputs) | ||
11 | + | ||
12 | +torch.onnx.export( | ||
13 | + model, | ||
14 | + inputs, | ||
15 | + 'bert.onnx', | ||
16 | + input_names=['inputs'], | ||
17 | + output_names=['outputs'], | ||
18 | + export_params=True) | ||
19 | + | ||
20 | +print('done, onnx model') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
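A quick way to sanity-check the export is to run the same batch through ONNX Runtime and compare it with torch_out. This continuation of the script above is only a sketch and assumes the onnxruntime package is available; it is not part of this repository.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('bert.onnx')
input_name = sess.get_inputs()[0].name                        # 'inputs', as named in the export above
onnx_out = sess.run(None, {input_name: inputs.detach().cpu().numpy()})[0]
np.testing.assert_allclose(torch_out.detach().cpu().numpy(), onnx_out, rtol=1e-3, atol=1e-5)
print('onnx output matches torch output')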
코드/연합학습/utils.py
0 → 100644
1 | +import pandas as pd | ||
2 | +import numpy as np | ||
3 | +import csv | ||
4 | +import os | ||
5 | +import const | ||
6 | +from matplotlib import pyplot as plt | ||
7 | + | ||
8 | + | ||
9 | +def run_benchmark_cnn(): | ||
10 | + import sys | ||
11 | + sys.path.append("/content/drive/My Drive/capstone1/CAN/torch2trt") | ||
12 | + from torch2trt import torch2trt | ||
13 | + import model | ||
14 | + import time | ||
15 | + import torch | ||
16 | + import dataset | ||
17 | + import torch.nn as nn | ||
18 | + | ||
19 | + test_model = model.CnnNet() | ||
20 | + test_model.eval().cuda() | ||
21 | + | ||
22 | + batch_size = 1 | ||
23 | + # _, _, _, test_data_set = dataset.GetCanDataset(100, 0, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") | ||
24 | + | ||
25 | + # sampler = dataset.BatchIntervalSampler(len(test_data_set), batch_size) | ||
26 | + # testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler, | ||
27 | + # shuffle=False, num_workers=2, drop_last=True) | ||
28 | + | ||
29 | + # create model and input data | ||
30 | + # for inputs, labels in testloader: | ||
31 | + # trt_x = inputs.float().cuda() | ||
32 | + # trt_state = torch.zeros((batch_size, 8 * 32)).float().cuda() | ||
33 | + # trt_model = model.OneNet() | ||
34 | + # trt_model.load_state_dict(torch.load(weight_path)) | ||
35 | + # trt_model.float().eval().cuda() | ||
36 | + | ||
37 | + # trt_f16_x = inputs.half().cuda() | ||
38 | + # trt_f16_state = torch.zeros((batch_size, 8 * 32)).half().cuda() | ||
39 | + # trt_f16_model = model.OneNet() | ||
40 | + # trt_f16_model.load_state_dict(torch.load(weight_path)) | ||
41 | + # trt_f16_model.half().eval().cuda() | ||
42 | + | ||
43 | + # trt_int8_strict_x = inputs.float().cuda() | ||
44 | + # trt_int8_strict_state = torch.zeros((batch_size, 8 * 32)).float().cuda() # match model weight | ||
45 | + # trt_int8_strict_model = model.OneNet() | ||
46 | + # trt_int8_strict_model.load_state_dict(torch.load(weight_path)) | ||
47 | + # trt_int8_strict_model.eval().cuda() # no attribute 'char' | ||
48 | + | ||
49 | + # break | ||
50 | + | ||
51 | + inputs = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)) | ||
52 | + | ||
53 | + | ||
54 | + trt_x = inputs.half().cuda() # ??? densenet error? | ||
55 | + trt_model = model.CnnNet() | ||
56 | + # trt_model.load_state_dict(torch.load(weight_path)) | ||
57 | + trt_model.eval().cuda() | ||
58 | + | ||
59 | + trt_f16_x = inputs.half().cuda() | ||
60 | + trt_f16_model = model.CnnNet().half() | ||
61 | + # trt_f16_model.load_state_dict(torch.load(weight_path)) | ||
62 | + trt_f16_model.half().eval().cuda() | ||
63 | + | ||
64 | + trt_int8_strict_x = inputs.half().cuda() # match model weight | ||
65 | + trt_int8_strict_model = model.CnnNet() | ||
66 | + # trt_int8_strict_model.load_state_dict(torch.load(weight_path)) | ||
67 | + trt_int8_strict_model.eval().cuda() # no attribute 'char' | ||
68 | + | ||
69 | + # convert to TensorRT feeding sample data as input | ||
70 | + model_trt = torch2trt(trt_model, [trt_x], max_batch_size=batch_size) | ||
71 | + model_trt_f16 = torch2trt(trt_f16_model, [trt_f16_x], fp16_mode=True, max_batch_size=batch_size) | ||
72 | + model_trt_int8_strict = torch2trt(trt_int8_strict_model, [trt_int8_strict_x], fp16_mode=False, int8_mode=True, strict_type_constraints=True, max_batch_size=batch_size) | ||
73 | + | ||
74 | + # testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler, | ||
75 | + # shuffle=False, num_workers=2, drop_last=True) | ||
76 | + | ||
77 | + with torch.no_grad(): | ||
78 | + ### test inference time | ||
79 | + dummy_x = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)).half().cuda() | ||
80 | + dummy_cnt = 10000 | ||
81 | + print('ignore data loading time, inference random data') | ||
82 | + | ||
83 | + check_time = time.time() | ||
84 | + for i in range(dummy_cnt): | ||
85 | + _ = test_model(dummy_x) | ||
86 | + print('torch model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
87 | + | ||
88 | + check_time = time.time() | ||
89 | + for i in range(dummy_cnt): | ||
90 | + _ = model_trt(dummy_x) | ||
91 | + print('trt model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
92 | + | ||
93 | + dummy_x = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)).half().cuda() | ||
94 | + check_time = time.time() | ||
95 | + for i in range(dummy_cnt): | ||
96 | + _ = model_trt_f16(dummy_x) | ||
97 | + print('trt float 16 model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
98 | + | ||
99 | + dummy_x = torch.ones((batch_size, 1, const.CNN_FRAME_LEN, const.CNN_FRAME_LEN)).char().cuda() | ||
100 | + check_time = time.time() | ||
101 | + for i in range(dummy_cnt): | ||
102 | + _ = model_trt_int8_strict(dummy_x) | ||
103 | + print('trt int8 strict model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
104 | + return | ||
105 | + ## end | ||
106 | + | ||
107 | + criterion = nn.CrossEntropyLoss() | ||
108 | + state_temp = torch.zeros((batch_size, 8 * 32)).cuda() | ||
109 | + step_acc = 0.0 | ||
110 | + step_loss = 0.0 | ||
111 | + cnt = 0 | ||
112 | + loss_cnt = 0 | ||
113 | + for i, (inputs, labels) in enumerate(testloader): | ||
114 | + inputs, labels = inputs.float().cuda(), labels.long().cuda() | ||
115 | + normal_outputs, state_temp = test_model(inputs, state_temp) | ||
116 | + | ||
117 | + _, preds = torch.max(normal_outputs, 1) | ||
118 | + edge_loss = criterion(normal_outputs, labels) | ||
119 | + step_loss += edge_loss.item() | ||
120 | + loss_cnt += 1 | ||
121 | + | ||
122 | + corr_sum = torch.sum(preds == labels.data) | ||
123 | + step_acc += corr_sum.double() | ||
124 | + cnt += batch_size | ||
125 | + print('torch', step_acc.item() / cnt, step_loss / loss_cnt) | ||
126 | + | ||
127 | + state_temp = torch.zeros((batch_size, 8 * 32)).cuda() | ||
128 | + step_acc = 0.0 | ||
129 | + cnt = 0 | ||
130 | + step_loss = 0.0 | ||
131 | + loss_cnt = 0 | ||
132 | + for i, (inputs, labels) in enumerate(testloader): | ||
133 | + inputs, labels = inputs.float().cuda(), labels.long().cuda() | ||
134 | + normal_outputs, state_temp = model_trt(inputs, state_temp) | ||
135 | + | ||
136 | + _, preds = torch.max(normal_outputs, 1) | ||
137 | + edge_loss = criterion(normal_outputs, labels) | ||
138 | + step_loss += edge_loss.item() | ||
139 | + loss_cnt += 1 | ||
140 | + | ||
141 | + corr_sum = torch.sum(preds == labels.data) | ||
142 | + step_acc += corr_sum.double() | ||
143 | + cnt += batch_size | ||
144 | + print('trt', step_acc.item() / cnt, step_loss / loss_cnt) | ||
145 | + | ||
146 | + state_temp = torch.zeros((batch_size, 8 * 32)).half().cuda() | ||
147 | + step_acc = 0.0 | ||
148 | + cnt = 0 | ||
149 | + step_loss = 0.0 | ||
150 | + loss_cnt = 0 | ||
151 | + for i, (inputs, labels) in enumerate(testloader): | ||
152 | + inputs, labels = inputs.half().cuda(), labels.long().cuda() | ||
153 | + normal_outputs, state_temp = model_trt_f16(inputs, state_temp) | ||
154 | + | ||
155 | + _, preds = torch.max(normal_outputs, 1) | ||
156 | + edge_loss = criterion(normal_outputs, labels) | ||
157 | + step_loss += edge_loss.item() | ||
158 | + loss_cnt += 1 | ||
159 | + | ||
160 | + corr_sum = torch.sum(preds == labels.data) | ||
161 | + step_acc += corr_sum.double() | ||
162 | + cnt += batch_size | ||
163 | + print('float16', step_acc.item() / cnt, step_loss / loss_cnt) | ||
164 | + | ||
165 | + state_temp = torch.zeros((batch_size, 8 * 32)).float().cuda() | ||
166 | + step_acc = 0.0 | ||
167 | + cnt = 0 | ||
168 | + step_loss = 0.0 | ||
169 | + loss_cnt = 0 | ||
170 | + for i, (inputs, labels) in enumerate(testloader): | ||
171 | + inputs, labels = inputs.float().cuda(), labels.long().cuda() | ||
172 | + normal_outputs, state_temp = model_trt_int8_strict(inputs, state_temp) | ||
173 | + | ||
174 | + _, preds = torch.max(normal_outputs, 1) | ||
175 | + edge_loss = criterion(normal_outputs, labels) | ||
176 | + step_loss += edge_loss.item() | ||
177 | + loss_cnt += 1 | ||
178 | + | ||
179 | + corr_sum = torch.sum(preds == labels.data) | ||
180 | + step_acc += corr_sum.double() | ||
181 | + cnt += batch_size | ||
182 | + print('int8 strict', step_acc.item() / cnt, step_loss / loss_cnt) | ||
183 | + | ||
184 | + | ||
185 | +def run_benchmark(weight_path): | ||
186 | + import sys | ||
187 | + sys.path.append("/content/drive/My Drive/capstone1/CAN/torch2trt") | ||
188 | + from torch2trt import torch2trt | ||
189 | + import model | ||
190 | + import time | ||
191 | + import torch | ||
192 | + import dataset | ||
193 | + import torch.nn as nn | ||
194 | + | ||
195 | + test_model = model.OneNet() | ||
196 | + test_model.load_state_dict(torch.load(weight_path)) | ||
197 | + test_model.eval().cuda() | ||
198 | + | ||
199 | + batch_size = 1 | ||
200 | + _, _, _, test_data_set = dataset.GetCanDataset(100, 0, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") | ||
201 | + | ||
202 | + sampler = dataset.BatchIntervalSampler(len(test_data_set), batch_size) | ||
203 | + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler, | ||
204 | + shuffle=False, num_workers=2, drop_last=True) | ||
205 | + | ||
206 | + # create model and input data | ||
207 | + for inputs, labels in testloader: | ||
208 | + # inputs = torch.cat([inputs, inputs, inputs], 1) | ||
209 | + | ||
210 | + trt_x = inputs.float().cuda() | ||
211 | + trt_state = torch.zeros((batch_size, 8 * 32)).float().cuda() | ||
212 | + trt_model = model.OneNet() | ||
213 | + trt_model.load_state_dict(torch.load(weight_path)) | ||
214 | + trt_model.float().eval().cuda() | ||
215 | + | ||
216 | + trt_f16_x = inputs.half().cuda() | ||
217 | + trt_f16_state = torch.zeros((batch_size, 8 * 32)).half().cuda() | ||
218 | + trt_f16_model = model.OneNet().half() | ||
219 | + trt_f16_model.load_state_dict(torch.load(weight_path)) | ||
220 | + trt_f16_model.half().eval().cuda() | ||
221 | + | ||
222 | + trt_int8_strict_x = inputs.float().cuda() | ||
223 | + trt_int8_strict_state = torch.zeros((batch_size, 8 * 32)).float().cuda() # match model weight | ||
224 | + trt_int8_strict_model = model.OneNet() | ||
225 | + trt_int8_strict_model.load_state_dict(torch.load(weight_path)) | ||
226 | + trt_int8_strict_model.eval().cuda() # no attribute 'char' | ||
227 | + | ||
228 | + break | ||
229 | + | ||
230 | + # convert to TensorRT feeding sample data as input | ||
231 | + model_trt = torch2trt(trt_model, [trt_x, trt_state], max_batch_size=batch_size) | ||
232 | + model_trt_f16 = torch2trt(trt_f16_model, [trt_f16_x, trt_f16_state], fp16_mode=True, max_batch_size=batch_size) | ||
233 | + model_trt_int8_strict = torch2trt(trt_int8_strict_model, [trt_int8_strict_x, trt_int8_strict_state], fp16_mode=False, int8_mode=True, strict_type_constraints=True, max_batch_size=batch_size) | ||
234 | + | ||
235 | + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, sampler=sampler, | ||
236 | + shuffle=False, num_workers=2, drop_last=True) | ||
237 | + | ||
238 | + with torch.no_grad(): | ||
239 | + ### test inference time | ||
240 | + dummy_x = torch.ones((batch_size, 8)).cuda() | ||
241 | + dummy_state = torch.zeros(batch_size, model.STATE_DIM).cuda() | ||
242 | + dummy_cnt = 10000 | ||
243 | + print('ignore data loading time, inference random data') | ||
244 | + | ||
245 | + check_time = time.time() | ||
246 | + for i in range(dummy_cnt): | ||
247 | + _, _ = test_model(dummy_x, dummy_state) | ||
248 | + print('torch model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
249 | + | ||
250 | + check_time = time.time() | ||
251 | + for i in range(dummy_cnt): | ||
252 | + _, _ = model_trt(dummy_x, dummy_state) | ||
253 | + print('trt model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
254 | + | ||
255 | + dummy_x = torch.ones((batch_size, 8)).half().cuda() | ||
256 | + dummy_state = torch.zeros(batch_size, model.STATE_DIM).half().cuda() | ||
257 | + check_time = time.time() | ||
258 | + for i in range(dummy_cnt): | ||
259 | + _, _ = model_trt_f16(dummy_x, dummy_state) | ||
260 | + print('trt float 16 model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
261 | + | ||
262 | + dummy_x = torch.ones((batch_size, 8)).char().cuda() | ||
263 | + dummy_state = torch.zeros(batch_size, model.STATE_DIM).char().cuda() | ||
264 | + check_time = time.time() | ||
265 | + for i in range(dummy_cnt): | ||
266 | + _, _ = model_trt_int8_strict(dummy_x, dummy_state) | ||
267 | + print('trt int8 strict model: %.6f' % ((time.time() - check_time) / dummy_cnt)) | ||
268 | + return | ||
269 | + ## end | ||
270 | + | ||
271 | + criterion = nn.CrossEntropyLoss() | ||
272 | + state_temp = torch.zeros((batch_size, 8 * 32)).cuda() | ||
273 | + step_acc = 0.0 | ||
274 | + step_loss = 0.0 | ||
275 | + cnt = 0 | ||
276 | + loss_cnt = 0 | ||
277 | + for i, (inputs, labels) in enumerate(testloader): | ||
278 | + inputs, labels = inputs.float().cuda(), labels.long().cuda() | ||
279 | + normal_outputs, state_temp = test_model(inputs, state_temp) | ||
280 | + | ||
281 | + _, preds = torch.max(normal_outputs, 1) | ||
282 | + edge_loss = criterion(normal_outputs, labels) | ||
283 | + step_loss += edge_loss.item() | ||
284 | + loss_cnt += 1 | ||
285 | + | ||
286 | + corr_sum = torch.sum(preds == labels.data) | ||
287 | + step_acc += corr_sum.double() | ||
288 | + cnt += batch_size | ||
289 | + print('torch', step_acc.item() / cnt, step_loss / loss_cnt) | ||
290 | + | ||
291 | + state_temp = torch.zeros((batch_size, 8 * 32)).cuda() | ||
292 | + step_acc = 0.0 | ||
293 | + cnt = 0 | ||
294 | + step_loss = 0.0 | ||
295 | + loss_cnt = 0 | ||
296 | + for i, (inputs, labels) in enumerate(testloader): | ||
297 | + inputs, labels = inputs.float().cuda(), labels.long().cuda() | ||
298 | + normal_outputs, state_temp = model_trt(inputs, state_temp) | ||
299 | + | ||
300 | + _, preds = torch.max(normal_outputs, 1) | ||
301 | + edge_loss = criterion(normal_outputs, labels) | ||
302 | + step_loss += edge_loss.item() | ||
303 | + loss_cnt += 1 | ||
304 | + | ||
305 | + corr_sum = torch.sum(preds == labels.data) | ||
306 | + step_acc += corr_sum.double() | ||
307 | + cnt += batch_size | ||
308 | + print('trt', step_acc.item() / cnt, step_loss / loss_cnt) | ||
309 | + | ||
310 | + state_temp = torch.zeros((batch_size, 8 * 32)).half().cuda() | ||
311 | + step_acc = 0.0 | ||
312 | + cnt = 0 | ||
313 | + step_loss = 0.0 | ||
314 | + loss_cnt = 0 | ||
315 | + for i, (inputs, labels) in enumerate(testloader): | ||
316 | + inputs, labels = inputs.half().cuda(), labels.long().cuda() | ||
317 | + normal_outputs, state_temp = model_trt_f16(inputs, state_temp) | ||
318 | + | ||
319 | + _, preds = torch.max(normal_outputs, 1) | ||
320 | + edge_loss = criterion(normal_outputs, labels) | ||
321 | + step_loss += edge_loss.item() | ||
322 | + loss_cnt += 1 | ||
323 | + | ||
324 | + corr_sum = torch.sum(preds == labels.data) | ||
325 | + step_acc += corr_sum.double() | ||
326 | + cnt += batch_size | ||
327 | + print('float16', step_acc.item() / cnt, step_loss / loss_cnt) | ||
328 | + | ||
329 | + state_temp = torch.zeros((batch_size, 8 * 32)).float().cuda() | ||
330 | + step_acc = 0.0 | ||
331 | + cnt = 0 | ||
332 | + step_loss = 0.0 | ||
333 | + loss_cnt = 0 | ||
334 | + for i, (inputs, labels) in enumerate(testloader): | ||
335 | + inputs, labels = inputs.float().cuda(), labels.long().cuda() | ||
336 | + normal_outputs, state_temp = model_trt_int8_strict(inputs, state_temp) | ||
337 | + | ||
338 | + _, preds = torch.max(normal_outputs, 1) | ||
339 | + edge_loss = criterion(normal_outputs, labels) | ||
340 | + step_loss += edge_loss.item() | ||
341 | + loss_cnt += 1 | ||
342 | + | ||
343 | + corr_sum = torch.sum(preds == labels.data) | ||
344 | + step_acc += corr_sum.double() | ||
345 | + cnt += batch_size | ||
346 | + print('int8 strict', step_acc.item() / cnt, step_loss / loss_cnt) | ||
347 | + | ||
348 | + | ||
349 | +def drawGraph(x_value, x_label, y_axis, y_label): | ||
350 | + pass | ||
351 | + | ||
352 | + | ||
353 | +def CsvToTextOne(csv_file): | ||
354 | + target_csv = pd.read_csv(csv_file) | ||
355 | + file_name, extension = os.path.splitext(csv_file) | ||
356 | + print(file_name, extension) | ||
357 | + target_text = open(file_name + '_1.txt', mode='wt', encoding='utf-8') | ||
358 | + | ||
359 | + idx = 0 | ||
360 | + print(len(target_csv)) | ||
361 | + | ||
362 | + while idx < len(target_csv): | ||
363 | + csv_row = target_csv.iloc[idx] | ||
364 | + data_len = csv_row[1] | ||
365 | + is_regular = (csv_row[data_len + 2] == 'R') | ||
366 | + | ||
367 | + if is_regular: | ||
368 | + target_text.write("%d R\n" % idx) | ||
369 | + else: | ||
370 | + target_text.write("%d T\n" % idx) | ||
371 | + | ||
372 | + idx += 1 | ||
373 | + if (idx % 1000000 == 0): | ||
374 | + print(idx) | ||
375 | + | ||
376 | + target_text.close() | ||
377 | + print('done') | ||
378 | + | ||
379 | + | ||
380 | +def Mix_Four_CANDataset(): | ||
381 | + Dos_csv = pd.read_csv('./dataset/DoS_dataset.csv') | ||
382 | + Other_csv = [pd.read_csv('./dataset/Fuzzy_dataset.csv'), | ||
383 | + pd.read_csv('./dataset/RPM_dataset.csv'), | ||
384 | + pd.read_csv('./dataset/gear_dataset.csv')] | ||
385 | + Other_csv_idx = [0, 0, 0] | ||
386 | + | ||
387 | + save_csv = open('./dataset/Mixed_dataset.csv', 'w') | ||
388 | + save_csv_file = csv.writer(save_csv) | ||
389 | + | ||
390 | + # change the period of the injected DoS traffic | ||
391 | + # replace the three DoS slots after each DoS frame with other attack traffic | ||
392 | + # DoS / three picks from (Fuzzy, RPM, gear), random order and random counts / DoS ... | ||
393 | + dos_idx = 0 | ||
394 | + dos_period = 3 | ||
395 | + while dos_idx < len(Dos_csv): | ||
396 | + dos_row = Dos_csv.iloc[dos_idx] | ||
397 | + number_of_data = dos_row[2] | ||
398 | + is_regular = (dos_row[number_of_data + 3] == 'R') | ||
399 | + dos_row.dropna(inplace=True) | ||
400 | + | ||
401 | + if is_regular: | ||
402 | + save_csv_file.writerow(dos_row[1:]) | ||
403 | + else: | ||
404 | + if dos_period == 3: | ||
405 | + save_csv_file.writerow(dos_row[1:]) | ||
406 | + np.random.seed(dos_idx) | ||
407 | + selected_edge = np.random.choice([0, 1, 2], 3, replace=True) | ||
408 | + else: | ||
409 | + selected_idx = selected_edge[dos_preriod] | ||
410 | + local_csv = Other_csv[selected_idx] | ||
411 | + local_idx = Other_csv_idx[selected_idx] | ||
412 | + | ||
413 | + while True: | ||
414 | + local_row = local_csv.iloc[local_idx] | ||
415 | + local_number_of_data = local_row[2] | ||
416 | + is_injected = (local_row[local_number_of_data + 3] == 'T') | ||
417 | + local_idx += 1 | ||
418 | + if is_injected: | ||
419 | + local_row.dropna(inplace=True) | ||
420 | + save_csv_file.writerow(local_row[1:]) | ||
421 | + break | ||
422 | + Other_csv_idx[selected_idx] = local_idx | ||
423 | + | ||
424 | + dos_period -= 1 | ||
425 | + if dos_period == -1: | ||
426 | + dos_period = 3 | ||
427 | + | ||
428 | + dos_idx += 1 | ||
429 | + if dos_idx % 100000 == 0: | ||
430 | + print(dos_idx) | ||
431 | + # break | ||
432 | + save_csv.close() | ||
433 | + | ||
434 | +def Mix_Six_SynCANDataset(): | ||
435 | + normal_csv = pd.read_csv('./dataset/test_normal.csv') | ||
436 | + normal_idx = 0 | ||
437 | + target_len = len(normal_csv) | ||
438 | + | ||
439 | + save_csv = open('./dataset/test_mixed.csv', 'w') | ||
440 | + save_csv_file = csv.writer(save_csv) | ||
441 | + | ||
442 | + other_csv = [pd.read_csv('./dataset/test_continuous.csv'), | ||
443 | + pd.read_csv('./dataset/test_flooding.csv'), | ||
444 | + pd.read_csv('./dataset/test_plateau.csv'), | ||
445 | + pd.read_csv('./dataset/test_playback.csv'), | ||
446 | + pd.read_csv('./dataset/test_suppress.csv')] | ||
447 | + other_csv_idx = [0, 0, 0, 0, 0] | ||
448 | + | ||
449 | + while normal_idx < target_len: | ||
450 | + np.random.seed(normal_idx) | ||
451 | + selected_csv = np.random.choice([0, 1, 2, 3, 4], 5, replace=True) | ||
452 | + all_done = True | ||
453 | + for csv_idx in selected_csv: | ||
454 | + now_csv = other_csv[csv_idx] | ||
455 | + now_idx = other_csv_idx[csv_idx] | ||
456 | + | ||
457 | + start_normal_idx = now_idx | ||
458 | + while now_idx < len(now_csv): | ||
459 | + csv_row_ahead = now_csv.iloc[now_idx + 1] | ||
460 | + label_ahead = csv_row_ahead[0] | ||
461 | + | ||
462 | + csv_row_behind = now_csv.iloc[now_idx] | ||
463 | + label_behind = csv_row_behind[0] | ||
464 | + | ||
465 | + if label_ahead == 1 and label_behind == 0: | ||
466 | + print(now_idx, 'start error') | ||
467 | + add_normal_len = (now_idx - start_normal_idx) // 9 | ||
468 | + start_abnormal_idx = now_idx + 1 | ||
469 | + elif label_ahead == 0 and label_behind == 1: | ||
470 | + print(now_idx, 'end error') | ||
471 | + add_abnormal_len = (now_idx - start_abnormal_idx) // 6 | ||
472 | + | ||
473 | + for _ in range(6): | ||
474 | + # done | ||
475 | + if normal_idx + add_normal_len >= target_len: | ||
476 | + save_csv.close() | ||
477 | + return | ||
478 | + | ||
479 | + # write normal | ||
480 | + for idx in range(normal_idx, normal_idx + add_normal_len): | ||
481 | + row = normal_csv.iloc[idx] | ||
482 | + row = row.fillna(0) | ||
483 | + save_csv_file.writerow(row[0:1].append(row[2:])) | ||
484 | + normal_idx += add_normal_len | ||
485 | + # write abnormal | ||
486 | + for idx in range(start_abnormal_idx, start_abnormal_idx + add_abnormal_len): | ||
487 | + row = now_csv.iloc[idx] | ||
488 | + row = row.fillna(0) | ||
489 | + save_csv_file.writerow(row[0:1].append(row[2:])) | ||
490 | + start_abnormal_idx += add_abnormal_len | ||
491 | + | ||
492 | + other_csv_idx[csv_idx] = now_idx + 1 | ||
493 | + # check other csv not end | ||
494 | + all_done = False | ||
495 | + break | ||
496 | + | ||
497 | + now_idx += 1 | ||
498 | + | ||
499 | + if all_done: | ||
500 | + break | ||
501 | + | ||
502 | + save_csv.close() |
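A short usage sketch for the dataset-preparation helpers above; the call order is an assumption, but the output paths match the files loaded by start_train().

import utils

# build the mixed CAN csv from the DoS/Fuzzy/RPM/gear csv files, then index its labels
utils.Mix_Four_CANDataset()                          # writes ./dataset/Mixed_dataset.csv
utils.CsvToTextOne('./dataset/Mixed_dataset.csv')    # writes ./dataset/Mixed_dataset_1.txt

# SynCAN equivalent of the mixing step
utils.Mix_Six_SynCANDataset()                        # writes ./dataset/test_mixed.csv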