김지훈

docs: federated train readme

1 +# khu_capstone_1
2 +
3 +## 연합학습 기반 유해트래픽 탐지
4 +- Pytorch
5 +- CAN protocol 유해 트래픽 데이터 셋
6 +- FedAvg, FedProx, Fed using timestamp, Fed dynamic weight 논문 구현 및 성능 비교
7 +
8 +## Model train
9 +- Install [PyTorch](http://pytorch.org)
10 +- Train model
11 +```bash
12 +python3 fed_train.py --packet_num 3 --fold_num 0 --batch_size 128 --lr 0.001 --n_nets 100 --comm_type fedprox --comm_round 50
13 +```
...@@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0 ...@@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0
49 def add_args(parser): 49 def add_args(parser):
50 parser.add_argument('--packet_num', type=int, default=1, 50 parser.add_argument('--packet_num', type=int, default=1,
51 help='packet number used in training, 1 ~ 3') 51 help='packet number used in training, 1 ~ 3')
52 - parser.add_argument('--dataset', type=str, default='can',
53 - help='dataset used for training, can or syncan')
54 parser.add_argument('--fold_num', type=int, default=0, 52 parser.add_argument('--fold_num', type=int, default=0,
55 help='5-fold, 0 ~ 4') 53 help='5-fold, 0 ~ 4')
56 parser.add_argument('--batch_size', type=int, default=128, 54 parser.add_argument('--batch_size', type=int, default=128,
...@@ -549,21 +547,14 @@ def start_train(): ...@@ -549,21 +547,14 @@ def start_train():
549 torch.manual_seed(seed) 547 torch.manual_seed(seed)
550 548
551 print("Loading data...") 549 print("Loading data...")
552 - if args.dataset == 'can': 550 + train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")
553 - train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") 551 +
554 - elif args.dataset == 'syncan':
555 - train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/test_mixed.csv", "./dataset/Mixed_dataset_1.txt")
556 -
557 sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size) 552 sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size)
558 testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler, 553 testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler,
559 shuffle=False, num_workers=2, drop_last=True) 554 shuffle=False, num_workers=2, drop_last=True)
560 555
561 - if args.dataset == 'can': 556 + fed_model = model.OneNet(args.packet_num)
562 - fed_model = model.OneNet(args.packet_num) 557 + edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
563 - edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
564 - elif args.dataset == 'syncan':
565 - fed_model = model.OneNet(args.packet_num)
566 - edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
567 558
568 if args.comm_type == "fedavg": 559 if args.comm_type == "fedavg":
569 start_fedavg(fed_model, args, 560 start_fedavg(fed_model, args,
......