김지훈

Before revising the edge model, training speed needs to be improved.
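The note above says training is too slow to iterate on the edge model. One plausible direction (a sketch only, not part of this commit) is to unpack every CAN ID into its 29-bit vector once, up front, so that `__getitem__` in `dataset.py` below only slices a precomputed array instead of re-parsing 29 hex strings per sample. The helper name `precompute_id_bits` is invented here; the column index and bit width follow the dataset code.

```python
import numpy as np
import pandas as pd

CAN_ID_BIT = 29  # same constant as const.py below


def precompute_id_bits(csv_path):
    """Unpack the CAN ID column (hex strings, column 1) once, LSB first.

    Returns an (N, 29) uint8 array; a Dataset can then assemble each
    (1, 29, 29) sample with bits[start_i:start_i + CAN_ID_BIT].
    """
    ids = pd.read_csv(csv_path).iloc[:, 1].apply(lambda s: int(s, 16)).to_numpy(dtype=np.int64)
    mask = 2 ** np.arange(CAN_ID_BIT, dtype=np.int64)
    return ((ids[:, None] & mask) != 0).astype(np.uint8)
```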

1 +# Byte-compiled / optimized / DLL files
2 +__pycache__/
3 +*.py[cod]
4 +*$py.class
5 +
6 +# C extensions
7 +*.so
8 +
9 +# Distribution / packaging
10 +.Python
11 +build/
12 +develop-eggs/
13 +dist/
14 +downloads/
15 +eggs/
16 +.eggs/
17 +lib/
18 +lib64/
19 +parts/
20 +sdist/
21 +var/
22 +wheels/
23 +share/python-wheels/
24 +*.egg-info/
25 +.installed.cfg
26 +*.egg
27 +MANIFEST
28 +
29 +# PyInstaller
30 +# Usually these files are written by a python script from a template
31 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
32 +*.manifest
33 +*.spec
34 +
35 +# Installer logs
36 +pip-log.txt
37 +pip-delete-this-directory.txt
38 +
39 +# Unit test / coverage reports
40 +htmlcov/
41 +.tox/
42 +.nox/
43 +.coverage
44 +.coverage.*
45 +.cache
46 +nosetests.xml
47 +coverage.xml
48 +*.cover
49 +*.py,cover
50 +.hypothesis/
51 +.pytest_cache/
52 +cover/
53 +
54 +# Translations
55 +*.mo
56 +*.pot
57 +
58 +# Django stuff:
59 +*.log
60 +local_settings.py
61 +db.sqlite3
62 +db.sqlite3-journal
63 +
64 +# Flask stuff:
65 +instance/
66 +.webassets-cache
67 +
68 +# Scrapy stuff:
69 +.scrapy
70 +
71 +# Sphinx documentation
72 +docs/_build/
73 +
74 +# PyBuilder
75 +.pybuilder/
76 +target/
77 +
78 +# Jupyter Notebook
79 +.ipynb_checkpoints
80 +
81 +# IPython
82 +profile_default/
83 +ipython_config.py
84 +
85 +# pyenv
86 +# For a library or package, you might want to ignore these files since the code is
87 +# intended to run in multiple environments; otherwise, check them in:
88 +# .python-version
89 +
90 +# pipenv
91 +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 +# However, in case of collaboration, if having platform-specific dependencies or dependencies
93 +# having no cross-platform support, pipenv may install dependencies that don't work, or not
94 +# install all needed dependencies.
95 +#Pipfile.lock
96 +
97 +# PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 +__pypackages__/
99 +
100 +# Celery stuff
101 +celerybeat-schedule
102 +celerybeat.pid
103 +
104 +# SageMath parsed files
105 +*.sage.py
106 +
107 +# Environments
108 +.env
109 +.venv
110 +env/
111 +venv/
112 +ENV/
113 +env.bak/
114 +venv.bak/
115 +
116 +# Spyder project settings
117 +.spyderproject
118 +.spyproject
119 +
120 +# Rope project settings
121 +.ropeproject
122 +
123 +# mkdocs documentation
124 +/site
125 +
126 +# mypy
127 +.mypy_cache/
128 +.dmypy.json
129 +dmypy.json
130 +
131 +# Pyre type checker
132 +.pyre/
133 +
134 +# pytype static type analyzer
135 +.pytype/
136 +
137 +# Cython debug symbols
138 +cython_debug/
\ No newline at end of file
1 +CAN_ID_BIT = 29  # bit width of an extended CAN identifier; also the window length (29 consecutive frames per sample)
\ No newline at end of file
1 +import os
2 +import torch
3 +import pandas as pd
4 +import numpy as np
5 +from torch.utils.data import Dataset, DataLoader
6 +import const
7 +
8 +'''
9 +def int_to_binary(x, bits):
10 + mask = 2 ** torch.arange(bits).to(x.device, x.dtype)
11 + return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
12 +'''
13 +
14 +def unpack_bits(x, num_bits):
15 + """
16 + Args:
17 +        x (np.ndarray): integer array to convert to bits
18 +        num_bits (int): number of bits to represent
19 + """
20 + xshape = list(x.shape)
21 + x = x.reshape([-1, 1])
22 + mask = 2**np.arange(num_bits).reshape([1, num_bits])
23 + return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])
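# Worked example (illustrative): unpack_bits(np.array(0x316), const.CAN_ID_BIT) returns a
# length-29 vector, least-significant bit first (0x316 = 790):
#   [0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, ..., 0]
# The Dataset classes below stack 29 consecutive IDs row by row into a (1, 29, 29) binary image.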
24 +
25 +
26 +# def CsvToNumpy(csv_file):
27 +# target_csv = pd.read_csv(csv_file)
28 +# inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
29 +# labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
30 +# print(inputs_save_numpy, labels_save_numpy)
31 +
32 +# i = 0
33 +# inputs_array = []
34 +# labels_array = []
35 +# print(len(target_csv))
36 +
37 +# while i + const.CAN_ID_BIT - 1 < len(target_csv):
38 +
39 +# is_regular = True
40 +# for j in range(const.CAN_ID_BIT):
41 +# l = target_csv.iloc[i + j]
42 +# b = l[2]
43 +# r = (l[b+2+1] == 'R')
44 +
45 +# if not r:
46 +# is_regular = False
47 +# break
48 +
49 +# inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
50 +# for idx in range(const.CAN_ID_BIT):
51 +# can_id = int(target_csv.iloc[i + idx, 1], 16)
52 +# inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT)
53 +# inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
54 +
55 +# if is_regular:
56 +# labels = 1
57 +# else:
58 +# labels = 0
59 +
60 +# inputs_array.append(inputs)
61 +# labels_array.append(labels)
62 +
63 +# i+=1
64 +# if (i % 5000 == 0):
65 +# print(i)
66 +# # break
67 +
68 +# inputs_array = np.array(inputs_array)
69 +# labels_array = np.array(labels_array)
70 +# np.save(inputs_save_numpy, arr=inputs_array)
71 +# np.save(labels_save_numpy, arr=labels_array)
72 +# print('done')
73 +
74 +
75 +def CsvToText(csv_file):
76 + target_csv = pd.read_csv(csv_file)
77 + text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt'
78 + print(text_file_name)
79 + target_text = open(text_file_name, mode='wt', encoding='utf-8')
80 +
81 + i = 0
82 + datum = [ [], [] ]
83 + print(len(target_csv))
84 +
85 + while i + const.CAN_ID_BIT - 1 < len(target_csv):
86 +
87 + is_regular = True
88 + for j in range(const.CAN_ID_BIT):
89 +            l = target_csv.iloc[i + j]
90 +            b = l[2]                # DLC: number of data bytes in this frame
91 +            r = (l[b+2+1] == 'R')   # flag column comes right after the data bytes: 'R' = normal, 'T' = injected
92 +
93 + if not r:
94 + is_regular = False
95 + break
96 +
97 + if is_regular:
98 + target_text.write("%d R\n" % i)
99 + else:
100 + target_text.write("%d T\n" % i)
101 +
102 + i+=1
103 + if (i % 5000 == 0):
104 + print(i)
105 +
106 + target_text.close()
107 + print('done')
108 +
109 +
110 +def record_net_data_stats(label_temp, data_idx_map):
111 + net_class_count = {}
112 + net_data_count= {}
113 +
114 + for net_i, dataidx in data_idx_map.items():
115 + unq, unq_cnt = np.unique(label_temp[dataidx], return_counts=True)
116 + tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
117 + net_class_count[net_i] = tmp
118 + net_data_count[net_i] = len(dataidx)
119 + print('Data statistics: %s' % str(net_class_count))
120 + return net_class_count, net_data_count
121 +
122 +
123 +def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs):
124 + csv_list = []
125 + total_datum = []
126 + total_label_temp = []
127 + csv_idx = 0
128 + for csv_file, txt_file in kwargs.items():
129 + csv = pd.read_csv(csv_file)
130 + csv_list.append(csv)
131 +
132 + txt = open(txt_file, "r")
133 + lines = txt.read().splitlines()
134 +
135 + idx = 0
136 + local_datum = []
137 + while idx + const.CAN_ID_BIT - 1 < len(csv):
138 + line = lines[idx]
139 + if not line:
140 + break
141 +
142 + if line.split(' ')[1] == 'R':
143 + local_datum.append((csv_idx, idx, 1))
144 + total_label_temp.append(1)
145 + else:
146 + local_datum.append((csv_idx, idx, 0))
147 + total_label_temp.append(0)
148 +
149 + idx += 1
150 + if (idx % 1000000 == 0):
151 + print(idx)
152 +
153 + csv_idx += 1
154 + total_datum += local_datum
155 +
156 + fold_length = int(len(total_label_temp) / 5)
157 + datum = []
158 + label_temp = []
159 + for i in range(5):
160 + if i != fold_num:
161 + datum += total_datum[i*fold_length:(i+1)*fold_length]
162 + label_temp += total_label_temp[i*fold_length:(i+1)*fold_length]
163 + else:
164 + test_datum = total_datum[i*fold_length:(i+1)*fold_length]
165 +
166 + min_size = 0
167 + output_class_num = 2
168 + N = len(label_temp)
169 + label_temp = np.array(label_temp)
170 + data_idx_map = {}
171 +
172 +    while min_size < 512:
173 +        idx_batch = [[] for _ in range(total_edge)]
174 +        # for each class in the dataset
175 +        for k in range(output_class_num):
176 +            idx_k = np.where(label_temp == k)[0]
177 +            np.random.shuffle(idx_k)
178 +            proportions = np.random.dirichlet(np.repeat(1, total_edge))  # Dirichlet(alpha=1) share of class k per edge -> non-IID split
179 +            ## Balance: drop edges that already hold more than an equal share
180 +            proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)])
181 +            proportions = proportions/proportions.sum()
182 +            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]  # cut points into idx_k
183 +            idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
184 +        min_size = min([len(idx_j) for idx_j in idx_batch])  # resample until every edge has at least 512 samples
185 +
186 + for j in range(total_edge):
187 + np.random.shuffle(idx_batch[j])
188 + data_idx_map[j] = idx_batch[j]
189 +
190 + net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map)
191 +
192 + return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False)
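# Shapes returned above (example values are illustrative only):
#   data_idx_map    : {edge_index: [sample indices assigned to that edge]}, e.g. {0: [12, 873, ...], ...}
#   net_class_count : per-edge label histogram, e.g. {0: {0: 431, 1: 3980}, ...}
#   net_data_count  : per-edge sample count, later used as aggregation weights
# Because the shares come from a Dirichlet(alpha=1) draw, the label mix differs per edge (non-IID).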
193 +
194 +
195 +class CanDatasetKwarg(Dataset):
196 +
197 + def __init__(self, csv_list, datum, is_train=True):
198 + self.csv_list = csv_list
199 + self.datum = datum
200 + if is_train:
201 + self.idx_map = []
202 + else:
203 + self.idx_map = [idx for idx in range(len(self.datum))]
204 +
205 + def __len__(self):
206 + return len(self.idx_map)
207 +
208 + def set_idx_map(self, data_idx_map):
209 + self.idx_map = data_idx_map
210 +
211 + def __getitem__(self, idx):
212 + csv_idx = self.datum[self.idx_map[idx]][0]
213 + start_i = self.datum[self.idx_map[idx]][1]
214 + is_regular = self.datum[self.idx_map[idx]][2]
215 +
216 + l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
217 + for i in range(const.CAN_ID_BIT):
218 + id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16)
219 + bits = unpack_bits(np.array(id_), const.CAN_ID_BIT)
220 + l[i] = bits
221 + l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
222 +
223 + return (l, is_regular)
224 +
225 +
226 +def GetCanDatasetUsingTxt(csv_file, txt_path, length):
227 + csv = pd.read_csv(csv_file)
228 + txt = open(txt_path, "r")
229 + lines = txt.read().splitlines()
230 +
231 + idx = 0
232 + datum = [ [], [] ]
233 + while idx + const.CAN_ID_BIT - 1 < len(csv):
234 + if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
235 + break
236 +
237 + line = lines[idx]
238 + if not line:
239 + break
240 +
241 + if line.split(' ')[1] == 'R':
242 + if len(datum[0]) < length//2:
243 + datum[0].append((idx, 1))
244 + else:
245 + if len(datum[1]) < length//2:
246 + datum[1].append((idx, 0))
247 +
248 + idx += 1
249 + if (idx % 5000 == 0):
250 + print(idx, len(datum[0]), len(datum[1]))
251 +
252 + l = int((length // 2) * 0.9)
253 + return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
254 + CanDataset(csv, datum[0][l:] + datum[1][l:])
255 +
256 +
257 +def GetCanDataset(csv_file, length):
258 + csv = pd.read_csv(csv_file)
259 +
260 + i = 0
261 + datum = [ [], [] ]
262 +
263 + while i + const.CAN_ID_BIT - 1 < len(csv):
264 + if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
265 + break
266 +
267 + is_regular = True
268 + for j in range(const.CAN_ID_BIT):
269 + l = csv.iloc[i + j]
270 + b = l[2]
271 + r = (l[b+2+1] == 'R')
272 +
273 + if not r:
274 + is_regular = False
275 + break
276 +
277 + if is_regular:
278 + if len(datum[0]) < length//2:
279 + datum[0].append((i, 1))
280 + else:
281 + if len(datum[1]) < length//2:
282 + datum[1].append((i, 0))
283 + i+=1
284 + if (i % 5000 == 0):
285 + print(i, len(datum[0]), len(datum[1]))
286 +
287 + l = int((length // 2) * 0.9)
288 + return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
289 + CanDataset(csv, datum[0][l:] + datum[1][l:])
290 +
291 +
292 +class CanDataset(Dataset):
293 +
294 + def __init__(self, csv, datum):
295 + self.csv = csv
296 + self.datum = datum
297 +
298 + def __len__(self):
299 + return len(self.datum)
300 +
301 + def __getitem__(self, idx):
302 + start_i = self.datum[idx][0]
303 + is_regular = self.datum[idx][1]
304 +
305 + l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
306 + for i in range(const.CAN_ID_BIT):
307 +            id_ = int(self.csv.iloc[start_i + i, 1], 16)  # CAN ID column holds hex strings; avoid shadowing builtin id
308 +            bits = unpack_bits(np.array(id_), const.CAN_ID_BIT)
309 + l[i] = bits
310 + l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
311 +
312 + return (l, is_regular)
313 +
314 +
315 +if __name__ == "__main__":
316 +    kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
317 +    # smoke test: build the split (2 edges, fold 0) and read one batch from the test set
318 +    _, _, _, _, test_data_set = GetCanDatasetUsingTxtKwarg(2, 0, **kwargs)
319 +    testloader = torch.utils.data.DataLoader(test_data_set, batch_size=256, shuffle=False, num_workers=2)
320 +
321 +    for x, y in testloader:
322 +        print(x)
323 +        print(y)
324 +        break
1 +import utils
2 +import copy
3 +from collections import OrderedDict
4 +
5 +import model
6 +import dataset
7 +
8 +import importlib
9 +importlib.reload(utils)
10 +importlib.reload(model)
11 +importlib.reload(dataset)
12 +
13 +from utils import *
14 +
15 +
16 +def add_args(parser):
17 + # parser.add_argument('--model', type=str, default='moderate-cnn',
18 + # help='neural network used in training')
19 + parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
20 + help='dataset used for training')
21 + parser.add_argument('--fold_num', type=int, default=0,
22 + help='5-fold, 0 ~ 4')
23 + parser.add_argument('--batch_size', type=int, default=256, metavar='N',
24 + help='input batch size for training')
25 + parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
26 + help='learning rate')
27 + parser.add_argument('--n_nets', type=int, default=100, metavar='NN',
28 + help='number of workers in a distributed cluster')
29 + parser.add_argument('--comm_type', type=str, default='fedtwa',
30 +                            help='federated aggregation strategy: fedavg / fedprox / fedtwa / feddw')
31 + parser.add_argument('--comm_round', type=int, default=10,
32 +                            help='how many rounds of communication we should use')
33 + args = parser.parse_args(args=[])
34 + return args
35 +
36 +
37 +def start_fedavg(fed_model, args,
38 + train_data_set,
39 + data_idx_map,
40 + net_data_count,
41 + testloader,
42 + edges,
43 + device):
44 + print("start fed avg")
45 + criterion = nn.CrossEntropyLoss()
46 + C = 0.1
47 + num_edge = int(max(C * args.n_nets, 1))
48 + total_data_count = 0
49 + for _, data_count in net_data_count.items():
50 + total_data_count += data_count
51 + print("total data: %d" % total_data_count)
52 +
53 + for cr in range(1, args.comm_round + 1):
54 + print("Communication round : %d" % (cr))
55 +
56 + np.random.seed(cr) # make sure for each comparison, select the same clients each round
57 + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
58 + print("selected edge", selected_edge)
59 +
60 + for edge_progress, edge_index in enumerate(selected_edge):
61 + train_data_set.set_idx_map(data_idx_map[edge_index])
62 + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
63 + shuffle=True, num_workers=2)
64 + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
65 +
66 + edges[edge_index] = copy.deepcopy(fed_model)
67 + edges[edge_index].to(device)
68 + edges[edge_index].train()
69 + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
70 + # train
71 + for data_idx, (inputs, labels) in enumerate(train_loader):
72 + inputs, labels = inputs.float().to(device), labels.long().to(device)
73 +
74 + edge_opt.zero_grad()
75 + edge_pred = edges[edge_index](inputs)
76 +
77 + edge_loss = criterion(edge_pred, labels)
78 + edge_loss.backward()
79 +
80 + edge_opt.step()
81 + edge_loss = edge_loss.item()
82 + if data_idx % 100 == 0:
83 + print('[%4d] loss: %.3f' % (data_idx, edge_loss))
84 + # break
85 + edges[edge_index].to('cpu')
86 +
87 +        # FedAvg aggregation: data-size weighted average over every model in edges (not only this round's selected edges)
88 + update_state = OrderedDict()
89 + for k, edge in enumerate(edges):
90 + local_state = edge.state_dict()
91 + for key in fed_model.state_dict().keys():
92 + if k == 0:
93 + update_state[key] = local_state[key] * net_data_count[k] / total_data_count
94 + else:
95 + update_state[key] += local_state[key] * net_data_count[k] / total_data_count
96 +
97 + fed_model.load_state_dict(update_state)
98 + if cr % 10 == 0:
99 + fed_model.to(device)
100 + fed_model.eval()
101 +
102 + total_loss = 0.0
103 + cnt = 0
104 + step_acc = 0.0
105 + with torch.no_grad():
106 + for i, data in enumerate(testloader):
107 + inputs, labels = data
108 + inputs, labels = inputs.float().to(device), labels.long().to(device)
109 +
110 + outputs = fed_model(inputs)
111 + _, preds = torch.max(outputs, 1)
112 +
113 + loss = criterion(outputs, labels)
114 + cnt += inputs.shape[0]
115 +
116 + corr_sum = torch.sum(preds == labels.data)
117 + step_acc += corr_sum.double()
118 + running_loss = loss.item() * inputs.shape[0]
119 + total_loss += running_loss
120 + if i % 200 == 0:
121 + print('test [%4d] loss: %.3f' % (i, loss.item()))
122 + # break
123 + print((step_acc / cnt).data)
124 + print(total_loss / cnt)
125 + fed_model.to('cpu')
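# The aggregation above is plain FedAvg: with n_k samples on edge k and n = sum_k n_k,
#     w_global <- sum_k (n_k / n) * w_k
# applied key by key over the state_dicts.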
126 +
127 +
128 +def start_fedprox(fed_model, args,
129 + train_data_set,
130 + data_idx_map,
131 + testloader,
132 + device):
133 + print("start fed prox")
134 + criterion = nn.CrossEntropyLoss()
135 + mu = 0.001
136 + C = 0.1
137 + num_edge = int(max(C * args.n_nets, 1))
138 + fed_model.to(device)
139 +
140 + for cr in range(1, args.comm_round + 1):
141 + print("Communication round : %d" % (cr))
142 + edge_weight_dict = {}
143 + fed_weight_dict = {}
144 + for fed_name, fed_param in fed_model.named_parameters():
145 + edge_weight_dict[fed_name] = []
146 + fed_weight_dict[fed_name] = fed_param
147 +
148 + np.random.seed(cr) # make sure for each comparison, select the same clients each round
149 + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
150 + print("selected edge", selected_edge)
151 +
152 + total_data_length = 0
153 + edge_data_len = []
154 + for edge_progress, edge_index in enumerate(selected_edge):
155 + train_data_set.set_idx_map(data_idx_map[edge_index])
156 + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
157 + shuffle=True, num_workers=2)
158 + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
159 + total_data_length += len(train_data_set)
160 + edge_data_len.append(len(train_data_set))
161 +
162 + edge_model = copy.deepcopy(fed_model)
163 + edge_model.to(device)
164 + edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr)
165 + # train
166 + for data_idx, (inputs, labels) in enumerate(train_loader):
167 + inputs, labels = inputs.float().to(device), labels.long().to(device)
168 +
169 + edge_opt.zero_grad()
170 + edge_pred = edge_model(inputs)
171 +
172 + edge_loss = criterion(edge_pred, labels)
173 +                # FedProx proximal term: (mu / 2) * ||w_edge - w_global||^2
174 + fed_prox_reg = 0.0
175 + for edge_name, edge_param in edge_model.named_parameters():
176 + fed_prox_reg += ((mu / 2) * torch.norm((fed_weight_dict[edge_name] - edge_param))**2)
177 + edge_loss += fed_prox_reg
178 +
179 + edge_loss.backward()
180 +
181 + edge_opt.step()
182 + edge_loss = edge_loss.item()
183 + if data_idx % 100 == 0:
184 + print('[%4d] loss: %.3f' % (data_idx, edge_loss))
185 + # break
186 +
187 + edge_model.to('cpu')
188 + # save edge weight
189 + for edge_name, edge_param in edge_model.named_parameters():
190 + edge_weight_dict[edge_name].append(edge_param)
191 +
192 + fed_model.to('cpu')
193 + # cal weight, / number of edge
194 + for fed_name, fed_param in fed_model.named_parameters():
195 + fed_param.data.copy_( sum(weight / num_edge for weight in edge_weight_dict[fed_name]) )
196 + fed_model.to(device)
197 +
198 + if cr % 10 == 0:
199 + fed_model.eval()
200 + total_loss = 0.0
201 + cnt = 0
202 + step_acc = 0.0
203 + with torch.no_grad():
204 + for i, data in enumerate(testloader):
205 + inputs, labels = data
206 + inputs, labels = inputs.float().to(device), labels.long().to(device)
207 +
208 + outputs = fed_model(inputs)
209 + _, preds = torch.max(outputs, 1)
210 +
211 + loss = criterion(outputs, labels)
212 + cnt += inputs.shape[0]
213 +
214 + corr_sum = torch.sum(preds == labels.data)
215 + step_acc += corr_sum.double()
216 + running_loss = loss.item() * inputs.shape[0]
217 + total_loss += running_loss
218 + if i % 200 == 0:
219 + print('test [%4d] loss: %.3f' % (i, loss.item()))
220 + # break
221 + print((step_acc / cnt).data)
222 + print(total_loss / cnt)
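# FedProx changes only the local objective: each selected edge minimizes
#     CE(w) + (mu / 2) * ||w - w_global||^2        (mu = 0.001 above)
# and the server then takes an unweighted average of the selected edges' parameters.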
223 +
224 +
225 +def start_fedtwa(fed_model, args,
226 + train_data_set,
227 + data_idx_map,
228 + net_data_count,
229 + testloader,
230 + edges,
231 + device):
232 + # TEFL, without asynchronous model update
233 + print("start fed temporally weighted aggregation")
234 + criterion = nn.CrossEntropyLoss()
235 + time_stamp = [0 for worker in range(args.n_nets)]
236 + twa_exp = math.e / 2.0
237 + C = 0.1
238 + num_edge = int(max(C * args.n_nets, 1))
239 + total_data_count = 0
240 + for _, data_count in net_data_count.items():
241 + total_data_count += data_count
242 + print("total data: %d" % total_data_count)
243 +
244 + for cr in range(1, args.comm_round + 1):
245 + print("Communication round : %d" % (cr))
246 +
247 + np.random.seed(cr) # make sure for each comparison, select the same clients each round
248 + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
249 + print("selected edge", selected_edge)
250 +
251 + for edge_progress, edge_index in enumerate(selected_edge):
252 + time_stamp[edge_index] = cr
253 + train_data_set.set_idx_map(data_idx_map[edge_index])
254 + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
255 + shuffle=True, num_workers=2)
256 + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
257 +
258 + edges[edge_index] = copy.deepcopy(fed_model)
259 + edges[edge_index].to(device)
260 + edges[edge_index].train()
261 + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
262 + # train
263 + for data_idx, (inputs, labels) in enumerate(train_loader):
264 + inputs, labels = inputs.float().to(device), labels.long().to(device)
265 +
266 + edge_opt.zero_grad()
267 + edge_pred = edges[edge_index](inputs)
268 +
269 + edge_loss = criterion(edge_pred, labels)
270 + edge_loss.backward()
271 +
272 + edge_opt.step()
273 + edge_loss = edge_loss.item()
274 + if data_idx % 100 == 0:
275 + print('[%4d] loss: %.3f' % (data_idx, edge_loss))
276 + # break
277 + edges[edge_index].to('cpu')
278 +
279 + # cal weight using time stamp
280 + update_state = OrderedDict()
281 + for k, edge in enumerate(edges):
282 + local_state = edge.state_dict()
283 + for key in fed_model.state_dict().keys():
284 + if k == 0:
285 + update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k]))
286 + else:
287 + update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k]))
288 +
289 + fed_model.load_state_dict(update_state)
290 + if cr % 10 == 0:
291 + fed_model.to(device)
292 + fed_model.eval()
293 +
294 + total_loss = 0.0
295 + cnt = 0
296 + step_acc = 0.0
297 + with torch.no_grad():
298 + for i, data in enumerate(testloader):
299 + inputs, labels = data
300 + inputs, labels = inputs.float().to(device), labels.long().to(device)
301 +
302 + outputs = fed_model(inputs)
303 + _, preds = torch.max(outputs, 1)
304 +
305 + loss = criterion(outputs, labels)
306 + cnt += inputs.shape[0]
307 +
308 + corr_sum = torch.sum(preds == labels.data)
309 + step_acc += corr_sum.double()
310 + running_loss = loss.item() * inputs.shape[0]
311 + total_loss += running_loss
312 + if i % 200 == 0:
313 + print('test [%4d] loss: %.3f' % (i, loss.item()))
314 + # break
315 + print((step_acc / cnt).data)
316 + print(total_loss / cnt)
317 + fed_model.to('cpu')
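# Temporally weighted aggregation: edge k contributes in proportion to its data share and
# its freshness, so models trained in older rounds count less:
#     w_global <- sum_k (n_k / n) * (e/2)^-(cr - t_k) * w_k
# where t_k is the last round edge k was trained. Note the coefficients are not re-normalized.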
318 +
319 +
320 +def start_feddw(fed_model, args,
321 + train_data_set,
322 + data_idx_map,
323 + net_data_count,
324 + testloader,
325 + local_test_loader,
326 + edges,
327 + device):
328 + print("start fed Node-aware Dynamic Weighting")
329 + worker_selected_frequency = [0 for worker in range(args.n_nets)]
330 + criterion = nn.CrossEntropyLoss()
331 + H = 0.5
332 + P = 0.5
333 + G = 0.1
334 + R = 0.1
335 + alpha, beta, gamma = 30.0/100.0, 50.0/100.0, 20.0/100.0
336 + num_edge = int(max(G * args.n_nets, 1))
337 +
338 + # cal data weight for selecting participants
339 + total_data_count = 0
340 + for _, data_count in net_data_count.items():
341 + total_data_count += data_count
342 + print("total data: %d" % total_data_count)
343 +
344 + total_data_weight = 0.0
345 + net_weight_dict = {}
346 + for net_key, data_count in net_data_count.items():
347 + net_data_count[net_key] = data_count / total_data_count
348 + net_weight_dict[net_key] = total_data_count / data_count
349 + total_data_weight += net_weight_dict[net_key]
350 +
351 + for net_key, data_count in net_weight_dict.items():
352 + net_weight_dict[net_key] = net_weight_dict[net_key] / total_data_weight
353 + # end
354 +
355 + worker_local_accuracy = [0 for worker in range(args.n_nets)]
356 +
357 + for cr in range(1, args.comm_round + 1):
358 + print("Communication round : %d" % (cr))
359 +
360 + # select participants
361 + candidates = []
362 + sum_frequency = sum(worker_selected_frequency)
363 + if sum_frequency == 0:
364 + sum_frequency = 1
365 + for worker_index in range(args.n_nets):
366 + candidates.append((H * worker_selected_frequency[worker_index] / sum_frequency + (1 - H) * net_weight_dict[worker_index], worker_index))
367 + candidates = sorted(candidates)[:int(R * args.n_nets)]
368 + candidates = [temp[1] for temp in candidates]
369 +
370 + np.random.seed(cr)
371 + selected_edge = np.random.choice(candidates, num_edge, replace=False)
372 + # end select
373 +
374 + # weighted frequency
375 + avg_selected_frequency = sum(worker_selected_frequency) / len(worker_selected_frequency)
376 + weighted_frequency = [P * (avg_selected_frequency - worker_frequency) for worker_frequency in worker_selected_frequency]
377 + frequency_prime = min(weighted_frequency)
378 + weighted_frequency = [frequency + frequency_prime + 1 for frequency in weighted_frequency]
379 +        # end weighted
380 +
381 + print("selected edge", selected_edge)
382 + for edge_progress, edge_index in enumerate(selected_edge):
383 + worker_selected_frequency[edge_index] += 1
384 + train_data_set.set_idx_map(data_idx_map[edge_index])
385 + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
386 + shuffle=True, num_workers=2)
387 + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
388 +
389 + edges[edge_index] = copy.deepcopy(fed_model)
390 + edges[edge_index].to(device)
391 + edges[edge_index].train()
392 + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
393 + # train
394 + for data_idx, (inputs, labels) in enumerate(train_loader):
395 + inputs, labels = inputs.float().to(device), labels.long().to(device)
396 +
397 + edge_opt.zero_grad()
398 + edge_pred = edges[edge_index](inputs)
399 +
400 + edge_loss = criterion(edge_pred, labels)
401 + edge_loss.backward()
402 +
403 + edge_opt.step()
404 + edge_loss = edge_loss.item()
405 + if data_idx % 100 == 0:
406 + print('[%4d] loss: %.3f' % (data_idx, edge_loss))
407 + # break
408 +
409 + # get edge accuracy using subset of testset
410 + edges[edge_index].eval()
411 + print("[%2d/%2d] edge: %d, cal accuracy" % (edge_progress, len(selected_edge), edge_index))
412 + cnt = 0
413 + step_acc = 0.0
414 + with torch.no_grad():
415 + for inputs, labels in local_test_loader:
416 + inputs, labels = inputs.float().to(device), labels.long().to(device)
417 +
418 + outputs = edges[edge_index](inputs)
419 + _, preds = torch.max(outputs, 1)
420 +
421 + loss = criterion(outputs, labels)
422 + cnt += inputs.shape[0]
423 +
424 + corr_sum = torch.sum(preds == labels.data)
425 + step_acc += corr_sum.double()
426 + # break
427 +
428 + worker_local_accuracy[edge_index] = (step_acc / cnt).item()
429 + print(worker_local_accuracy[edge_index])
430 + edges[edge_index].to('cpu')
431 +
432 + # cal weight dynamically
433 + sum_accuracy = sum(worker_local_accuracy)
434 + sum_weighted_frequency = sum(weighted_frequency)
435 + update_state = OrderedDict()
436 + for k, edge in enumerate(edges):
437 + local_state = edge.state_dict()
438 + for key in fed_model.state_dict().keys():
439 + if k == 0:
440 + update_state[key] = local_state[key] \
441 + * (net_data_count[k] * alpha \
442 + + worker_local_accuracy[k] / sum_accuracy * beta \
443 + + weighted_frequency[k] / sum_weighted_frequency * gamma)
444 + else:
445 + update_state[key] += local_state[key] \
446 + * (net_data_count[k] * alpha \
447 + + worker_local_accuracy[k] / sum_accuracy * beta \
448 + + weighted_frequency[k] / sum_weighted_frequency * gamma)
449 +
450 + fed_model.load_state_dict(update_state)
451 + if cr % 10 == 0:
452 + fed_model.to(device)
453 + fed_model.eval()
454 +
455 + total_loss = 0.0
456 + cnt = 0
457 + step_acc = 0.0
458 + with torch.no_grad():
459 + for i, data in enumerate(testloader):
460 + inputs, labels = data
461 + inputs, labels = inputs.float().to(device), labels.long().to(device)
462 +
463 + outputs = fed_model(inputs)
464 + _, preds = torch.max(outputs, 1)
465 +
466 + loss = criterion(outputs, labels)
467 + cnt += inputs.shape[0]
468 +
469 + corr_sum = torch.sum(preds == labels.data)
470 + step_acc += corr_sum.double()
471 + running_loss = loss.item() * inputs.shape[0]
472 + total_loss += running_loss
473 + if i % 200 == 0:
474 + print('test [%4d] loss: %.3f' % (i, loss.item()))
475 + # break
476 + print((step_acc / cnt).data)
477 + print(total_loss / cnt)
478 + fed_model.to('cpu')
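# FedDW aggregation coefficient for edge k (alpha = 0.3, beta = 0.5, gamma = 0.2 above):
#     coef_k = alpha * (n_k / n) + beta * acc_k / sum(acc) + gamma * f_k / sum(f)
# mixing data share, local accuracy on the shared test subset, and the smoothed selection
# frequency; participants are drawn from candidates with the lowest frequency/data score.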
479 +
480 +
481 +def start_train():
482 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
483 + print(device)
484 + args = add_args(argparse.ArgumentParser())
485 +
486 + seed = 0
487 + np.random.seed(seed)
488 + torch.manual_seed(seed)
489 +
490 + print("Loading data...")
491 + # kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt',
492 + # "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt',
493 + # "./dataset/RPM_dataset.csv" : './RPM_dataset.txt',
494 + # "./dataset/gear_dataset.csv" : './gear_dataset.txt'
495 + # }
496 + kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
497 + train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs)
498 + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size,
499 + shuffle=False, num_workers=2)
500 +
501 + fed_model = model.Net()
502 +    args.comm_type = 'feddw'  # hard-coded override of the parsed --comm_type for this run
503 + if args.comm_type == "fedavg":
504 + edges, _, _ = init_models(args.n_nets, args)
505 + start_fedavg(fed_model, args,
506 + train_data_set,
507 + data_idx_map,
508 + net_data_count,
509 + testloader,
510 + edges,
511 + device)
512 + elif args.comm_type == "fedprox":
513 + start_fedprox(fed_model, args,
514 + train_data_set,
515 + data_idx_map,
516 + testloader,
517 + device)
518 + elif args.comm_type == "fedtwa":
519 + edges, _, _ = init_models(args.n_nets, args)
520 + start_fedtwa(fed_model, args,
521 + train_data_set,
522 + data_idx_map,
523 + net_data_count,
524 + testloader,
525 + edges,
526 + device)
527 + elif args.comm_type == "feddw":
528 + local_test_set = copy.deepcopy(test_data_set)
529 +        # reference: MNIST uses 60,000 train / 10,000 test with a 1,000-sample local test subset;
530 +        # the CAN data is ~13,000,000 train / ~2,000,000 test, so take 1/50 (~40,000) of the test set for speed
531 + local_test_idx = np.random.choice(len(local_test_set), len(local_test_set) // 50, replace=False)
532 + local_test_set.set_idx_map(local_test_idx)
533 + local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size,
534 + shuffle=False, num_workers=2)
535 +
536 + edges, _, _ = init_models(args.n_nets, args)
537 + start_feddw(fed_model, args,
538 + train_data_set,
539 + data_idx_map,
540 + net_data_count,
541 + testloader,
542 + local_test_loader,
543 + edges,
544 + device)
545 +
546 +if __name__ == "__main__":
547 + start_train()
\ No newline at end of file
1 +import torch.nn as nn
2 +import torch.nn.functional as F
3 +import torch
4 +import const
5 +
6 +class Net(nn.Module):
7 + def __init__(self):
8 + super(Net, self).__init__()
9 +
10 + self.f1 = nn.Sequential(
11 + nn.Conv2d(1, 2, 3),
12 + nn.ReLU(True),
13 + )
14 + self.f2 = nn.Sequential(
15 + nn.Conv2d(2, 4, 3),
16 + nn.ReLU(True),
17 + )
18 + self.f3 = nn.Sequential(
19 + nn.Conv2d(4, 8, 3),
20 + nn.ReLU(True),
21 + )
22 + self.f4 = nn.Sequential(
23 +            nn.Linear(8 * 23 * 23, 2),  # 29x29 input shrinks to 23x23 after three 3x3 convs (29->27->25->23), 8 channels
24 + )
25 +
26 + def forward(self, x):
27 + x = self.f1(x)
28 + x = self.f2(x)
29 + x = self.f3(x)
30 + x = torch.flatten(x, 1)
31 + x = self.f4(x)
32 + return x
\ No newline at end of file
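A quick shape check of `Net` against the (1, 29, 29) samples that `dataset.py` produces (a standalone sketch; the batch size of 4 is arbitrary). Three 3x3 convolutions without padding shrink 29x29 to 27, 25, then 23, which is where the `8 * 23 * 23` in `f4` comes from.

```python
import torch
from model import Net

net = Net()
x = torch.randn(4, 1, 29, 29)   # batch of 4 fake CAN-ID windows
out = net(x)
print(out.shape)                # torch.Size([4, 2]) -> logits for regular vs. injected
```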