조현아

Update classifier: call get_dataset with only (args, split) — transform arguments removed from train/val/test dataset loading; move TensorBoard log dir to runs/classify/; disable writer.add_graph(model)

...@@ -33,12 +33,9 @@ def eval(model_path): ...@@ -33,12 +33,9 @@ def eval(model_path):
33 model.load_state_dict(torch.load(weight_path)) 33 model.load_state_dict(torch.load(weight_path))
34 34
35 print('\n[+] Load dataset') 35 print('\n[+] Load dataset')
36 - test_transform = get_valid_transform(args, model)
37 - #print('\nTEST Transform\n', test_transform)
38 test_dataset = get_dataset(args, 'test') 36 test_dataset = get_dataset(args, 'test')
39 37
40 38
41 -
42 test_loader = iter(get_dataloader(args, test_dataset)) ### 39 test_loader = iter(get_dataloader(args, test_dataset)) ###
43 40
44 print('\n[+] Start testing') 41 print('\n[+] Start testing')
......
...@@ -16,6 +16,8 @@ class BaseNet(nn.Module): ...@@ -16,6 +16,8 @@ class BaseNet(nn.Module):
16 x = self.after(f) 16 x = self.after(f)
17 x = x.reshape(x.size(0), -1) 17 x = x.reshape(x.size(0), -1)
18 x = self.fc(x) 18 x = self.fc(x)
19 +
20 + # output, first
19 return x, f 21 return x, f
20 22
21 """ 23 """
......
...@@ -24,7 +24,7 @@ def train(**kwargs): ...@@ -24,7 +24,7 @@ def train(**kwargs):
24 24
25 print('\n[+] Create log dir') 25 print('\n[+] Create log dir')
26 model_name = get_model_name(args) 26 model_name = get_model_name(args)
27 - log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name) 27 + log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name)
28 os.makedirs(os.path.join(log_dir, 'model')) 28 os.makedirs(os.path.join(log_dir, 'model'))
29 json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) 29 json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
30 writer = SummaryWriter(log_dir=log_dir) 30 writer = SummaryWriter(log_dir=log_dir)
...@@ -42,13 +42,11 @@ def train(**kwargs): ...@@ -42,13 +42,11 @@ def train(**kwargs):
42 if args.use_cuda: 42 if args.use_cuda:
43 model = model.cuda() 43 model = model.cuda()
44 criterion = criterion.cuda() 44 criterion = criterion.cuda()
45 - writer.add_graph(model) 45 + #writer.add_graph(model)
46 46
47 print('\n[+] Load dataset') 47 print('\n[+] Load dataset')
48 - transform = get_train_transform(args, model, log_dir) 48 + train_dataset = get_dataset(args, 'train')
49 - val_transform = get_valid_transform(args, model) 49 + valid_dataset = get_dataset(args, 'val')
50 - train_dataset = get_dataset(args, transform, 'train')
51 - valid_dataset = get_dataset(args, val_transform, 'val')
52 train_loader = iter(get_inf_dataloader(args, train_dataset)) 50 train_loader = iter(get_inf_dataloader(args, train_dataset))
53 max_epoch = len(train_dataset) // args.batch_size 51 max_epoch = len(train_dataset) // args.batch_size
54 best_acc = -1 52 best_acc = -1
...@@ -82,6 +80,7 @@ def train(**kwargs): ...@@ -82,6 +80,7 @@ def train(**kwargs):
82 print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) 80 print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000))
83 81
84 if step % args.val_step == args.val_step-1: 82 if step % args.val_step == args.val_step-1:
83 + # print("\nstep, args.val_step: ", step, args.val_step)
85 valid_loader = iter(get_dataloader(args, valid_dataset)) 84 valid_loader = iter(get_dataloader(args, valid_dataset))
86 _valid_res = validate(args, model, criterion, valid_loader, step, writer) 85 _valid_res = validate(args, model, criterion, valid_loader, step, writer)
87 print('\n[+] Valid results') 86 print('\n[+] Valid results')
......
This diff is collapsed. Click to expand it.