Showing
3 changed files
with
102 additions
and
49 deletions
| ... | @@ -5,19 +5,19 @@ from pprint import pprint | ... | @@ -5,19 +5,19 @@ from pprint import pprint |
| 5 | 5 | ||
| 6 | import torch | 6 | import torch |
| 7 | import torch.nn as nn | 7 | import torch.nn as nn |
| 8 | -from torch.utils.tensorboard import SummaryWriter | 8 | +import torchvision.transforms as transforms |
| 9 | +#from torch.utils.tensorboard import SummaryWriter | ||
| 9 | 10 | ||
| 10 | from utils import * | 11 | from utils import * |
| 11 | 12 | ||
| 12 | # command | 13 | # command |
| 13 | # python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/' | 14 | # python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/' |
| 14 | 15 | ||
| 15 | -def eval(model_path, num_data): | 16 | +def eval(model_path): |
| 16 | print('\n[+] Parse arguments') | 17 | print('\n[+] Parse arguments') |
| 17 | kwargs_path = os.path.join(model_path, 'kwargs.json') | 18 | kwargs_path = os.path.join(model_path, 'kwargs.json') |
| 18 | kwargs = json.loads(open(kwargs_path).read()) | 19 | kwargs = json.loads(open(kwargs_path).read()) |
| 19 | args, kwargs = parse_args(kwargs) | 20 | args, kwargs = parse_args(kwargs) |
| 20 | - args.batch_size = num_data | ||
| 21 | pprint(args) | 21 | pprint(args) |
| 22 | device = torch.device('cuda' if args.use_cuda else 'cpu') | 22 | device = torch.device('cuda' if args.use_cuda else 'cpu') |
| 23 | 23 | ||
| ... | @@ -35,23 +35,25 @@ def eval(model_path, num_data): | ... | @@ -35,23 +35,25 @@ def eval(model_path, num_data): |
| 35 | model.load_state_dict(torch.load(weight_path)) | 35 | model.load_state_dict(torch.load(weight_path)) |
| 36 | 36 | ||
| 37 | print('\n[+] Load dataset') | 37 | print('\n[+] Load dataset') |
| 38 | - test_dataset = get_dataset(args, 'test') | 38 | + transform = transforms.Compose([ |
| 39 | + transforms.Resize([240, 240]), | ||
| 40 | + transforms.ToTensor() | ||
| 41 | + ]) | ||
| 42 | + test_dataset = get_dataset(args, transform, 'test') | ||
| 39 | 43 | ||
| 40 | 44 | ||
| 41 | test_loader = iter(get_dataloader(args, test_dataset)) ### | 45 | test_loader = iter(get_dataloader(args, test_dataset)) ### |
| 42 | 46 | ||
| 43 | - print('\n[+] Start testing') | 47 | + # print('\n[+] Start testing') |
| 44 | - writer = SummaryWriter(log_dir=model_path) | 48 | + # writer = SummaryWriter(log_dir=model_path) |
| 45 | - _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer) | 49 | + _test_res = validate(args, model, criterion, test_loader, step=0) |
| 46 | 50 | ||
| 47 | print('\n[+] Valid results') | 51 | print('\n[+] Valid results') |
| 48 | print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100)) | 52 | print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100)) |
| 49 | - print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100)) | 53 | + print(' Loss : {:.3f}'.format(_test_res[1].data)) |
| 50 | - print(' Acc_all : {:.3f}%'.format(_test_res[2].data.cpu().numpy()[0]*100)) | 54 | + print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[2]*1000 / len(test_dataset))) |
| 51 | - print(' Loss : {:.3f}'.format(_test_res[3].data)) | ||
| 52 | - print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[4]*1000 / len(test_dataset))) | ||
| 53 | 55 | ||
| 54 | - writer.close() | 56 | + #writer.close() |
| 55 | 57 | ||
| 56 | if __name__ == '__main__': | 58 | if __name__ == '__main__': |
| 57 | fire.Fire(eval) | 59 | fire.Fire(eval) | ... | ... |
| ... | @@ -7,7 +7,7 @@ from pprint import pprint | ... | @@ -7,7 +7,7 @@ from pprint import pprint |
| 7 | 7 | ||
| 8 | import torch.nn as nn | 8 | import torch.nn as nn |
| 9 | import torch.backends.cudnn as cudnn | 9 | import torch.backends.cudnn as cudnn |
| 10 | -from torch.utils.tensorboard import SummaryWriter | 10 | +#from torch.utils.tensorboard import SummaryWriter |
| 11 | 11 | ||
| 12 | from networks import * | 12 | from networks import * |
| 13 | from utils import * | 13 | from utils import * |
| ... | @@ -27,7 +27,7 @@ def train(**kwargs): | ... | @@ -27,7 +27,7 @@ def train(**kwargs): |
| 27 | log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name) | 27 | log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name) |
| 28 | os.makedirs(os.path.join(log_dir, 'model')) | 28 | os.makedirs(os.path.join(log_dir, 'model')) |
| 29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | 29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) |
| 30 | - writer = SummaryWriter(log_dir=log_dir) | 30 | + #writer = SummaryWriter(log_dir=log_dir) |
| 31 | 31 | ||
| 32 | if args.seed is not None: | 32 | if args.seed is not None: |
| 33 | random.seed(args.seed) | 33 | random.seed(args.seed) |
| ... | @@ -45,8 +45,10 @@ def train(**kwargs): | ... | @@ -45,8 +45,10 @@ def train(**kwargs): |
| 45 | #writer.add_graph(model) | 45 | #writer.add_graph(model) |
| 46 | 46 | ||
| 47 | print('\n[+] Load dataset') | 47 | print('\n[+] Load dataset') |
| 48 | - train_dataset = get_dataset(args, 'train') | 48 | + transform = get_train_transform(args, model, log_dir) |
| 49 | - valid_dataset = get_dataset(args, 'val') | 49 | + val_transform = get_valid_transform(args, model) |
| 50 | + train_dataset = get_dataset(args, transform, 'train') | ||
| 51 | + valid_dataset = get_dataset(args, val_transform, 'val') | ||
| 50 | train_loader = iter(get_inf_dataloader(args, train_dataset)) | 52 | train_loader = iter(get_inf_dataloader(args, train_dataset)) |
| 51 | max_epoch = len(train_dataset) // args.batch_size | 53 | max_epoch = len(train_dataset) // args.batch_size |
| 52 | best_acc = -1 | 54 | best_acc = -1 |
| ... | @@ -62,16 +64,16 @@ def train(**kwargs): | ... | @@ -62,16 +64,16 @@ def train(**kwargs): |
| 62 | start_t = time.time() | 64 | start_t = time.time() |
| 63 | for step in range(args.start_step, args.max_step): | 65 | for step in range(args.start_step, args.max_step): |
| 64 | batch = next(train_loader) | 66 | batch = next(train_loader) |
| 65 | - _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer) | 67 | + _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step) |
| 66 | 68 | ||
| 67 | if step % args.print_step == 0: | 69 | if step % args.print_step == 0: |
| 68 | print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format( | 70 | print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format( |
| 69 | step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'])) | 71 | step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'])) |
| 70 | - writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step) | 72 | + # writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step) |
| 71 | - writer.add_scalar('train/acc1', _train_res[0], global_step=step) | 73 | + # writer.add_scalar('train/acc1', _train_res[0], global_step=step) |
| 72 | - writer.add_scalar('train/loss', _train_res[1], global_step=step) | 74 | + # writer.add_scalar('train/loss', _train_res[1], global_step=step) |
| 73 | - writer.add_scalar('train/forward_time', _train_res[2], global_step=step) | 75 | + # writer.add_scalar('train/forward_time', _train_res[2], global_step=step) |
| 74 | - writer.add_scalar('train/backward_time', _train_res[3], global_step=step) | 76 | + # writer.add_scalar('train/backward_time', _train_res[3], global_step=step) |
| 75 | print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | 77 | print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) |
| 76 | print(' Loss : {}'.format(_train_res[1].data)) | 78 | print(' Loss : {}'.format(_train_res[1].data)) |
| 77 | print(' FW Time : {:.3f}ms'.format(_train_res[2]*1000)) | 79 | print(' FW Time : {:.3f}ms'.format(_train_res[2]*1000)) |
| ... | @@ -80,10 +82,10 @@ def train(**kwargs): | ... | @@ -80,10 +82,10 @@ def train(**kwargs): |
| 80 | if step % args.val_step == args.val_step-1: | 82 | if step % args.val_step == args.val_step-1: |
| 81 | # print("\nstep, args.val_step: ", step, args.val_step) | 83 | # print("\nstep, args.val_step: ", step, args.val_step) |
| 82 | valid_loader = iter(get_dataloader(args, valid_dataset)) | 84 | valid_loader = iter(get_dataloader(args, valid_dataset)) |
| 83 | - _valid_res = validate(args, model, criterion, valid_loader, step, writer) | 85 | + _valid_res = validate(args, model, criterion, valid_loader, step) |
| 84 | - print('\n[+] Valid results') | 86 | + print('\n[+] (Valid results) Valid step: {}/{}'.format(step, args.max_step)) |
| 85 | - writer.add_scalar('valid/acc1', _valid_res[0], global_step=step) | 87 | + # writer.add_scalar('valid/acc1', _valid_res[0], global_step=step) |
| 86 | - writer.add_scalar('valid/loss', _valid_res[1], global_step=step) | 88 | + # writer.add_scalar('valid/loss', _valid_res[1], global_step=step) |
| 87 | print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100)) | 89 | print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100)) |
| 88 | print(' Loss : {}'.format(_valid_res[1].data)) | 90 | print(' Loss : {}'.format(_valid_res[1].data)) |
| 89 | 91 | ||
| ... | @@ -92,7 +94,7 @@ def train(**kwargs): | ... | @@ -92,7 +94,7 @@ def train(**kwargs): |
| 92 | torch.save(model.state_dict(), os.path.join(log_dir, "model","model.pt")) | 94 | torch.save(model.state_dict(), os.path.join(log_dir, "model","model.pt")) |
| 93 | print('\n[+] Model saved') | 95 | print('\n[+] Model saved') |
| 94 | 96 | ||
| 95 | - writer.close() | 97 | + # writer.close() |
| 96 | 98 | ||
| 97 | 99 | ||
| 98 | if __name__ == '__main__': | 100 | if __name__ == '__main__': | ... | ... |
| ... | @@ -23,6 +23,7 @@ from sklearn.model_selection import KFold | ... | @@ -23,6 +23,7 @@ from sklearn.model_selection import KFold |
| 23 | 23 | ||
| 24 | from networks import basenet, grayResNet2 | 24 | from networks import basenet, grayResNet2 |
| 25 | 25 | ||
| 26 | +DATASET_PATH = '/content/drive/My Drive/CD2 Project/' | ||
| 26 | 27 | ||
| 27 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | 28 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' |
| 28 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | 29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' |
| ... | @@ -131,17 +132,17 @@ def parse_args(kwargs): | ... | @@ -131,17 +132,17 @@ def parse_args(kwargs): |
| 131 | kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' | 132 | kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' |
| 132 | kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' | 133 | kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' |
| 133 | kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | 134 | kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' |
| 134 | - kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.001 | 135 | + kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.01 |
| 135 | kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | 136 | kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None |
| 136 | kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | 137 | kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True |
| 137 | kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | 138 | kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() |
| 138 | kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | 139 | kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 |
| 139 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 100 | 140 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 50 |
| 140 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 100 | 141 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 50 |
| 141 | kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | 142 | kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' |
| 142 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 32 | 143 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 16 |
| 143 | kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | 144 | kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 |
| 144 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 2500 | 145 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 500 |
| 145 | kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | 146 | kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None |
| 146 | 147 | ||
| 147 | # to named tuple | 148 | # to named tuple |
| ... | @@ -155,11 +156,10 @@ def select_model(args): | ... | @@ -155,11 +156,10 @@ def select_model(args): |
| 155 | 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | 156 | 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} |
| 156 | 157 | ||
| 157 | if args.network in resnet_dict: | 158 | if args.network in resnet_dict: |
| 158 | - backbone = resnet_dict[args.network] | 159 | + model = resnet_dict[args.network] |
| 159 | - model = basenet.BaseNet(backbone, args) | 160 | + # else: # 3 channels |
| 160 | - else: | 161 | + # backbone = models.__dict__[args.network]() |
| 161 | - Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | 162 | + # model = basenet.BaseNet(backbone, args) |
| 162 | - model = Net(args) | ||
| 163 | 163 | ||
| 164 | #print(model) # print model architecture | 164 | #print(model) # print model architecture |
| 165 | return model | 165 | return model |
| ... | @@ -187,16 +187,44 @@ def select_scheduler(args, optimizer): | ... | @@ -187,16 +187,44 @@ def select_scheduler(args, optimizer): |
| 187 | else: | 187 | else: |
| 188 | raise Exception('Unknown Scheduler') | 188 | raise Exception('Unknown Scheduler') |
| 189 | 189 | ||
| 190 | +def get_train_transform(args, model, transform, log_dir=None): | ||
| 191 | + if args.dataset == 'cifar10': | ||
| 192 | + transform = transforms.Compose([ | ||
| 193 | + transforms.Pad(4), | ||
| 194 | + transforms.RandomCrop(32), | ||
| 195 | + transforms.RandomHorizontalFlip(), | ||
| 196 | + transforms.ToTensor() | ||
| 197 | + ]) | ||
| 198 | + | ||
| 199 | + else: | ||
| 200 | + transform = transforms.Compose([ | ||
| 201 | + transforms.Resize([240, 240]), | ||
| 202 | + transforms.ToTensor() | ||
| 203 | + ]) | ||
| 204 | + | ||
| 205 | + return transform | ||
| 206 | + | ||
| 207 | +def get_valid_transform(args, model): | ||
| 208 | + if args.dataset == 'cifar10': | ||
| 209 | + val_transform = transforms.Compose([ | ||
| 210 | + transforms.Resize(32), | ||
| 211 | + transforms.ToTensor() | ||
| 212 | + ]) | ||
| 213 | + | ||
| 214 | + else: | ||
| 215 | + val_transform = transforms.Compose([ | ||
| 216 | + transforms.Resize([240, 240]), | ||
| 217 | + transforms.ToTensor() | ||
| 218 | + ]) | ||
| 219 | + | ||
| 220 | + return val_transform | ||
| 190 | 221 | ||
| 191 | class CustomDataset(Dataset): | 222 | class CustomDataset(Dataset): |
| 192 | - def __init__(self, data_path, csv_path): | 223 | + def __init__(self, data_path, csv_path, transform): |
| 193 | self.path = data_path | 224 | self.path = data_path |
| 194 | self.imgs = natsorted(os.listdir(data_path)) | 225 | self.imgs = natsorted(os.listdir(data_path)) |
| 195 | self.len = len(self.imgs) | 226 | self.len = len(self.imgs) |
| 196 | - self.transform = transforms.Compose([ | 227 | + self.transform = transform |
| 197 | - transforms.Resize([240, 240]), | ||
| 198 | - transforms.ToTensor() | ||
| 199 | - ]) | ||
| 200 | 228 | ||
| 201 | df = pd.read_csv(csv_path) | 229 | df = pd.read_csv(csv_path) |
| 202 | targets_list = [] | 230 | targets_list = [] |
| ... | @@ -215,6 +243,7 @@ class CustomDataset(Dataset): | ... | @@ -215,6 +243,7 @@ class CustomDataset(Dataset): |
| 215 | targets = self.targets[idx] | 243 | targets = self.targets[idx] |
| 216 | image = Image.open(img_loc) | 244 | image = Image.open(img_loc) |
| 217 | image = self.transform(image) | 245 | image = self.transform(image) |
| 246 | + #print("\n idx, img, targets: ", idx, img_loc, targets) | ||
| 218 | return image, targets | 247 | return image, targets |
| 219 | 248 | ||
| 220 | 249 | ||
| ... | @@ -222,12 +251,32 @@ class CustomDataset(Dataset): | ... | @@ -222,12 +251,32 @@ class CustomDataset(Dataset): |
| 222 | def get_dataset(args, transform, split='train'): | 251 | def get_dataset(args, transform, split='train'): |
| 223 | assert split in ['train', 'val', 'test'] | 252 | assert split in ['train', 'val', 'test'] |
| 224 | 253 | ||
| 254 | + if args.dataset == 'cifar10': | ||
| 255 | + train = split in ['train', 'val', 'trainval'] | ||
| 256 | + dataset = torchvision.datasets.CIFAR10(DATASET_PATH, | ||
| 257 | + train=train, | ||
| 258 | + transform=transform, | ||
| 259 | + download=True) | ||
| 260 | + | ||
| 261 | + if split in ['train', 'val']: | ||
| 262 | + split_path = os.path.join(DATASET_PATH, | ||
| 263 | + 'cifar-10-batches-py', 'train_val_index.cp') | ||
| 264 | + | ||
| 265 | + if not os.path.exists(split_path): | ||
| 266 | + [train_index], [val_index] = split_dataset(args, dataset, k=1) | ||
| 267 | + split_index = {'train':train_index, 'val':val_index} | ||
| 268 | + cp.dump(split_index, open(split_path, 'wb')) | ||
| 269 | + | ||
| 270 | + split_index = cp.load(open(split_path, 'rb')) | ||
| 271 | + dataset = Subset(dataset, split_index[split]) | ||
| 272 | + | ||
| 273 | + else: | ||
| 225 | if split in ['train']: | 274 | if split in ['train']: |
| 226 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH) | 275 | + dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform) |
| 227 | elif split in ['val']: | 276 | elif split in ['val']: |
| 228 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH) | 277 | + dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform) |
| 229 | else : # test | 278 | else : # test |
| 230 | - dataset = CustomDataset(TEST_DATASET_PATH, TEST_TARGET_PATH) | 279 | + dataset = CustomDataset(TEST_DATASET_PATH, TEST_TARGET_PATH, transform) |
| 231 | 280 | ||
| 232 | 281 | ||
| 233 | return dataset | 282 | return dataset |
| ... | @@ -261,7 +310,7 @@ def get_inf_dataloader(args, dataset): | ... | @@ -261,7 +310,7 @@ def get_inf_dataloader(args, dataset): |
| 261 | 310 | ||
| 262 | 311 | ||
| 263 | 312 | ||
| 264 | -def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | 313 | +def train_step(args, model, optimizer, scheduler, criterion, batch, step, device=None): |
| 265 | model.train() | 314 | model.train() |
| 266 | images, target = batch | 315 | images, target = batch |
| 267 | 316 | ||
| ... | @@ -275,7 +324,7 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer | ... | @@ -275,7 +324,7 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer |
| 275 | 324 | ||
| 276 | # compute output | 325 | # compute output |
| 277 | start_t = time.time() | 326 | start_t = time.time() |
| 278 | - output, first = model(images) | 327 | + output= model(images) |
| 279 | forward_t = time.time() - start_t | 328 | forward_t = time.time() - start_t |
| 280 | loss = criterion(output, target) | 329 | loss = criterion(output, target) |
| 281 | 330 | ||
| ... | @@ -323,7 +372,7 @@ def accuracy(output, target, topk=(1,)): | ... | @@ -323,7 +372,7 @@ def accuracy(output, target, topk=(1,)): |
| 323 | res.append(correct_k) | 372 | res.append(correct_k) |
| 324 | return res | 373 | return res |
| 325 | 374 | ||
| 326 | -def validate(args, model, criterion, valid_loader, step, writer, device=None): | 375 | +def validate(args, model, criterion, valid_loader, step, device=None): |
| 327 | # switch to evaluate mode | 376 | # switch to evaluate mode |
| 328 | model.eval() | 377 | model.eval() |
| 329 | 378 | ||
| ... | @@ -344,7 +393,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -344,7 +393,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
| 344 | target = target.cuda(non_blocking=True) | 393 | target = target.cuda(non_blocking=True) |
| 345 | 394 | ||
| 346 | # compute output | 395 | # compute output |
| 347 | - output, first = model(images) | 396 | + output = model(images) |
| 348 | loss = criterion(output, target) | 397 | loss = criterion(output, target) |
| 349 | infer_t += time.time() - start_t | 398 | infer_t += time.time() - start_t |
| 350 | 399 | ... | ... |
-
Please register or login to post a comment