1 +# Byte-compiled / optimized / DLL files
2 +__pycache__/
3 +*.py[cod]
4 +*$py.class
5 +
6 +# C extensions
7 +*.so
8 +
9 +# Distribution / packaging
10 +.Python
11 +env/
12 +build/
13 +develop-eggs/
14 +dist/
15 +downloads/
16 +eggs/
17 +.eggs/
18 +lib/
19 +lib64/
20 +parts/
21 +sdist/
22 +var/
23 +*.egg-info/
24 +.installed.cfg
25 +*.egg
26 +
27 +# PyInstaller
28 +# Usually these files are written by a python script from a template
29 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
30 +*.manifest
31 +*.spec
32 +
33 +# Installer logs
34 +pip-log.txt
35 +pip-delete-this-directory.txt
36 +
37 +# Unit test / coverage reports
38 +htmlcov/
39 +.tox/
40 +.coverage
41 +.coverage.*
42 +.cache
43 +nosetests.xml
44 +coverage.xml
45 +*,cover
46 +
47 +# Translations
48 +*.mo
49 +*.pot
50 +
51 +# Django stuff:
52 +*.log
53 +
54 +# Sphinx documentation
55 +docs/_build/
56 +
57 +# PyBuilder
58 +target/
59 +
60 +# PyTorch
61 +*.pt
62 +*.pdf
63 +*.png
64 +*.txt
65 +*.swp
66 +.vscode
1 +MIT License
2 +
3 +Copyright (c) 2018 Sanghyun Son
4 +
5 +Permission is hereby granted, free of charge, to any person obtaining a copy
6 +of this software and associated documentation files (the "Software"), to deal
7 +in the Software without restriction, including without limitation the rights
8 +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 +copies of the Software, and to permit persons to whom the Software is
10 +furnished to do so, subject to the following conditions:
11 +
12 +The above copyright notice and this permission notice shall be included in all
13 +copies or substantial portions of the Software.
14 +
15 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 +SOFTWARE.
1 +**About PyTorch 1.2.0**
2 + * Now the master branch supports PyTorch 1.2.0 by default.
3 + * Due to serious compatibility problems (especially with torch.utils.data.dataloader), MDSR functions are temporarily disabled. If you have to train/evaluate the MDSR model, please use the legacy branches.
4 +
5 +# EDSR-PyTorch
6 +
7 +**About PyTorch 1.1.0**
8 + * There have been minor changes with the 1.1.0 update. We now support PyTorch 1.1.0 by default; please use the legacy branch if you prefer an older version.
9 +
10 +![](/figs/main.png)
11 +
12 +This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017, 2nd NTIRE**.
13 +You can find the original code and more information [here](https://github.com/LimBee/NTIRE2017).
14 +
15 +If you find our work useful in your research or publication, please cite our work:
16 +
17 +[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** <i>2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. </i> [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
18 +```
19 +@InProceedings{Lim_2017_CVPR_Workshops,
20 + author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
21 + title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
22 + booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
23 + month = {July},
24 + year = {2017}
25 +}
26 +```
27 +We provide scripts for reproducing all the results from our paper. You can train your model from scratch, or use a pre-trained model to enlarge your images.
28 +
29 +**Differences from the Torch version**
30 +* Code is much more compact. (Removed all unnecessary parts.)
31 +* Models are smaller. (About half the size.)
32 +* Slightly better performance.
33 +* Training and evaluation require less memory.
34 +* Python-based.
35 +
36 +## Dependencies
37 +* Python 3.6
38 +* PyTorch >= 1.0.0
39 +* numpy
40 +* skimage
41 +* **imageio**
42 +* matplotlib
43 +* tqdm
44 +* cv2 >= 3.xx (Only if you want to use video input/output)
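For reference, a minimal environment setup might look like the sketch below. The PyPI package names (``scikit-image`` for skimage, ``opencv-python`` for cv2) are our assumptions and are not pinned by this repository:

```bash
# One possible setup; adjust the PyTorch install to your CUDA version.
pip install torch numpy scikit-image imageio matplotlib tqdm
pip install opencv-python  # only needed for video input/output
```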
45 +
46 +## Code
47 +Clone this repository into any place you want.
48 +```bash
49 +git clone https://github.com/thstkdgus35/EDSR-PyTorch
50 +cd EDSR-PyTorch
51 +```
52 +
53 +## Quickstart (Demo)
54 +You can test our super-resolution algorithm with your own images. Place your images in the ``test`` folder (e.g., ``test/<your_image>``). We support **png** and **jpeg** files.
55 +
56 +Run the script in the ``src`` folder. Before you run the demo, please uncomment the line in ``demo.sh`` that you want to execute.
57 +```bash
58 +cd src # You are now in */EDSR-PyTorch/src
59 +sh demo.sh
60 +```
61 +
62 +You can find the result images in the ``experiment/test/results`` folder.
63 +
64 +| Model | Scale | File name (.pt) | Parameters | PSNR** |
65 +| --- | --- | --- | --- | --- |
66 +| **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB |
67 +| | | *EDSR_x2 | 40.7 M | 35.03 dB |
68 +| | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB |
69 +| | | *EDSR_x3 | 43.7 M | 31.26 dB |
70 +| | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB |
71 +| | | *EDSR_x4 | 43.1 M | 29.25 dB |
72 +| **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB |
73 +| | | *MDSR | 7.95 M | 34.92 dB |
74 +| | 3 | MDSR_baseline | | 30.94 dB |
75 +| | | *MDSR | | 31.22 dB |
76 +| | 4 | MDSR_baseline | | 28.97 dB |
77 +| | | *MDSR | | 29.24 dB |
78 +
79 +*Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB).
80 +**We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored.
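For clarity, here is a minimal sketch of the PSNR protocol described above (RGB channels, (scale + 2)-pixel boundary crop, no self-ensemble); the function name and the numpy-based implementation are ours, not part of this repository:

```python
import numpy as np

def psnr_rgb(sr, hr, scale, rgb_range=255):
    """PSNR over RGB channels, ignoring (scale + 2) pixels at the image boundary."""
    shave = scale + 2
    sr = sr[shave:-shave, shave:-shave].astype(np.float64)
    hr = hr[shave:-shave, shave:-shave].astype(np.float64)
    mse = np.mean((sr - hr) ** 2)
    return 10 * np.log10(rgb_range ** 2 / mse)
```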
81 +
82 +You can evaluate your models with widely-used benchmark datasets:
83 +
84 +[Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html),
85 +
86 +[Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests),
87 +
88 +[B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
89 +
90 +[Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
91 +
92 +For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data <where_benchmark_folder_located>`` to evaluate the EDSR and MDSR with the benchmarks.
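A rough sketch of this Y-channel protocol, reusing ``skimage.color.rgb2ycbcr`` as ``src/data/common.py`` does (the helper name and the boundary-crop width are our assumptions):

```python
import numpy as np
import skimage.color as sc

def psnr_y(sr, hr, scale):
    """PSNR on the Y channel of YCbCr, for uint8 RGB arrays of shape (H, W, 3)."""
    sr_y = sc.rgb2ycbcr(sr)[:, :, 0]
    hr_y = sc.rgb2ycbcr(hr)[:, :, 0]
    shave = scale + 2  # assumed to match the DIV2K protocol above
    diff = (sr_y - hr_y)[shave:-shave, shave:-shave]
    return 10 * np.log10(255 ** 2 / np.mean(diff ** 2))
```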
93 +
94 +You can download some results from [here](https://cv.snu.ac.kr/research/EDSR/result_image/edsr-results.tar).
95 +The link contains **EDSR+_baseline_x4** and **EDSR+_x4**.
96 +Otherwise, you can easily generate result images with ``demo.sh`` scripts.
97 +
98 +## How to train EDSR and MDSR
99 +We used the [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB).
100 +
101 +Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where DIV2K images are located.
102 +
103 +We recommend you pre-process the images before training. This step will decode all **png** files and save them as binaries. Use the ``--ext sep_reset`` argument on your first run. You can skip the decoding part and use the saved binaries with the ``--ext sep`` argument.
104 +
105 +If you have enough RAM (>= 32GB), you can use the ``--ext bin`` argument to pack all DIV2K images into one binary file.
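For example, a first run and subsequent runs might look like this (the training arguments are taken from ``demo.sh``; adapt them to your own experiment):

```bash
# First run: decode every png and cache the binaries next to the DIV2K files.
python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --ext sep_reset
# Subsequent runs: reuse the cached binaries.
python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --ext sep
```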
106 +
107 +You can train EDSR and MDSR by yourself. All scripts are provided in ``src/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2). You can ignore this constraint by removing the ``--pre_train <x2 model>`` argument (see the example below the script).
108 +
109 +```bash
110 +cd src # You are now in */EDSR-PyTorch/src
111 +sh demo.sh
112 +```
113 +
114 +**Update log**
115 +* Jan 04, 2018
116 + * Many parts are re-written. You cannot use previous scripts and models directly.
117 + * Pre-trained MDSR is temporarily disabled.
118 + * Training details are included.
119 +
120 +* Jan 09, 2018
121 + * Missing files are included (```src/data/MyImage.py```).
122 + * Some links are fixed.
123 +
124 +* Jan 16, 2018
125 + * Memory efficient forward function is implemented.
126 + * Add --chop_forward argument to your script to enable it.
127 +    * Basically, this function first splits a large image into small patches. Those patches are merged after super-resolution. I checked this function with 12GB memory and a 4000 x 2000 input image at scale 4. (Therefore, the output is 16000 x 8000.)
128 +
129 +* Feb 21, 2018
130 +  * Fixed a problem when loading a pre-trained multi-GPU model.
131 +  * Added a pre-trained scale 2 baseline model.
132 +  * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use the --save_models argument to keep all of the intermediate models.
133 +  * PyTorch 0.3.1 changed its implementation of the DataLoader. Therefore, I also changed my implementation of MSDataLoader accordingly. You can find it on the feature/dataloader branch.
134 +
135 +* Feb 23, 2018
136 +  * Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.
137 +  * With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.
138 +  * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have one.)
139 +    * With ``--ext bin``, this code will automatically generate and save a binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires huge RAM (~45GB; swap can be used), so please be careful.)
140 +    * If you cannot make the binary pack, use the default setting (``--ext img``).
141 +
142 +  * Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.
143 +    * Now saved images have better quality! (PSNR is ~0.1dB higher than with the original code.)
144 + * Added performance comparison between Torch7 model and PyTorch models.
145 +
146 +* Mar 5, 2018
147 + * All baseline models are uploaded.
148 + * Now supports half-precision at test time. Use ``--precision half`` to enable it. This does not degrade the output images.
149 +
150 +* Mar 11, 2018
151 + * Fixed some typos in the code and script.
152 +  * Now --ext img is the default setting. Although we recommend --ext bin for training, please use --ext img when you use --test_only.
153 +  * Skip_batch operation is implemented. Use the --skip_threshold argument to skip batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it works as you would expect.
154 +
155 +* Mar 20, 2018
156 +  * Use ``--ext sep-reset`` to pre-decode large png files. The decoded files will be saved to the same directory as the DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
157 + * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
158 + * Changed the behavior of skip_batch.
159 +
160 +* Mar 29, 2018
161 + * We now provide all models from our paper.
162 +  * We also provide the ``MDSR_baseline_jpeg`` model that suppresses JPEG artifacts in the original low-resolution image. Please use it if your inputs suffer from JPEG artifacts.
163 +  * The ``MyImage`` dataset is changed to the ``Demo`` dataset. Also, it works more efficiently than before.
164 +  * Some code and scripts are re-written.
165 +
166 +* Apr 9, 2018
167 +  * VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.
168 +  * Much of the code is refactored. If you find a bug, please report it.
169 + * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. The default setting is D-DBPN-L.
170 +
171 +* Apr 26, 2018
172 + * Compatible with PyTorch 0.4.0
173 + * Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.
174 + * Minor bug fixes
175 +
176 +* July 22, 2018
177 +  * Thanks to recent commits, RDN and RCAN are now included. Please see ``code/demo.sh`` to train/test those models.
178 +  * Now the dataloader is much more stable than the previous version. Please erase the ``DIV2K/bin`` folder that was created before this commit. Also, please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode the png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but it is SLOW!).
179 +
180 +* Oct 18, 2018
181 +  * With ``--pre_train download``, pretrained models will be automatically downloaded from the server.
182 + * Supports video input/output (inference only). Try with ``--data_test video --dir_demo [video file directory]``.
183 +
184 +* About PyTorch 1.0.0
185 +  * We support PyTorch 1.0.0. If you prefer a previous version of PyTorch, use the legacy branches.
186 +  * ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you have successfully built those bin files, you can remove ``-reset`` from the argument.
1 +*
2 +!.gitignore
3 +!/model/*.pt
File mode changed
1 +from importlib import import_module
2 +#from dataloader import MSDataLoader
3 +from torch.utils.data import dataloader
4 +from torch.utils.data import ConcatDataset
5 +
6 +# This is a simple wrapper class for ConcatDataset
7 +class MyConcatDataset(ConcatDataset):
8 + def __init__(self, datasets):
9 + super(MyConcatDataset, self).__init__(datasets)
10 + self.train = datasets[0].train
11 +
12 + def set_scale(self, idx_scale):
13 + for d in self.datasets:
14 + if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
15 +
16 +class Data:
17 + def __init__(self, args):
18 + self.loader_train = None
19 + if not args.test_only:
20 + datasets = []
21 + for d in args.data_train:
22 + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
23 + m = import_module('data.' + module_name.lower())
24 + datasets.append(getattr(m, module_name)(args, name=d))
25 +
26 + self.loader_train = dataloader.DataLoader(
27 + MyConcatDataset(datasets),
28 + batch_size=args.batch_size,
29 + shuffle=True,
30 + pin_memory=not args.cpu,
31 + num_workers=args.n_threads,
32 + )
33 +
34 + self.loader_test = []
35 + for d in args.data_test:
36 + if d in ['Set5', 'Set14', 'B100', 'Urban100']:
37 + m = import_module('data.benchmark')
38 + testset = getattr(m, 'Benchmark')(args, train=False, name=d)
39 + else:
40 + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
41 + m = import_module('data.' + module_name.lower())
42 + testset = getattr(m, module_name)(args, train=False, name=d)
43 +
44 + self.loader_test.append(
45 + dataloader.DataLoader(
46 + testset,
47 + batch_size=1,
48 + shuffle=False,
49 + pin_memory=not args.cpu,
50 + num_workers=args.n_threads,
51 + )
52 + )
1 +import os
2 +
3 +from data import common
4 +from data import srdata
5 +
6 +import numpy as np
7 +
8 +import torch
9 +import torch.utils.data as data
10 +
11 +class Benchmark(srdata.SRData):
12 + def __init__(self, args, name='', train=True, benchmark=True):
13 + super(Benchmark, self).__init__(
14 + args, name=name, train=train, benchmark=True
15 + )
16 +
17 + def _set_filesystem(self, dir_data):
18 + self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 + self.dir_hr = os.path.join(self.apath, 'HR')
20 + if self.input_large:
21 + self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22 + else:
23 + self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24 + self.ext = ('', '.png')
25 +
1 +import random
2 +
3 +import numpy as np
4 +import skimage.color as sc
5 +
6 +import torch
7 +
8 +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
9 + ih, iw = args[0].shape[:2]
10 +
11 + if not input_large:
12 + p = scale if multi else 1
13 + tp = p * patch_size
14 + ip = tp // scale
15 + else:
16 + tp = patch_size
17 + ip = patch_size
18 +
19 + ix = random.randrange(0, iw - ip + 1)
20 + iy = random.randrange(0, ih - ip + 1)
21 +
22 + if not input_large:
23 + tx, ty = scale * ix, scale * iy
24 + else:
25 + tx, ty = ix, iy
26 +
27 + ret = [
28 + args[0][iy:iy + ip, ix:ix + ip, :],
29 + *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
30 + ]
31 +
32 + return ret
33 +
34 +def set_channel(*args, n_channels=3):
35 + def _set_channel(img):
36 + if img.ndim == 2:
37 + img = np.expand_dims(img, axis=2)
38 +
39 + c = img.shape[2]
40 + if n_channels == 1 and c == 3:
41 + img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
42 + elif n_channels == 3 and c == 1:
43 + img = np.concatenate([img] * n_channels, 2)
44 +
45 + return img
46 +
47 + return [_set_channel(a) for a in args]
48 +
49 +def np2Tensor(*args, rgb_range=255):
50 + def _np2Tensor(img):
51 + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
52 + tensor = torch.from_numpy(np_transpose).float()
53 + tensor.mul_(rgb_range / 255)
54 +
55 + return tensor
56 +
57 + return [_np2Tensor(a) for a in args]
58 +
59 +def augment(*args, hflip=True, rot=True):
60 + hflip = hflip and random.random() < 0.5
61 + vflip = rot and random.random() < 0.5
62 + rot90 = rot and random.random() < 0.5
63 +
64 + def _augment(img):
65 + if hflip: img = img[:, ::-1, :]
66 + if vflip: img = img[::-1, :, :]
67 + if rot90: img = img.transpose(1, 0, 2)
68 +
69 + return img
70 +
71 + return [_augment(a) for a in args]
72 +
1 +import os
2 +
3 +from data import common
4 +
5 +import numpy as np
6 +import imageio
7 +
8 +import torch
9 +import torch.utils.data as data
10 +
11 +class Demo(data.Dataset):
12 + def __init__(self, args, name='Demo', train=False, benchmark=False):
13 + self.args = args
14 + self.name = name
15 + self.scale = args.scale
16 + self.idx_scale = 0
17 + self.train = False
18 + self.benchmark = benchmark
19 +
20 + self.filelist = []
21 + for f in os.listdir(args.dir_demo):
22 + if f.find('.png') >= 0 or f.find('.jp') >= 0:
23 + self.filelist.append(os.path.join(args.dir_demo, f))
24 + self.filelist.sort()
25 +
26 + def __getitem__(self, idx):
27 + filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
28 + lr = imageio.imread(self.filelist[idx])
29 + lr, = common.set_channel(lr, n_channels=self.args.n_colors)
30 + lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
31 +
32 + return lr_t, -1, filename
33 +
34 + def __len__(self):
35 + return len(self.filelist)
36 +
37 + def set_scale(self, idx_scale):
38 + self.idx_scale = idx_scale
39 +
1 +import os
2 +from data import srdata
3 +
4 +class DIV2K(srdata.SRData):
5 + def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6 + data_range = [r.split('-') for r in args.data_range.split('/')]
7 + if train:
8 + data_range = data_range[0]
9 + else:
10 + if args.test_only and len(data_range) == 1:
11 + data_range = data_range[0]
12 + else:
13 + data_range = data_range[1]
14 +
15 + self.begin, self.end = list(map(lambda x: int(x), data_range))
16 + super(DIV2K, self).__init__(
17 + args, name=name, train=train, benchmark=benchmark
18 + )
19 +
20 + def _scan(self):
21 + names_hr, names_lr = super(DIV2K, self)._scan()
22 + names_hr = names_hr[self.begin - 1:self.end]
23 + names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 +
25 + return names_hr, names_lr
26 +
27 + def _set_filesystem(self, dir_data):
28 + super(DIV2K, self)._set_filesystem(dir_data)
29 + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30 + self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31 + if self.input_large: self.dir_lr += 'L'
32 +
1 +import os
2 +from data import srdata
3 +from data import div2k
4 +
5 +class DIV2KJPEG(div2k.DIV2K):
6 + def __init__(self, args, name='', train=True, benchmark=False):
7 + self.q_factor = int(name.replace('DIV2K-Q', ''))
8 + super(DIV2KJPEG, self).__init__(
9 + args, name=name, train=train, benchmark=benchmark
10 + )
11 +
12 + def _set_filesystem(self, dir_data):
13 + self.apath = os.path.join(dir_data, 'DIV2K')
14 + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
15 + self.dir_lr = os.path.join(
16 + self.apath, 'DIV2K_Q{}'.format(self.q_factor)
17 + )
18 + if self.input_large: self.dir_lr += 'L'
19 + self.ext = ('.png', '.jpg')
20 +
1 +from data import srdata
2 +
3 +class SR291(srdata.SRData):
4 + def __init__(self, args, name='SR291', train=True, benchmark=False):
5 + super(SR291, self).__init__(args, name=name)
6 +
1 +import os
2 +import glob
3 +import random
4 +import pickle
5 +
6 +from data import common
7 +
8 +import numpy as np
9 +import imageio
10 +import torch
11 +import torch.utils.data as data
12 +
13 +class SRData(data.Dataset):
14 + def __init__(self, args, name='', train=True, benchmark=False):
15 + self.args = args
16 + self.name = name
17 + self.train = train
18 + self.split = 'train' if train else 'test'
19 + self.do_eval = True
20 + self.benchmark = benchmark
21 + self.input_large = (args.model == 'VDSR')
22 + self.scale = args.scale
23 + self.idx_scale = 0
24 +
25 + self._set_filesystem(args.dir_data)
26 + if args.ext.find('img') < 0:
27 + path_bin = os.path.join(self.apath, 'bin')
28 + os.makedirs(path_bin, exist_ok=True)
29 +
30 + list_hr, list_lr = self._scan()
31 + if args.ext.find('img') >= 0 or benchmark:
32 + self.images_hr, self.images_lr = list_hr, list_lr
33 + elif args.ext.find('sep') >= 0:
34 + os.makedirs(
35 + self.dir_hr.replace(self.apath, path_bin),
36 + exist_ok=True
37 + )
38 + for s in self.scale:
39 + os.makedirs(
40 + os.path.join(
41 + self.dir_lr.replace(self.apath, path_bin),
42 + 'X{}'.format(s)
43 + ),
44 + exist_ok=True
45 + )
46 +
47 + self.images_hr, self.images_lr = [], [[] for _ in self.scale]
48 + for h in list_hr:
49 + b = h.replace(self.apath, path_bin)
50 + b = b.replace(self.ext[0], '.pt')
51 + self.images_hr.append(b)
52 + self._check_and_load(args.ext, h, b, verbose=True)
53 + for i, ll in enumerate(list_lr):
54 + for l in ll:
55 + b = l.replace(self.apath, path_bin)
56 + b = b.replace(self.ext[1], '.pt')
57 + self.images_lr[i].append(b)
58 + self._check_and_load(args.ext, l, b, verbose=True)
59 + if train:
60 + n_patches = args.batch_size * args.test_every
61 + n_images = len(args.data_train) * len(self.images_hr)
62 + if n_images == 0:
63 + self.repeat = 0
64 + else:
65 + self.repeat = max(n_patches // n_images, 1)
66 +
67 +    # The functions below are used to prepare images
68 + def _scan(self):
69 + names_hr = sorted(
70 + glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
71 + )
72 + names_lr = [[] for _ in self.scale]
73 + for f in names_hr:
74 + filename, _ = os.path.splitext(os.path.basename(f))
75 + for si, s in enumerate(self.scale):
76 + names_lr[si].append(os.path.join(
77 + self.dir_lr, 'X{}/{}x{}{}'.format(
78 + s, filename, s, self.ext[1]
79 + )
80 + ))
81 +
82 + return names_hr, names_lr
83 +
84 + def _set_filesystem(self, dir_data):
85 + self.apath = os.path.join(dir_data, self.name)
86 + self.dir_hr = os.path.join(self.apath, 'HR')
87 + self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
88 + if self.input_large: self.dir_lr += 'L'
89 + self.ext = ('.png', '.png')
90 +
91 + def _check_and_load(self, ext, img, f, verbose=True):
92 + if not os.path.isfile(f) or ext.find('reset') >= 0:
93 + if verbose:
94 + print('Making a binary: {}'.format(f))
95 + with open(f, 'wb') as _f:
96 + pickle.dump(imageio.imread(img), _f)
97 +
98 + def __getitem__(self, idx):
99 + lr, hr, filename = self._load_file(idx)
100 + pair = self.get_patch(lr, hr)
101 + pair = common.set_channel(*pair, n_channels=self.args.n_colors)
102 + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
103 +
104 + return pair_t[0], pair_t[1], filename
105 +
106 + def __len__(self):
107 + if self.train:
108 + return len(self.images_hr) * self.repeat
109 + else:
110 + return len(self.images_hr)
111 +
112 + def _get_index(self, idx):
113 + if self.train:
114 + return idx % len(self.images_hr)
115 + else:
116 + return idx
117 +
118 + def _load_file(self, idx):
119 + idx = self._get_index(idx)
120 + f_hr = self.images_hr[idx]
121 + f_lr = self.images_lr[self.idx_scale][idx]
122 +
123 + filename, _ = os.path.splitext(os.path.basename(f_hr))
124 + if self.args.ext == 'img' or self.benchmark:
125 + hr = imageio.imread(f_hr)
126 + lr = imageio.imread(f_lr)
127 + elif self.args.ext.find('sep') >= 0:
128 + with open(f_hr, 'rb') as _f:
129 + hr = pickle.load(_f)
130 + with open(f_lr, 'rb') as _f:
131 + lr = pickle.load(_f)
132 +
133 + return lr, hr, filename
134 +
135 + def get_patch(self, lr, hr):
136 + scale = self.scale[self.idx_scale]
137 + if self.train:
138 + lr, hr = common.get_patch(
139 + lr, hr,
140 + patch_size=self.args.patch_size,
141 + scale=scale,
142 + multi=(len(self.scale) > 1),
143 + input_large=self.input_large
144 + )
145 + if not self.args.no_augment: lr, hr = common.augment(lr, hr)
146 + else:
147 + ih, iw = lr.shape[:2]
148 + hr = hr[0:ih * scale, 0:iw * scale]
149 +
150 + return lr, hr
151 +
152 + def set_scale(self, idx_scale):
153 + if not self.input_large:
154 + self.idx_scale = idx_scale
155 + else:
156 + self.idx_scale = random.randint(0, len(self.scale) - 1)
157 +
1 +import os
2 +
3 +from data import common
4 +
5 +import cv2
6 +import numpy as np
7 +import imageio
8 +
9 +import torch
10 +import torch.utils.data as data
11 +
12 +class Video(data.Dataset):
13 + def __init__(self, args, name='Video', train=False, benchmark=False):
14 + self.args = args
15 + self.name = name
16 + self.scale = args.scale
17 + self.idx_scale = 0
18 + self.train = False
19 + self.do_eval = False
20 + self.benchmark = benchmark
21 +
22 + self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 + self.vidcap = cv2.VideoCapture(args.dir_demo)
24 + self.n_frames = 0
25 + self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 +
27 + def __getitem__(self, idx):
28 + success, lr = self.vidcap.read()
29 + if success:
30 + self.n_frames += 1
31 + lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 + lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 +
34 + return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 + else:
36 +            self.vidcap.release()
37 + return None
38 +
39 + def __len__(self):
40 + return self.total_frames
41 +
42 + def set_scale(self, idx_scale):
43 + self.idx_scale = idx_scale
44 +
1 +import threading
2 +import random
3 +import sys  # for sys.exc_info() in _ms_loop
3 +
4 +import torch
5 +import torch.multiprocessing as multiprocessing
6 +from torch.utils.data import DataLoader
7 +from torch.utils.data import SequentialSampler
8 +from torch.utils.data import RandomSampler
9 +from torch.utils.data import BatchSampler
10 +from torch.utils.data import _utils
11 +from torch.utils.data.dataloader import _DataLoaderIter
12 +
13 +from torch.utils.data._utils import collate
14 +from torch.utils.data._utils import signal_handling
15 +from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
16 +from torch.utils.data._utils import ExceptionWrapper
17 +from torch.utils.data._utils import IS_WINDOWS
18 +from torch.utils.data._utils.worker import ManagerWatchdog
19 +
20 +from torch._six import queue
21 +
22 +def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
23 + try:
24 + collate._use_shared_memory = True
25 + signal_handling._set_worker_signal_handlers()
26 +
27 + torch.set_num_threads(1)
28 + random.seed(seed)
29 + torch.manual_seed(seed)
30 +
31 + data_queue.cancel_join_thread()
32 +
33 + if init_fn is not None:
34 + init_fn(worker_id)
35 +
36 + watchdog = ManagerWatchdog()
37 +
38 + while watchdog.is_alive():
39 + try:
40 + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
41 + except queue.Empty:
42 + continue
43 +
44 + if r is None:
45 + assert done_event.is_set()
46 + return
47 + elif done_event.is_set():
48 + continue
49 +
50 + idx, batch_indices = r
51 + try:
52 + idx_scale = 0
53 + if len(scale) > 1 and dataset.train:
54 + idx_scale = random.randrange(0, len(scale))
55 + dataset.set_scale(idx_scale)
56 +
57 + samples = collate_fn([dataset[i] for i in batch_indices])
58 + samples.append(idx_scale)
59 + except Exception:
60 + data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
61 + else:
62 + data_queue.put((idx, samples))
63 + del samples
64 +
65 + except KeyboardInterrupt:
66 + pass
67 +
68 +class _MSDataLoaderIter(_DataLoaderIter):
69 +
70 + def __init__(self, loader):
71 + self.dataset = loader.dataset
72 + self.scale = loader.scale
73 + self.collate_fn = loader.collate_fn
74 + self.batch_sampler = loader.batch_sampler
75 + self.num_workers = loader.num_workers
76 + self.pin_memory = loader.pin_memory and torch.cuda.is_available()
77 + self.timeout = loader.timeout
78 +
79 + self.sample_iter = iter(self.batch_sampler)
80 +
81 + base_seed = torch.LongTensor(1).random_().item()
82 +
83 + if self.num_workers > 0:
84 + self.worker_init_fn = loader.worker_init_fn
85 + self.worker_queue_idx = 0
86 + self.worker_result_queue = multiprocessing.Queue()
87 + self.batches_outstanding = 0
88 + self.worker_pids_set = False
89 + self.shutdown = False
90 + self.send_idx = 0
91 + self.rcvd_idx = 0
92 + self.reorder_dict = {}
93 + self.done_event = multiprocessing.Event()
94 +
95 + base_seed = torch.LongTensor(1).random_()[0]
96 +
97 + self.index_queues = []
98 + self.workers = []
99 + for i in range(self.num_workers):
100 + index_queue = multiprocessing.Queue()
101 + index_queue.cancel_join_thread()
102 + w = multiprocessing.Process(
103 + target=_ms_loop,
104 + args=(
105 + self.dataset,
106 + index_queue,
107 + self.worker_result_queue,
108 + self.done_event,
109 + self.collate_fn,
110 + self.scale,
111 + base_seed + i,
112 + self.worker_init_fn,
113 + i
114 + )
115 + )
116 + w.daemon = True
117 + w.start()
118 + self.index_queues.append(index_queue)
119 + self.workers.append(w)
120 +
121 + if self.pin_memory:
122 + self.data_queue = queue.Queue()
123 + pin_memory_thread = threading.Thread(
124 + target=_utils.pin_memory._pin_memory_loop,
125 + args=(
126 + self.worker_result_queue,
127 + self.data_queue,
128 + torch.cuda.current_device(),
129 + self.done_event
130 + )
131 + )
132 + pin_memory_thread.daemon = True
133 + pin_memory_thread.start()
134 + self.pin_memory_thread = pin_memory_thread
135 + else:
136 + self.data_queue = self.worker_result_queue
137 +
138 + _utils.signal_handling._set_worker_pids(
139 + id(self), tuple(w.pid for w in self.workers)
140 + )
141 + _utils.signal_handling._set_SIGCHLD_handler()
142 + self.worker_pids_set = True
143 +
144 + for _ in range(2 * self.num_workers):
145 + self._put_indices()
146 +
147 +
148 +class MSDataLoader(DataLoader):
149 +
150 + def __init__(self, cfg, *args, **kwargs):
151 + super(MSDataLoader, self).__init__(
152 + *args, **kwargs, num_workers=cfg.n_threads
153 + )
154 + self.scale = cfg.scale
155 +
156 + def __iter__(self):
157 + return _MSDataLoaderIter(self)
158 +
1 +# EDSR baseline model (x2) + JPEG augmentation
2 +python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
3 +#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
4 +
5 +# EDSR baseline model (x3) - from EDSR baseline model (x2)
6 +#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
7 +
8 +# EDSR baseline model (x4) - from EDSR baseline model (x2)
9 +#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
10 +
11 +# EDSR in the paper (x2)
12 +#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
13 +
14 +# EDSR in the paper (x3) - from EDSR (x2)
15 +#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]
16 +
17 +# EDSR in the paper (x4) - from EDSR (x2)
18 +#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
19 +
20 +# MDSR baseline model
21 +#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
22 +
23 +# MDSR in the paper
24 +#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
25 +
26 +# Standard benchmarks (Ex. EDSR_baseline_x4)
27 +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble
28 +
29 +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
30 +
31 +# Test your own images
32 +#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
33 +
34 +# Advanced - Test with JPEG images
35 +#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
36 +
37 +# Advanced - Training with adversarial loss
38 +#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
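# Note: the --loss value is a '+'-separated list of 'weight*type' terms
# (e.g. 5*VGG54+0.15*GAN above); see src/loss/__init__.py for the parser.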
39 +
40 +# RDN BI model (x2)
41 +#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
42 +# RDN BI model (x3)
43 +#python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset
44 +# RDN BI model (x4)
45 +#python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset
46 +
47 +# RCAN_BIX2_G10R20P48, input=48x48, output=96x96
48 +# pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0
49 +#python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96
50 +# RCAN_BIX3_G10R20P48, input=48x48, output=144x144
51 +#python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
52 +# RCAN_BIX4_G10R20P48, input=48x48, output=192x192
53 +#python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
54 +# RCAN_BIX8_G10R20P48, input=48x48, output=384x384
55 +#python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
56 +
1 +import os
2 +from importlib import import_module
3 +
4 +import matplotlib
5 +matplotlib.use('Agg')
6 +import matplotlib.pyplot as plt
7 +
8 +import numpy as np
9 +
10 +import torch
11 +import torch.nn as nn
12 +import torch.nn.functional as F
13 +
14 +class Loss(nn.modules.loss._Loss):
15 + def __init__(self, args, ckp):
16 + super(Loss, self).__init__()
17 + print('Preparing loss function:')
18 +
19 + self.n_GPUs = args.n_GPUs
20 + self.loss = []
21 + self.loss_module = nn.ModuleList()
22 + for loss in args.loss.split('+'):
23 + weight, loss_type = loss.split('*')
24 + if loss_type == 'MSE':
25 + loss_function = nn.MSELoss()
26 + elif loss_type == 'L1':
27 + loss_function = nn.L1Loss()
28 + elif loss_type.find('VGG') >= 0:
29 + module = import_module('loss.vgg')
30 + loss_function = getattr(module, 'VGG')(
31 + loss_type[3:],
32 + rgb_range=args.rgb_range
33 + )
34 + elif loss_type.find('GAN') >= 0:
35 + module = import_module('loss.adversarial')
36 + loss_function = getattr(module, 'Adversarial')(
37 + args,
38 + loss_type
39 + )
40 +
41 + self.loss.append({
42 + 'type': loss_type,
43 + 'weight': float(weight),
44 + 'function': loss_function}
45 + )
46 + if loss_type.find('GAN') >= 0:
47 + self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
48 +
49 + if len(self.loss) > 1:
50 + self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
51 +
52 + for l in self.loss:
53 + if l['function'] is not None:
54 + print('{:.3f} * {}'.format(l['weight'], l['type']))
55 + self.loss_module.append(l['function'])
56 +
57 + self.log = torch.Tensor()
58 +
59 + device = torch.device('cpu' if args.cpu else 'cuda')
60 + self.loss_module.to(device)
61 + if args.precision == 'half': self.loss_module.half()
62 + if not args.cpu and args.n_GPUs > 1:
63 + self.loss_module = nn.DataParallel(
64 + self.loss_module, range(args.n_GPUs)
65 + )
66 +
67 + if args.load != '': self.load(ckp.dir, cpu=args.cpu)
68 +
69 + def forward(self, sr, hr):
70 + losses = []
71 + for i, l in enumerate(self.loss):
72 + if l['function'] is not None:
73 + loss = l['function'](sr, hr)
74 + effective_loss = l['weight'] * loss
75 + losses.append(effective_loss)
76 + self.log[-1, i] += effective_loss.item()
77 + elif l['type'] == 'DIS':
78 + self.log[-1, i] += self.loss[i - 1]['function'].loss
79 +
80 + loss_sum = sum(losses)
81 + if len(self.loss) > 1:
82 + self.log[-1, -1] += loss_sum.item()
83 +
84 + return loss_sum
85 +
86 + def step(self):
87 + for l in self.get_loss_module():
88 + if hasattr(l, 'scheduler'):
89 + l.scheduler.step()
90 +
91 + def start_log(self):
92 + self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
93 +
94 + def end_log(self, n_batches):
95 + self.log[-1].div_(n_batches)
96 +
97 + def display_loss(self, batch):
98 + n_samples = batch + 1
99 + log = []
100 + for l, c in zip(self.loss, self.log[-1]):
101 + log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
102 +
103 + return ''.join(log)
104 +
105 + def plot_loss(self, apath, epoch):
106 + axis = np.linspace(1, epoch, epoch)
107 + for i, l in enumerate(self.loss):
108 + label = '{} Loss'.format(l['type'])
109 + fig = plt.figure()
110 + plt.title(label)
111 + plt.plot(axis, self.log[:, i].numpy(), label=label)
112 + plt.legend()
113 + plt.xlabel('Epochs')
114 + plt.ylabel('Loss')
115 + plt.grid(True)
116 + plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
117 + plt.close(fig)
118 +
119 + def get_loss_module(self):
120 + if self.n_GPUs == 1:
121 + return self.loss_module
122 + else:
123 + return self.loss_module.module
124 +
125 + def save(self, apath):
126 + torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
127 + torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
128 +
129 + def load(self, apath, cpu=False):
130 + if cpu:
131 + kwargs = {'map_location': lambda storage, loc: storage}
132 + else:
133 + kwargs = {}
134 +
135 + self.load_state_dict(torch.load(
136 + os.path.join(apath, 'loss.pt'),
137 + **kwargs
138 + ))
139 + self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
140 + for l in self.get_loss_module():
141 + if hasattr(l, 'scheduler'):
142 + for _ in range(len(self.log)): l.scheduler.step()
143 +
1 +import utility
2 +from types import SimpleNamespace
3 +
4 +from model import common
5 +from loss import discriminator
6 +
7 +import torch
8 +import torch.nn as nn
9 +import torch.nn.functional as F
10 +import torch.optim as optim
11 +
12 +class Adversarial(nn.Module):
13 + def __init__(self, args, gan_type):
14 + super(Adversarial, self).__init__()
15 + self.gan_type = gan_type
16 + self.gan_k = args.gan_k
17 + self.dis = discriminator.Discriminator(args)
18 + if gan_type == 'WGAN_GP':
19 + # see https://arxiv.org/pdf/1704.00028.pdf pp.4
20 + optim_dict = {
21 + 'optimizer': 'ADAM',
22 + 'betas': (0, 0.9),
23 + 'epsilon': 1e-8,
24 + 'lr': 1e-5,
25 + 'weight_decay': args.weight_decay,
26 + 'decay': args.decay,
27 + 'gamma': args.gamma
28 + }
29 + optim_args = SimpleNamespace(**optim_dict)
30 + else:
31 + optim_args = args
32 +
33 + self.optimizer = utility.make_optimizer(optim_args, self.dis)
34 +
35 + def forward(self, fake, real):
36 + # updating discriminator...
37 + self.loss = 0
38 + fake_detach = fake.detach() # do not backpropagate through G
39 + for _ in range(self.gan_k):
40 + self.optimizer.zero_grad()
41 + # d: B x 1 tensor
42 + d_fake = self.dis(fake_detach)
43 + d_real = self.dis(real)
44 + retain_graph = False
45 + if self.gan_type == 'GAN':
46 + loss_d = self.bce(d_real, d_fake)
47 + elif self.gan_type.find('WGAN') >= 0:
48 + loss_d = (d_fake - d_real).mean()
49 + if self.gan_type.find('GP') >= 0:
50 +                    # one interpolation weight per sample, on the same device/dtype as fake
50 +                    epsilon = torch.rand(fake.size(0), 1, 1, 1).to(fake)
51 + hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
52 + hat.requires_grad = True
53 + d_hat = self.dis(hat)
54 + gradients = torch.autograd.grad(
55 + outputs=d_hat.sum(), inputs=hat,
56 + retain_graph=True, create_graph=True, only_inputs=True
57 + )[0]
58 + gradients = gradients.view(gradients.size(0), -1)
59 + gradient_norm = gradients.norm(2, dim=1)
60 + gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
61 + loss_d += gradient_penalty
62 + # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
63 + elif self.gan_type == 'RGAN':
64 + better_real = d_real - d_fake.mean(dim=0, keepdim=True)
65 + better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
66 + loss_d = self.bce(better_real, better_fake)
67 + retain_graph = True
68 +
69 + # Discriminator update
70 + self.loss += loss_d.item()
71 + loss_d.backward(retain_graph=retain_graph)
72 + self.optimizer.step()
73 +
74 + if self.gan_type == 'WGAN':
75 + for p in self.dis.parameters():
76 + p.data.clamp_(-1, 1)
77 +
78 + self.loss /= self.gan_k
79 +
80 + # updating generator...
81 + d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
82 + if self.gan_type == 'GAN':
83 + label_real = torch.ones_like(d_fake_bp)
84 + loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
85 + elif self.gan_type.find('WGAN') >= 0:
86 + loss_g = -d_fake_bp.mean()
87 + elif self.gan_type == 'RGAN':
88 + better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
89 + better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
90 + loss_g = self.bce(better_fake, better_real)
91 +
92 + # Generator loss
93 + return loss_g
94 +
95 + def state_dict(self, *args, **kwargs):
96 + state_discriminator = self.dis.state_dict(*args, **kwargs)
97 + state_optimizer = self.optimizer.state_dict()
98 +
99 + return dict(**state_discriminator, **state_optimizer)
100 +
101 + def bce(self, real, fake):
102 + label_real = torch.ones_like(real)
103 + label_fake = torch.zeros_like(fake)
104 + bce_real = F.binary_cross_entropy_with_logits(real, label_real)
105 + bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
106 + bce_loss = bce_real + bce_fake
107 + return bce_loss
108 +
109 +# Some references
110 +# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
111 +# OR
112 +# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
1 +from model import common
2 +
3 +import torch.nn as nn
4 +
5 +class Discriminator(nn.Module):
6 + '''
7 + output is not normalized
8 + '''
9 + def __init__(self, args):
10 + super(Discriminator, self).__init__()
11 +
12 + in_channels = args.n_colors
13 + out_channels = 64
14 + depth = 7
15 +
16 + def _block(_in_channels, _out_channels, stride=1):
17 + return nn.Sequential(
18 + nn.Conv2d(
19 + _in_channels,
20 + _out_channels,
21 + 3,
22 + padding=1,
23 + stride=stride,
24 + bias=False
25 + ),
26 + nn.BatchNorm2d(_out_channels),
27 + nn.LeakyReLU(negative_slope=0.2, inplace=True)
28 + )
29 +
30 + m_features = [_block(in_channels, out_channels)]
31 + for i in range(depth):
32 + in_channels = out_channels
33 + if i % 2 == 1:
34 + stride = 1
35 + out_channels *= 2
36 + else:
37 + stride = 2
38 + m_features.append(_block(in_channels, out_channels, stride=stride))
39 +
40 + patch_size = args.patch_size // (2**((depth + 1) // 2))
41 + m_classifier = [
42 + nn.Linear(out_channels * patch_size**2, 1024),
43 + nn.LeakyReLU(negative_slope=0.2, inplace=True),
44 + nn.Linear(1024, 1)
45 + ]
46 +
47 + self.features = nn.Sequential(*m_features)
48 + self.classifier = nn.Sequential(*m_classifier)
49 +
50 + def forward(self, x):
51 + features = self.features(x)
52 + output = self.classifier(features.view(features.size(0), -1))
53 +
54 + return output
55 +
1 +from model import common
2 +
3 +import torch
4 +import torch.nn as nn
5 +import torch.nn.functional as F
6 +import torchvision.models as models
7 +
8 +class VGG(nn.Module):
9 + def __init__(self, conv_index, rgb_range=1):
10 + super(VGG, self).__init__()
11 + vgg_features = models.vgg19(pretrained=True).features
12 + modules = [m for m in vgg_features]
13 + if conv_index.find('22') >= 0:
14 + self.vgg = nn.Sequential(*modules[:8])
15 + elif conv_index.find('54') >= 0:
16 + self.vgg = nn.Sequential(*modules[:35])
17 +
18 + vgg_mean = (0.485, 0.456, 0.406)
19 + vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
20 + self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
21 + for p in self.parameters():
22 + p.requires_grad = False
23 +
24 + def forward(self, sr, hr):
25 + def _forward(x):
26 + x = self.sub_mean(x)
27 + x = self.vgg(x)
28 + return x
29 +
30 + vgg_sr = _forward(sr)
31 + with torch.no_grad():
32 + vgg_hr = _forward(hr.detach())
33 +
34 + loss = F.mse_loss(vgg_sr, vgg_hr)
35 +
36 + return loss
1 +import torch
2 +
3 +import utility
4 +import data
5 +import model
6 +import loss
7 +from option import args
8 +from trainer import Trainer
9 +
10 +torch.manual_seed(args.seed)
11 +checkpoint = utility.checkpoint(args)
12 +
13 +def main():
14 + global model
15 + if args.data_test == ['video']:
16 + from videotester import VideoTester
17 + model = model.Model(args, checkpoint)
18 + t = VideoTester(args, model, checkpoint)
19 + t.test()
20 + else:
21 + if checkpoint.ok:
22 + loader = data.Data(args)
23 + _model = model.Model(args, checkpoint)
24 + _loss = loss.Loss(args, checkpoint) if not args.test_only else None
25 + t = Trainer(args, loader, _model, _loss, checkpoint)
26 + while not t.terminate():
27 + t.train()
28 + t.test()
29 +
30 + checkpoint.done()
31 +
32 +if __name__ == '__main__':
33 + main()
1 +import os
2 +from importlib import import_module
3 +
4 +import torch
5 +import torch.nn as nn
6 +import torch.nn.parallel as P
7 +import torch.utils.model_zoo
8 +
9 +class Model(nn.Module):
10 + def __init__(self, args, ckp):
11 + super(Model, self).__init__()
12 + print('Making model...')
13 +
14 + self.scale = args.scale
15 + self.idx_scale = 0
16 + self.input_large = (args.model == 'VDSR')
17 + self.self_ensemble = args.self_ensemble
18 + self.chop = args.chop
19 + self.precision = args.precision
20 + self.cpu = args.cpu
21 + self.device = torch.device('cpu' if args.cpu else 'cuda')
22 + self.n_GPUs = args.n_GPUs
23 + self.save_models = args.save_models
24 +
25 + module = import_module('model.' + args.model.lower())
26 + self.model = module.make_model(args).to(self.device)
27 + if args.precision == 'half':
28 + self.model.half()
29 +
30 + self.load(
31 + ckp.get_path('model'),
32 + pre_train=args.pre_train,
33 + resume=args.resume,
34 + cpu=args.cpu
35 + )
36 + print(self.model, file=ckp.log_file)
37 +
38 + def forward(self, x, idx_scale):
39 + self.idx_scale = idx_scale
40 + if hasattr(self.model, 'set_scale'):
41 + self.model.set_scale(idx_scale)
42 +
43 + if self.training:
44 + if self.n_GPUs > 1:
45 + return P.data_parallel(self.model, x, range(self.n_GPUs))
46 + else:
47 + return self.model(x)
48 + else:
49 + if self.chop:
50 + forward_function = self.forward_chop
51 + else:
52 + forward_function = self.model.forward
53 +
54 + if self.self_ensemble:
55 + return self.forward_x8(x, forward_function=forward_function)
56 + else:
57 + return forward_function(x)
58 +
59 + def save(self, apath, epoch, is_best=False):
60 + save_dirs = [os.path.join(apath, 'model_latest.pt')]
61 +
62 + if is_best:
63 + save_dirs.append(os.path.join(apath, 'model_best.pt'))
64 + if self.save_models:
65 + save_dirs.append(
66 + os.path.join(apath, 'model_{}.pt'.format(epoch))
67 + )
68 +
69 + for s in save_dirs:
70 + torch.save(self.model.state_dict(), s)
71 +
72 + def load(self, apath, pre_train='', resume=-1, cpu=False):
73 + load_from = None
74 + kwargs = {}
75 + if cpu:
76 + kwargs = {'map_location': lambda storage, loc: storage}
77 +
78 + if resume == -1:
79 + load_from = torch.load(
80 + os.path.join(apath, 'model_latest.pt'),
81 + **kwargs
82 + )
83 + elif resume == 0:
84 + if pre_train == 'download':
85 + print('Download the model')
86 + dir_model = os.path.join('..', 'models')
87 + os.makedirs(dir_model, exist_ok=True)
88 + load_from = torch.utils.model_zoo.load_url(
89 + self.model.url,
90 + model_dir=dir_model,
91 + **kwargs
92 + )
93 + elif pre_train:
94 + print('Load the model from {}'.format(pre_train))
95 + load_from = torch.load(pre_train, **kwargs)
96 + else:
97 + load_from = torch.load(
98 + os.path.join(apath, 'model_{}.pt'.format(resume)),
99 + **kwargs
100 + )
101 +
102 + if load_from:
103 + self.model.load_state_dict(load_from, strict=False)
104 +
105 + def forward_chop(self, *args, shave=10, min_size=160000):
106 + scale = 1 if self.input_large else self.scale[self.idx_scale]
107 + n_GPUs = min(self.n_GPUs, 4)
108 + # height, width
109 + h, w = args[0].size()[-2:]
110 +
111 + top = slice(0, h//2 + shave)
112 + bottom = slice(h - h//2 - shave, h)
113 + left = slice(0, w//2 + shave)
114 + right = slice(w - w//2 - shave, w)
115 + x_chops = [torch.cat([
116 + a[..., top, left],
117 + a[..., top, right],
118 + a[..., bottom, left],
119 + a[..., bottom, right]
120 + ]) for a in args]
121 +
122 + y_chops = []
123 + if h * w < 4 * min_size:
124 + for i in range(0, 4, n_GPUs):
125 + x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
126 + y = P.data_parallel(self.model, *x, range(n_GPUs))
127 + if not isinstance(y, list): y = [y]
128 + if not y_chops:
129 + y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
130 + else:
131 + for y_chop, _y in zip(y_chops, y):
132 + y_chop.extend(_y.chunk(n_GPUs, dim=0))
133 + else:
134 + for p in zip(*x_chops):
135 + y = self.forward_chop(*p, shave=shave, min_size=min_size)
136 + if not isinstance(y, list): y = [y]
137 + if not y_chops:
138 + y_chops = [[_y] for _y in y]
139 + else:
140 + for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
141 +
142 + h *= scale
143 + w *= scale
144 + top = slice(0, h//2)
145 + bottom = slice(h - h//2, h)
146 + bottom_r = slice(h//2 - h, None)
147 + left = slice(0, w//2)
148 + right = slice(w - w//2, w)
149 + right_r = slice(w//2 - w, None)
150 +
151 + # batch size, number of color channels
152 + b, c = y_chops[0][0].size()[:-2]
153 + y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
154 + for y_chop, _y in zip(y_chops, y):
155 + _y[..., top, left] = y_chop[0][..., top, left]
156 + _y[..., top, right] = y_chop[1][..., top, right_r]
157 + _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
158 + _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
159 +
160 + if len(y) == 1: y = y[0]
161 +
162 + return y
163 +
164 + def forward_x8(self, *args, forward_function=None):
165 + def _transform(v, op):
166 + if self.precision != 'single': v = v.float()
167 +
168 + v2np = v.data.cpu().numpy()
169 + if op == 'v':
170 + tfnp = v2np[:, :, :, ::-1].copy()
171 + elif op == 'h':
172 + tfnp = v2np[:, :, ::-1, :].copy()
173 + elif op == 't':
174 + tfnp = v2np.transpose((0, 1, 3, 2)).copy()
175 +
176 + ret = torch.Tensor(tfnp).to(self.device)
177 + if self.precision == 'half': ret = ret.half()
178 +
179 + return ret
180 +
181 + list_x = []
182 + for a in args:
183 + x = [a]
184 + for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
185 +
186 + list_x.append(x)
187 +
188 + list_y = []
189 + for x in zip(*list_x):
190 + y = forward_function(*x)
191 + if not isinstance(y, list): y = [y]
192 + if not list_y:
193 + list_y = [[_y] for _y in y]
194 + else:
195 + for _list_y, _y in zip(list_y, y): _list_y.append(_y)
196 +
197 + for _list_y in list_y:
198 + for i in range(len(_list_y)):
199 + if i > 3:
200 + _list_y[i] = _transform(_list_y[i], 't')
201 + if i % 4 > 1:
202 + _list_y[i] = _transform(_list_y[i], 'h')
203 + if (i % 4) % 2 == 1:
204 + _list_y[i] = _transform(_list_y[i], 'v')
205 +
206 + y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
207 + if len(y) == 1: y = y[0]
208 +
209 + return y
1 +import math
2 +
3 +import torch
4 +import torch.nn as nn
5 +import torch.nn.functional as F
6 +
7 +def default_conv(in_channels, out_channels, kernel_size, bias=True):
8 + return nn.Conv2d(
9 + in_channels, out_channels, kernel_size,
10 + padding=(kernel_size//2), bias=bias)
11 +
12 +class MeanShift(nn.Conv2d):
13 + def __init__(
14 + self, rgb_range,
15 + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
16 +
17 + super(MeanShift, self).__init__(3, 3, kernel_size=1)
18 + std = torch.Tensor(rgb_std)
19 + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
20 + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
21 + for p in self.parameters():
22 + p.requires_grad = False
23 +
24 +class BasicBlock(nn.Sequential):
25 + def __init__(
26 + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
27 + bn=True, act=nn.ReLU(True)):
28 +
29 + m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
30 + if bn:
31 + m.append(nn.BatchNorm2d(out_channels))
32 + if act is not None:
33 + m.append(act)
34 +
35 + super(BasicBlock, self).__init__(*m)
36 +
37 +class ResBlock(nn.Module):
38 + def __init__(
39 + self, conv, n_feats, kernel_size,
40 + bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
41 +
42 + super(ResBlock, self).__init__()
43 + m = []
44 + for i in range(2):
45 + m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
46 + if bn:
47 + m.append(nn.BatchNorm2d(n_feats))
48 + if i == 0:
49 + m.append(act)
50 +
51 + self.body = nn.Sequential(*m)
52 + self.res_scale = res_scale
53 +
54 + def forward(self, x):
55 + res = self.body(x).mul(self.res_scale)
56 + res += x
57 +
58 + return res
59 +
60 +class Upsampler(nn.Sequential):
61 + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
62 +
63 + m = []
64 + if (scale & (scale - 1)) == 0: # Is scale = 2^n?
65 + for _ in range(int(math.log(scale, 2))):
66 + m.append(conv(n_feats, 4 * n_feats, 3, bias))
67 + m.append(nn.PixelShuffle(2))
68 + if bn:
69 + m.append(nn.BatchNorm2d(n_feats))
70 + if act == 'relu':
71 + m.append(nn.ReLU(True))
72 + elif act == 'prelu':
73 + m.append(nn.PReLU(n_feats))
74 +
75 + elif scale == 3:
76 + m.append(conv(n_feats, 9 * n_feats, 3, bias))
77 + m.append(nn.PixelShuffle(3))
78 + if bn:
79 + m.append(nn.BatchNorm2d(n_feats))
80 + if act == 'relu':
81 + m.append(nn.ReLU(True))
82 + elif act == 'prelu':
83 + m.append(nn.PReLU(n_feats))
84 + else:
85 + raise NotImplementedError
86 +
87 + super(Upsampler, self).__init__(*m)
88 +
1 +# Deep Back-Projection Networks For Super-Resolution
2 +# https://arxiv.org/abs/1803.02735
3 +
4 +from model import common
5 +
6 +import torch
7 +import torch.nn as nn
8 +
9 +
10 +def make_model(args, parent=False):
11 + return DDBPN(args)
12 +
13 +def projection_conv(in_channels, out_channels, scale, up=True):
14 + kernel_size, stride, padding = {
15 + 2: (6, 2, 2),
16 + 4: (8, 4, 2),
17 + 8: (12, 8, 2)
18 + }[scale]
19 + if up:
20 + conv_f = nn.ConvTranspose2d
21 + else:
22 + conv_f = nn.Conv2d
23 +
24 + return conv_f(
25 + in_channels, out_channels, kernel_size,
26 + stride=stride, padding=padding
27 + )
28 +
29 +class DenseProjection(nn.Module):
30 + def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
31 + super(DenseProjection, self).__init__()
32 + if bottleneck:
33 + self.bottleneck = nn.Sequential(*[
34 + nn.Conv2d(in_channels, nr, 1),
35 + nn.PReLU(nr)
36 + ])
37 + inter_channels = nr
38 + else:
39 + self.bottleneck = None
40 + inter_channels = in_channels
41 +
42 + self.conv_1 = nn.Sequential(*[
43 + projection_conv(inter_channels, nr, scale, up),
44 + nn.PReLU(nr)
45 + ])
46 + self.conv_2 = nn.Sequential(*[
47 + projection_conv(nr, inter_channels, scale, not up),
48 + nn.PReLU(inter_channels)
49 + ])
50 + self.conv_3 = nn.Sequential(*[
51 + projection_conv(inter_channels, nr, scale, up),
52 + nn.PReLU(nr)
53 + ])
54 +
55 + def forward(self, x):
56 + if self.bottleneck is not None:
57 + x = self.bottleneck(x)
58 +
59 + a_0 = self.conv_1(x)
60 + b_0 = self.conv_2(a_0)
61 + e = b_0.sub(x)
62 + a_1 = self.conv_3(e)
63 +
64 + out = a_0.add(a_1)
65 +
66 + return out
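
# Back-projection, step by step: a_0 projects features across scales, b_0
# projects them back, e = b_0 - x is the reconstruction error, and conv_3
# re-projects that error so the output a_0 + a_1 is self-corrected.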
67 +
68 +class DDBPN(nn.Module):
69 + def __init__(self, args):
70 + super(DDBPN, self).__init__()
71 + scale = args.scale[0]
72 +
73 + n0 = 128
74 + nr = 32
75 + self.depth = 6
76 +
77 + rgb_mean = (0.4488, 0.4371, 0.4040)
78 + rgb_std = (1.0, 1.0, 1.0)
79 + self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
80 + initial = [
81 + nn.Conv2d(args.n_colors, n0, 3, padding=1),
82 + nn.PReLU(n0),
83 + nn.Conv2d(n0, nr, 1),
84 + nn.PReLU(nr)
85 + ]
86 + self.initial = nn.Sequential(*initial)
87 +
88 + self.upmodules = nn.ModuleList()
89 + self.downmodules = nn.ModuleList()
90 + channels = nr
91 + for i in range(self.depth):
92 + self.upmodules.append(
93 + DenseProjection(channels, nr, scale, True, i > 1)
94 + )
95 + if i != 0:
96 + channels += nr
97 +
98 + channels = nr
99 + for i in range(self.depth - 1):
100 + self.downmodules.append(
101 + DenseProjection(channels, nr, scale, False, i != 0)
102 + )
103 + channels += nr
104 +
105 + reconstruction = [
106 + nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
107 + ]
108 + self.reconstruction = nn.Sequential(*reconstruction)
109 +
110 + self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
111 +
112 + def forward(self, x):
113 + x = self.sub_mean(x)
114 + x = self.initial(x)
115 +
116 + h_list = []
117 + l_list = []
118 + for i in range(self.depth - 1):
119 + if i == 0:
120 + l = x
121 + else:
122 + l = torch.cat(l_list, dim=1)
123 + h_list.append(self.upmodules[i](l))
124 + l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
125 +
126 + h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
127 + out = self.reconstruction(torch.cat(h_list, dim=1))
128 + out = self.add_mean(out)
129 +
130 + return out
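
# Dense connections: from the second stage on, every projection consumes the
# concatenation of all previous outputs (hence channels += nr in __init__),
# and the final reconstruction fuses all self.depth HR feature maps.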
131 +
1 +from model import common
2 +
3 +import torch.nn as nn
4 +
5 +url = {
6 + 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
7 + 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
8 + 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
9 + 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
10 + 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
11 + 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
12 +}
13 +
14 +def make_model(args, parent=False):
15 + return EDSR(args)
16 +
17 +class EDSR(nn.Module):
18 + def __init__(self, args, conv=common.default_conv):
19 + super(EDSR, self).__init__()
20 +
21 + n_resblocks = args.n_resblocks
22 + n_feats = args.n_feats
23 + kernel_size = 3
24 + scale = args.scale[0]
25 + act = nn.ReLU(True)
26 + url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
27 + if url_name in url:
28 + self.url = url[url_name]
29 + else:
30 + self.url = None
31 + self.sub_mean = common.MeanShift(args.rgb_range)
32 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
33 +
34 + # define head module
35 + m_head = [conv(args.n_colors, n_feats, kernel_size)]
36 +
37 + # define body module
38 + m_body = [
39 + common.ResBlock(
40 + conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
41 + ) for _ in range(n_resblocks)
42 + ]
43 + m_body.append(conv(n_feats, n_feats, kernel_size))
44 +
45 + # define tail module
46 + m_tail = [
47 + common.Upsampler(conv, scale, n_feats, act=False),
48 + conv(n_feats, args.n_colors, kernel_size)
49 + ]
50 +
51 + self.head = nn.Sequential(*m_head)
52 + self.body = nn.Sequential(*m_body)
53 + self.tail = nn.Sequential(*m_tail)
54 +
55 + def forward(self, x):
56 + x = self.sub_mean(x)
57 + x = self.head(x)
58 +
59 + res = self.body(x)
60 + res += x
61 +
62 + x = self.tail(res)
63 + x = self.add_mean(x)
64 +
65 + return x
66 +
67 + def load_state_dict(self, state_dict, strict=True):
68 + own_state = self.state_dict()
69 + for name, param in state_dict.items():
70 + if name in own_state:
71 + if isinstance(param, nn.Parameter):
72 + param = param.data
73 + try:
74 + own_state[name].copy_(param)
75 + except Exception:
76 + if name.find('tail') == -1:
77 + raise RuntimeError('While copying the parameter named {}, '
78 + 'whose dimensions in the model are {} and '
79 + 'whose dimensions in the checkpoint are {}.'
80 + .format(name, own_state[name].size(), param.size()))
81 + elif strict:
82 + if name.find('tail') == -1:
83 + raise KeyError('unexpected key "{}" in state_dict'
84 + .format(name))
85 +
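
# A hedged smoke test (a sketch, not part of the original file): it assumes
# the repository's src directory is importable so `from model import common`
# resolves, and uses SimpleNamespace as a stand-in for the parsed options.
if __name__ == '__main__':
    from types import SimpleNamespace
    import torch

    _args = SimpleNamespace(n_resblocks=16, n_feats=64, scale=[2],
                            rgb_range=255, n_colors=3, res_scale=1)
    _sr = EDSR(_args)(torch.randn(1, 3, 48, 48))
    assert _sr.shape == (1, 3, 96, 96)  # x2: 48 -> 96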
1 +from model import common
2 +
3 +import torch.nn as nn
4 +
5 +url = {
6 + 'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt',
7 + 'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt'
8 +}
9 +
10 +def make_model(args, parent=False):
11 + return MDSR(args)
12 +
13 +class MDSR(nn.Module):
14 + def __init__(self, args, conv=common.default_conv):
15 + super(MDSR, self).__init__()
16 + n_resblocks = args.n_resblocks
17 + n_feats = args.n_feats
18 + kernel_size = 3
19 + act = nn.ReLU(True)
20 + self.scale_idx = 0
21 +        self.url = url.get('r{}f{}'.format(n_resblocks, n_feats))
22 + self.sub_mean = common.MeanShift(args.rgb_range)
23 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
24 +
25 + m_head = [conv(args.n_colors, n_feats, kernel_size)]
26 +
27 + self.pre_process = nn.ModuleList([
28 + nn.Sequential(
29 + common.ResBlock(conv, n_feats, 5, act=act),
30 + common.ResBlock(conv, n_feats, 5, act=act)
31 + ) for _ in args.scale
32 + ])
33 +
34 + m_body = [
35 + common.ResBlock(
36 + conv, n_feats, kernel_size, act=act
37 + ) for _ in range(n_resblocks)
38 + ]
39 + m_body.append(conv(n_feats, n_feats, kernel_size))
40 +
41 + self.upsample = nn.ModuleList([
42 + common.Upsampler(conv, s, n_feats, act=False) for s in args.scale
43 + ])
44 +
45 + m_tail = [conv(n_feats, args.n_colors, kernel_size)]
46 +
47 + self.head = nn.Sequential(*m_head)
48 + self.body = nn.Sequential(*m_body)
49 + self.tail = nn.Sequential(*m_tail)
50 +
51 + def forward(self, x):
52 + x = self.sub_mean(x)
53 + x = self.head(x)
54 + x = self.pre_process[self.scale_idx](x)
55 +
56 + res = self.body(x)
57 + res += x
58 +
59 + x = self.upsample[self.scale_idx](res)
60 + x = self.tail(x)
61 + x = self.add_mean(x)
62 +
63 + return x
64 +
65 + def set_scale(self, scale_idx):
66 + self.scale_idx = scale_idx
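
# Usage note: MDSR shares one body across scales; set_scale(idx) selects the
# scale-specific pre_process / upsample branch before forward. The trainer
# does this indirectly by passing idx_scale to the wrapped model.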
67 +
1 +## ECCV 2018: Image Super-Resolution Using Very Deep Residual Channel Attention Networks
2 +## https://arxiv.org/abs/1807.02758
3 +from model import common
4 +
5 +import torch.nn as nn
6 +
7 +def make_model(args, parent=False):
8 + return RCAN(args)
9 +
10 +## Channel Attention (CA) Layer
11 +class CALayer(nn.Module):
12 + def __init__(self, channel, reduction=16):
13 + super(CALayer, self).__init__()
14 + # global average pooling: feature --> point
15 + self.avg_pool = nn.AdaptiveAvgPool2d(1)
16 + # feature channel downscale and upscale --> channel weight
17 + self.conv_du = nn.Sequential(
18 + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
19 + nn.ReLU(inplace=True),
20 + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
21 + nn.Sigmoid()
22 + )
23 +
24 + def forward(self, x):
25 + y = self.avg_pool(x)
26 + y = self.conv_du(y)
27 + return x * y
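
# Shapes: avg_pool yields (B, C, 1, 1); conv_du squeezes to C // reduction
# and back; the sigmoid output broadcasts over H x W as a per-channel gate.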
28 +
29 +## Residual Channel Attention Block (RCAB)
30 +class RCAB(nn.Module):
31 + def __init__(
32 + self, conv, n_feat, kernel_size, reduction,
33 + bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
34 +
35 + super(RCAB, self).__init__()
36 + modules_body = []
37 + for i in range(2):
38 + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
39 + if bn: modules_body.append(nn.BatchNorm2d(n_feat))
40 + if i == 0: modules_body.append(act)
41 + modules_body.append(CALayer(n_feat, reduction))
42 + self.body = nn.Sequential(*modules_body)
43 + self.res_scale = res_scale
44 +
45 + def forward(self, x):
46 + res = self.body(x)
47 + #res = self.body(x).mul(self.res_scale)
48 + res += x
49 + return res
50 +
51 +## Residual Group (RG)
52 +class ResidualGroup(nn.Module):
53 + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
54 + super(ResidualGroup, self).__init__()
56 + modules_body = [
57 + RCAB(
58 + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
59 + for _ in range(n_resblocks)]
60 + modules_body.append(conv(n_feat, n_feat, kernel_size))
61 + self.body = nn.Sequential(*modules_body)
62 +
63 + def forward(self, x):
64 + res = self.body(x)
65 + res += x
66 + return res
67 +
68 +## Residual Channel Attention Network (RCAN)
69 +class RCAN(nn.Module):
70 + def __init__(self, args, conv=common.default_conv):
71 + super(RCAN, self).__init__()
72 +
73 + n_resgroups = args.n_resgroups
74 + n_resblocks = args.n_resblocks
75 + n_feats = args.n_feats
76 + kernel_size = 3
77 + reduction = args.reduction
78 + scale = args.scale[0]
79 + act = nn.ReLU(True)
80 +
81 + # RGB mean for DIV2K
82 + self.sub_mean = common.MeanShift(args.rgb_range)
83 +
84 + # define head module
85 + modules_head = [conv(args.n_colors, n_feats, kernel_size)]
86 +
87 + # define body module
88 + modules_body = [
89 + ResidualGroup(
90 + conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
91 + for _ in range(n_resgroups)]
92 +
93 + modules_body.append(conv(n_feats, n_feats, kernel_size))
94 +
95 + # define tail module
96 + modules_tail = [
97 + common.Upsampler(conv, scale, n_feats, act=False),
98 + conv(n_feats, args.n_colors, kernel_size)]
99 +
100 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
101 +
102 + self.head = nn.Sequential(*modules_head)
103 + self.body = nn.Sequential(*modules_body)
104 + self.tail = nn.Sequential(*modules_tail)
105 +
106 + def forward(self, x):
107 + x = self.sub_mean(x)
108 + x = self.head(x)
109 +
110 + res = self.body(x)
111 + res += x
112 +
113 + x = self.tail(res)
114 + x = self.add_mean(x)
115 +
116 + return x
117 +
118 + def load_state_dict(self, state_dict, strict=False):
119 + own_state = self.state_dict()
120 + for name, param in state_dict.items():
121 + if name in own_state:
122 + if isinstance(param, nn.Parameter):
123 + param = param.data
124 + try:
125 + own_state[name].copy_(param)
126 + except Exception:
127 + if name.find('tail') >= 0:
128 + print('Replace pre-trained upsampler to new one...')
129 + else:
130 + raise RuntimeError('While copying the parameter named {}, '
131 + 'whose dimensions in the model are {} and '
132 + 'whose dimensions in the checkpoint are {}.'
133 + .format(name, own_state[name].size(), param.size()))
134 + elif strict:
135 + if name.find('tail') == -1:
136 + raise KeyError('unexpected key "{}" in state_dict'
137 + .format(name))
138 +
139 + if strict:
140 + missing = set(own_state.keys()) - set(state_dict.keys())
141 + if len(missing) > 0:
142 + raise KeyError('missing keys in state_dict: "{}"'.format(missing))
1 +# Residual Dense Network for Image Super-Resolution
2 +# https://arxiv.org/abs/1802.08797
3 +
4 +from model import common
5 +
6 +import torch
7 +import torch.nn as nn
8 +
9 +
10 +def make_model(args, parent=False):
11 + return RDN(args)
12 +
13 +class RDB_Conv(nn.Module):
14 + def __init__(self, inChannels, growRate, kSize=3):
15 + super(RDB_Conv, self).__init__()
16 + Cin = inChannels
17 + G = growRate
18 + self.conv = nn.Sequential(*[
19 + nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
20 + nn.ReLU()
21 + ])
22 +
23 + def forward(self, x):
24 + out = self.conv(x)
25 + return torch.cat((x, out), 1)
26 +
27 +class RDB(nn.Module):
28 + def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
29 + super(RDB, self).__init__()
30 + G0 = growRate0
31 + G = growRate
32 + C = nConvLayers
33 +
34 + convs = []
35 + for c in range(C):
36 + convs.append(RDB_Conv(G0 + c*G, G))
37 + self.convs = nn.Sequential(*convs)
38 +
39 + # Local Feature Fusion
40 + self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
41 +
42 + def forward(self, x):
43 + return self.LFF(self.convs(x)) + x
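
# Channel bookkeeping: each RDB_Conv appends its G new feature maps, so after
# C layers the block carries G0 + C * G channels; the 1x1 LFF fuses them back
# to G0 so the local residual (+ x) lines up.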
44 +
45 +class RDN(nn.Module):
46 + def __init__(self, args):
47 + super(RDN, self).__init__()
48 + r = args.scale[0]
49 + G0 = args.G0
50 + kSize = args.RDNkSize
51 +
52 +        # number of RDB blocks (D), conv layers per block (C), growth rate (G)
53 + self.D, C, G = {
54 + 'A': (20, 6, 32),
55 + 'B': (16, 8, 64),
56 + }[args.RDNconfig]
57 +
58 + # Shallow feature extraction net
59 + self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
60 + self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
61 +
62 +        # Residual dense blocks and dense feature fusion
63 + self.RDBs = nn.ModuleList()
64 + for i in range(self.D):
65 + self.RDBs.append(
66 + RDB(growRate0 = G0, growRate = G, nConvLayers = C)
67 + )
68 +
69 + # Global Feature Fusion
70 + self.GFF = nn.Sequential(*[
71 + nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
72 + nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
73 + ])
74 +
75 + # Up-sampling net
76 + if r == 2 or r == 3:
77 + self.UPNet = nn.Sequential(*[
78 + nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
79 + nn.PixelShuffle(r),
80 + nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
81 + ])
82 + elif r == 4:
83 + self.UPNet = nn.Sequential(*[
84 + nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
85 + nn.PixelShuffle(2),
86 + nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
87 + nn.PixelShuffle(2),
88 + nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
89 + ])
90 + else:
91 + raise ValueError("scale must be 2 or 3 or 4.")
92 +
93 + def forward(self, x):
94 + f__1 = self.SFENet1(x)
95 + x = self.SFENet2(f__1)
96 +
97 + RDBs_out = []
98 + for i in range(self.D):
99 + x = self.RDBs[i](x)
100 + RDBs_out.append(x)
101 +
102 + x = self.GFF(torch.cat(RDBs_out,1))
103 + x += f__1
104 +
105 + return self.UPNet(x)
1 +from model import common
2 +
3 +import torch.nn as nn
5 +
6 +url = {
7 + 'r20f64': ''
8 +}
9 +
10 +def make_model(args, parent=False):
11 + return VDSR(args)
12 +
13 +class VDSR(nn.Module):
14 + def __init__(self, args, conv=common.default_conv):
15 + super(VDSR, self).__init__()
16 +
17 + n_resblocks = args.n_resblocks
18 + n_feats = args.n_feats
19 + kernel_size = 3
20 +        self.url = url.get('r{}f{}'.format(n_resblocks, n_feats))
21 + self.sub_mean = common.MeanShift(args.rgb_range)
22 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
23 +
24 + def basic_block(in_channels, out_channels, act):
25 + return common.BasicBlock(
26 + conv, in_channels, out_channels, kernel_size,
27 + bias=True, bn=False, act=act
28 + )
29 +
30 + # define body module
31 + m_body = []
32 + m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
33 + for _ in range(n_resblocks - 2):
34 + m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
35 + m_body.append(basic_block(n_feats, args.n_colors, None))
36 +
37 + self.body = nn.Sequential(*m_body)
38 +
39 + def forward(self, x):
40 + x = self.sub_mean(x)
41 + res = self.body(x)
42 + res += x
43 + x = self.add_mean(res)
44 +
45 + return x
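
# Note: VDSR has no upsampler; it expects a bicubic-interpolated input at the
# target resolution and predicts only the residual, which the global skip
# connection (res += x) adds back.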
46 +
1 +import argparse
2 +import template
3 +
4 +parser = argparse.ArgumentParser(description='EDSR and MDSR')
5 +
6 +parser.add_argument('--debug', action='store_true',
7 + help='Enables debug mode')
8 +parser.add_argument('--template', default='.',
9 +                    help='You can set various templates in template.py')
10 +
11 +# Hardware specifications
12 +parser.add_argument('--n_threads', type=int, default=6,
13 + help='number of threads for data loading')
14 +parser.add_argument('--cpu', action='store_true',
15 + help='use cpu only')
16 +parser.add_argument('--n_GPUs', type=int, default=1,
17 + help='number of GPUs')
18 +parser.add_argument('--seed', type=int, default=1,
19 + help='random seed')
20 +
21 +# Data specifications
22 +parser.add_argument('--dir_data', type=str, default='../../../dataset',
23 + help='dataset directory')
24 +parser.add_argument('--dir_demo', type=str, default='../test',
25 + help='demo image directory')
26 +parser.add_argument('--data_train', type=str, default='DIV2K',
27 + help='train dataset name')
28 +parser.add_argument('--data_test', type=str, default='DIV2K',
29 + help='test dataset name')
30 +parser.add_argument('--data_range', type=str, default='1-800/801-810',
31 + help='train/test data range')
32 +parser.add_argument('--ext', type=str, default='sep',
33 + help='dataset file extension')
34 +parser.add_argument('--scale', type=str, default='4',
35 + help='super resolution scale')
36 +parser.add_argument('--patch_size', type=int, default=192,
37 + help='output patch size')
38 +parser.add_argument('--rgb_range', type=int, default=255,
39 + help='maximum value of RGB')
40 +parser.add_argument('--n_colors', type=int, default=3,
41 + help='number of color channels to use')
42 +parser.add_argument('--chop', action='store_true',
43 + help='enable memory-efficient forward')
44 +parser.add_argument('--no_augment', action='store_true',
45 + help='do not use data augmentation')
46 +
47 +# Model specifications
48 +parser.add_argument('--model', default='EDSR',
49 + help='model name')
50 +
51 +parser.add_argument('--act', type=str, default='relu',
52 + help='activation function')
53 +parser.add_argument('--pre_train', type=str, default='',
54 + help='pre-trained model directory')
55 +parser.add_argument('--extend', type=str, default='.',
56 + help='pre-trained model directory')
57 +parser.add_argument('--n_resblocks', type=int, default=16,
58 + help='number of residual blocks')
59 +parser.add_argument('--n_feats', type=int, default=64,
60 + help='number of feature maps')
61 +parser.add_argument('--res_scale', type=float, default=1,
62 + help='residual scaling')
63 +parser.add_argument('--shift_mean', default=True,
64 + help='subtract pixel mean from the input')
65 +parser.add_argument('--dilation', action='store_true',
66 + help='use dilated convolution')
67 +parser.add_argument('--precision', type=str, default='single',
68 + choices=('single', 'half'),
69 + help='FP precision for test (single | half)')
70 +
71 +# Option for Residual dense network (RDN)
72 +parser.add_argument('--G0', type=int, default=64,
73 + help='default number of filters. (Use in RDN)')
74 +parser.add_argument('--RDNkSize', type=int, default=3,
75 + help='default kernel size. (Use in RDN)')
76 +parser.add_argument('--RDNconfig', type=str, default='B',
77 + help='parameters config of RDN. (Use in RDN)')
78 +
79 +# Option for Residual channel attention network (RCAN)
80 +parser.add_argument('--n_resgroups', type=int, default=10,
81 + help='number of residual groups')
82 +parser.add_argument('--reduction', type=int, default=16,
83 + help='number of feature maps reduction')
84 +
85 +# Training specifications
86 +parser.add_argument('--reset', action='store_true',
87 + help='reset the training')
88 +parser.add_argument('--test_every', type=int, default=1000,
89 +                    help='run a test every N training batches')
90 +parser.add_argument('--epochs', type=int, default=300,
91 + help='number of epochs to train')
92 +parser.add_argument('--batch_size', type=int, default=16,
93 + help='input batch size for training')
94 +parser.add_argument('--split_batch', type=int, default=1,
95 + help='split the batch into smaller chunks')
96 +parser.add_argument('--self_ensemble', action='store_true',
97 + help='use self-ensemble method for test')
98 +parser.add_argument('--test_only', action='store_true',
99 + help='set this option to test the model')
100 +parser.add_argument('--gan_k', type=int, default=1,
101 + help='k value for adversarial loss')
102 +
103 +# Optimization specifications
104 +parser.add_argument('--lr', type=float, default=1e-4,
105 + help='learning rate')
106 +parser.add_argument('--decay', type=str, default='200',
107 +                    help='learning rate decay milestones, joined by "-" (e.g. 200-400)')
108 +parser.add_argument('--gamma', type=float, default=0.5,
109 + help='learning rate decay factor for step decay')
110 +parser.add_argument('--optimizer', default='ADAM',
111 + choices=('SGD', 'ADAM', 'RMSprop'),
112 + help='optimizer to use (SGD | ADAM | RMSprop)')
113 +parser.add_argument('--momentum', type=float, default=0.9,
114 + help='SGD momentum')
115 +parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999),
116 +                    help='ADAM betas')
117 +parser.add_argument('--epsilon', type=float, default=1e-8,
118 + help='ADAM epsilon for numerical stability')
119 +parser.add_argument('--weight_decay', type=float, default=0,
120 + help='weight decay')
121 +parser.add_argument('--gclip', type=float, default=0,
122 + help='gradient clipping threshold (0 = no clipping)')
123 +
124 +# Loss specifications
125 +parser.add_argument('--loss', type=str, default='1*L1',
126 + help='loss function configuration')
127 +parser.add_argument('--skip_threshold', type=float, default=1e8,
128 +                    help='skip batches whose error is abnormally large')
129 +
130 +# Log specifications
131 +parser.add_argument('--save', type=str, default='test',
132 + help='file name to save')
133 +parser.add_argument('--load', type=str, default='',
134 + help='file name to load')
135 +parser.add_argument('--resume', type=int, default=0,
136 + help='resume from specific checkpoint')
137 +parser.add_argument('--save_models', action='store_true',
138 + help='save all intermediate models')
139 +parser.add_argument('--print_every', type=int, default=100,
140 + help='how many batches to wait before logging training status')
141 +parser.add_argument('--save_results', action='store_true',
142 + help='save output results')
143 +parser.add_argument('--save_gt', action='store_true',
144 + help='save low-resolution and high-resolution images together')
145 +
146 +args = parser.parse_args()
147 +template.set_template(args)
148 +
149 +args.scale = list(map(lambda x: int(x), args.scale.split('+')))
150 +args.data_train = args.data_train.split('+')
151 +args.data_test = args.data_test.split('+')
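# e.g. '--scale 2+3+4' -> [2, 3, 4]; '--data_test Set5+Set14' -> ['Set5', 'Set14']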
152 +
153 +if args.epochs == 0:
154 + args.epochs = 1e8
155 +
156 +for arg in vars(args):
157 + if vars(args)[arg] == 'True':
158 + vars(args)[arg] = True
159 + elif vars(args)[arg] == 'False':
160 + vars(args)[arg] = False
161 +
1 +def set_template(args):
2 + # Set the templates here
3 + if args.template.find('jpeg') >= 0:
4 + args.data_train = 'DIV2K_jpeg'
5 + args.data_test = 'DIV2K_jpeg'
6 + args.epochs = 200
7 + args.decay = '100'
8 +
9 + if args.template.find('EDSR_paper') >= 0:
10 + args.model = 'EDSR'
11 + args.n_resblocks = 32
12 + args.n_feats = 256
13 + args.res_scale = 0.1
14 +
15 + if args.template.find('MDSR') >= 0:
16 + args.model = 'MDSR'
17 + args.patch_size = 48
18 + args.epochs = 650
19 +
20 + if args.template.find('DDBPN') >= 0:
21 + args.model = 'DDBPN'
22 + args.patch_size = 128
23 + args.scale = '4'
24 +
25 + args.data_test = 'Set5'
26 +
27 + args.batch_size = 20
28 + args.epochs = 1000
29 + args.decay = '500'
30 + args.gamma = 0.1
31 + args.weight_decay = 1e-4
32 +
33 + args.loss = '1*MSE'
34 +
35 + if args.template.find('GAN') >= 0:
36 + args.epochs = 200
37 + args.lr = 5e-5
38 + args.decay = '150'
39 +
40 + if args.template.find('RCAN') >= 0:
41 + args.model = 'RCAN'
42 + args.n_resgroups = 10
43 + args.n_resblocks = 20
44 + args.n_feats = 64
45 + args.chop = True
46 +
47 + if args.template.find('VDSR') >= 0:
48 + args.model = 'VDSR'
49 + args.n_resblocks = 20
50 + args.n_feats = 64
51 + args.patch_size = 41
52 + args.lr = 1e-1
53 +
3 +from decimal import Decimal
4 +
5 +import utility
6 +
7 +import torch
8 +import torch.nn.utils as utils
9 +from tqdm import tqdm
10 +
11 +class Trainer():
12 + def __init__(self, args, loader, my_model, my_loss, ckp):
13 + self.args = args
14 + self.scale = args.scale
15 +
16 + self.ckp = ckp
17 + self.loader_train = loader.loader_train
18 + self.loader_test = loader.loader_test
19 + self.model = my_model
20 + self.loss = my_loss
21 + self.optimizer = utility.make_optimizer(args, self.model)
22 +
23 + if self.args.load != '':
24 + self.optimizer.load(ckp.dir, epoch=len(ckp.log))
25 +
26 + self.error_last = 1e8
27 +
28 + def train(self):
29 + self.loss.step()
30 + epoch = self.optimizer.get_last_epoch() + 1
31 + lr = self.optimizer.get_lr()
32 +
33 + self.ckp.write_log(
34 + '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
35 + )
36 + self.loss.start_log()
37 + self.model.train()
38 +
39 + timer_data, timer_model = utility.timer(), utility.timer()
40 + # TEMP
41 + self.loader_train.dataset.set_scale(0)
42 +        for batch, (lr, hr, _) in enumerate(self.loader_train):
43 + lr, hr = self.prepare(lr, hr)
44 + timer_data.hold()
45 + timer_model.tic()
46 +
47 + self.optimizer.zero_grad()
48 + sr = self.model(lr, 0)
49 + loss = self.loss(sr, hr)
50 + loss.backward()
51 + if self.args.gclip > 0:
52 + utils.clip_grad_value_(
53 + self.model.parameters(),
54 + self.args.gclip
55 + )
56 + self.optimizer.step()
57 +
58 + timer_model.hold()
59 +
60 + if (batch + 1) % self.args.print_every == 0:
61 + self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
62 + (batch + 1) * self.args.batch_size,
63 + len(self.loader_train.dataset),
64 + self.loss.display_loss(batch),
65 + timer_model.release(),
66 + timer_data.release()))
67 +
68 + timer_data.tic()
69 +
70 + self.loss.end_log(len(self.loader_train))
71 + self.error_last = self.loss.log[-1, -1]
72 + self.optimizer.schedule()
73 +
74 + def test(self):
75 + torch.set_grad_enabled(False)
76 +
77 + epoch = self.optimizer.get_last_epoch()
78 + self.ckp.write_log('\nEvaluation:')
79 + self.ckp.add_log(
80 + torch.zeros(1, len(self.loader_test), len(self.scale))
81 + )
82 + self.model.eval()
83 +
84 + timer_test = utility.timer()
85 + if self.args.save_results: self.ckp.begin_background()
86 + for idx_data, d in enumerate(self.loader_test):
87 + for idx_scale, scale in enumerate(self.scale):
88 + d.dataset.set_scale(idx_scale)
89 + for lr, hr, filename in tqdm(d, ncols=80):
90 + lr, hr = self.prepare(lr, hr)
91 + sr = self.model(lr, idx_scale)
92 + sr = utility.quantize(sr, self.args.rgb_range)
93 +
94 + save_list = [sr]
95 + self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
96 + sr, hr, scale, self.args.rgb_range, dataset=d
97 + )
98 + if self.args.save_gt:
99 + save_list.extend([lr, hr])
100 +
101 + if self.args.save_results:
102 + self.ckp.save_results(d, filename[0], save_list, scale)
103 +
104 + self.ckp.log[-1, idx_data, idx_scale] /= len(d)
105 + best = self.ckp.log.max(0)
106 + self.ckp.write_log(
107 + '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
108 + d.dataset.name,
109 + scale,
110 + self.ckp.log[-1, idx_data, idx_scale],
111 + best[0][idx_data, idx_scale],
112 + best[1][idx_data, idx_scale] + 1
113 + )
114 + )
115 +
116 + self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
117 + self.ckp.write_log('Saving...')
118 +
119 + if self.args.save_results:
120 + self.ckp.end_background()
121 +
122 + if not self.args.test_only:
123 + self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
124 +
125 + self.ckp.write_log(
126 + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
127 + )
128 +
129 + torch.set_grad_enabled(True)
130 +
131 + def prepare(self, *args):
132 + device = torch.device('cpu' if self.args.cpu else 'cuda')
133 + def _prepare(tensor):
134 + if self.args.precision == 'half': tensor = tensor.half()
135 + return tensor.to(device)
136 +
137 + return [_prepare(a) for a in args]
138 +
139 + def terminate(self):
140 + if self.args.test_only:
141 + self.test()
142 + return True
143 + else:
144 + epoch = self.optimizer.get_last_epoch() + 1
145 + return epoch >= self.args.epochs
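
# Typical driving loop (a sketch; the actual driver lives in main.py, which
# is not shown here):
#
#     t = Trainer(args, loader, model, loss, checkpoint)
#     while not t.terminate():
#         t.train()
#         t.test()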
146 +
1 +import os
2 +import math
3 +import time
4 +import datetime
5 +from multiprocessing import Process
6 +from multiprocessing import Queue
7 +
8 +import matplotlib
9 +matplotlib.use('Agg')
10 +import matplotlib.pyplot as plt
11 +
12 +import numpy as np
13 +import imageio
14 +
15 +import torch
16 +import torch.optim as optim
17 +import torch.optim.lr_scheduler as lrs
18 +
19 +class timer():
20 + def __init__(self):
21 + self.acc = 0
22 + self.tic()
23 +
24 + def tic(self):
25 + self.t0 = time.time()
26 +
27 + def toc(self, restart=False):
28 + diff = time.time() - self.t0
29 + if restart: self.t0 = time.time()
30 + return diff
31 +
32 + def hold(self):
33 + self.acc += self.toc()
34 +
35 + def release(self):
36 + ret = self.acc
37 + self.acc = 0
38 +
39 + return ret
40 +
41 + def reset(self):
42 + self.acc = 0
43 +
44 +class checkpoint():
45 + def __init__(self, args):
46 + self.args = args
47 + self.ok = True
48 + self.log = torch.Tensor()
49 + now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
50 +
51 + if not args.load:
52 + if not args.save:
53 + args.save = now
54 + self.dir = os.path.join('..', 'experiment', args.save)
55 + else:
56 + self.dir = os.path.join('..', 'experiment', args.load)
57 + if os.path.exists(self.dir):
58 + self.log = torch.load(self.get_path('psnr_log.pt'))
59 + print('Continue from epoch {}...'.format(len(self.log)))
60 + else:
61 + args.load = ''
62 +
63 + if args.reset:
64 + os.system('rm -rf ' + self.dir)
65 + args.load = ''
66 +
67 + os.makedirs(self.dir, exist_ok=True)
68 + os.makedirs(self.get_path('model'), exist_ok=True)
69 + for d in args.data_test:
70 + os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
71 +
72 +        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
73 + self.log_file = open(self.get_path('log.txt'), open_type)
74 + with open(self.get_path('config.txt'), open_type) as f:
75 + f.write(now + '\n\n')
76 + for arg in vars(args):
77 + f.write('{}: {}\n'.format(arg, getattr(args, arg)))
78 + f.write('\n')
79 +
80 + self.n_processes = 8
81 +
82 + def get_path(self, *subdir):
83 + return os.path.join(self.dir, *subdir)
84 +
85 + def save(self, trainer, epoch, is_best=False):
86 + trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
87 + trainer.loss.save(self.dir)
88 + trainer.loss.plot_loss(self.dir, epoch)
89 +
90 + self.plot_psnr(epoch)
91 + trainer.optimizer.save(self.dir)
92 + torch.save(self.log, self.get_path('psnr_log.pt'))
93 +
94 + def add_log(self, log):
95 + self.log = torch.cat([self.log, log])
96 +
97 + def write_log(self, log, refresh=False):
98 + print(log)
99 + self.log_file.write(log + '\n')
100 + if refresh:
101 + self.log_file.close()
102 + self.log_file = open(self.get_path('log.txt'), 'a')
103 +
104 + def done(self):
105 + self.log_file.close()
106 +
107 + def plot_psnr(self, epoch):
108 + axis = np.linspace(1, epoch, epoch)
109 + for idx_data, d in enumerate(self.args.data_test):
110 + label = 'SR on {}'.format(d)
111 + fig = plt.figure()
112 + plt.title(label)
113 + for idx_scale, scale in enumerate(self.args.scale):
114 + plt.plot(
115 + axis,
116 + self.log[:, idx_data, idx_scale].numpy(),
117 + label='Scale {}'.format(scale)
118 + )
119 + plt.legend()
120 + plt.xlabel('Epochs')
121 + plt.ylabel('PSNR')
122 + plt.grid(True)
123 + plt.savefig(self.get_path('test_{}.pdf'.format(d)))
124 + plt.close(fig)
125 +
126 + def begin_background(self):
127 + self.queue = Queue()
128 +
129 + def bg_target(queue):
130 + while True:
131 +                # a blocking get() avoids busy-waiting on an empty queue
132 +                filename, tensor = queue.get()
133 +                if filename is None: break
134 +                imageio.imwrite(filename, tensor.numpy())
135 +
136 + self.process = [
137 + Process(target=bg_target, args=(self.queue,)) \
138 + for _ in range(self.n_processes)
139 + ]
140 +
141 + for p in self.process: p.start()
142 +
143 + def end_background(self):
144 + for _ in range(self.n_processes): self.queue.put((None, None))
145 + while not self.queue.empty(): time.sleep(1)
146 + for p in self.process: p.join()
147 +
148 + def save_results(self, dataset, filename, save_list, scale):
149 + if self.args.save_results:
150 + filename = self.get_path(
151 + 'results-{}'.format(dataset.dataset.name),
152 + '{}_x{}_'.format(filename, scale)
153 + )
154 +
155 + postfix = ('SR', 'LR', 'HR')
156 + for v, p in zip(save_list, postfix):
157 + normalized = v[0].mul(255 / self.args.rgb_range)
158 + tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
159 + self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
160 +
161 +def quantize(img, rgb_range):
162 + pixel_range = 255 / rgb_range
163 + return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
164 +
165 +def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
166 + if hr.nelement() == 1: return 0
167 +
168 + diff = (sr - hr) / rgb_range
169 + if dataset and dataset.dataset.benchmark:
170 + shave = scale
171 + if diff.size(1) > 1:
172 + gray_coeffs = [65.738, 129.057, 25.064]
173 + convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
174 + diff = diff.mul(convert).sum(dim=1)
175 + else:
176 + shave = scale + 6
177 +
178 + valid = diff[..., shave:-shave, shave:-shave]
179 + mse = valid.pow(2).mean()
180 +
181 + return -10 * math.log10(mse)
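
# PSNR here is -10 * log10(MSE) because diff is normalized to [0, 1] by
# rgb_range. Benchmark sets shave a border of `scale` pixels and convert RGB
# to BT.601 luma (the 65.738 / 129.057 / 25.064 coefficients over 256);
# other sets (e.g. DIV2K) shave scale + 6 pixels on full RGB instead.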
182 +
183 +def make_optimizer(args, target):
184 + '''
185 + make optimizer and scheduler together
186 + '''
187 + # optimizer
188 + trainable = filter(lambda x: x.requires_grad, target.parameters())
189 + kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
190 +
191 + if args.optimizer == 'SGD':
192 + optimizer_class = optim.SGD
193 + kwargs_optimizer['momentum'] = args.momentum
194 + elif args.optimizer == 'ADAM':
195 + optimizer_class = optim.Adam
196 + kwargs_optimizer['betas'] = args.betas
197 + kwargs_optimizer['eps'] = args.epsilon
198 + elif args.optimizer == 'RMSprop':
199 + optimizer_class = optim.RMSprop
200 + kwargs_optimizer['eps'] = args.epsilon
201 +
202 + # scheduler
203 + milestones = list(map(lambda x: int(x), args.decay.split('-')))
204 + kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
205 + scheduler_class = lrs.MultiStepLR
206 +
207 + class CustomOptimizer(optimizer_class):
208 + def __init__(self, *args, **kwargs):
209 + super(CustomOptimizer, self).__init__(*args, **kwargs)
210 +
211 + def _register_scheduler(self, scheduler_class, **kwargs):
212 + self.scheduler = scheduler_class(self, **kwargs)
213 +
214 + def save(self, save_dir):
215 + torch.save(self.state_dict(), self.get_dir(save_dir))
216 +
217 + def load(self, load_dir, epoch=1):
218 + self.load_state_dict(torch.load(self.get_dir(load_dir)))
219 + if epoch > 1:
220 + for _ in range(epoch): self.scheduler.step()
221 +
222 + def get_dir(self, dir_path):
223 + return os.path.join(dir_path, 'optimizer.pt')
224 +
225 + def schedule(self):
226 + self.scheduler.step()
227 +
228 + def get_lr(self):
229 + return self.scheduler.get_lr()[0]
230 +
231 + def get_last_epoch(self):
232 + return self.scheduler.last_epoch
233 +
234 + optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
235 + optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
236 + return optimizer
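
# The returned object acts like a regular torch optimizer extended with
# schedule() / save() / load(); e.g. '--decay 200-400' yields MultiStepLR
# milestones [200, 400], multiplying the learning rate by gamma at each.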
237 +
1 +import os
2 +import math
3 +
4 +import utility
5 +from data import common
6 +
7 +import torch
8 +import cv2
9 +
10 +from tqdm import tqdm
11 +
12 +class VideoTester():
13 + def __init__(self, args, my_model, ckp):
14 + self.args = args
15 + self.scale = args.scale
16 +
17 + self.ckp = ckp
18 + self.model = my_model
19 +
20 + self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
21 +
22 + def test(self):
23 + torch.set_grad_enabled(False)
24 +
25 + self.ckp.write_log('\nEvaluation on video:')
26 + self.model.eval()
27 +
28 + timer_test = utility.timer()
29 + for idx_scale, scale in enumerate(self.scale):
30 + vidcap = cv2.VideoCapture(self.args.dir_demo)
31 + total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
32 + vidwri = cv2.VideoWriter(
33 + self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
34 + cv2.VideoWriter_fourcc(*'XVID'),
35 + vidcap.get(cv2.CAP_PROP_FPS),
36 + (
37 + int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
38 + int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
39 + )
40 + )
41 +
42 + tqdm_test = tqdm(range(total_frames), ncols=80)
43 + for _ in tqdm_test:
44 + success, lr = vidcap.read()
45 + if not success: break
46 +
47 + lr, = common.set_channel(lr, n_channels=self.args.n_colors)
48 + lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
49 + lr, = self.prepare(lr.unsqueeze(0))
50 + sr = self.model(lr, idx_scale)
51 + sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
52 +
53 + normalized = sr * 255 / self.args.rgb_range
54 + ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
55 + vidwri.write(ndarr)
56 +
57 + vidcap.release()
58 + vidwri.release()
59 +
60 + self.ckp.write_log(
61 + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
62 + )
63 + torch.set_grad_enabled(True)
64 +
65 + def prepare(self, *args):
66 + device = torch.device('cpu' if self.args.cpu else 'cuda')
67 + def _prepare(tensor):
68 + if self.args.precision == 'half': tensor = tensor.half()
69 + return tensor.to(device)
70 +
71 + return [_prepare(a) for a in args]
72 +