Showing
5 changed files
with
476 additions
and
0 deletions
code/classifier/eval.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | + | ||
| 6 | +import torch | ||
| 7 | +import torch.nn as nn | ||
| 8 | +from torch.utils.tensorboard import SummaryWriter | ||
| 9 | + | ||
| 10 | +from utils import * | ||
| 11 | + | ||
| 12 | +# command | ||
| 13 | +# python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/' | ||
| 14 | + | ||
| 15 | +def eval(model_path): | ||
| 16 | + print('\n[+] Parse arguments') | ||
| 17 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 18 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 19 | + args, kwargs = parse_args(kwargs) | ||
| 20 | + pprint(args) | ||
| 21 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 22 | + | ||
| 23 | + print('\n[+] Create network') | ||
| 24 | + model = select_model(args) | ||
| 25 | + optimizer = select_optimizer(args, model) | ||
| 26 | + criterion = nn.CrossEntropyLoss() | ||
| 27 | + if args.use_cuda: | ||
| 28 | + model = model.cuda() | ||
| 29 | + criterion = criterion.cuda() | ||
| 30 | + | ||
| 31 | + print('\n[+] Load model') | ||
| 32 | + weight_path = os.path.join(model_path, 'model', 'model.pt') | ||
| 33 | + model.load_state_dict(torch.load(weight_path)) | ||
| 34 | + | ||
| 35 | + print('\n[+] Load dataset') | ||
| 36 | + test_transform = get_valid_transform(args, model) | ||
| 37 | + #print('\nTEST Transform\n', test_transform) | ||
| 38 | + test_dataset = get_dataset(args, 'test') | ||
| 39 | + | ||
| 40 | + | ||
| 41 | + | ||
| 42 | + test_loader = iter(get_dataloader(args, test_dataset)) ### | ||
| 43 | + | ||
| 44 | + print('\n[+] Start testing') | ||
| 45 | + writer = SummaryWriter(log_dir=model_path) | ||
| 46 | + _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer) | ||
| 47 | + | ||
| 48 | + print('\n[+] Valid results') | ||
| 49 | + print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100)) | ||
| 50 | + print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100)) | ||
| 51 | + print(' Loss : {:.3f}'.format(_test_res[2].data)) | ||
| 52 | + print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset))) | ||
| 53 | + | ||
| 54 | + writer.close() | ||
| 55 | + | ||
| 56 | +if __name__ == '__main__': | ||
| 57 | + fire.Fire(eval) |
code/classifier/networks/basenet.py
0 → 100644
| 1 | +import torch.nn as nn | ||
| 2 | + | ||
| 3 | +class BaseNet(nn.Module): | ||
| 4 | + def __init__(self, backbone, args): | ||
| 5 | + super(BaseNet, self).__init__() | ||
| 6 | + | ||
| 7 | + # Separate layers | ||
| 8 | + self.first = nn.Sequential(*list(backbone.children())[:1]) | ||
| 9 | + self.after = nn.Sequential(*list(backbone.children())[1:-1]) | ||
| 10 | + self.fc = list(backbone.children())[-1] | ||
| 11 | + | ||
| 12 | + self.img_size = (240, 240) | ||
| 13 | + | ||
| 14 | + def forward(self, x): | ||
| 15 | + f = self.first(x) | ||
| 16 | + x = self.after(f) | ||
| 17 | + x = x.reshape(x.size(0), -1) | ||
| 18 | + x = self.fc(x) | ||
| 19 | + return x, f | ||
| 20 | + | ||
| 21 | +""" | ||
| 22 | + print("before reshape:\n", x.size()) | ||
| 23 | + #[128, 2048, 4, 4] | ||
| 24 | + # #cifar [128, 2048, 1, 1] | ||
| 25 | + x = x.reshape(x.size(0), -1) | ||
| 26 | + print("after reshape:\n", x.size()) | ||
| 27 | + #[128, 32768] | ||
| 28 | + #cifar [128, 2048] | ||
| 29 | + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10] | ||
| 30 | + print("fc :\n", self.fc) | ||
| 31 | + #Linear(in_features=2048, out_features=10, bias=True) | ||
| 32 | + #cifar Linear(in_features=2048, out_features=1000, bias=True) | ||
| 33 | +""" |
code/classifier/requirements.txt
0 → 100644
code/classifier/train.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import time | ||
| 4 | +import json | ||
| 5 | +import random | ||
| 6 | +from pprint import pprint | ||
| 7 | + | ||
| 8 | +import torch.nn as nn | ||
| 9 | +import torch.backends.cudnn as cudnn | ||
| 10 | +from torch.utils.tensorboard import SummaryWriter | ||
| 11 | + | ||
| 12 | +from networks import * | ||
| 13 | +from utils import * | ||
| 14 | + | ||
| 15 | +# python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam | ||
| 16 | +# nohup python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam & | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +def train(**kwargs): | ||
| 20 | + print('\n[+] Parse arguments') | ||
| 21 | + args, kwargs = parse_args(kwargs) | ||
| 22 | + pprint(args) | ||
| 23 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 24 | + | ||
| 25 | + print('\n[+] Create log dir') | ||
| 26 | + model_name = get_model_name(args) | ||
| 27 | + log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name) | ||
| 28 | + os.makedirs(os.path.join(log_dir, 'model')) | ||
| 29 | + json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | ||
| 30 | + writer = SummaryWriter(log_dir=log_dir) | ||
| 31 | + | ||
| 32 | + if args.seed is not None: | ||
| 33 | + random.seed(args.seed) | ||
| 34 | + torch.manual_seed(args.seed) | ||
| 35 | + cudnn.deterministic = True | ||
| 36 | + | ||
| 37 | + print('\n[+] Create network') | ||
| 38 | + model = select_model(args) | ||
| 39 | + optimizer = select_optimizer(args, model) | ||
| 40 | + scheduler = select_scheduler(args, optimizer) | ||
| 41 | + criterion = nn.CrossEntropyLoss() | ||
| 42 | + if args.use_cuda: | ||
| 43 | + model = model.cuda() | ||
| 44 | + criterion = criterion.cuda() | ||
| 45 | + writer.add_graph(model) | ||
| 46 | + | ||
| 47 | + print('\n[+] Load dataset') | ||
| 48 | + transform = get_train_transform(args, model, log_dir) | ||
| 49 | + val_transform = get_valid_transform(args, model) | ||
| 50 | + train_dataset = get_dataset(args, transform, 'train') | ||
| 51 | + valid_dataset = get_dataset(args, val_transform, 'val') | ||
| 52 | + train_loader = iter(get_inf_dataloader(args, train_dataset)) | ||
| 53 | + max_epoch = len(train_dataset) // args.batch_size | ||
| 54 | + best_acc = -1 | ||
| 55 | + | ||
| 56 | + print('\n[+] Start training') | ||
| 57 | + if torch.cuda.device_count() > 1: | ||
| 58 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
| 59 | + model = nn.DataParallel(model) | ||
| 60 | + | ||
| 61 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
| 62 | + print('\n[+] Using GPU: {} '.format(torch.cuda.get_device_name(0))) | ||
| 63 | + | ||
| 64 | + start_t = time.time() | ||
| 65 | + for step in range(args.start_step, args.max_step): | ||
| 66 | + batch = next(train_loader) | ||
| 67 | + _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer) | ||
| 68 | + | ||
| 69 | + if step % args.print_step == 0: | ||
| 70 | + print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format( | ||
| 71 | + step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'])) | ||
| 72 | + writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step) | ||
| 73 | + writer.add_scalar('train/acc1', _train_res[0], global_step=step) | ||
| 74 | + writer.add_scalar('train/acc5', _train_res[1], global_step=step) | ||
| 75 | + writer.add_scalar('train/loss', _train_res[2], global_step=step) | ||
| 76 | + writer.add_scalar('train/forward_time', _train_res[3], global_step=step) | ||
| 77 | + writer.add_scalar('train/backward_time', _train_res[4], global_step=step) | ||
| 78 | + print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
| 79 | + print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
| 80 | + print(' Loss : {}'.format(_train_res[2].data)) | ||
| 81 | + print(' FW Time : {:.3f}ms'.format(_train_res[3]*1000)) | ||
| 82 | + print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | ||
| 83 | + | ||
| 84 | + if step % args.val_step == args.val_step-1: | ||
| 85 | + valid_loader = iter(get_dataloader(args, valid_dataset)) | ||
| 86 | + _valid_res = validate(args, model, criterion, valid_loader, step, writer) | ||
| 87 | + print('\n[+] Valid results') | ||
| 88 | + writer.add_scalar('valid/acc1', _valid_res[0], global_step=step) | ||
| 89 | + writer.add_scalar('valid/acc5', _valid_res[1], global_step=step) | ||
| 90 | + writer.add_scalar('valid/loss', _valid_res[2], global_step=step) | ||
| 91 | + print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100)) | ||
| 92 | + print(' Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100)) | ||
| 93 | + print(' Loss : {}'.format(_valid_res[2].data)) | ||
| 94 | + | ||
| 95 | + if _valid_res[0] >= best_acc: | ||
| 96 | + best_acc = _valid_res[0] | ||
| 97 | + torch.save(model.state_dict(), os.path.join(log_dir, "model","model.pt")) | ||
| 98 | + print('\n[+] Model saved') | ||
| 99 | + | ||
| 100 | + writer.close() | ||
| 101 | + | ||
| 102 | + | ||
| 103 | +if __name__ == '__main__': | ||
| 104 | + fire.Fire(train) |
code/classifier/util.py
0 → 100644
| 1 | +import os | ||
| 2 | +import time | ||
| 3 | +import importlib | ||
| 4 | +import collections | ||
| 5 | +import pickle as cp | ||
| 6 | +import glob | ||
| 7 | +import numpy as np | ||
| 8 | +import pandas as pd | ||
| 9 | + | ||
| 10 | +from natsort import natsorted | ||
| 11 | +from PIL import Image | ||
| 12 | +import torch | ||
| 13 | +import torchvision | ||
| 14 | +import torch.nn.functional as F | ||
| 15 | +import torchvision.models as models | ||
| 16 | +import torchvision.transforms as transforms | ||
| 17 | +from torch.utils.data import Subset | ||
| 18 | +from torch.utils.data import Dataset, DataLoader | ||
| 19 | + | ||
| 20 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
| 21 | +from sklearn.model_selection import train_test_split | ||
| 22 | +from sklearn.model_selection import KFold | ||
| 23 | + | ||
| 24 | +from networks import * | ||
| 25 | + | ||
| 26 | + | ||
| 27 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | ||
| 28 | +TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | ||
| 29 | +# VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/' | ||
| 30 | +# VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv' | ||
| 31 | + | ||
| 32 | +current_epoch = 0 | ||
| 33 | + | ||
| 34 | + | ||
| 35 | +def split_dataset(args, dataset, k): | ||
| 36 | + # load dataset | ||
| 37 | + X = list(range(len(dataset))) | ||
| 38 | + Y = dataset.targets | ||
| 39 | + | ||
| 40 | + # split to k-fold | ||
| 41 | + assert len(X) == len(Y) | ||
| 42 | + | ||
| 43 | + def _it_to_list(_it): | ||
| 44 | + return list(zip(*list(_it))) | ||
| 45 | + | ||
| 46 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
| 47 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
| 48 | + | ||
| 49 | + return Dm_indexes, Da_indexes | ||
| 50 | + | ||
| 51 | + | ||
| 52 | + | ||
| 53 | +def get_model_name(args): | ||
| 54 | + from datetime import datetime, timedelta, timezone | ||
| 55 | + now = datetime.now(timezone.utc) | ||
| 56 | + tz = timezone(timedelta(hours=9)) | ||
| 57 | + now = now.astimezone(tz) | ||
| 58 | + date_time = now.strftime("%B_%d_%H:%M:%S") | ||
| 59 | + model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
| 60 | + return model_name | ||
| 61 | + | ||
| 62 | + | ||
| 63 | +def dict_to_namedtuple(d): | ||
| 64 | + Args = collections.namedtuple('Args', sorted(d.keys())) | ||
| 65 | + | ||
| 66 | + for k,v in d.items(): | ||
| 67 | + if type(v) is dict: | ||
| 68 | + d[k] = dict_to_namedtuple(v) | ||
| 69 | + | ||
| 70 | + elif type(v) is str: | ||
| 71 | + try: | ||
| 72 | + d[k] = eval(v) | ||
| 73 | + except: | ||
| 74 | + d[k] = v | ||
| 75 | + | ||
| 76 | + args = Args(**d) | ||
| 77 | + return args | ||
| 78 | + | ||
| 79 | + | ||
| 80 | +def parse_args(kwargs): | ||
| 81 | + # combine with default args | ||
| 82 | + kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' | ||
| 83 | + kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' | ||
| 84 | + kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
| 85 | + kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.0001 | ||
| 86 | + kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
| 87 | + kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
| 88 | + kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
| 89 | + kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
| 90 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | ||
| 91 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | ||
| 92 | + kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
| 93 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
| 94 | + kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
| 95 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | ||
| 96 | + kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
| 97 | + | ||
| 98 | + # to named tuple | ||
| 99 | + args = dict_to_namedtuple(kwargs) | ||
| 100 | + return args, kwargs | ||
| 101 | + | ||
| 102 | + | ||
| 103 | +def select_model(args): | ||
| 104 | + # grayResNet2 | ||
| 105 | + resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
| 106 | + 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
| 107 | + | ||
| 108 | + if args.network in resnet_dict: | ||
| 109 | + backbone = resnet_dict[args.network] | ||
| 110 | + model = basenet.BaseNet(backbone, args) | ||
| 111 | + else: | ||
| 112 | + Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
| 113 | + model = Net(args) | ||
| 114 | + | ||
| 115 | + #print(model) # print model architecture | ||
| 116 | + return model | ||
| 117 | + | ||
| 118 | + | ||
| 119 | +def select_optimizer(args, model): | ||
| 120 | + if args.optimizer == 'sgd': | ||
| 121 | + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
| 122 | + elif args.optimizer == 'rms': | ||
| 123 | + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
| 124 | + elif args.optimizer == 'adam': | ||
| 125 | + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
| 126 | + else: | ||
| 127 | + raise Exception('Unknown Optimizer') | ||
| 128 | + return optimizer | ||
| 129 | + | ||
| 130 | + | ||
| 131 | +def select_scheduler(args, optimizer): | ||
| 132 | + if not args.scheduler or args.scheduler == 'None': | ||
| 133 | + return None | ||
| 134 | + elif args.scheduler =='clr': | ||
| 135 | + return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
| 136 | + elif args.scheduler =='exp': | ||
| 137 | + return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
| 138 | + else: | ||
| 139 | + raise Exception('Unknown Scheduler') | ||
| 140 | + | ||
| 141 | + | ||
| 142 | +class CustomDataset(Dataset): | ||
| 143 | + def __init__(self, data_path, csv_path): | ||
| 144 | + self.len = len(self.imgs) | ||
| 145 | + self.path = data_path | ||
| 146 | + self.imgs = natsorted(os.listdir(data_path)) | ||
| 147 | + | ||
| 148 | + df = pd.read_csv(csv_path) | ||
| 149 | + targets_list = [] | ||
| 150 | + | ||
| 151 | + for fname in self.imgs: | ||
| 152 | + row = df.loc[df['filename'] == fname] | ||
| 153 | + targets_list.append(row.iloc[0, 1]) | ||
| 154 | + | ||
| 155 | + self.targets = targets_list | ||
| 156 | + | ||
| 157 | + def __len__(self): | ||
| 158 | + return self.len | ||
| 159 | + | ||
| 160 | + def __getitem__(self, idx): | ||
| 161 | + img_loc = os.path.join(self.path, self.imgs[idx]) | ||
| 162 | + targets = self.targets[idx] | ||
| 163 | + image = Image.open(img_loc) | ||
| 164 | + return image, targets | ||
| 165 | + | ||
| 166 | + | ||
| 167 | + | ||
| 168 | +def get_dataset(args, transform, split='train'): | ||
| 169 | + assert split in ['train', 'val', 'test', 'trainval'] | ||
| 170 | + | ||
| 171 | + if split in ['train']: | ||
| 172 | + dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | ||
| 173 | + else: #test | ||
| 174 | + dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | ||
| 175 | + | ||
| 176 | + return dataset | ||
| 177 | + | ||
| 178 | + | ||
| 179 | +def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 180 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
| 181 | + batch_size=args.batch_size, | ||
| 182 | + shuffle=shuffle, | ||
| 183 | + num_workers=args.num_workers, | ||
| 184 | + pin_memory=pin_memory) | ||
| 185 | + return data_loader | ||
| 186 | + | ||
| 187 | + | ||
| 188 | +def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 189 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
| 190 | + batch_size=args.batch_size, | ||
| 191 | + shuffle=shuffle, | ||
| 192 | + num_workers=args.num_workers, | ||
| 193 | + pin_memory=pin_memory) | ||
| 194 | + return data_loader | ||
| 195 | + | ||
| 196 | + | ||
| 197 | +def get_inf_dataloader(args, dataset): | ||
| 198 | + global current_epoch | ||
| 199 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 200 | + | ||
| 201 | + while True: | ||
| 202 | + try: | ||
| 203 | + batch = next(data_loader) | ||
| 204 | + | ||
| 205 | + except StopIteration: | ||
| 206 | + current_epoch += 1 | ||
| 207 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 208 | + batch = next(data_loader) | ||
| 209 | + | ||
| 210 | + yield batch | ||
| 211 | + | ||
| 212 | + | ||
| 213 | + | ||
| 214 | + | ||
| 215 | +def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
| 216 | + model.train() | ||
| 217 | + images, target = batch | ||
| 218 | + | ||
| 219 | + if device: | ||
| 220 | + images = images.to(device) | ||
| 221 | + target = target.to(device) | ||
| 222 | + | ||
| 223 | + elif args.use_cuda: | ||
| 224 | + images = images.cuda(non_blocking=True) | ||
| 225 | + target = target.cuda(non_blocking=True) | ||
| 226 | + | ||
| 227 | + # compute output | ||
| 228 | + start_t = time.time() | ||
| 229 | + output, first = model(images) | ||
| 230 | + forward_t = time.time() - start_t | ||
| 231 | + loss = criterion(output, target) | ||
| 232 | + | ||
| 233 | + # measure accuracy and record loss | ||
| 234 | + acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 235 | + acc1 /= images.size(0) | ||
| 236 | + acc5 /= images.size(0) | ||
| 237 | + | ||
| 238 | + # compute gradient and do SGD step | ||
| 239 | + optimizer.zero_grad() | ||
| 240 | + start_t = time.time() | ||
| 241 | + loss.backward() | ||
| 242 | + backward_t = time.time() - start_t | ||
| 243 | + optimizer.step() | ||
| 244 | + if scheduler: scheduler.step() | ||
| 245 | + | ||
| 246 | + if writer and step % args.print_step == 0: | ||
| 247 | + n_imgs = min(images.size(0), 10) | ||
| 248 | + tag = 'train/' + str(step) | ||
| 249 | + for j in range(n_imgs): | ||
| 250 | + writer.add_image(tag, | ||
| 251 | + concat_image_features(images[j], first[j]), global_step=step) | ||
| 252 | + | ||
| 253 | + return acc1, acc5, loss, forward_t, backward_t | ||
| 254 | + | ||
| 255 | + | ||
| 256 | +#_acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 257 | +def accuracy(output, target, topk=(1,)): | ||
| 258 | + """Computes the accuracy over the k top predictions for the specified values of k""" | ||
| 259 | + with torch.no_grad(): | ||
| 260 | + maxk = max(topk) | ||
| 261 | + batch_size = target.size(0) | ||
| 262 | + | ||
| 263 | + _, pred = output.topk(maxk, 1, True, True) | ||
| 264 | + pred = pred.t() | ||
| 265 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 266 | + | ||
| 267 | + res = [] | ||
| 268 | + for k in topk: | ||
| 269 | + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
| 270 | + res.append(correct_k) | ||
| 271 | + return res | ||
| 272 | + |
-
Please register or login to post a comment