Showing
4 changed files
with
7 additions
and
9 deletions
... | @@ -33,12 +33,9 @@ def eval(model_path): | ... | @@ -33,12 +33,9 @@ def eval(model_path): |
33 | model.load_state_dict(torch.load(weight_path)) | 33 | model.load_state_dict(torch.load(weight_path)) |
34 | 34 | ||
35 | print('\n[+] Load dataset') | 35 | print('\n[+] Load dataset') |
36 | - test_transform = get_valid_transform(args, model) | ||
37 | - #print('\nTEST Transform\n', test_transform) | ||
38 | test_dataset = get_dataset(args, 'test') | 36 | test_dataset = get_dataset(args, 'test') |
39 | 37 | ||
40 | 38 | ||
41 | - | ||
42 | test_loader = iter(get_dataloader(args, test_dataset)) ### | 39 | test_loader = iter(get_dataloader(args, test_dataset)) ### |
43 | 40 | ||
44 | print('\n[+] Start testing') | 41 | print('\n[+] Start testing') | ... | ... |
... | @@ -16,6 +16,8 @@ class BaseNet(nn.Module): | ... | @@ -16,6 +16,8 @@ class BaseNet(nn.Module): |
16 | x = self.after(f) | 16 | x = self.after(f) |
17 | x = x.reshape(x.size(0), -1) | 17 | x = x.reshape(x.size(0), -1) |
18 | x = self.fc(x) | 18 | x = self.fc(x) |
19 | + | ||
20 | + # output, first | ||
19 | return x, f | 21 | return x, f |
20 | 22 | ||
21 | """ | 23 | """ | ... | ... |
... | @@ -24,7 +24,7 @@ def train(**kwargs): | ... | @@ -24,7 +24,7 @@ def train(**kwargs): |
24 | 24 | ||
25 | print('\n[+] Create log dir') | 25 | print('\n[+] Create log dir') |
26 | model_name = get_model_name(args) | 26 | model_name = get_model_name(args) |
27 | - log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name) | 27 | + log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name) |
28 | os.makedirs(os.path.join(log_dir, 'model')) | 28 | os.makedirs(os.path.join(log_dir, 'model')) |
29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | 29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) |
30 | writer = SummaryWriter(log_dir=log_dir) | 30 | writer = SummaryWriter(log_dir=log_dir) |
... | @@ -42,13 +42,11 @@ def train(**kwargs): | ... | @@ -42,13 +42,11 @@ def train(**kwargs): |
42 | if args.use_cuda: | 42 | if args.use_cuda: |
43 | model = model.cuda() | 43 | model = model.cuda() |
44 | criterion = criterion.cuda() | 44 | criterion = criterion.cuda() |
45 | - writer.add_graph(model) | 45 | + #writer.add_graph(model) |
46 | 46 | ||
47 | print('\n[+] Load dataset') | 47 | print('\n[+] Load dataset') |
48 | - transform = get_train_transform(args, model, log_dir) | 48 | + train_dataset = get_dataset(args, 'train') |
49 | - val_transform = get_valid_transform(args, model) | 49 | + valid_dataset = get_dataset(args, 'val') |
50 | - train_dataset = get_dataset(args, transform, 'train') | ||
51 | - valid_dataset = get_dataset(args, val_transform, 'val') | ||
52 | train_loader = iter(get_inf_dataloader(args, train_dataset)) | 50 | train_loader = iter(get_inf_dataloader(args, train_dataset)) |
53 | max_epoch = len(train_dataset) // args.batch_size | 51 | max_epoch = len(train_dataset) // args.batch_size |
54 | best_acc = -1 | 52 | best_acc = -1 |
... | @@ -82,6 +80,7 @@ def train(**kwargs): | ... | @@ -82,6 +80,7 @@ def train(**kwargs): |
82 | print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | 80 | print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) |
83 | 81 | ||
84 | if step % args.val_step == args.val_step-1: | 82 | if step % args.val_step == args.val_step-1: |
83 | + # print("\nstep, args.val_step: ", step, args.val_step) | ||
85 | valid_loader = iter(get_dataloader(args, valid_dataset)) | 84 | valid_loader = iter(get_dataloader(args, valid_dataset)) |
86 | _valid_res = validate(args, model, criterion, valid_loader, step, writer) | 85 | _valid_res = validate(args, model, criterion, valid_loader, step, writer) |
87 | print('\n[+] Valid results') | 86 | print('\n[+] Valid results') | ... | ... |
code/classifier/utils.py
0 → 100644
This diff is collapsed. Click to expand it.
-
Please register or login to post a comment