Showing
5 changed files
with
1042 additions
and
0 deletions
코드/.gitignore
0 → 100644
1 | +# Byte-compiled / optimized / DLL files | ||
2 | +__pycache__/ | ||
3 | +*.py[cod] | ||
4 | +*$py.class | ||
5 | + | ||
6 | +# C extensions | ||
7 | +*.so | ||
8 | + | ||
9 | +# Distribution / packaging | ||
10 | +.Python | ||
11 | +build/ | ||
12 | +develop-eggs/ | ||
13 | +dist/ | ||
14 | +downloads/ | ||
15 | +eggs/ | ||
16 | +.eggs/ | ||
17 | +lib/ | ||
18 | +lib64/ | ||
19 | +parts/ | ||
20 | +sdist/ | ||
21 | +var/ | ||
22 | +wheels/ | ||
23 | +share/python-wheels/ | ||
24 | +*.egg-info/ | ||
25 | +.installed.cfg | ||
26 | +*.egg | ||
27 | +MANIFEST | ||
28 | + | ||
29 | +# PyInstaller | ||
30 | +# Usually these files are written by a python script from a template | ||
31 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
32 | +*.manifest | ||
33 | +*.spec | ||
34 | + | ||
35 | +# Installer logs | ||
36 | +pip-log.txt | ||
37 | +pip-delete-this-directory.txt | ||
38 | + | ||
39 | +# Unit test / coverage reports | ||
40 | +htmlcov/ | ||
41 | +.tox/ | ||
42 | +.nox/ | ||
43 | +.coverage | ||
44 | +.coverage.* | ||
45 | +.cache | ||
46 | +nosetests.xml | ||
47 | +coverage.xml | ||
48 | +*.cover | ||
49 | +*.py,cover | ||
50 | +.hypothesis/ | ||
51 | +.pytest_cache/ | ||
52 | +cover/ | ||
53 | + | ||
54 | +# Translations | ||
55 | +*.mo | ||
56 | +*.pot | ||
57 | + | ||
58 | +# Django stuff: | ||
59 | +*.log | ||
60 | +local_settings.py | ||
61 | +db.sqlite3 | ||
62 | +db.sqlite3-journal | ||
63 | + | ||
64 | +# Flask stuff: | ||
65 | +instance/ | ||
66 | +.webassets-cache | ||
67 | + | ||
68 | +# Scrapy stuff: | ||
69 | +.scrapy | ||
70 | + | ||
71 | +# Sphinx documentation | ||
72 | +docs/_build/ | ||
73 | + | ||
74 | +# PyBuilder | ||
75 | +.pybuilder/ | ||
76 | +target/ | ||
77 | + | ||
78 | +# Jupyter Notebook | ||
79 | +.ipynb_checkpoints | ||
80 | + | ||
81 | +# IPython | ||
82 | +profile_default/ | ||
83 | +ipython_config.py | ||
84 | + | ||
85 | +# pyenv | ||
86 | +# For a library or package, you might want to ignore these files since the code is | ||
87 | +# intended to run in multiple environments; otherwise, check them in: | ||
88 | +# .python-version | ||
89 | + | ||
90 | +# pipenv | ||
91 | +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
92 | +# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
93 | +# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
94 | +# install all needed dependencies. | ||
95 | +#Pipfile.lock | ||
96 | + | ||
97 | +# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
98 | +__pypackages__/ | ||
99 | + | ||
100 | +# Celery stuff | ||
101 | +celerybeat-schedule | ||
102 | +celerybeat.pid | ||
103 | + | ||
104 | +# SageMath parsed files | ||
105 | +*.sage.py | ||
106 | + | ||
107 | +# Environments | ||
108 | +.env | ||
109 | +.venv | ||
110 | +env/ | ||
111 | +venv/ | ||
112 | +ENV/ | ||
113 | +env.bak/ | ||
114 | +venv.bak/ | ||
115 | + | ||
116 | +# Spyder project settings | ||
117 | +.spyderproject | ||
118 | +.spyproject | ||
119 | + | ||
120 | +# Rope project settings | ||
121 | +.ropeproject | ||
122 | + | ||
123 | +# mkdocs documentation | ||
124 | +/site | ||
125 | + | ||
126 | +# mypy | ||
127 | +.mypy_cache/ | ||
128 | +.dmypy.json | ||
129 | +dmypy.json | ||
130 | + | ||
131 | +# Pyre type checker | ||
132 | +.pyre/ | ||
133 | + | ||
134 | +# pytype static type analyzer | ||
135 | +.pytype/ | ||
136 | + | ||
137 | +# Cython debug symbols | ||
138 | +cython_debug/ | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/const.py
0 → 100644
1 | +CAN_ID_BIT = 29 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/dataset.py
0 → 100644
1 | +import os | ||
2 | +import torch | ||
3 | +import pandas as pd | ||
4 | +import numpy as np | ||
5 | +from torch.utils.data import Dataset, DataLoader | ||
6 | +import const | ||
7 | + | ||
8 | +''' | ||
9 | +def int_to_binary(x, bits): | ||
10 | + mask = 2 ** torch.arange(bits).to(x.device, x.dtype) | ||
11 | + return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte() | ||
12 | +''' | ||
13 | + | ||
14 | +def unpack_bits(x, num_bits): | ||
15 | + """ | ||
16 | + Args: | ||
17 | + x (int): bit로 변환할 정수 | ||
18 | + num_bits (int): 표현할 비트수 | ||
19 | + """ | ||
20 | + xshape = list(x.shape) | ||
21 | + x = x.reshape([-1, 1]) | ||
22 | + mask = 2**np.arange(num_bits).reshape([1, num_bits]) | ||
23 | + return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits]) | ||
24 | + | ||
25 | + | ||
26 | +# def CsvToNumpy(csv_file): | ||
27 | +# target_csv = pd.read_csv(csv_file) | ||
28 | +# inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy' | ||
29 | +# labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy' | ||
30 | +# print(inputs_save_numpy, labels_save_numpy) | ||
31 | + | ||
32 | +# i = 0 | ||
33 | +# inputs_array = [] | ||
34 | +# labels_array = [] | ||
35 | +# print(len(target_csv)) | ||
36 | + | ||
37 | +# while i + const.CAN_ID_BIT - 1 < len(target_csv): | ||
38 | + | ||
39 | +# is_regular = True | ||
40 | +# for j in range(const.CAN_ID_BIT): | ||
41 | +# l = target_csv.iloc[i + j] | ||
42 | +# b = l[2] | ||
43 | +# r = (l[b+2+1] == 'R') | ||
44 | + | ||
45 | +# if not r: | ||
46 | +# is_regular = False | ||
47 | +# break | ||
48 | + | ||
49 | +# inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
50 | +# for idx in range(const.CAN_ID_BIT): | ||
51 | +# can_id = int(target_csv.iloc[i + idx, 1], 16) | ||
52 | +# inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT) | ||
53 | +# inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
54 | + | ||
55 | +# if is_regular: | ||
56 | +# labels = 1 | ||
57 | +# else: | ||
58 | +# labels = 0 | ||
59 | + | ||
60 | +# inputs_array.append(inputs) | ||
61 | +# labels_array.append(labels) | ||
62 | + | ||
63 | +# i+=1 | ||
64 | +# if (i % 5000 == 0): | ||
65 | +# print(i) | ||
66 | +# # break | ||
67 | + | ||
68 | +# inputs_array = np.array(inputs_array) | ||
69 | +# labels_array = np.array(labels_array) | ||
70 | +# np.save(inputs_save_numpy, arr=inputs_array) | ||
71 | +# np.save(labels_save_numpy, arr=labels_array) | ||
72 | +# print('done') | ||
73 | + | ||
74 | + | ||
75 | +def CsvToText(csv_file): | ||
76 | + target_csv = pd.read_csv(csv_file) | ||
77 | + text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt' | ||
78 | + print(text_file_name) | ||
79 | + target_text = open(text_file_name, mode='wt', encoding='utf-8') | ||
80 | + | ||
81 | + i = 0 | ||
82 | + datum = [ [], [] ] | ||
83 | + print(len(target_csv)) | ||
84 | + | ||
85 | + while i + const.CAN_ID_BIT - 1 < len(target_csv): | ||
86 | + | ||
87 | + is_regular = True | ||
88 | + for j in range(const.CAN_ID_BIT): | ||
89 | + l = target_csv.iloc[i + j] | ||
90 | + b = l[2] | ||
91 | + r = (l[b+2+1] == 'R') | ||
92 | + | ||
93 | + if not r: | ||
94 | + is_regular = False | ||
95 | + break | ||
96 | + | ||
97 | + if is_regular: | ||
98 | + target_text.write("%d R\n" % i) | ||
99 | + else: | ||
100 | + target_text.write("%d T\n" % i) | ||
101 | + | ||
102 | + i+=1 | ||
103 | + if (i % 5000 == 0): | ||
104 | + print(i) | ||
105 | + | ||
106 | + target_text.close() | ||
107 | + print('done') | ||
108 | + | ||
109 | + | ||
110 | +def record_net_data_stats(label_temp, data_idx_map): | ||
111 | + net_class_count = {} | ||
112 | + net_data_count= {} | ||
113 | + | ||
114 | + for net_i, dataidx in data_idx_map.items(): | ||
115 | + unq, unq_cnt = np.unique(label_temp[dataidx], return_counts=True) | ||
116 | + tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))} | ||
117 | + net_class_count[net_i] = tmp | ||
118 | + net_data_count[net_i] = len(dataidx) | ||
119 | + print('Data statistics: %s' % str(net_class_count)) | ||
120 | + return net_class_count, net_data_count | ||
121 | + | ||
122 | + | ||
123 | +def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs): | ||
124 | + csv_list = [] | ||
125 | + total_datum = [] | ||
126 | + total_label_temp = [] | ||
127 | + csv_idx = 0 | ||
128 | + for csv_file, txt_file in kwargs.items(): | ||
129 | + csv = pd.read_csv(csv_file) | ||
130 | + csv_list.append(csv) | ||
131 | + | ||
132 | + txt = open(txt_file, "r") | ||
133 | + lines = txt.read().splitlines() | ||
134 | + | ||
135 | + idx = 0 | ||
136 | + local_datum = [] | ||
137 | + while idx + const.CAN_ID_BIT - 1 < len(csv): | ||
138 | + line = lines[idx] | ||
139 | + if not line: | ||
140 | + break | ||
141 | + | ||
142 | + if line.split(' ')[1] == 'R': | ||
143 | + local_datum.append((csv_idx, idx, 1)) | ||
144 | + total_label_temp.append(1) | ||
145 | + else: | ||
146 | + local_datum.append((csv_idx, idx, 0)) | ||
147 | + total_label_temp.append(0) | ||
148 | + | ||
149 | + idx += 1 | ||
150 | + if (idx % 1000000 == 0): | ||
151 | + print(idx) | ||
152 | + | ||
153 | + csv_idx += 1 | ||
154 | + total_datum += local_datum | ||
155 | + | ||
156 | + fold_length = int(len(total_label_temp) / 5) | ||
157 | + datum = [] | ||
158 | + label_temp = [] | ||
159 | + for i in range(5): | ||
160 | + if i != fold_num: | ||
161 | + datum += total_datum[i*fold_length:(i+1)*fold_length] | ||
162 | + label_temp += total_label_temp[i*fold_length:(i+1)*fold_length] | ||
163 | + else: | ||
164 | + test_datum = total_datum[i*fold_length:(i+1)*fold_length] | ||
165 | + | ||
166 | + min_size = 0 | ||
167 | + output_class_num = 2 | ||
168 | + N = len(label_temp) | ||
169 | + label_temp = np.array(label_temp) | ||
170 | + data_idx_map = {} | ||
171 | + | ||
172 | + while min_size < 512: | ||
173 | + idx_batch = [[] for _ in range(total_edge)] | ||
174 | + # for each class in the dataset | ||
175 | + for k in range(output_class_num): | ||
176 | + idx_k = np.where(label_temp == k)[0] | ||
177 | + np.random.shuffle(idx_k) | ||
178 | + proportions = np.random.dirichlet(np.repeat(1, total_edge)) | ||
179 | + ## Balance | ||
180 | + proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)]) | ||
181 | + proportions = proportions/proportions.sum() | ||
182 | + proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1] | ||
183 | + idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))] | ||
184 | + min_size = min([len(idx_j) for idx_j in idx_batch]) | ||
185 | + | ||
186 | + for j in range(total_edge): | ||
187 | + np.random.shuffle(idx_batch[j]) | ||
188 | + data_idx_map[j] = idx_batch[j] | ||
189 | + | ||
190 | + net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map) | ||
191 | + | ||
192 | + return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False) | ||
193 | + | ||
194 | + | ||
195 | +class CanDatasetKwarg(Dataset): | ||
196 | + | ||
197 | + def __init__(self, csv_list, datum, is_train=True): | ||
198 | + self.csv_list = csv_list | ||
199 | + self.datum = datum | ||
200 | + if is_train: | ||
201 | + self.idx_map = [] | ||
202 | + else: | ||
203 | + self.idx_map = [idx for idx in range(len(self.datum))] | ||
204 | + | ||
205 | + def __len__(self): | ||
206 | + return len(self.idx_map) | ||
207 | + | ||
208 | + def set_idx_map(self, data_idx_map): | ||
209 | + self.idx_map = data_idx_map | ||
210 | + | ||
211 | + def __getitem__(self, idx): | ||
212 | + csv_idx = self.datum[self.idx_map[idx]][0] | ||
213 | + start_i = self.datum[self.idx_map[idx]][1] | ||
214 | + is_regular = self.datum[self.idx_map[idx]][2] | ||
215 | + | ||
216 | + l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
217 | + for i in range(const.CAN_ID_BIT): | ||
218 | + id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16) | ||
219 | + bits = unpack_bits(np.array(id_), const.CAN_ID_BIT) | ||
220 | + l[i] = bits | ||
221 | + l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
222 | + | ||
223 | + return (l, is_regular) | ||
224 | + | ||
225 | + | ||
226 | +def GetCanDatasetUsingTxt(csv_file, txt_path, length): | ||
227 | + csv = pd.read_csv(csv_file) | ||
228 | + txt = open(txt_path, "r") | ||
229 | + lines = txt.read().splitlines() | ||
230 | + | ||
231 | + idx = 0 | ||
232 | + datum = [ [], [] ] | ||
233 | + while idx + const.CAN_ID_BIT - 1 < len(csv): | ||
234 | + if len(datum[0]) >= length//2 and len(datum[1]) >= length//2: | ||
235 | + break | ||
236 | + | ||
237 | + line = lines[idx] | ||
238 | + if not line: | ||
239 | + break | ||
240 | + | ||
241 | + if line.split(' ')[1] == 'R': | ||
242 | + if len(datum[0]) < length//2: | ||
243 | + datum[0].append((idx, 1)) | ||
244 | + else: | ||
245 | + if len(datum[1]) < length//2: | ||
246 | + datum[1].append((idx, 0)) | ||
247 | + | ||
248 | + idx += 1 | ||
249 | + if (idx % 5000 == 0): | ||
250 | + print(idx, len(datum[0]), len(datum[1])) | ||
251 | + | ||
252 | + l = int((length // 2) * 0.9) | ||
253 | + return CanDataset(csv, datum[0][:l] + datum[1][:l]), \ | ||
254 | + CanDataset(csv, datum[0][l:] + datum[1][l:]) | ||
255 | + | ||
256 | + | ||
257 | +def GetCanDataset(csv_file, length): | ||
258 | + csv = pd.read_csv(csv_file) | ||
259 | + | ||
260 | + i = 0 | ||
261 | + datum = [ [], [] ] | ||
262 | + | ||
263 | + while i + const.CAN_ID_BIT - 1 < len(csv): | ||
264 | + if len(datum[0]) >= length//2 and len(datum[1]) >= length//2: | ||
265 | + break | ||
266 | + | ||
267 | + is_regular = True | ||
268 | + for j in range(const.CAN_ID_BIT): | ||
269 | + l = csv.iloc[i + j] | ||
270 | + b = l[2] | ||
271 | + r = (l[b+2+1] == 'R') | ||
272 | + | ||
273 | + if not r: | ||
274 | + is_regular = False | ||
275 | + break | ||
276 | + | ||
277 | + if is_regular: | ||
278 | + if len(datum[0]) < length//2: | ||
279 | + datum[0].append((i, 1)) | ||
280 | + else: | ||
281 | + if len(datum[1]) < length//2: | ||
282 | + datum[1].append((i, 0)) | ||
283 | + i+=1 | ||
284 | + if (i % 5000 == 0): | ||
285 | + print(i, len(datum[0]), len(datum[1])) | ||
286 | + | ||
287 | + l = int((length // 2) * 0.9) | ||
288 | + return CanDataset(csv, datum[0][:l] + datum[1][:l]), \ | ||
289 | + CanDataset(csv, datum[0][l:] + datum[1][l:]) | ||
290 | + | ||
291 | + | ||
292 | +class CanDataset(Dataset): | ||
293 | + | ||
294 | + def __init__(self, csv, datum): | ||
295 | + self.csv = csv | ||
296 | + self.datum = datum | ||
297 | + | ||
298 | + def __len__(self): | ||
299 | + return len(self.datum) | ||
300 | + | ||
301 | + def __getitem__(self, idx): | ||
302 | + start_i = self.datum[idx][0] | ||
303 | + is_regular = self.datum[idx][1] | ||
304 | + | ||
305 | + l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
306 | + for i in range(const.CAN_ID_BIT): | ||
307 | + id = int(self.csv.iloc[start_i + i, 1], 16) | ||
308 | + bits = unpack_bits(np.array(id), const.CAN_ID_BIT) | ||
309 | + l[i] = bits | ||
310 | + l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
311 | + | ||
312 | + return (l, is_regular) | ||
313 | + | ||
314 | + | ||
315 | +if __name__ == "__main__": | ||
316 | + kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'} | ||
317 | + test_data_set = dataset.GetCanDatasetUsingTxtKwarg(-1, -1, False, **kwargs) | ||
318 | + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, | ||
319 | + shuffle=False, num_workers=2) | ||
320 | + | ||
321 | + for x, y in testloader: | ||
322 | + print(x) | ||
323 | + print(y) | ||
324 | + break |
코드/fed_train.py
0 → 100644
1 | +import utils | ||
2 | +import copy | ||
3 | +from collections import OrderedDict | ||
4 | + | ||
5 | +import model | ||
6 | +import dataset | ||
7 | + | ||
8 | +import importlib | ||
9 | +importlib.reload(utils) | ||
10 | +importlib.reload(model) | ||
11 | +importlib.reload(dataset) | ||
12 | + | ||
13 | +from utils import * | ||
14 | + | ||
15 | + | ||
16 | +def add_args(parser): | ||
17 | + # parser.add_argument('--model', type=str, default='moderate-cnn', | ||
18 | + # help='neural network used in training') | ||
19 | + parser.add_argument('--dataset', type=str, default='cifar10', metavar='N', | ||
20 | + help='dataset used for training') | ||
21 | + parser.add_argument('--fold_num', type=int, default=0, | ||
22 | + help='5-fold, 0 ~ 4') | ||
23 | + parser.add_argument('--batch_size', type=int, default=256, metavar='N', | ||
24 | + help='input batch size for training') | ||
25 | + parser.add_argument('--lr', type=float, default=0.002, metavar='LR', | ||
26 | + help='learning rate') | ||
27 | + parser.add_argument('--n_nets', type=int, default=100, metavar='NN', | ||
28 | + help='number of workers in a distributed cluster') | ||
29 | + parser.add_argument('--comm_type', type=str, default='fedtwa', | ||
30 | + help='which type of communication strategy is going to be used: layerwise/blockwise') | ||
31 | + parser.add_argument('--comm_round', type=int, default=10, | ||
32 | + help='how many round of communications we shoud use') | ||
33 | + args = parser.parse_args(args=[]) | ||
34 | + return args | ||
35 | + | ||
36 | + | ||
37 | +def start_fedavg(fed_model, args, | ||
38 | + train_data_set, | ||
39 | + data_idx_map, | ||
40 | + net_data_count, | ||
41 | + testloader, | ||
42 | + edges, | ||
43 | + device): | ||
44 | + print("start fed avg") | ||
45 | + criterion = nn.CrossEntropyLoss() | ||
46 | + C = 0.1 | ||
47 | + num_edge = int(max(C * args.n_nets, 1)) | ||
48 | + total_data_count = 0 | ||
49 | + for _, data_count in net_data_count.items(): | ||
50 | + total_data_count += data_count | ||
51 | + print("total data: %d" % total_data_count) | ||
52 | + | ||
53 | + for cr in range(1, args.comm_round + 1): | ||
54 | + print("Communication round : %d" % (cr)) | ||
55 | + | ||
56 | + np.random.seed(cr) # make sure for each comparison, select the same clients each round | ||
57 | + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) | ||
58 | + print("selected edge", selected_edge) | ||
59 | + | ||
60 | + for edge_progress, edge_index in enumerate(selected_edge): | ||
61 | + train_data_set.set_idx_map(data_idx_map[edge_index]) | ||
62 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | ||
63 | + shuffle=True, num_workers=2) | ||
64 | + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | ||
65 | + | ||
66 | + edges[edge_index] = copy.deepcopy(fed_model) | ||
67 | + edges[edge_index].to(device) | ||
68 | + edges[edge_index].train() | ||
69 | + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | ||
70 | + # train | ||
71 | + for data_idx, (inputs, labels) in enumerate(train_loader): | ||
72 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
73 | + | ||
74 | + edge_opt.zero_grad() | ||
75 | + edge_pred = edges[edge_index](inputs) | ||
76 | + | ||
77 | + edge_loss = criterion(edge_pred, labels) | ||
78 | + edge_loss.backward() | ||
79 | + | ||
80 | + edge_opt.step() | ||
81 | + edge_loss = edge_loss.item() | ||
82 | + if data_idx % 100 == 0: | ||
83 | + print('[%4d] loss: %.3f' % (data_idx, edge_loss)) | ||
84 | + # break | ||
85 | + edges[edge_index].to('cpu') | ||
86 | + | ||
87 | + # cal weight using fed avg | ||
88 | + update_state = OrderedDict() | ||
89 | + for k, edge in enumerate(edges): | ||
90 | + local_state = edge.state_dict() | ||
91 | + for key in fed_model.state_dict().keys(): | ||
92 | + if k == 0: | ||
93 | + update_state[key] = local_state[key] * net_data_count[k] / total_data_count | ||
94 | + else: | ||
95 | + update_state[key] += local_state[key] * net_data_count[k] / total_data_count | ||
96 | + | ||
97 | + fed_model.load_state_dict(update_state) | ||
98 | + if cr % 10 == 0: | ||
99 | + fed_model.to(device) | ||
100 | + fed_model.eval() | ||
101 | + | ||
102 | + total_loss = 0.0 | ||
103 | + cnt = 0 | ||
104 | + step_acc = 0.0 | ||
105 | + with torch.no_grad(): | ||
106 | + for i, data in enumerate(testloader): | ||
107 | + inputs, labels = data | ||
108 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
109 | + | ||
110 | + outputs = fed_model(inputs) | ||
111 | + _, preds = torch.max(outputs, 1) | ||
112 | + | ||
113 | + loss = criterion(outputs, labels) | ||
114 | + cnt += inputs.shape[0] | ||
115 | + | ||
116 | + corr_sum = torch.sum(preds == labels.data) | ||
117 | + step_acc += corr_sum.double() | ||
118 | + running_loss = loss.item() * inputs.shape[0] | ||
119 | + total_loss += running_loss | ||
120 | + if i % 200 == 0: | ||
121 | + print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
122 | + # break | ||
123 | + print((step_acc / cnt).data) | ||
124 | + print(total_loss / cnt) | ||
125 | + fed_model.to('cpu') | ||
126 | + | ||
127 | + | ||
128 | +def start_fedprox(fed_model, args, | ||
129 | + train_data_set, | ||
130 | + data_idx_map, | ||
131 | + testloader, | ||
132 | + device): | ||
133 | + print("start fed prox") | ||
134 | + criterion = nn.CrossEntropyLoss() | ||
135 | + mu = 0.001 | ||
136 | + C = 0.1 | ||
137 | + num_edge = int(max(C * args.n_nets, 1)) | ||
138 | + fed_model.to(device) | ||
139 | + | ||
140 | + for cr in range(1, args.comm_round + 1): | ||
141 | + print("Communication round : %d" % (cr)) | ||
142 | + edge_weight_dict = {} | ||
143 | + fed_weight_dict = {} | ||
144 | + for fed_name, fed_param in fed_model.named_parameters(): | ||
145 | + edge_weight_dict[fed_name] = [] | ||
146 | + fed_weight_dict[fed_name] = fed_param | ||
147 | + | ||
148 | + np.random.seed(cr) # make sure for each comparison, select the same clients each round | ||
149 | + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) | ||
150 | + print("selected edge", selected_edge) | ||
151 | + | ||
152 | + total_data_length = 0 | ||
153 | + edge_data_len = [] | ||
154 | + for edge_progress, edge_index in enumerate(selected_edge): | ||
155 | + train_data_set.set_idx_map(data_idx_map[edge_index]) | ||
156 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | ||
157 | + shuffle=True, num_workers=2) | ||
158 | + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | ||
159 | + total_data_length += len(train_data_set) | ||
160 | + edge_data_len.append(len(train_data_set)) | ||
161 | + | ||
162 | + edge_model = copy.deepcopy(fed_model) | ||
163 | + edge_model.to(device) | ||
164 | + edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr) | ||
165 | + # train | ||
166 | + for data_idx, (inputs, labels) in enumerate(train_loader): | ||
167 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
168 | + | ||
169 | + edge_opt.zero_grad() | ||
170 | + edge_pred = edge_model(inputs) | ||
171 | + | ||
172 | + edge_loss = criterion(edge_pred, labels) | ||
173 | + # prox term | ||
174 | + fed_prox_reg = 0.0 | ||
175 | + for edge_name, edge_param in edge_model.named_parameters(): | ||
176 | + fed_prox_reg += ((mu / 2) * torch.norm((fed_weight_dict[edge_name] - edge_param))**2) | ||
177 | + edge_loss += fed_prox_reg | ||
178 | + | ||
179 | + edge_loss.backward() | ||
180 | + | ||
181 | + edge_opt.step() | ||
182 | + edge_loss = edge_loss.item() | ||
183 | + if data_idx % 100 == 0: | ||
184 | + print('[%4d] loss: %.3f' % (data_idx, edge_loss)) | ||
185 | + # break | ||
186 | + | ||
187 | + edge_model.to('cpu') | ||
188 | + # save edge weight | ||
189 | + for edge_name, edge_param in edge_model.named_parameters(): | ||
190 | + edge_weight_dict[edge_name].append(edge_param) | ||
191 | + | ||
192 | + fed_model.to('cpu') | ||
193 | + # cal weight, / number of edge | ||
194 | + for fed_name, fed_param in fed_model.named_parameters(): | ||
195 | + fed_param.data.copy_( sum(weight / num_edge for weight in edge_weight_dict[fed_name]) ) | ||
196 | + fed_model.to(device) | ||
197 | + | ||
198 | + if cr % 10 == 0: | ||
199 | + fed_model.eval() | ||
200 | + total_loss = 0.0 | ||
201 | + cnt = 0 | ||
202 | + step_acc = 0.0 | ||
203 | + with torch.no_grad(): | ||
204 | + for i, data in enumerate(testloader): | ||
205 | + inputs, labels = data | ||
206 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
207 | + | ||
208 | + outputs = fed_model(inputs) | ||
209 | + _, preds = torch.max(outputs, 1) | ||
210 | + | ||
211 | + loss = criterion(outputs, labels) | ||
212 | + cnt += inputs.shape[0] | ||
213 | + | ||
214 | + corr_sum = torch.sum(preds == labels.data) | ||
215 | + step_acc += corr_sum.double() | ||
216 | + running_loss = loss.item() * inputs.shape[0] | ||
217 | + total_loss += running_loss | ||
218 | + if i % 200 == 0: | ||
219 | + print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
220 | + # break | ||
221 | + print((step_acc / cnt).data) | ||
222 | + print(total_loss / cnt) | ||
223 | + | ||
224 | + | ||
225 | +def start_fedtwa(fed_model, args, | ||
226 | + train_data_set, | ||
227 | + data_idx_map, | ||
228 | + net_data_count, | ||
229 | + testloader, | ||
230 | + edges, | ||
231 | + device): | ||
232 | + # TEFL, without asynchronous model update | ||
233 | + print("start fed temporally weighted aggregation") | ||
234 | + criterion = nn.CrossEntropyLoss() | ||
235 | + time_stamp = [0 for worker in range(args.n_nets)] | ||
236 | + twa_exp = math.e / 2.0 | ||
237 | + C = 0.1 | ||
238 | + num_edge = int(max(C * args.n_nets, 1)) | ||
239 | + total_data_count = 0 | ||
240 | + for _, data_count in net_data_count.items(): | ||
241 | + total_data_count += data_count | ||
242 | + print("total data: %d" % total_data_count) | ||
243 | + | ||
244 | + for cr in range(1, args.comm_round + 1): | ||
245 | + print("Communication round : %d" % (cr)) | ||
246 | + | ||
247 | + np.random.seed(cr) # make sure for each comparison, select the same clients each round | ||
248 | + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) | ||
249 | + print("selected edge", selected_edge) | ||
250 | + | ||
251 | + for edge_progress, edge_index in enumerate(selected_edge): | ||
252 | + time_stamp[edge_index] = cr | ||
253 | + train_data_set.set_idx_map(data_idx_map[edge_index]) | ||
254 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | ||
255 | + shuffle=True, num_workers=2) | ||
256 | + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | ||
257 | + | ||
258 | + edges[edge_index] = copy.deepcopy(fed_model) | ||
259 | + edges[edge_index].to(device) | ||
260 | + edges[edge_index].train() | ||
261 | + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | ||
262 | + # train | ||
263 | + for data_idx, (inputs, labels) in enumerate(train_loader): | ||
264 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
265 | + | ||
266 | + edge_opt.zero_grad() | ||
267 | + edge_pred = edges[edge_index](inputs) | ||
268 | + | ||
269 | + edge_loss = criterion(edge_pred, labels) | ||
270 | + edge_loss.backward() | ||
271 | + | ||
272 | + edge_opt.step() | ||
273 | + edge_loss = edge_loss.item() | ||
274 | + if data_idx % 100 == 0: | ||
275 | + print('[%4d] loss: %.3f' % (data_idx, edge_loss)) | ||
276 | + # break | ||
277 | + edges[edge_index].to('cpu') | ||
278 | + | ||
279 | + # cal weight using time stamp | ||
280 | + update_state = OrderedDict() | ||
281 | + for k, edge in enumerate(edges): | ||
282 | + local_state = edge.state_dict() | ||
283 | + for key in fed_model.state_dict().keys(): | ||
284 | + if k == 0: | ||
285 | + update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k])) | ||
286 | + else: | ||
287 | + update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k])) | ||
288 | + | ||
289 | + fed_model.load_state_dict(update_state) | ||
290 | + if cr % 10 == 0: | ||
291 | + fed_model.to(device) | ||
292 | + fed_model.eval() | ||
293 | + | ||
294 | + total_loss = 0.0 | ||
295 | + cnt = 0 | ||
296 | + step_acc = 0.0 | ||
297 | + with torch.no_grad(): | ||
298 | + for i, data in enumerate(testloader): | ||
299 | + inputs, labels = data | ||
300 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
301 | + | ||
302 | + outputs = fed_model(inputs) | ||
303 | + _, preds = torch.max(outputs, 1) | ||
304 | + | ||
305 | + loss = criterion(outputs, labels) | ||
306 | + cnt += inputs.shape[0] | ||
307 | + | ||
308 | + corr_sum = torch.sum(preds == labels.data) | ||
309 | + step_acc += corr_sum.double() | ||
310 | + running_loss = loss.item() * inputs.shape[0] | ||
311 | + total_loss += running_loss | ||
312 | + if i % 200 == 0: | ||
313 | + print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
314 | + # break | ||
315 | + print((step_acc / cnt).data) | ||
316 | + print(total_loss / cnt) | ||
317 | + fed_model.to('cpu') | ||
318 | + | ||
319 | + | ||
320 | +def start_feddw(fed_model, args, | ||
321 | + train_data_set, | ||
322 | + data_idx_map, | ||
323 | + net_data_count, | ||
324 | + testloader, | ||
325 | + local_test_loader, | ||
326 | + edges, | ||
327 | + device): | ||
328 | + print("start fed Node-aware Dynamic Weighting") | ||
329 | + worker_selected_frequency = [0 for worker in range(args.n_nets)] | ||
330 | + criterion = nn.CrossEntropyLoss() | ||
331 | + H = 0.5 | ||
332 | + P = 0.5 | ||
333 | + G = 0.1 | ||
334 | + R = 0.1 | ||
335 | + alpha, beta, gamma = 30.0/100.0, 50.0/100.0, 20.0/100.0 | ||
336 | + num_edge = int(max(G * args.n_nets, 1)) | ||
337 | + | ||
338 | + # cal data weight for selecting participants | ||
339 | + total_data_count = 0 | ||
340 | + for _, data_count in net_data_count.items(): | ||
341 | + total_data_count += data_count | ||
342 | + print("total data: %d" % total_data_count) | ||
343 | + | ||
344 | + total_data_weight = 0.0 | ||
345 | + net_weight_dict = {} | ||
346 | + for net_key, data_count in net_data_count.items(): | ||
347 | + net_data_count[net_key] = data_count / total_data_count | ||
348 | + net_weight_dict[net_key] = total_data_count / data_count | ||
349 | + total_data_weight += net_weight_dict[net_key] | ||
350 | + | ||
351 | + for net_key, data_count in net_weight_dict.items(): | ||
352 | + net_weight_dict[net_key] = net_weight_dict[net_key] / total_data_weight | ||
353 | + # end | ||
354 | + | ||
355 | + worker_local_accuracy = [0 for worker in range(args.n_nets)] | ||
356 | + | ||
357 | + for cr in range(1, args.comm_round + 1): | ||
358 | + print("Communication round : %d" % (cr)) | ||
359 | + | ||
360 | + # select participants | ||
361 | + candidates = [] | ||
362 | + sum_frequency = sum(worker_selected_frequency) | ||
363 | + if sum_frequency == 0: | ||
364 | + sum_frequency = 1 | ||
365 | + for worker_index in range(args.n_nets): | ||
366 | + candidates.append((H * worker_selected_frequency[worker_index] / sum_frequency + (1 - H) * net_weight_dict[worker_index], worker_index)) | ||
367 | + candidates = sorted(candidates)[:int(R * args.n_nets)] | ||
368 | + candidates = [temp[1] for temp in candidates] | ||
369 | + | ||
370 | + np.random.seed(cr) | ||
371 | + selected_edge = np.random.choice(candidates, num_edge, replace=False) | ||
372 | + # end select | ||
373 | + | ||
374 | + # weighted frequency | ||
375 | + avg_selected_frequency = sum(worker_selected_frequency) / len(worker_selected_frequency) | ||
376 | + weighted_frequency = [P * (avg_selected_frequency - worker_frequency) for worker_frequency in worker_selected_frequency] | ||
377 | + frequency_prime = min(weighted_frequency) | ||
378 | + weighted_frequency = [frequency + frequency_prime + 1 for frequency in weighted_frequency] | ||
379 | + # end weigthed | ||
380 | + | ||
381 | + print("selected edge", selected_edge) | ||
382 | + for edge_progress, edge_index in enumerate(selected_edge): | ||
383 | + worker_selected_frequency[edge_index] += 1 | ||
384 | + train_data_set.set_idx_map(data_idx_map[edge_index]) | ||
385 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | ||
386 | + shuffle=True, num_workers=2) | ||
387 | + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | ||
388 | + | ||
389 | + edges[edge_index] = copy.deepcopy(fed_model) | ||
390 | + edges[edge_index].to(device) | ||
391 | + edges[edge_index].train() | ||
392 | + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | ||
393 | + # train | ||
394 | + for data_idx, (inputs, labels) in enumerate(train_loader): | ||
395 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
396 | + | ||
397 | + edge_opt.zero_grad() | ||
398 | + edge_pred = edges[edge_index](inputs) | ||
399 | + | ||
400 | + edge_loss = criterion(edge_pred, labels) | ||
401 | + edge_loss.backward() | ||
402 | + | ||
403 | + edge_opt.step() | ||
404 | + edge_loss = edge_loss.item() | ||
405 | + if data_idx % 100 == 0: | ||
406 | + print('[%4d] loss: %.3f' % (data_idx, edge_loss)) | ||
407 | + # break | ||
408 | + | ||
409 | + # get edge accuracy using subset of testset | ||
410 | + edges[edge_index].eval() | ||
411 | + print("[%2d/%2d] edge: %d, cal accuracy" % (edge_progress, len(selected_edge), edge_index)) | ||
412 | + cnt = 0 | ||
413 | + step_acc = 0.0 | ||
414 | + with torch.no_grad(): | ||
415 | + for inputs, labels in local_test_loader: | ||
416 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
417 | + | ||
418 | + outputs = edges[edge_index](inputs) | ||
419 | + _, preds = torch.max(outputs, 1) | ||
420 | + | ||
421 | + loss = criterion(outputs, labels) | ||
422 | + cnt += inputs.shape[0] | ||
423 | + | ||
424 | + corr_sum = torch.sum(preds == labels.data) | ||
425 | + step_acc += corr_sum.double() | ||
426 | + # break | ||
427 | + | ||
428 | + worker_local_accuracy[edge_index] = (step_acc / cnt).item() | ||
429 | + print(worker_local_accuracy[edge_index]) | ||
430 | + edges[edge_index].to('cpu') | ||
431 | + | ||
432 | + # cal weight dynamically | ||
433 | + sum_accuracy = sum(worker_local_accuracy) | ||
434 | + sum_weighted_frequency = sum(weighted_frequency) | ||
435 | + update_state = OrderedDict() | ||
436 | + for k, edge in enumerate(edges): | ||
437 | + local_state = edge.state_dict() | ||
438 | + for key in fed_model.state_dict().keys(): | ||
439 | + if k == 0: | ||
440 | + update_state[key] = local_state[key] \ | ||
441 | + * (net_data_count[k] * alpha \ | ||
442 | + + worker_local_accuracy[k] / sum_accuracy * beta \ | ||
443 | + + weighted_frequency[k] / sum_weighted_frequency * gamma) | ||
444 | + else: | ||
445 | + update_state[key] += local_state[key] \ | ||
446 | + * (net_data_count[k] * alpha \ | ||
447 | + + worker_local_accuracy[k] / sum_accuracy * beta \ | ||
448 | + + weighted_frequency[k] / sum_weighted_frequency * gamma) | ||
449 | + | ||
450 | + fed_model.load_state_dict(update_state) | ||
451 | + if cr % 10 == 0: | ||
452 | + fed_model.to(device) | ||
453 | + fed_model.eval() | ||
454 | + | ||
455 | + total_loss = 0.0 | ||
456 | + cnt = 0 | ||
457 | + step_acc = 0.0 | ||
458 | + with torch.no_grad(): | ||
459 | + for i, data in enumerate(testloader): | ||
460 | + inputs, labels = data | ||
461 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
462 | + | ||
463 | + outputs = fed_model(inputs) | ||
464 | + _, preds = torch.max(outputs, 1) | ||
465 | + | ||
466 | + loss = criterion(outputs, labels) | ||
467 | + cnt += inputs.shape[0] | ||
468 | + | ||
469 | + corr_sum = torch.sum(preds == labels.data) | ||
470 | + step_acc += corr_sum.double() | ||
471 | + running_loss = loss.item() * inputs.shape[0] | ||
472 | + total_loss += running_loss | ||
473 | + if i % 200 == 0: | ||
474 | + print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
475 | + # break | ||
476 | + print((step_acc / cnt).data) | ||
477 | + print(total_loss / cnt) | ||
478 | + fed_model.to('cpu') | ||
479 | + | ||
480 | + | ||
481 | +def start_train(): | ||
482 | + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
483 | + print(device) | ||
484 | + args = add_args(argparse.ArgumentParser()) | ||
485 | + | ||
486 | + seed = 0 | ||
487 | + np.random.seed(seed) | ||
488 | + torch.manual_seed(seed) | ||
489 | + | ||
490 | + print("Loading data...") | ||
491 | + # kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt', | ||
492 | + # "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt', | ||
493 | + # "./dataset/RPM_dataset.csv" : './RPM_dataset.txt', | ||
494 | + # "./dataset/gear_dataset.csv" : './gear_dataset.txt' | ||
495 | + # } | ||
496 | + kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'} | ||
497 | + train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs) | ||
498 | + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, | ||
499 | + shuffle=False, num_workers=2) | ||
500 | + | ||
501 | + fed_model = model.Net() | ||
502 | + args.comm_type = 'feddw' | ||
503 | + if args.comm_type == "fedavg": | ||
504 | + edges, _, _ = init_models(args.n_nets, args) | ||
505 | + start_fedavg(fed_model, args, | ||
506 | + train_data_set, | ||
507 | + data_idx_map, | ||
508 | + net_data_count, | ||
509 | + testloader, | ||
510 | + edges, | ||
511 | + device) | ||
512 | + elif args.comm_type == "fedprox": | ||
513 | + start_fedprox(fed_model, args, | ||
514 | + train_data_set, | ||
515 | + data_idx_map, | ||
516 | + testloader, | ||
517 | + device) | ||
518 | + elif args.comm_type == "fedtwa": | ||
519 | + edges, _, _ = init_models(args.n_nets, args) | ||
520 | + start_fedtwa(fed_model, args, | ||
521 | + train_data_set, | ||
522 | + data_idx_map, | ||
523 | + net_data_count, | ||
524 | + testloader, | ||
525 | + edges, | ||
526 | + device) | ||
527 | + elif args.comm_type == "feddw": | ||
528 | + local_test_set = copy.deepcopy(test_data_set) | ||
529 | + # mnist train 60,000 / test 10,000 / 1,000 | ||
530 | + # CAN train ~ 13,000,000 / test 2,000,000 / for speed 40,000 | ||
531 | + local_test_idx = np.random.choice(len(local_test_set), len(local_test_set) // 50, replace=False) | ||
532 | + local_test_set.set_idx_map(local_test_idx) | ||
533 | + local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size, | ||
534 | + shuffle=False, num_workers=2) | ||
535 | + | ||
536 | + edges, _, _ = init_models(args.n_nets, args) | ||
537 | + start_feddw(fed_model, args, | ||
538 | + train_data_set, | ||
539 | + data_idx_map, | ||
540 | + net_data_count, | ||
541 | + testloader, | ||
542 | + local_test_loader, | ||
543 | + edges, | ||
544 | + device) | ||
545 | + | ||
546 | +if __name__ == "__main__": | ||
547 | + start_train() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/model.py
0 → 100644
1 | +import torch.nn as nn | ||
2 | +import torch.nn.functional as F | ||
3 | +import torch | ||
4 | +import const | ||
5 | + | ||
6 | +class Net(nn.Module): | ||
7 | + def __init__(self): | ||
8 | + super(Net, self).__init__() | ||
9 | + | ||
10 | + self.f1 = nn.Sequential( | ||
11 | + nn.Conv2d(1, 2, 3), | ||
12 | + nn.ReLU(True), | ||
13 | + ) | ||
14 | + self.f2 = nn.Sequential( | ||
15 | + nn.Conv2d(2, 4, 3), | ||
16 | + nn.ReLU(True), | ||
17 | + ) | ||
18 | + self.f3 = nn.Sequential( | ||
19 | + nn.Conv2d(4, 8, 3), | ||
20 | + nn.ReLU(True), | ||
21 | + ) | ||
22 | + self.f4 = nn.Sequential( | ||
23 | + nn.Linear(8 * 23 * 23, 2), | ||
24 | + ) | ||
25 | + | ||
26 | + def forward(self, x): | ||
27 | + x = self.f1(x) | ||
28 | + x = self.f2(x) | ||
29 | + x = self.f3(x) | ||
30 | + x = torch.flatten(x, 1) | ||
31 | + x = self.f4(x) | ||
32 | + return x | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment