김지훈

Before revising the edge model, training speed needs to be improved.

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
CAN_ID_BIT = 29  # extended CAN identifiers are 29 bits wide; each sample is a CAN_ID_BIT x CAN_ID_BIT bitmap
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import const
'''
def int_to_binary(x, bits):
mask = 2 ** torch.arange(bits).to(x.device, x.dtype)
return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
'''
def unpack_bits(x, num_bits):
"""
    Unpack integer value(s) into an array of 0/1 bits, least-significant bit first.

    Args:
        x (np.ndarray): integer value(s) to convert to bits
        num_bits (int): number of bits in the representation
"""
xshape = list(x.shape)
x = x.reshape([-1, 1])
mask = 2**np.arange(num_bits).reshape([1, num_bits])
return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])
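# Illustrative example (values chosen for this comment, not taken from the pipeline):
# unpack_bits(np.array(0x316), const.CAN_ID_BIT) yields a length-29 vector of 0/1,
# least-significant bit first; an input of shape (N,) yields an array of shape (N, 29).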
# def CsvToNumpy(csv_file):
# target_csv = pd.read_csv(csv_file)
# inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
# labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
# print(inputs_save_numpy, labels_save_numpy)
# i = 0
# inputs_array = []
# labels_array = []
# print(len(target_csv))
# while i + const.CAN_ID_BIT - 1 < len(target_csv):
# is_regular = True
# for j in range(const.CAN_ID_BIT):
# l = target_csv.iloc[i + j]
# b = l[2]
# r = (l[b+2+1] == 'R')
# if not r:
# is_regular = False
# break
# inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
# for idx in range(const.CAN_ID_BIT):
# can_id = int(target_csv.iloc[i + idx, 1], 16)
# inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT)
# inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
# if is_regular:
# labels = 1
# else:
# labels = 0
# inputs_array.append(inputs)
# labels_array.append(labels)
# i+=1
# if (i % 5000 == 0):
# print(i)
# # break
# inputs_array = np.array(inputs_array)
# labels_array = np.array(labels_array)
# np.save(inputs_save_numpy, arr=inputs_array)
# np.save(labels_save_numpy, arr=labels_array)
# print('done')
def CsvToText(csv_file):
target_csv = pd.read_csv(csv_file)
text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt'
print(text_file_name)
target_text = open(text_file_name, mode='wt', encoding='utf-8')
i = 0
datum = [ [], [] ]
print(len(target_csv))
while i + const.CAN_ID_BIT - 1 < len(target_csv):
is_regular = True
for j in range(const.CAN_ID_BIT):
            l = target_csv.iloc[i + j]
            b = l[2]               # DLC: number of data bytes in this frame
            r = (l[b+2+1] == 'R')  # flag column follows the data bytes: 'R' normal, 'T' injected
if not r:
is_regular = False
break
if is_regular:
target_text.write("%d R\n" % i)
else:
target_text.write("%d T\n" % i)
i+=1
if (i % 5000 == 0):
print(i)
target_text.close()
print('done')
def record_net_data_stats(label_temp, data_idx_map):
net_class_count = {}
net_data_count= {}
for net_i, dataidx in data_idx_map.items():
unq, unq_cnt = np.unique(label_temp[dataidx], return_counts=True)
tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
net_class_count[net_i] = tmp
net_data_count[net_i] = len(dataidx)
print('Data statistics: %s' % str(net_class_count))
return net_class_count, net_data_count
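# Illustrative output (numbers made up for this comment): with two classes and three
# edges, net_class_count might look like {0: {0: 812, 1: 3305}, 1: {0: 640, 1: 2998},
# 2: {0: 701, 1: 3120}} and net_data_count like {0: 4117, 1: 3638, 2: 3821}.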
def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs):
csv_list = []
total_datum = []
total_label_temp = []
csv_idx = 0
for csv_file, txt_file in kwargs.items():
csv = pd.read_csv(csv_file)
csv_list.append(csv)
txt = open(txt_file, "r")
lines = txt.read().splitlines()
idx = 0
local_datum = []
while idx + const.CAN_ID_BIT - 1 < len(csv):
line = lines[idx]
if not line:
break
if line.split(' ')[1] == 'R':
local_datum.append((csv_idx, idx, 1))
total_label_temp.append(1)
else:
local_datum.append((csv_idx, idx, 0))
total_label_temp.append(0)
idx += 1
if (idx % 1000000 == 0):
print(idx)
csv_idx += 1
total_datum += local_datum
fold_length = int(len(total_label_temp) / 5)
datum = []
label_temp = []
for i in range(5):
if i != fold_num:
datum += total_datum[i*fold_length:(i+1)*fold_length]
label_temp += total_label_temp[i*fold_length:(i+1)*fold_length]
else:
test_datum = total_datum[i*fold_length:(i+1)*fold_length]
min_size = 0
output_class_num = 2
N = len(label_temp)
label_temp = np.array(label_temp)
data_idx_map = {}
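    # Non-IID partition (sketch of the loop below): for every class, draw edge
    # proportions from a flat Dirichlet(1, ..., 1), zero out edges that already hold
    # more than an equal share (N / total_edge), renormalize, and split that class's
    # shuffled indices at the cumulative proportions. Repeat until every edge holds
    # at least 512 samples.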
while min_size < 512:
idx_batch = [[] for _ in range(total_edge)]
# for each class in the dataset
for k in range(output_class_num):
idx_k = np.where(label_temp == k)[0]
np.random.shuffle(idx_k)
proportions = np.random.dirichlet(np.repeat(1, total_edge))
## Balance
proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)])
proportions = proportions/proportions.sum()
proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
min_size = min([len(idx_j) for idx_j in idx_batch])
for j in range(total_edge):
np.random.shuffle(idx_batch[j])
data_idx_map[j] = idx_batch[j]
net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map)
return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False)
class CanDatasetKwarg(Dataset):
def __init__(self, csv_list, datum, is_train=True):
self.csv_list = csv_list
self.datum = datum
if is_train:
self.idx_map = []
else:
self.idx_map = [idx for idx in range(len(self.datum))]
def __len__(self):
return len(self.idx_map)
def set_idx_map(self, data_idx_map):
self.idx_map = data_idx_map
def __getitem__(self, idx):
csv_idx = self.datum[self.idx_map[idx]][0]
start_i = self.datum[self.idx_map[idx]][1]
is_regular = self.datum[self.idx_map[idx]][2]
l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
for i in range(const.CAN_ID_BIT):
id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16)
bits = unpack_bits(np.array(id_), const.CAN_ID_BIT)
l[i] = bits
l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
return (l, is_regular)
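# Each item is a (1, CAN_ID_BIT, CAN_ID_BIT) float array built from 29 consecutive
# CAN IDs, each unpacked into 29 bits, together with an int label: 1 for a window of
# regular frames, 0 for a window containing an injected frame. A DataLoader over this
# dataset therefore yields batches shaped (B, 1, 29, 29).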
def GetCanDatasetUsingTxt(csv_file, txt_path, length):
csv = pd.read_csv(csv_file)
txt = open(txt_path, "r")
lines = txt.read().splitlines()
idx = 0
datum = [ [], [] ]
while idx + const.CAN_ID_BIT - 1 < len(csv):
if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
break
line = lines[idx]
if not line:
break
if line.split(' ')[1] == 'R':
if len(datum[0]) < length//2:
datum[0].append((idx, 1))
else:
if len(datum[1]) < length//2:
datum[1].append((idx, 0))
idx += 1
if (idx % 5000 == 0):
print(idx, len(datum[0]), len(datum[1]))
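    # Keep a class-balanced subset (up to length//2 regular and length//2 attack
    # windows) and split each class 90% / 10% into the two CanDatasets returned below.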
l = int((length // 2) * 0.9)
return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
CanDataset(csv, datum[0][l:] + datum[1][l:])
def GetCanDataset(csv_file, length):
csv = pd.read_csv(csv_file)
i = 0
datum = [ [], [] ]
while i + const.CAN_ID_BIT - 1 < len(csv):
if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
break
is_regular = True
for j in range(const.CAN_ID_BIT):
l = csv.iloc[i + j]
b = l[2]
r = (l[b+2+1] == 'R')
if not r:
is_regular = False
break
if is_regular:
if len(datum[0]) < length//2:
datum[0].append((i, 1))
else:
if len(datum[1]) < length//2:
datum[1].append((i, 0))
i+=1
if (i % 5000 == 0):
print(i, len(datum[0]), len(datum[1]))
l = int((length // 2) * 0.9)
return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
CanDataset(csv, datum[0][l:] + datum[1][l:])
class CanDataset(Dataset):
def __init__(self, csv, datum):
self.csv = csv
self.datum = datum
def __len__(self):
return len(self.datum)
def __getitem__(self, idx):
start_i = self.datum[idx][0]
is_regular = self.datum[idx][1]
l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
for i in range(const.CAN_ID_BIT):
            can_id = int(self.csv.iloc[start_i + i, 1], 16)
            bits = unpack_bits(np.array(can_id), const.CAN_ID_BIT)
l[i] = bits
l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))
return (l, is_regular)
if __name__ == "__main__":
    # Smoke test: build the dataset from a single CSV / label-txt pair and inspect one
    # batch. The edge count, fold number and batch size below are arbitrary test values.
    kwargs = {"./dataset/DoS_dataset.csv": './DoS_dataset.txt'}
    batch_size = 256
    _, _, _, _, test_data_set = GetCanDatasetUsingTxtKwarg(10, 0, **kwargs)
    testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size,
                                             shuffle=False, num_workers=2)
for x, y in testloader:
print(x)
print(y)
break
import argparse
import copy
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import utils
import model
import dataset
import importlib
importlib.reload(utils)
importlib.reload(model)
importlib.reload(dataset)
from utils import *
def add_args(parser):
# parser.add_argument('--model', type=str, default='moderate-cnn',
# help='neural network used in training')
parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
help='dataset used for training')
parser.add_argument('--fold_num', type=int, default=0,
help='5-fold, 0 ~ 4')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
help='input batch size for training')
parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
help='learning rate')
parser.add_argument('--n_nets', type=int, default=100, metavar='NN',
help='number of workers in a distributed cluster')
parser.add_argument('--comm_type', type=str, default='fedtwa',
                        help='which federated aggregation strategy to use: fedavg / fedprox / fedtwa / feddw')
parser.add_argument('--comm_round', type=int, default=10,
                        help='how many rounds of communication to run')
args = parser.parse_args(args=[])
return args
def start_fedavg(fed_model, args,
train_data_set,
data_idx_map,
net_data_count,
testloader,
edges,
device):
print("start fed avg")
criterion = nn.CrossEntropyLoss()
C = 0.1
num_edge = int(max(C * args.n_nets, 1))
total_data_count = 0
for _, data_count in net_data_count.items():
total_data_count += data_count
print("total data: %d" % total_data_count)
for cr in range(1, args.comm_round + 1):
print("Communication round : %d" % (cr))
np.random.seed(cr) # make sure for each comparison, select the same clients each round
selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
print("selected edge", selected_edge)
for edge_progress, edge_index in enumerate(selected_edge):
train_data_set.set_idx_map(data_idx_map[edge_index])
train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
shuffle=True, num_workers=2)
print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
edges[edge_index] = copy.deepcopy(fed_model)
edges[edge_index].to(device)
edges[edge_index].train()
edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
# train
for data_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.float().to(device), labels.long().to(device)
edge_opt.zero_grad()
edge_pred = edges[edge_index](inputs)
edge_loss = criterion(edge_pred, labels)
edge_loss.backward()
edge_opt.step()
edge_loss = edge_loss.item()
if data_idx % 100 == 0:
print('[%4d] loss: %.3f' % (data_idx, edge_loss))
# break
edges[edge_index].to('cpu')
# cal weight using fed avg
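        # FedAvg aggregation: the new global state is the data-weighted average of the
        # edge models, w_global = sum_k (n_k / n_total) * w_k, accumulated key by key
        # over every edge's state_dict.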
update_state = OrderedDict()
for k, edge in enumerate(edges):
local_state = edge.state_dict()
for key in fed_model.state_dict().keys():
if k == 0:
update_state[key] = local_state[key] * net_data_count[k] / total_data_count
else:
update_state[key] += local_state[key] * net_data_count[k] / total_data_count
fed_model.load_state_dict(update_state)
if cr % 10 == 0:
fed_model.to(device)
fed_model.eval()
total_loss = 0.0
cnt = 0
step_acc = 0.0
with torch.no_grad():
for i, data in enumerate(testloader):
inputs, labels = data
inputs, labels = inputs.float().to(device), labels.long().to(device)
outputs = fed_model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
cnt += inputs.shape[0]
corr_sum = torch.sum(preds == labels.data)
step_acc += corr_sum.double()
running_loss = loss.item() * inputs.shape[0]
total_loss += running_loss
if i % 200 == 0:
print('test [%4d] loss: %.3f' % (i, loss.item()))
# break
print((step_acc / cnt).data)
print(total_loss / cnt)
fed_model.to('cpu')
def start_fedprox(fed_model, args,
train_data_set,
data_idx_map,
testloader,
device):
print("start fed prox")
criterion = nn.CrossEntropyLoss()
mu = 0.001
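    # FedProx: each edge minimizes its local loss plus a proximal term
    # (mu / 2) * ||w_edge - w_global||^2 (added in the training loop below), which
    # keeps local updates close to the current global model; mu controls its strength.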
C = 0.1
num_edge = int(max(C * args.n_nets, 1))
fed_model.to(device)
for cr in range(1, args.comm_round + 1):
print("Communication round : %d" % (cr))
edge_weight_dict = {}
fed_weight_dict = {}
for fed_name, fed_param in fed_model.named_parameters():
edge_weight_dict[fed_name] = []
fed_weight_dict[fed_name] = fed_param
np.random.seed(cr) # make sure for each comparison, select the same clients each round
selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
print("selected edge", selected_edge)
total_data_length = 0
edge_data_len = []
for edge_progress, edge_index in enumerate(selected_edge):
train_data_set.set_idx_map(data_idx_map[edge_index])
train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
shuffle=True, num_workers=2)
print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
total_data_length += len(train_data_set)
edge_data_len.append(len(train_data_set))
edge_model = copy.deepcopy(fed_model)
edge_model.to(device)
edge_opt = optim.Adam(params=edge_model.parameters(),lr=args.lr)
# train
for data_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.float().to(device), labels.long().to(device)
edge_opt.zero_grad()
edge_pred = edge_model(inputs)
edge_loss = criterion(edge_pred, labels)
# prox term
fed_prox_reg = 0.0
for edge_name, edge_param in edge_model.named_parameters():
fed_prox_reg += ((mu / 2) * torch.norm((fed_weight_dict[edge_name] - edge_param))**2)
edge_loss += fed_prox_reg
edge_loss.backward()
edge_opt.step()
edge_loss = edge_loss.item()
if data_idx % 100 == 0:
print('[%4d] loss: %.3f' % (data_idx, edge_loss))
# break
edge_model.to('cpu')
# save edge weight
for edge_name, edge_param in edge_model.named_parameters():
edge_weight_dict[edge_name].append(edge_param)
fed_model.to('cpu')
    # aggregate: average the selected edge models (sum of their weights / number of edges)
for fed_name, fed_param in fed_model.named_parameters():
fed_param.data.copy_( sum(weight / num_edge for weight in edge_weight_dict[fed_name]) )
fed_model.to(device)
if cr % 10 == 0:
fed_model.eval()
total_loss = 0.0
cnt = 0
step_acc = 0.0
with torch.no_grad():
for i, data in enumerate(testloader):
inputs, labels = data
inputs, labels = inputs.float().to(device), labels.long().to(device)
outputs = fed_model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
cnt += inputs.shape[0]
corr_sum = torch.sum(preds == labels.data)
step_acc += corr_sum.double()
running_loss = loss.item() * inputs.shape[0]
total_loss += running_loss
if i % 200 == 0:
print('test [%4d] loss: %.3f' % (i, loss.item()))
# break
print((step_acc / cnt).data)
print(total_loss / cnt)
def start_fedtwa(fed_model, args,
train_data_set,
data_idx_map,
net_data_count,
testloader,
edges,
device):
# TEFL, without asynchronous model update
print("start fed temporally weighted aggregation")
criterion = nn.CrossEntropyLoss()
time_stamp = [0 for worker in range(args.n_nets)]
twa_exp = math.e / 2.0
C = 0.1
num_edge = int(max(C * args.n_nets, 1))
total_data_count = 0
for _, data_count in net_data_count.items():
total_data_count += data_count
print("total data: %d" % total_data_count)
for cr in range(1, args.comm_round + 1):
print("Communication round : %d" % (cr))
np.random.seed(cr) # make sure for each comparison, select the same clients each round
selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
print("selected edge", selected_edge)
for edge_progress, edge_index in enumerate(selected_edge):
time_stamp[edge_index] = cr
train_data_set.set_idx_map(data_idx_map[edge_index])
train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
shuffle=True, num_workers=2)
print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
edges[edge_index] = copy.deepcopy(fed_model)
edges[edge_index].to(device)
edges[edge_index].train()
edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
# train
for data_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.float().to(device), labels.long().to(device)
edge_opt.zero_grad()
edge_pred = edges[edge_index](inputs)
edge_loss = criterion(edge_pred, labels)
edge_loss.backward()
edge_opt.step()
edge_loss = edge_loss.item()
if data_idx % 100 == 0:
print('[%4d] loss: %.3f' % (data_idx, edge_loss))
# break
edges[edge_index].to('cpu')
# cal weight using time stamp
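        # Temporally weighted aggregation: each edge contributes its FedAvg data weight
        # (n_k / n_total) scaled by (e/2)^-(cr - time_stamp_k), so models trained in
        # older rounds are discounted exponentially.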
update_state = OrderedDict()
for k, edge in enumerate(edges):
local_state = edge.state_dict()
for key in fed_model.state_dict().keys():
if k == 0:
update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k]))
else:
update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) * math.pow(twa_exp, -(cr - time_stamp[k]))
fed_model.load_state_dict(update_state)
if cr % 10 == 0:
fed_model.to(device)
fed_model.eval()
total_loss = 0.0
cnt = 0
step_acc = 0.0
with torch.no_grad():
for i, data in enumerate(testloader):
inputs, labels = data
inputs, labels = inputs.float().to(device), labels.long().to(device)
outputs = fed_model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
cnt += inputs.shape[0]
corr_sum = torch.sum(preds == labels.data)
step_acc += corr_sum.double()
running_loss = loss.item() * inputs.shape[0]
total_loss += running_loss
if i % 200 == 0:
print('test [%4d] loss: %.3f' % (i, loss.item()))
# break
print((step_acc / cnt).data)
print(total_loss / cnt)
fed_model.to('cpu')
def start_feddw(fed_model, args,
train_data_set,
data_idx_map,
net_data_count,
testloader,
local_test_loader,
edges,
device):
print("start fed Node-aware Dynamic Weighting")
worker_selected_frequency = [0 for worker in range(args.n_nets)]
criterion = nn.CrossEntropyLoss()
H = 0.5
P = 0.5
G = 0.1
R = 0.1
alpha, beta, gamma = 30.0/100.0, 50.0/100.0, 20.0/100.0
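    # FedDW hyperparameters: H trades off past selection frequency against the inverse
    # data-size weight when ranking candidates, R is the fraction of edges kept as
    # candidates, G the fraction actually trained per round, P scales the frequency
    # term used at aggregation, and (alpha, beta, gamma) weight data size, local
    # accuracy and selection frequency in the final aggregation (they sum to 1).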
num_edge = int(max(G * args.n_nets, 1))
# cal data weight for selecting participants
total_data_count = 0
for _, data_count in net_data_count.items():
total_data_count += data_count
print("total data: %d" % total_data_count)
total_data_weight = 0.0
net_weight_dict = {}
for net_key, data_count in net_data_count.items():
net_data_count[net_key] = data_count / total_data_count
net_weight_dict[net_key] = total_data_count / data_count
total_data_weight += net_weight_dict[net_key]
for net_key, data_count in net_weight_dict.items():
net_weight_dict[net_key] = net_weight_dict[net_key] / total_data_weight
# end
worker_local_accuracy = [0 for worker in range(args.n_nets)]
for cr in range(1, args.comm_round + 1):
print("Communication round : %d" % (cr))
# select participants
candidates = []
sum_frequency = sum(worker_selected_frequency)
if sum_frequency == 0:
sum_frequency = 1
for worker_index in range(args.n_nets):
candidates.append((H * worker_selected_frequency[worker_index] / sum_frequency + (1 - H) * net_weight_dict[worker_index], worker_index))
candidates = sorted(candidates)[:int(R * args.n_nets)]
candidates = [temp[1] for temp in candidates]
np.random.seed(cr)
selected_edge = np.random.choice(candidates, num_edge, replace=False)
# end select
# weighted frequency
avg_selected_frequency = sum(worker_selected_frequency) / len(worker_selected_frequency)
weighted_frequency = [P * (avg_selected_frequency - worker_frequency) for worker_frequency in worker_selected_frequency]
frequency_prime = min(weighted_frequency)
weighted_frequency = [frequency + frequency_prime + 1 for frequency in weighted_frequency]
        # end weighted
print("selected edge", selected_edge)
for edge_progress, edge_index in enumerate(selected_edge):
worker_selected_frequency[edge_index] += 1
train_data_set.set_idx_map(data_idx_map[edge_index])
train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
shuffle=True, num_workers=2)
print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
edges[edge_index] = copy.deepcopy(fed_model)
edges[edge_index].to(device)
edges[edge_index].train()
edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
# train
for data_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.float().to(device), labels.long().to(device)
edge_opt.zero_grad()
edge_pred = edges[edge_index](inputs)
edge_loss = criterion(edge_pred, labels)
edge_loss.backward()
edge_opt.step()
edge_loss = edge_loss.item()
if data_idx % 100 == 0:
print('[%4d] loss: %.3f' % (data_idx, edge_loss))
# break
# get edge accuracy using subset of testset
edges[edge_index].eval()
print("[%2d/%2d] edge: %d, cal accuracy" % (edge_progress, len(selected_edge), edge_index))
cnt = 0
step_acc = 0.0
with torch.no_grad():
for inputs, labels in local_test_loader:
inputs, labels = inputs.float().to(device), labels.long().to(device)
outputs = edges[edge_index](inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
cnt += inputs.shape[0]
corr_sum = torch.sum(preds == labels.data)
step_acc += corr_sum.double()
# break
worker_local_accuracy[edge_index] = (step_acc / cnt).item()
print(worker_local_accuracy[edge_index])
edges[edge_index].to('cpu')
# cal weight dynamically
sum_accuracy = sum(worker_local_accuracy)
sum_weighted_frequency = sum(weighted_frequency)
update_state = OrderedDict()
for k, edge in enumerate(edges):
local_state = edge.state_dict()
for key in fed_model.state_dict().keys():
if k == 0:
update_state[key] = local_state[key] \
* (net_data_count[k] * alpha \
+ worker_local_accuracy[k] / sum_accuracy * beta \
+ weighted_frequency[k] / sum_weighted_frequency * gamma)
else:
update_state[key] += local_state[key] \
* (net_data_count[k] * alpha \
+ worker_local_accuracy[k] / sum_accuracy * beta \
+ weighted_frequency[k] / sum_weighted_frequency * gamma)
fed_model.load_state_dict(update_state)
if cr % 10 == 0:
fed_model.to(device)
fed_model.eval()
total_loss = 0.0
cnt = 0
step_acc = 0.0
with torch.no_grad():
for i, data in enumerate(testloader):
inputs, labels = data
inputs, labels = inputs.float().to(device), labels.long().to(device)
outputs = fed_model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
cnt += inputs.shape[0]
corr_sum = torch.sum(preds == labels.data)
step_acc += corr_sum.double()
running_loss = loss.item() * inputs.shape[0]
total_loss += running_loss
if i % 200 == 0:
print('test [%4d] loss: %.3f' % (i, loss.item()))
# break
print((step_acc / cnt).data)
print(total_loss / cnt)
fed_model.to('cpu')
def start_train():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
args = add_args(argparse.ArgumentParser())
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
print("Loading data...")
# kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt',
# "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt',
# "./dataset/RPM_dataset.csv" : './RPM_dataset.txt',
# "./dataset/gear_dataset.csv" : './gear_dataset.txt'
# }
kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs)
testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size,
shuffle=False, num_workers=2)
fed_model = model.Net()
    args.comm_type = 'feddw'  # hard-coded override of the parsed default for this run
if args.comm_type == "fedavg":
edges, _, _ = init_models(args.n_nets, args)
start_fedavg(fed_model, args,
train_data_set,
data_idx_map,
net_data_count,
testloader,
edges,
device)
elif args.comm_type == "fedprox":
start_fedprox(fed_model, args,
train_data_set,
data_idx_map,
testloader,
device)
elif args.comm_type == "fedtwa":
edges, _, _ = init_models(args.n_nets, args)
start_fedtwa(fed_model, args,
train_data_set,
data_idx_map,
net_data_count,
testloader,
edges,
device)
elif args.comm_type == "feddw":
local_test_set = copy.deepcopy(test_data_set)
        # Local test subset: MNIST uses 60,000 train / 10,000 test images with a
        # 1,000-image local test set; the CAN data has ~13,000,000 train / ~2,000,000
        # test windows, so a 1/50 subset (~40,000) keeps per-edge evaluation fast.
local_test_idx = np.random.choice(len(local_test_set), len(local_test_set) // 50, replace=False)
local_test_set.set_idx_map(local_test_idx)
local_test_loader = torch.utils.data.DataLoader(local_test_set, batch_size=args.batch_size,
shuffle=False, num_workers=2)
edges, _, _ = init_models(args.n_nets, args)
start_feddw(fed_model, args,
train_data_set,
data_idx_map,
net_data_count,
testloader,
local_test_loader,
edges,
device)
if __name__ == "__main__":
start_train()
import torch.nn as nn
import torch.nn.functional as F
import torch
import const
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.f1 = nn.Sequential(
nn.Conv2d(1, 2, 3),
nn.ReLU(True),
)
self.f2 = nn.Sequential(
nn.Conv2d(2, 4, 3),
nn.ReLU(True),
)
self.f3 = nn.Sequential(
nn.Conv2d(4, 8, 3),
nn.ReLU(True),
)
self.f4 = nn.Sequential(
nn.Linear(8 * 23 * 23, 2),
)
def forward(self, x):
x = self.f1(x)
x = self.f2(x)
x = self.f3(x)
x = torch.flatten(x, 1)
x = self.f4(x)
return x
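# Quick shape check (illustrative, not part of the training pipeline): three 3x3
# convolutions shrink a 29x29 bitmap to 27 -> 25 -> 23, so the flattened feature has
# 8 * 23 * 23 elements and Net()(torch.zeros(4, 1, 29, 29)) returns logits of shape
# torch.Size([4, 2]).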