Showing
2 changed files
with
13 additions
and
9 deletions
코드/연합학습/README.md
0 → 100644
1 | +# khu_capstone_1 | ||
2 | + | ||
3 | +## 연합학습 기반 유해트래픽 탐지 | ||
4 | +- Pytorch | ||
5 | +- CAN protocol 유해 트래픽 데이터 셋 | ||
6 | +- FedAvg, FedProx, Fed using timestamp, Fed dynamic weight 논문 구현 및 성능 비교 | ||
7 | + | ||
8 | +## Model train | ||
9 | +- Install [PyTorch](http://pytorch.org) | ||
10 | +- Train model | ||
11 | +```bash | ||
12 | +python3 fed_train.py --packet_num 3 --fold_num 0 --batch_size 128 --lr 0.001 --n_nets 100 --comm_type fedprox --comm_round 50 | ||
13 | +``` |
... | @@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0 | ... | @@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0 |
49 | def add_args(parser): | 49 | def add_args(parser): |
50 | parser.add_argument('--packet_num', type=int, default=1, | 50 | parser.add_argument('--packet_num', type=int, default=1, |
51 | help='packet number used in training, 1 ~ 3') | 51 | help='packet number used in training, 1 ~ 3') |
52 | - parser.add_argument('--dataset', type=str, default='can', | ||
53 | - help='dataset used for training, can or syncan') | ||
54 | parser.add_argument('--fold_num', type=int, default=0, | 52 | parser.add_argument('--fold_num', type=int, default=0, |
55 | help='5-fold, 0 ~ 4') | 53 | help='5-fold, 0 ~ 4') |
56 | parser.add_argument('--batch_size', type=int, default=128, | 54 | parser.add_argument('--batch_size', type=int, default=128, |
... | @@ -549,19 +547,12 @@ def start_train(): | ... | @@ -549,19 +547,12 @@ def start_train(): |
549 | torch.manual_seed(seed) | 547 | torch.manual_seed(seed) |
550 | 548 | ||
551 | print("Loading data...") | 549 | print("Loading data...") |
552 | - if args.dataset == 'can': | ||
553 | train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") | 550 | train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") |
554 | - elif args.dataset == 'syncan': | ||
555 | - train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/test_mixed.csv", "./dataset/Mixed_dataset_1.txt") | ||
556 | 551 | ||
557 | sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size) | 552 | sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size) |
558 | testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler, | 553 | testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler, |
559 | shuffle=False, num_workers=2, drop_last=True) | 554 | shuffle=False, num_workers=2, drop_last=True) |
560 | 555 | ||
561 | - if args.dataset == 'can': | ||
562 | - fed_model = model.OneNet(args.packet_num) | ||
563 | - edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] | ||
564 | - elif args.dataset == 'syncan': | ||
565 | fed_model = model.OneNet(args.packet_num) | 556 | fed_model = model.OneNet(args.packet_num) |
566 | edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] | 557 | edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] |
567 | 558 | ... | ... |
-
Please register or login to post a comment