조현아

update classifier

......@@ -33,11 +33,8 @@ def eval(model_path):
model.load_state_dict(torch.load(weight_path))
print('\n[+] Load dataset')
test_transform = get_valid_transform(args, model)
#print('\nTEST Transform\n', test_transform)
test_dataset = get_dataset(args, 'test')
test_loader = iter(get_dataloader(args, test_dataset)) ###
......
......@@ -16,6 +16,8 @@ class BaseNet(nn.Module):
x = self.after(f)
x = x.reshape(x.size(0), -1)
x = self.fc(x)
# output, first
return x, f
"""
......
......@@ -24,7 +24,7 @@ def train(**kwargs):
print('\n[+] Create log dir')
model_name = get_model_name(args)
log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name)
log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name)
os.makedirs(os.path.join(log_dir, 'model'))
json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
writer = SummaryWriter(log_dir=log_dir)
......@@ -42,13 +42,11 @@ def train(**kwargs):
if args.use_cuda:
model = model.cuda()
criterion = criterion.cuda()
writer.add_graph(model)
#writer.add_graph(model)
print('\n[+] Load dataset')
transform = get_train_transform(args, model, log_dir)
val_transform = get_valid_transform(args, model)
train_dataset = get_dataset(args, transform, 'train')
valid_dataset = get_dataset(args, val_transform, 'val')
train_dataset = get_dataset(args, 'train')
valid_dataset = get_dataset(args, 'val')
train_loader = iter(get_inf_dataloader(args, train_dataset))
max_epoch = len(train_dataset) // args.batch_size
best_acc = -1
......@@ -82,6 +80,7 @@ def train(**kwargs):
print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000))
if step % args.val_step == args.val_step-1:
# print("\nstep, args.val_step: ", step, args.val_step)
valid_loader = iter(get_dataloader(args, valid_dataset))
_valid_res = validate(args, model, criterion, valid_loader, step, writer)
print('\n[+] Valid results')
......
This diff is collapsed. Click to expand it.