Showing 21 changed files with 1636 additions and 0 deletions
**.gitignore**

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
**README.md**

# Fast AutoAugment
<img src="figures/faa.png" width=800px>

A PyTorch implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946).

## Prerequisites
* torch==1.1.0
* torchvision==0.2.2
* hyperopt==0.1.2
* future==0.17.1
* tb-nightly==1.15.0a20190622

## Usage
### Training
#### CIFAR10
```bash
# ResNet20 (w/o FastAutoAugment)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False

# ResNet20 (w/ FastAutoAugment)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True

# ResNet20 (w/ FastAutoAugment, pre-found policy)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \
  --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp

# ResNet32 (w/o FastAutoAugment)
python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False

# ResNet32 (w/ FastAutoAugment)
python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True

# EfficientNet (w/ FastAutoAugment)
python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \
  --network=efficientnet_cifar10 --activation=swish
```

#### ImageNet (any backbone network in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html) can be used)
```bash
# BaseNet (w/o FastAutoAugment)
python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50

# EfficientNet (w/ FastAutoAugment) (under construction)
python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \
  --network=efficientnet --activation=swish
```

### Eval
```bash
# Single-image testing
python eval.py --model_path=runs/ResNet_Scale3_Basline

# 5-crop testing
python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True
```

## Experiments
### Fast AutoAugment
#### ResNet20 (CIFAR10)
* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)]
* Validation curve
<img src="figures/resnet20_valid.png">

* Evaluation (Acc@1)

|                | Valid     | Test (single) |
|----------------|-----------|---------------|
| ResNet20       | 90.70     | **91.45**     |
| ResNet20 + FAA | **92.46** | **91.45**     |

#### ResNet34 (CIFAR10)
* Validation curve
<img src="figures/resnet34_valid.png">

* Evaluation (Acc@1)

|                | Valid     | Test (single) |
|----------------|-----------|---------------|
| ResNet34       | 91.54     | 91.47         |
| ResNet34 + FAA | **92.76** | **91.99**     |

### Found policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)]
<img src="figures/pm.png">

### Augmented images
<img src="figures/augmented_images.png">
<img src="figures/augmented_images2.png">
**eval.py**

import os
import fire
import json
from pprint import pprint

import torch
import torch.nn as nn

from utils import *


def eval(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    kwargs = json.loads(open(kwargs_path).read())
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load model')
    weight_path = os.path.join(model_path, 'model', 'model.pt')
    model.load_state_dict(torch.load(weight_path))

    print('\n[+] Load dataset')
    test_transform = get_valid_transform(args, model)
    test_dataset = get_dataset(args, test_transform, 'test')
    test_loader = iter(get_dataloader(args, test_dataset))

    print('\n[+] Start testing')
    _test_res = validate(args, model, criterion, test_loader, step=0, writer=None)

    print('\n[+] Valid results')
    print('  Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
    print('  Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
    print('  Loss : {:.3f}'.format(_test_res[2].data))
    print('  Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset)))


if __name__ == '__main__':
    fire.Fire(eval)
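eval.py rebuilds the network from the `kwargs.json` that train.py serialized, then restores the checkpoint from `model/model.pt`. A minimal sketch of that reload path, assuming the run-directory name from the README examples:

```python
# Minimal sketch of the reload path eval.py relies on; the run directory
# name is an example taken from the README, not a fixed location.
import json
import os

import torch

log_dir = 'runs/ResNet_Scale3_FastAutoAugment'
kwargs = json.load(open(os.path.join(log_dir, 'kwargs.json')))     # written by train()
state = torch.load(os.path.join(log_dir, 'model', 'model.pt'),
                   map_location='cpu')                             # best-Acc@1 weights
```

One caveat: when more than one GPU is visible, train.py wraps the model in `nn.DataParallel` before saving, so the stored state dict carries a `module.` key prefix that the bare model built here will reject; stripping that prefix (or saving `model.module.state_dict()` instead) avoids the mismatch.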
**fast_auto_augment.py**

import copy
import json
import time
import torch
import torch.nn as nn
import random
import torchvision.transforms as transforms

from torch.utils.data import Subset
from sklearn.model_selection import StratifiedShuffleSplit
from concurrent.futures import ProcessPoolExecutor

from transforms import *
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from utils import *


DEFAULT_CANDIDATES = [
    ShearXY,
    TranslateXY,
    Rotate,
    AutoContrast,
    Invert,
    Equalize,
    Solarize,
    Posterize,
    Contrast,
    Color,
    Brightness,
    Sharpness,
    Cutout,
#    SamplePairing,
]


def train_child(args, model, dataset, subset_indx, device=None):
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()

    dataset.transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()])
    subset = Subset(dataset, subset_indx)
    data_loader = get_inf_dataloader(args, subset)

    if device:
        model = model.to(device)
        criterion = criterion.to(device)

    elif args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

        if torch.cuda.device_count() > 1:
            print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
            model = nn.DataParallel(model)

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(data_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device)

        if step % args.print_step == 0:
            print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format(
                step, args.max_step, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device))

            print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_train_res[2].data))

    return _train_res


def validate_child(args, model, dataset, subset_indx, transform, device=None):
    criterion = nn.CrossEntropyLoss()

    if device:
        model = model.to(device)
        criterion = criterion.to(device)

    elif args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    dataset.transform = transform
    subset = Subset(dataset, subset_indx)
    data_loader = get_dataloader(args, subset, pin_memory=False)

    return validate(args, model, criterion, data_loader, 0, None, device)


def get_next_subpolicy(transform_candidates, op_per_subpolicy=2):
    n_candidates = len(transform_candidates)
    subpolicy = []

    for i in range(op_per_subpolicy):
        indx = random.randrange(n_candidates)
        prob = random.random()
        mag = random.random()
        subpolicy.append(transform_candidates[indx](prob, mag))

    subpolicy = transforms.Compose([
        *subpolicy,
        transforms.Resize(32),
        transforms.ToTensor()])

    return subpolicy


def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    subpolicies = []

    for b in range(B):
        subpolicy = get_next_subpolicy(transform_candidates)
        val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device)
        subpolicies.append((subpolicy, val_res[2]))

    return subpolicies


def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device):

    def _objective(sampled):
        subpolicy = [transform(prob, mag)
                     for transform, prob, mag in sampled]

        subpolicy = transforms.Compose([
            transforms.Resize(32),
            *subpolicy,
            transforms.ToTensor()])

        val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device)
        loss = val_res[2].cpu().numpy()
        return {'loss': loss, 'status': STATUS_OK}

    space = [(hp.choice('transform1', transform_candidates), hp.uniform('prob1', 0, 1.0), hp.uniform('mag1', 0, 1.0)),
             (hp.choice('transform2', transform_candidates), hp.uniform('prob2', 0, 1.0), hp.uniform('mag2', 0, 1.0))]

    trials = Trials()
    best = fmin(_objective,
                space=space,
                algo=tpe.suggest,
                max_evals=B,
                trials=trials)

    subpolicies = []
    for t in trials.trials:
        vals = t['misc']['vals']
        subpolicy = [transform_candidates[vals['transform1'][0]](vals['prob1'][0], vals['mag1'][0]),
                     transform_candidates[vals['transform2'][0]](vals['prob2'][0], vals['mag2'][0])]
        subpolicy = transforms.Compose([
            # baseline augmentation
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            # policy
            *subpolicy,
            # to tensor
            transforms.ToTensor()])
        subpolicies.append((subpolicy, t['result']['loss']))

    return subpolicies


def get_topn_subpolicies(subpolicies, N=10):
    return sorted(subpolicies, key=lambda subpolicy: subpolicy[1])[:N]


def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k):
    kwargs = json.loads(args_str)
    args, kwargs = parse_args(kwargs)
    device_id = k % torch.cuda.device_count()
    device = torch.device('cuda:%d' % device_id)
    _transform = []

    print('[+] Child %d training started (GPU: %d)' % (k, device_id))

    # train child model
    child_model = copy.deepcopy(model)
    train_res = train_child(args, child_model, dataset, Dm_indx, device)

    # search sub policy
    for t in range(T):
        #subpolicies = search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device)
        subpolicies = search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device)
        subpolicies = get_topn_subpolicies(subpolicies, N)
        _transform.extend([subpolicy[0] for subpolicy in subpolicies])

    return _transform


def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5):
    args_str = json.dumps(args._asdict())
    dataset = get_dataset(args, None, 'trainval')
    num_process = min(torch.cuda.device_count(), num_process)
    transform, futures = [], []

    torch.multiprocessing.set_start_method('spawn', force=True)

    if not transform_candidates:
        transform_candidates = DEFAULT_CANDIDATES

    # split
    Dm_indexes, Da_indexes = split_dataset(args, dataset, K)

    with ProcessPoolExecutor(max_workers=num_process) as executor:
        for k, (Dm_indx, Da_indx) in enumerate(zip(Dm_indexes, Da_indexes)):
            future = executor.submit(process_fn,
                                     args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k)
            futures.append(future)

        for future in futures:
            transform.extend(future.result())

    transform = transforms.RandomChoice(transform)

    return transform
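The search returns a `transforms.RandomChoice` over the top sub-policies from all K folds, and `get_train_transform()` in utils.py pickles it to `augmentation.cp` in the run directory. A minimal sketch of reusing a saved policy outside training, assuming the run path from the README and a hypothetical input image:

```python
# Sketch: apply a previously found policy to a single image. Unpickling
# needs the transforms module importable, since the policy object holds
# instances of its classes (Rotate, Cutout, ...).
import pickle

from PIL import Image

import transforms  # noqa: F401  (defines the classes stored in the pickle)

with open('runs/ResNet_Scale3_FastAutoAugment/augmentation.cp', 'rb') as f:
    policy = pickle.load(f)               # a transforms.RandomChoice([...])

img = Image.open('example.png').convert('RGB')   # hypothetical input image
tensor = policy(img)                      # applies one randomly chosen sub-policy
```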
[One file mode changed; 7 binary files added (92.9 KB, 410 KB, 425 KB, 423 KB, 48.3 KB, 229 KB, 227 KB) — the figures/ images referenced in the README.]
**networks/__init__.py**

from .basenet import BaseNet
**networks/basenet.py**

import torch.nn as nn


class BaseNet(nn.Module):
    def __init__(self, backbone, args):
        super(BaseNet, self).__init__()

        # Separate layers: stem, middle blocks, classifier head
        self.first = nn.Sequential(*list(backbone.children())[:1])
        self.after = nn.Sequential(*list(backbone.children())[1:-1])
        self.fc = list(backbone.children())[-1]

        self.img_size = (224, 224)

    def forward(self, x):
        f = self.first(x)    # stem output, returned alongside the logits
        x = self.after(f)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x, f
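For reference, `select_model()` in utils.py builds this wrapper whenever `--network` names a torchvision model. A minimal sketch of that path; note that `args` is never read by `BaseNet.__init__`, so a placeholder suffices here:

```python
# Sketch of the torchvision path taken by utils.select_model().
import torch
import torchvision.models as models

from networks import BaseNet

backbone = models.resnet50()           # any torchvision classifier
model = BaseNet(backbone, args=None)   # placeholder args (unused by BaseNet)

logits, stem_features = model(torch.randn(1, 3, 224, 224))
print(logits.shape, stem_features.shape)   # (1, 1000), (1, 64, 112, 112)
```

The child-list split assumes the backbone flattens cleanly into its final child module; that holds for the ResNet family, but not for every torchvision architecture.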
**networks/efficientnet.py**

import math

import torch
import torch.nn as nn


def round_fn(orig, multiplier):
    if not multiplier:
        return orig

    return int(math.ceil(multiplier * orig))


def get_activation_fn(activation):
    if activation == "swish":
        return Swish

    elif activation == "relu":
        return nn.ReLU

    else:
        raise Exception('Unknown activation %s' % activation)


class Swish(nn.Module):
    """ Swish activation function, s(x) = x * sigmoid(x) """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


class ConvBlock(nn.Module):
    """ Conv + BatchNorm + Activation """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        self.fw = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size,
                      padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
            get_activation_fn(activation)())

    def forward(self, x):
        return self.fw(x)


class DepthwiseConvBlock(nn.Module):
    """ DepthwiseConv2D + BatchNorm + Activation """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        self.fw = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, kernel_size,
                      padding=padding, stride=stride, groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel),
            get_activation_fn(activation)())

    def forward(self, x):
        return self.fw(x)


class MBConv(nn.Module):
    """ Inverted residual block """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish"):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride

        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1,
                                    activation=activation)

        self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size,
                                          padding=(kernel_size-1)//2,
                                          stride=stride, activation=activation)

        self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1,
                                 activation=activation)

    def forward(self, inputs):
        if self.expand_ratio != 1:
            x = self.expand(inputs)
        else:
            x = inputs

        x = self.dw_conv(x)
        x = self.pw_conv(x)

        if self.in_channel == self.out_channel and \
           self.stride == 1:
            x = x + inputs

        return x


class Net(nn.Module):
    """ EfficientNet """

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = args.num_classes

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(224, self.r), round_fn(224, self.r))

        self.stage1 = ConvBlock(3, round_fn(32, self.w),
                                kernel_size=3, padding=1, stride=2, activation=activation)

        self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=1, activation=activation)

        self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
                                       depth=round_fn(2, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
                                       depth=round_fn(2, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
                                       depth=round_fn(3, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
                                       depth=round_fn(3, self.d), kernel_size=5,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
                                       depth=round_fn(4, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
                                kernel_size=1, activation=activation)

        # NOTE: assumes the pi=0 geometry, where the stage9 output is 7x7
        self.fc = nn.Linear(round_fn(7*7*1280, self.w), num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
                    half_resolution=False, expand_ratio=1, activation="swish"):
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i == 0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                       stride=stride, expand_ratio=expand_ratio, activation=activation))
            in_channel = out_channel

        return nn.Sequential(*blocks)

    def forward(self, x):
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, x.size()[-2:])

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x, x    # second element mirrors the (logits, features) interface
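The `pi` knob implements the EfficientNet paper's compound scaling with its default coefficients alpha=1.2 (depth), beta=1.1 (width), and gamma=1.15 (resolution). A quick arithmetic sketch of what `round_fn` produces for pi=1 (the helper is re-stated inline so the snippet runs standalone):

```python
# Compound-scaling arithmetic: ceil(multiplier * base), as in round_fn above.
import math

def round_fn(orig, multiplier):
    if not multiplier:
        return orig
    return int(math.ceil(multiplier * orig))

pi = 1
d, w, r = 1.2 ** pi, 1.1 ** pi, 1.15 ** pi
print(round_fn(224, r))   # 258 -> scaled input resolution
print(round_fn(32, w))    # 36  -> stage1 channel width
print(round_fn(3, d))     # 4   -> stage5/stage6 block count
```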
**networks/efficientnet_cifar10.py**

import math

import torch
import torch.nn as nn


def round_fn(orig, multiplier):
    if not multiplier:
        return orig

    return int(math.ceil(multiplier * orig))


def get_activation_fn(activation):
    if activation == "swish":
        return Swish

    elif activation == "relu":
        return nn.ReLU

    else:
        raise Exception('Unknown activation %s' % activation)


class Swish(nn.Module):
    """ Swish activation function, s(x) = x * sigmoid(x) """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


class ConvBlock(nn.Module):
    """ Conv + BatchNorm + Activation """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        self.fw = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size,
                      padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
            get_activation_fn(activation)())

    def forward(self, x):
        return self.fw(x)


class DepthwiseConvBlock(nn.Module):
    """ DepthwiseConv2D + BatchNorm + Activation """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        self.fw = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, kernel_size,
                      padding=padding, stride=stride, groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel),
            get_activation_fn(activation)())

    def forward(self, x):
        return self.fw(x)


class SEBlock(nn.Module):
    """ Squeeze and Excitation Block """

    def __init__(self, in_channel, se_ratio=16):
        super().__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        inter_channel = in_channel // se_ratio

        self.reduce = nn.Sequential(
            nn.Conv2d(in_channel, inter_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.ReLU())

        self.expand = nn.Sequential(
            nn.Conv2d(inter_channel, in_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.Sigmoid())

    def forward(self, x):
        s = self.global_avgpool(x)
        s = self.reduce(s)
        s = self.expand(s)
        return x * s


class MBConv(nn.Module):
    """ Inverted residual block """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish", use_seblock=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride
        self.use_seblock = use_seblock

        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1,
                                    activation=activation)

        self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size,
                                          padding=(kernel_size-1)//2,
                                          stride=stride, activation=activation)

        if use_seblock:
            self.seblock = SEBlock(in_channel*expand_ratio)

        self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1,
                                 activation=activation)

    def forward(self, inputs):
        if self.expand_ratio != 1:
            x = self.expand(inputs)
        else:
            x = inputs

        x = self.dw_conv(x)

        if self.use_seblock:
            x = self.seblock(x)

        x = self.pw_conv(x)

        if self.in_channel == self.out_channel and \
           self.stride == 1:
            x = x + inputs

        return x


class Net(nn.Module):
    """ EfficientNet """

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = 10

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(32, self.r), round_fn(32, self.r))
        self.use_seblock = args.use_seblock

        self.stage1 = ConvBlock(3, round_fn(32, self.w),
                                kernel_size=3, padding=1, stride=2, activation=activation)

        self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=1, activation=activation)

        self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
                                       depth=round_fn(2, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
                                       depth=round_fn(2, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
                                       depth=round_fn(3, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
                                       depth=round_fn(3, self.d), kernel_size=5,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
                                       depth=round_fn(4, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
                                kernel_size=1, activation=activation)

        self.fc = nn.Linear(round_fn(1280, self.w), num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
                    half_resolution=False, expand_ratio=1, activation="swish"):
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i == 0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                       stride=stride, expand_ratio=expand_ratio, activation=activation, use_seblock=self.use_seblock))
            in_channel = out_channel

        return nn.Sequential(*blocks)

    def forward(self, x):
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, x.size()[-2:])

        s = self.stage1(x)
        x = self.stage2(s)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x, s
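This CIFAR variant adds the squeeze-and-excitation gate behind `--use_seblock`. A minimal shape check of the block in isolation (the tensor sizes are toy values, and the module path follows the `--network=efficientnet_cifar10` naming used in the README):

```python
# SEBlock squeezes to 1x1, bottlenecks channels by se_ratio, then gates
# the input channel-wise, so the output shape matches the input.
import torch

from networks.efficientnet_cifar10 import SEBlock

x = torch.randn(2, 64, 8, 8)
se = SEBlock(in_channel=64, se_ratio=16)   # reduce: 64 -> 4, expand: 4 -> 64
print(se(x).shape)                         # torch.Size([2, 64, 8, 8])
```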
**networks/resnet_cifar10.py**

import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, in_channel, out_channel, stride):
        super(ResidualBlock, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(out_channel))

        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel))

        if self.in_channel != self.out_channel or \
           self.stride != 1:
            self.down = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,
                          kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channel))

    def forward(self, b):
        t = self.conv1(b)
        t = self.relu(t)
        t = self.conv2(t)

        if self.in_channel != self.out_channel or \
           self.stride != 1:
            b = self.down(b)    # 1x1 projection shortcut

        t += b
        t = self.relu(t)

        return t


class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        scale = args.scale

        self.stem = nn.Sequential(
            nn.Conv2d(3, 16,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

        self.layer1 = nn.Sequential(*[
            ResidualBlock(16, 16, 1) for _ in range(2*scale)])

        self.layer2 = nn.Sequential(*[
            ResidualBlock(in_channel=(16 if i == 0 else 32),
                          out_channel=32,
                          stride=(2 if i == 0 else 1)) for i in range(2*scale)])

        self.layer3 = nn.Sequential(*[
            ResidualBlock(in_channel=(32 if i == 0 else 64),
                          out_channel=64,
                          stride=(2 if i == 0 else 1)) for i in range(2*scale)])

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(64, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        s = self.stem(x)
        x = self.layer1(s)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg_pool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x, s
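A minimal smoke-test sketch for this network; the `Args` namedtuple stands in for the parsed CLI args (only `scale` is read here):

```python
# Forward-pass shape check: 10-class logits plus the 16-channel stem
# feature map that forward() returns alongside them.
import collections

import torch

from networks.resnet_cifar10 import Net

Args = collections.namedtuple('Args', ['scale'])
model = Net(Args(scale=3))

logits, stem = model(torch.randn(2, 3, 32, 32))
print(logits.shape, stem.shape)   # torch.Size([2, 10]) torch.Size([2, 16, 32, 32])
```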
**train.py**

import os
import fire
import time
import json
import random
from pprint import pprint

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from networks import *
import utils
from utils import *


def train(**kwargs):
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('./runs', model_name)
    os.makedirs(os.path.join(log_dir, 'model'))
    json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
    writer = SummaryWriter(log_dir=log_dir)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    #writer.add_graph(model)

    print('\n[+] Load dataset')
    transform = get_train_transform(args, model, log_dir)
    val_transform = get_valid_transform(args, model)
    train_dataset = get_dataset(args, transform, 'train')
    valid_dataset = get_dataset(args, val_transform, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))
    max_epoch = len(train_dataset) // args.batch_size
    best_acc = -1

    print('\n[+] Start training')
    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(train_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)

        if step % args.print_step == 0:
            # utils.current_epoch is advanced inside get_inf_dataloader, so it
            # must be read through the module, not via a star-imported copy
            print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                step, args.max_step, utils.current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
            writer.add_scalar('train/acc1', _train_res[0], global_step=step)
            writer.add_scalar('train/acc5', _train_res[1], global_step=step)
            writer.add_scalar('train/loss', _train_res[2], global_step=step)
            writer.add_scalar('train/forward_time', _train_res[3], global_step=step)
            writer.add_scalar('train/backward_time', _train_res[4], global_step=step)
            print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_train_res[2].data))
            print('  FW Time : {:.3f}ms'.format(_train_res[3]*1000))
            print('  BW Time : {:.3f}ms'.format(_train_res[4]*1000))

        if step % args.val_step == args.val_step-1:
            valid_loader = iter(get_dataloader(args, valid_dataset))
            _valid_res = validate(args, model, criterion, valid_loader, step, writer)
            print('\n[+] Valid results')
            writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
            writer.add_scalar('valid/acc5', _valid_res[1], global_step=step)
            writer.add_scalar('valid/loss', _valid_res[2], global_step=step)
            print('  Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_valid_res[2].data))

            if _valid_res[0] > best_acc:
                best_acc = _valid_res[0]
                torch.save(model.state_dict(), os.path.join(log_dir, "model", "model.pt"))
                print('\n[+] Model saved')

    writer.close()


if __name__ == '__main__':
    fire.Fire(train)
**transforms.py**

import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms

from abc import ABC, abstractmethod
from PIL import Image, ImageOps, ImageEnhance


class BaseTransform(ABC):

    def __init__(self, prob, mag):
        self.prob = prob
        self.mag = mag

    def __call__(self, img):
        return transforms.RandomApply([self.transform], self.prob)(img)

    def __repr__(self):
        return '%s(prob=%.2f, magnitude=%.2f)' % \
            (self.__class__.__name__, self.prob, self.mag)

    @abstractmethod
    def transform(self, img):
        pass


class ShearXY(BaseTransform):

    def transform(self, img):
        degrees = self.mag * 360
        t = transforms.RandomAffine(0, shear=degrees, resample=Image.BILINEAR)
        return t(img)


class TranslateXY(BaseTransform):

    def transform(self, img):
        translate = (self.mag, self.mag)
        t = transforms.RandomAffine(0, translate=translate, resample=Image.BILINEAR)
        return t(img)


class Rotate(BaseTransform):

    def transform(self, img):
        degrees = self.mag * 360
        t = transforms.RandomRotation(degrees, Image.BILINEAR)
        return t(img)


class AutoContrast(BaseTransform):

    def transform(self, img):
        cutoff = int(self.mag * 49)
        return ImageOps.autocontrast(img, cutoff=cutoff)


class Invert(BaseTransform):

    def transform(self, img):
        return ImageOps.invert(img)


class Equalize(BaseTransform):

    def transform(self, img):
        return ImageOps.equalize(img)


class Solarize(BaseTransform):

    def transform(self, img):
        threshold = (1-self.mag) * 255
        return ImageOps.solarize(img, threshold)


class Posterize(BaseTransform):

    def transform(self, img):
        bits = int((1-self.mag) * 8)
        return ImageOps.posterize(img, bits=bits)


class Contrast(BaseTransform):

    def transform(self, img):
        factor = self.mag * 10
        return ImageEnhance.Contrast(img).enhance(factor)


class Color(BaseTransform):

    def transform(self, img):
        factor = self.mag * 10
        return ImageEnhance.Color(img).enhance(factor)


class Brightness(BaseTransform):

    def transform(self, img):
        factor = self.mag * 10
        return ImageEnhance.Brightness(img).enhance(factor)


class Sharpness(BaseTransform):

    def transform(self, img):
        factor = self.mag * 10
        return ImageEnhance.Sharpness(img).enhance(factor)


class Cutout(BaseTransform):

    def transform(self, img):
        n_holes = 1
        length = 24 * self.mag
        cutout_op = CutoutOp(n_holes=n_holes, length=length)
        return cutout_op(img)


class CutoutOp(object):
    """
    https://github.com/uoguelph-mlrg/Cutout

    Randomly mask out one or more patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image of size (W, H).
        Returns:
            PIL Image: Image with n_holes of dimension length x length cut out of it.
        """
        w, h = img.size

        mask = np.ones((h, w, 1), np.uint8)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h).astype(int)
            y2 = np.clip(y + self.length // 2, 0, h).astype(int)
            x1 = np.clip(x - self.length // 2, 0, w).astype(int)
            x2 = np.clip(x + self.length // 2, 0, w).astype(int)

            mask[y1: y2, x1: x2, :] = 0

        # apply the mask once and convert back to a PIL image
        img = Image.fromarray((mask * np.asarray(img)).astype(np.uint8))

        return img
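Each operation takes `(prob, mag)` in [0, 1] and scales its own parameter range from `mag` (e.g. `Rotate` maps mag to up to 360 degrees, `Cutout` to a hole of up to 24 px). A minimal sketch mirroring the Compose that `search_subpolicies_hyperopt()` assembles; the specific prob/mag values here are illustrative:

```python
# Build one two-op sub-policy on top of the baseline CIFAR augmentation.
import torchvision.transforms as transforms

from transforms import Cutout, Rotate

subpolicy = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    Rotate(prob=0.7, mag=0.1),   # applied with prob 0.7, rotation up to 36 deg
    Cutout(prob=0.5, mag=0.5),   # applied with prob 0.5, one ~12 px hole
    transforms.ToTensor()])

print(subpolicy)   # __repr__ shows each op's prob and magnitude
```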
| 1 | +import os | ||
| 2 | +import time | ||
| 3 | +import importlib | ||
| 4 | +import collections | ||
| 5 | +import pickle as cp | ||
| 6 | +import glob | ||
| 7 | +import numpy as np | ||
| 8 | + | ||
| 9 | +import torch | ||
| 10 | +import torchvision | ||
| 11 | +import torch.nn.functional as F | ||
| 12 | +import torchvision.models as models | ||
| 13 | +import torchvision.transforms as transforms | ||
| 14 | +from torch.utils.data import Subset | ||
| 15 | +from torch.utils.data import Dataset, DataLoader | ||
| 16 | + | ||
| 17 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | ||
| 21 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' | ||
| 22 | +current_epoch = 0 | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +def split_dataset(args, dataset, k): | ||
| 26 | + # load dataset | ||
| 27 | + X = list(range(len(dataset))) | ||
| 28 | + Y = dataset.targets | ||
| 29 | + | ||
| 30 | + # split to k-fold | ||
| 31 | + assert len(X) == len(Y) | ||
| 32 | + | ||
| 33 | + def _it_to_list(_it): | ||
| 34 | + return list(zip(*list(_it))) | ||
| 35 | + | ||
| 36 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
| 37 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
| 38 | + | ||
| 39 | + return Dm_indexes, Da_indexes | ||
| 40 | + | ||
| 41 | + | ||
| 42 | +def concat_image_features(image, features, max_features=3): | ||
| 43 | + _, h, w = image.shape | ||
| 44 | + | ||
| 45 | + max_features = min(features.size(0), max_features) | ||
| 46 | + image_feature = image.clone() | ||
| 47 | + | ||
| 48 | + for i in range(max_features): | ||
| 49 | + feature = features[i:i+1] | ||
| 50 | + _min, _max = torch.min(feature), torch.max(feature) | ||
| 51 | + feature = (feature - _min) / (_max - _min + 1e-6) | ||
| 52 | + feature = torch.cat([feature]*3, 0) | ||
| 53 | + feature = feature.view(1, 3, feature.size(1), feature.size(2)) | ||
| 54 | + feature = F.upsample(feature, size=(h,w), mode="bilinear") | ||
| 55 | + feature = feature.view(3, h, w) | ||
| 56 | + image_feature = torch.cat((image_feature, feature), 2) | ||
| 57 | + | ||
| 58 | + return image_feature | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +def get_model_name(args): | ||
| 62 | + from datetime import datetime | ||
| 63 | + now = datetime.now() | ||
| 64 | + date_time = now.strftime("%B_%d_%H:%M:%S") | ||
| 65 | + model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
| 66 | + return model_name | ||
| 67 | + | ||
| 68 | + | ||
| 69 | +def dict_to_namedtuple(d): | ||
| 70 | + Args = collections.namedtuple('Args', sorted(d.keys())) | ||
| 71 | + | ||
| 72 | + for k,v in d.items(): | ||
| 73 | + if type(v) is dict: | ||
| 74 | + d[k] = dict_to_namedtuple(v) | ||
| 75 | + | ||
| 76 | + elif type(v) is str: | ||
| 77 | + try: | ||
| 78 | + d[k] = eval(v) | ||
| 79 | + except: | ||
| 80 | + d[k] = v | ||
| 81 | + | ||
| 82 | + args = Args(**d) | ||
| 83 | + return args | ||
| 84 | + | ||
| 85 | + | ||
| 86 | +def parse_args(kwargs): | ||
| 87 | + # combine with default args | ||
| 88 | + kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'cifar10' | ||
| 89 | + kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet_cifar10' | ||
| 90 | + kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
| 91 | + kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.1 | ||
| 92 | + kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
| 93 | + kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
| 94 | + kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
| 95 | + kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
| 96 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 2000 | ||
| 97 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 2000 | ||
| 98 | + kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
| 99 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
| 100 | + kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
| 101 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 64000 | ||
| 102 | + kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | ||
| 103 | + kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
| 104 | + | ||
| 105 | + # to named tuple | ||
| 106 | + args = dict_to_namedtuple(kwargs) | ||
| 107 | + return args, kwargs | ||
| 108 | + | ||
| 109 | + | ||
| 110 | +def select_model(args): | ||
| 111 | + if args.network in models.__dict__: | ||
| 112 | + backbone = models.__dict__[args.network]() | ||
| 113 | + model = BaseNet(backbone, args) | ||
| 114 | + else: | ||
| 115 | + Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
| 116 | + model = Net(args) | ||
| 117 | + | ||
| 118 | + print(model) | ||
| 119 | + return model | ||
| 120 | + | ||
| 121 | + | ||
| 122 | +def select_optimizer(args, model): | ||
| 123 | + if args.optimizer == 'sgd': | ||
| 124 | + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
| 125 | + elif args.optimizer == 'rms': | ||
| 126 | + #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5) | ||
| 127 | + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
| 128 | + elif args.optimizer == 'adam': | ||
| 129 | + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
| 130 | + else: | ||
| 131 | + raise Exception('Unknown Optimizer') | ||
| 132 | + return optimizer | ||
| 133 | + | ||
| 134 | + | ||
| 135 | +def select_scheduler(args, optimizer): | ||
| 136 | + if not args.scheduler or args.scheduler == 'None': | ||
| 137 | + return None | ||
| 138 | + elif args.scheduler =='clr': | ||
| 139 | + return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
| 140 | + elif args.scheduler =='exp': | ||
| 141 | + return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
| 142 | + else: | ||
| 143 | + raise Exception('Unknown Scheduler') | ||
| 144 | + | ||
| 145 | + | ||
| 146 | +class CustomDataset(Dataset): | ||
| 147 | + def __init__(self, path, transform = None): | ||
| 148 | + self.path = path | ||
| 149 | + self.transform = transform | ||
| 150 | + self.img = np.load(path) | ||
| 151 | + self.len = self.img.shape[0] | ||
| 152 | + | ||
| 153 | + def __len__(self): | ||
| 154 | + return self.len | ||
| 155 | + | ||
| 156 | + def __getitem__(self, idx): | ||
| 157 | + if self.transforms is not None: | ||
| 158 | + img = self.transforms(img) | ||
| 159 | + return img | ||
| 160 | + | ||
| 161 | +def get_dataset(args, transform, split='train'): | ||
| 162 | + assert split in ['train', 'val', 'test', 'trainval'] | ||
| 163 | + | ||
| 164 | + if args.dataset == 'cifar10': | ||
| 165 | + train = split in ['train', 'val', 'trainval'] | ||
| 166 | + dataset = torchvision.datasets.CIFAR10(DATASET_PATH, | ||
| 167 | + train=train, | ||
| 168 | + transform=transform, | ||
| 169 | + download=True) | ||
| 170 | + | ||
| 171 | + if split in ['train', 'val']: | ||
| 172 | + split_path = os.path.join(DATASET_PATH, | ||
| 173 | + 'cifar-10-batches-py', 'train_val_index.cp') | ||
| 174 | + | ||
| 175 | + if not os.path.exists(split_path): | ||
| 176 | + [train_index], [val_index] = split_dataset(args, dataset, k=1) | ||
| 177 | + split_index = {'train':train_index, 'val':val_index} | ||
| 178 | + cp.dump(split_index, open(split_path, 'wb')) | ||
| 179 | + | ||
| 180 | + split_index = cp.load(open(split_path, 'rb')) | ||
| 181 | + dataset = Subset(dataset, split_index[split]) | ||
| 182 | + | ||
| 183 | + elif args.dataset == 'imagenet': | ||
| 184 | + dataset = torchvision.datasets.ImageNet(DATASET_PATH, | ||
| 185 | + split=split, | ||
| 186 | + transform=transform, | ||
| 187 | + download=(split is 'val')) | ||
| 188 | + | ||
| 189 | + elif args.dataset == 'BraTS': | ||
| 190 | + if split in ['train']: | ||
| 191 | + dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) | ||
| 192 | + else: | ||
| 193 | + dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) | ||
| 194 | + | ||
| 196 | + else: | ||
| 197 | + raise ValueError('Unknown dataset: {}'.format(args.dataset)) | ||
| 198 | + | ||
| 199 | + return dataset | ||
| 200 | + | ||
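|  | + # get_dataset sketch (illustrative): on cifar10, the first 'train'/'val' | ||
|  | + # call creates and caches train_val_index.cp, so later calls reuse the | ||
|  | + # exact same split. | ||
|  | + # | ||
|  | + #   train_set = get_dataset(args, get_train_transform(args, model), 'train') | ||
|  | + #   val_set   = get_dataset(args, get_valid_transform(args, model), 'val') | ||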
| 201 | + | ||
| 202 | +def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 203 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
| 204 | + batch_size=args.batch_size, | ||
| 205 | + shuffle=shuffle, | ||
| 206 | + num_workers=args.num_workers, | ||
| 207 | + pin_memory=pin_memory) | ||
| 208 | + return data_loader | ||
| 209 | + | ||
| 210 | + | ||
| 211 | +def get_inf_dataloader(args, dataset): | ||
| 212 | + global current_epoch | ||
| 213 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 214 | + | ||
| 215 | + while True: | ||
| 216 | + try: | ||
| 217 | + batch = next(data_loader) | ||
| 218 | + | ||
| 219 | + except StopIteration: | ||
| 220 | + current_epoch += 1 | ||
| 221 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 222 | + batch = next(data_loader) | ||
| 223 | + | ||
| 224 | + yield batch | ||
| 225 | + | ||
| 226 | + | ||
| 227 | +def get_train_transform(args, model, log_dir=None): | ||
| 228 | + if args.fast_auto_augment: | ||
| 229 | + assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet | ||
| 230 | + | ||
| 231 | + from fast_auto_augment import fast_auto_augment | ||
| 232 | + if args.augment_path: | ||
| 233 | + transform = cp.load(open(args.augment_path, 'rb')) | ||
| 234 | + os.system('cp {} {}'.format( | ||
| 235 | + args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | ||
| 236 | + else: | ||
| 237 | + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) | ||
| 238 | + if log_dir: | ||
| 239 | + cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | ||
| 240 | + | ||
| 241 | + elif args.dataset == 'cifar10': | ||
| 242 | + transform = transforms.Compose([ | ||
| 243 | + transforms.Pad(4), | ||
| 244 | + transforms.RandomCrop(32), | ||
| 245 | + transforms.RandomHorizontalFlip(), | ||
| 246 | + transforms.ToTensor() | ||
| 247 | + ]) | ||
| 248 | + | ||
| 249 | + elif args.dataset == 'imagenet': | ||
| 250 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
| 251 | + transform = transforms.Compose([ | ||
| 252 | + transforms.Resize([resize_h, resize_w]), | ||
| 253 | + transforms.RandomCrop(model.img_size), | ||
| 254 | + transforms.RandomHorizontalFlip(), | ||
| 255 | + transforms.ToTensor() | ||
| 256 | + ]) | ||
| 257 | + | ||
| 258 | + elif args.dataset == 'BraTS': | ||
| 259 | + resize_h, resize_w = 256, 256 | ||
| 260 | + transform = transforms.Compose([ | ||
| 261 | + transforms.Resize([resize_h, resize_w]), | ||
| 262 | + transforms.RandomCrop(model.img_size), | ||
| 263 | + transforms.RandomHorizontalFlip(), | ||
| 264 | + transforms.ToTensor() | ||
| 265 | + ]) | ||
| 266 | + else: | ||
| 267 | + raise ValueError('Unknown dataset: {}'.format(args.dataset)) | ||
| 268 | + | ||
| 269 | + print(transform) | ||
| 270 | + | ||
| 271 | + return transform | ||
| 272 | + | ||
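|  | + # get_train_transform sketch (illustrative): when fast_auto_augment is | ||
|  | + # set (BraTS only, per the assert above), a previously found policy can | ||
|  | + # be reloaded via augment_path instead of re-running the search. | ||
|  | + # | ||
|  | + #   args.dataset, args.fast_auto_augment = 'BraTS', True | ||
|  | + #   args.augment_path = 'runs/prev/augmentation.cp'  # hypothetical path | ||
|  | + #   transform = get_train_transform(args, model, log_dir='runs/new') | ||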
| 273 | + | ||
| 274 | +def get_valid_transform(args, model): | ||
| 275 | + if args.dataset == 'cifar10': | ||
| 276 | + val_transform = transforms.Compose([ | ||
| 277 | + transforms.Resize(32), | ||
| 278 | + transforms.ToTensor() | ||
| 279 | + ]) | ||
| 280 | + | ||
| 281 | + elif args.dataset == 'imagenet': | ||
| 282 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
| 283 | + val_transform = transforms.Compose([ | ||
| 284 | + transforms.Resize([resize_h, resize_w]), | ||
| 285 | + transforms.ToTensor() | ||
| 286 | + ]) | ||
| 287 | + elif args.dataset == 'BraTS': | ||
| 288 | + resize_h, resize_w = 256, 256 | ||
| 289 | + val_transform = transforms.Compose([ | ||
| 290 | + transforms.Resize([resize_h, resize_w]), | ||
| 291 | + transforms.ToTensor() | ||
| 292 | + ]) | ||
| 293 | + else: | ||
| 294 | + raise ValueError('Unknown dataset: {}'.format(args.dataset)) | ||
| 295 | + | ||
| 296 | + return val_transform | ||
| 297 | + | ||
| 298 | + | ||
| 299 | +def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
| 300 | + model.train() | ||
| 301 | + images, target = batch | ||
| 302 | + | ||
| 303 | + if device: | ||
| 304 | + images = images.to(device) | ||
| 305 | + target = target.to(device) | ||
| 306 | + | ||
| 307 | + elif args.use_cuda: | ||
| 308 | + images = images.cuda(non_blocking=True) | ||
| 309 | + target = target.cuda(non_blocking=True) | ||
| 310 | + | ||
| 311 | + # compute output | ||
| 312 | + start_t = time.time() | ||
| 313 | + output, first = model(images) | ||
| 314 | + forward_t = time.time() - start_t | ||
| 315 | + loss = criterion(output, target) | ||
| 316 | + | ||
| 317 | + # measure accuracy and record loss | ||
| 318 | + acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 319 | + acc1 /= images.size(0) | ||
| 320 | + acc5 /= images.size(0) | ||
| 321 | + | ||
| 322 | + # compute gradient and do SGD step | ||
| 323 | + optimizer.zero_grad() | ||
| 324 | + start_t = time.time() | ||
| 325 | + loss.backward() | ||
| 326 | + backward_t = time.time() - start_t | ||
| 327 | + optimizer.step() | ||
| 328 | + if scheduler: scheduler.step() | ||
| 329 | + | ||
| 330 | + if writer and step % args.print_step == 0: | ||
| 331 | + n_imgs = min(images.size(0), 10) | ||
| 332 | + for j in range(n_imgs): | ||
| 333 | + writer.add_image('train/input_image', | ||
| 334 | + concat_image_features(images[j], first[j]), global_step=step) | ||
| 335 | + | ||
| 336 | + return acc1, acc5, loss, forward_t, backward_t | ||
| 337 | + | ||
| 338 | + | ||
| 339 | +def validate(args, model, criterion, valid_loader, step, writer, device=None): | ||
| 340 | + # switch to evaluate mode | ||
| 341 | + model.eval() | ||
| 342 | + | ||
| 343 | + acc1, acc5 = 0, 0 | ||
| 344 | + samples = 0 | ||
| 345 | + infer_t = 0 | ||
| 346 | + | ||
| 347 | + with torch.no_grad(): | ||
| 348 | + for i, (images, target) in enumerate(valid_loader): | ||
| 349 | + | ||
| 350 | + start_t = time.time() | ||
| 351 | + if device: | ||
| 352 | + images = images.to(device) | ||
| 353 | + target = target.to(device) | ||
| 354 | + | ||
| 355 | + elif args.use_cuda:  # was `is not None`, which is truthy even when use_cuda=False | ||
| 356 | + images = images.cuda(non_blocking=True) | ||
| 357 | + target = target.cuda(non_blocking=True) | ||
| 358 | + | ||
| 359 | + # compute output | ||
| 360 | + output, first = model(images) | ||
| 361 | + loss = criterion(output, target) | ||
| 362 | + infer_t += time.time() - start_t | ||
| 363 | + | ||
| 364 | + # measure accuracy and record loss | ||
| 365 | + _acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 366 | + acc1 += _acc1 | ||
| 367 | + acc5 += _acc5 | ||
| 368 | + samples += images.size(0) | ||
| 369 | + | ||
| 370 | + acc1 /= samples | ||
| 371 | + acc5 /= samples | ||
| 372 | + | ||
| 373 | + if writer: | ||
| 374 | + n_imgs = min(images.size(0), 10) | ||
| 375 | + for j in range(n_imgs): | ||
| 376 | + writer.add_image('valid/input_image', | ||
| 377 | + concat_image_features(images[j], first[j]), global_step=step) | ||
| 378 | + | ||
| 379 | + return acc1, acc5, loss, infer_t  # note: `loss` is from the last batch only | ||
| 380 | + | ||
| 381 | + | ||
| 382 | +def accuracy(output, target, topk=(1,)): | ||
| 383 | + """Computes the number of correct top-k predictions; callers divide by the sample count to get accuracy.""" | ||
| 384 | + with torch.no_grad(): | ||
| 385 | + maxk = max(topk) | ||
| 386 | + batch_size = target.size(0) | ||
| 387 | + | ||
| 388 | + _, pred = output.topk(maxk, 1, True, True) | ||
| 389 | + pred = pred.t() | ||
| 390 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 391 | + | ||
| 392 | + res = [] | ||
| 393 | + for k in topk: | ||
| 394 | + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
| 395 | + res.append(correct_k) | ||
| 396 | + return res | ||
| 397 | + |
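|  | + # accuracy() sketch (illustrative): returns correct-prediction *counts*, | ||
|  | + # not rates -- the callers above divide by the number of samples. | ||
|  | + # | ||
|  | + #   output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # 2 samples, 2 classes | ||
|  | + #   target = torch.tensor([1, 1]) | ||
|  | + #   top1, = accuracy(output, target, topk=(1,))  # tensor([1.]): 1 of 2 correct | ||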