Showing
34 changed files
with
2672 additions
and
0 deletions
edsr/.gitignore
0 → 100644
1 | +# Byte-compiled / optimized / DLL files | ||
2 | +__pycache__/ | ||
3 | +*.py[cod] | ||
4 | +*$py.class | ||
5 | + | ||
6 | +# C extensions | ||
7 | +*.so | ||
8 | + | ||
9 | +# Distribution / packaging | ||
10 | +.Python | ||
11 | +env/ | ||
12 | +build/ | ||
13 | +develop-eggs/ | ||
14 | +dist/ | ||
15 | +downloads/ | ||
16 | +eggs/ | ||
17 | +.eggs/ | ||
18 | +lib/ | ||
19 | +lib64/ | ||
20 | +parts/ | ||
21 | +sdist/ | ||
22 | +var/ | ||
23 | +*.egg-info/ | ||
24 | +.installed.cfg | ||
25 | +*.egg | ||
26 | + | ||
27 | +# PyInstaller | ||
28 | +# Usually these files are written by a python script from a template | ||
29 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
30 | +*.manifest | ||
31 | +*.spec | ||
32 | + | ||
33 | +# Installer logs | ||
34 | +pip-log.txt | ||
35 | +pip-delete-this-directory.txt | ||
36 | + | ||
37 | +# Unit test / coverage reports | ||
38 | +htmlcov/ | ||
39 | +.tox/ | ||
40 | +.coverage | ||
41 | +.coverage.* | ||
42 | +.cache | ||
43 | +nosetests.xml | ||
44 | +coverage.xml | ||
45 | +*,cover | ||
46 | + | ||
47 | +# Translations | ||
48 | +*.mo | ||
49 | +*.pot | ||
50 | + | ||
51 | +# Django stuff: | ||
52 | +*.log | ||
53 | + | ||
54 | +# Sphinx documentation | ||
55 | +docs/_build/ | ||
56 | + | ||
57 | +# PyBuilder | ||
58 | +target/ | ||
59 | + | ||
60 | +# PyTorch | ||
61 | +*.pt | ||
62 | |||
63 | +*.png | ||
64 | +*.txt | ||
65 | +*.swp | ||
66 | +.vscode |
edsr/LICENSE
0 → 100644
1 | +MIT License | ||
2 | + | ||
3 | +Copyright (c) 2018 Sanghyun Son | ||
4 | + | ||
5 | +Permission is hereby granted, free of charge, to any person obtaining a copy | ||
6 | +of this software and associated documentation files (the "Software"), to deal | ||
7 | +in the Software without restriction, including without limitation the rights | ||
8 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
9 | +copies of the Software, and to permit persons to whom the Software is | ||
10 | +furnished to do so, subject to the following conditions: | ||
11 | + | ||
12 | +The above copyright notice and this permission notice shall be included in all | ||
13 | +copies or substantial portions of the Software. | ||
14 | + | ||
15 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
16 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
17 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
18 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
19 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
20 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
21 | +SOFTWARE. |
edsr/README.md
0 → 100755
This diff is collapsed. Click to expand it.
edsr/experiment/.gitignore
0 → 100644
edsr/src/__init__.py
0 → 100644
File mode changed
edsr/src/data/__init__.py
0 → 100644
1 | +from importlib import import_module | ||
2 | +#from dataloader import MSDataLoader | ||
3 | +from torch.utils.data import dataloader | ||
4 | +from torch.utils.data import ConcatDataset | ||
5 | + | ||
6 | +# This is a simple wrapper function for ConcatDataset | ||
7 | +class MyConcatDataset(ConcatDataset): | ||
8 | + def __init__(self, datasets): | ||
9 | + super(MyConcatDataset, self).__init__(datasets) | ||
10 | + self.train = datasets[0].train | ||
11 | + | ||
12 | + def set_scale(self, idx_scale): | ||
13 | + for d in self.datasets: | ||
14 | + if hasattr(d, 'set_scale'): d.set_scale(idx_scale) | ||
15 | + | ||
16 | +class Data: | ||
17 | + def __init__(self, args): | ||
18 | + self.loader_train = None | ||
19 | + if not args.test_only: | ||
20 | + datasets = [] | ||
21 | + for d in args.data_train: | ||
22 | + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' | ||
23 | + m = import_module('data.' + module_name.lower()) | ||
24 | + datasets.append(getattr(m, module_name)(args, name=d)) | ||
25 | + | ||
26 | + self.loader_train = dataloader.DataLoader( | ||
27 | + MyConcatDataset(datasets), | ||
28 | + batch_size=args.batch_size, | ||
29 | + shuffle=True, | ||
30 | + pin_memory=not args.cpu, | ||
31 | + num_workers=args.n_threads, | ||
32 | + ) | ||
33 | + | ||
34 | + self.loader_test = [] | ||
35 | + for d in args.data_test: | ||
36 | + if d in ['Set5', 'Set14', 'B100', 'Urban100']: | ||
37 | + m = import_module('data.benchmark') | ||
38 | + testset = getattr(m, 'Benchmark')(args, train=False, name=d) | ||
39 | + else: | ||
40 | + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' | ||
41 | + m = import_module('data.' + module_name.lower()) | ||
42 | + testset = getattr(m, module_name)(args, train=False, name=d) | ||
43 | + | ||
44 | + self.loader_test.append( | ||
45 | + dataloader.DataLoader( | ||
46 | + testset, | ||
47 | + batch_size=1, | ||
48 | + shuffle=False, | ||
49 | + pin_memory=not args.cpu, | ||
50 | + num_workers=args.n_threads, | ||
51 | + ) | ||
52 | + ) |
edsr/src/data/benchmark.py
0 → 100644
1 | +import os | ||
2 | + | ||
3 | +from data import common | ||
4 | +from data import srdata | ||
5 | + | ||
6 | +import numpy as np | ||
7 | + | ||
8 | +import torch | ||
9 | +import torch.utils.data as data | ||
10 | + | ||
11 | +class Benchmark(srdata.SRData): | ||
12 | + def __init__(self, args, name='', train=True, benchmark=True): | ||
13 | + super(Benchmark, self).__init__( | ||
14 | + args, name=name, train=train, benchmark=True | ||
15 | + ) | ||
16 | + | ||
17 | + def _set_filesystem(self, dir_data): | ||
18 | + self.apath = os.path.join(dir_data, 'benchmark', self.name) | ||
19 | + self.dir_hr = os.path.join(self.apath, 'HR') | ||
20 | + if self.input_large: | ||
21 | + self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') | ||
22 | + else: | ||
23 | + self.dir_lr = os.path.join(self.apath, 'LR_bicubic') | ||
24 | + self.ext = ('', '.png') | ||
25 | + |
edsr/src/data/common.py
0 → 100644
1 | +import random | ||
2 | + | ||
3 | +import numpy as np | ||
4 | +import skimage.color as sc | ||
5 | + | ||
6 | +import torch | ||
7 | + | ||
8 | +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): | ||
9 | + ih, iw = args[0].shape[:2] | ||
10 | + | ||
11 | + if not input_large: | ||
12 | + p = scale if multi else 1 | ||
13 | + tp = p * patch_size | ||
14 | + ip = tp // scale | ||
15 | + else: | ||
16 | + tp = patch_size | ||
17 | + ip = patch_size | ||
18 | + | ||
19 | + ix = random.randrange(0, iw - ip + 1) | ||
20 | + iy = random.randrange(0, ih - ip + 1) | ||
21 | + | ||
22 | + if not input_large: | ||
23 | + tx, ty = scale * ix, scale * iy | ||
24 | + else: | ||
25 | + tx, ty = ix, iy | ||
26 | + | ||
27 | + ret = [ | ||
28 | + args[0][iy:iy + ip, ix:ix + ip, :], | ||
29 | + *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] | ||
30 | + ] | ||
31 | + | ||
32 | + return ret | ||
33 | + | ||
34 | +def set_channel(*args, n_channels=3): | ||
35 | + def _set_channel(img): | ||
36 | + if img.ndim == 2: | ||
37 | + img = np.expand_dims(img, axis=2) | ||
38 | + | ||
39 | + c = img.shape[2] | ||
40 | + if n_channels == 1 and c == 3: | ||
41 | + img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) | ||
42 | + elif n_channels == 3 and c == 1: | ||
43 | + img = np.concatenate([img] * n_channels, 2) | ||
44 | + | ||
45 | + return img | ||
46 | + | ||
47 | + return [_set_channel(a) for a in args] | ||
48 | + | ||
49 | +def np2Tensor(*args, rgb_range=255): | ||
50 | + def _np2Tensor(img): | ||
51 | + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) | ||
52 | + tensor = torch.from_numpy(np_transpose).float() | ||
53 | + tensor.mul_(rgb_range / 255) | ||
54 | + | ||
55 | + return tensor | ||
56 | + | ||
57 | + return [_np2Tensor(a) for a in args] | ||
58 | + | ||
59 | +def augment(*args, hflip=True, rot=True): | ||
60 | + hflip = hflip and random.random() < 0.5 | ||
61 | + vflip = rot and random.random() < 0.5 | ||
62 | + rot90 = rot and random.random() < 0.5 | ||
63 | + | ||
64 | + def _augment(img): | ||
65 | + if hflip: img = img[:, ::-1, :] | ||
66 | + if vflip: img = img[::-1, :, :] | ||
67 | + if rot90: img = img.transpose(1, 0, 2) | ||
68 | + | ||
69 | + return img | ||
70 | + | ||
71 | + return [_augment(a) for a in args] | ||
72 | + |
edsr/src/data/demo.py
0 → 100644
1 | +import os | ||
2 | + | ||
3 | +from data import common | ||
4 | + | ||
5 | +import numpy as np | ||
6 | +import imageio | ||
7 | + | ||
8 | +import torch | ||
9 | +import torch.utils.data as data | ||
10 | + | ||
11 | +class Demo(data.Dataset): | ||
12 | + def __init__(self, args, name='Demo', train=False, benchmark=False): | ||
13 | + self.args = args | ||
14 | + self.name = name | ||
15 | + self.scale = args.scale | ||
16 | + self.idx_scale = 0 | ||
17 | + self.train = False | ||
18 | + self.benchmark = benchmark | ||
19 | + | ||
20 | + self.filelist = [] | ||
21 | + for f in os.listdir(args.dir_demo): | ||
22 | + if f.find('.png') >= 0 or f.find('.jp') >= 0: | ||
23 | + self.filelist.append(os.path.join(args.dir_demo, f)) | ||
24 | + self.filelist.sort() | ||
25 | + | ||
26 | + def __getitem__(self, idx): | ||
27 | + filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] | ||
28 | + lr = imageio.imread(self.filelist[idx]) | ||
29 | + lr, = common.set_channel(lr, n_channels=self.args.n_colors) | ||
30 | + lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) | ||
31 | + | ||
32 | + return lr_t, -1, filename | ||
33 | + | ||
34 | + def __len__(self): | ||
35 | + return len(self.filelist) | ||
36 | + | ||
37 | + def set_scale(self, idx_scale): | ||
38 | + self.idx_scale = idx_scale | ||
39 | + |
edsr/src/data/div2k.py
0 → 100644
1 | +import os | ||
2 | +from data import srdata | ||
3 | + | ||
4 | +class DIV2K(srdata.SRData): | ||
5 | + def __init__(self, args, name='DIV2K', train=True, benchmark=False): | ||
6 | + data_range = [r.split('-') for r in args.data_range.split('/')] | ||
7 | + if train: | ||
8 | + data_range = data_range[0] | ||
9 | + else: | ||
10 | + if args.test_only and len(data_range) == 1: | ||
11 | + data_range = data_range[0] | ||
12 | + else: | ||
13 | + data_range = data_range[1] | ||
14 | + | ||
15 | + self.begin, self.end = list(map(lambda x: int(x), data_range)) | ||
16 | + super(DIV2K, self).__init__( | ||
17 | + args, name=name, train=train, benchmark=benchmark | ||
18 | + ) | ||
19 | + | ||
20 | + def _scan(self): | ||
21 | + names_hr, names_lr = super(DIV2K, self)._scan() | ||
22 | + names_hr = names_hr[self.begin - 1:self.end] | ||
23 | + names_lr = [n[self.begin - 1:self.end] for n in names_lr] | ||
24 | + | ||
25 | + return names_hr, names_lr | ||
26 | + | ||
27 | + def _set_filesystem(self, dir_data): | ||
28 | + super(DIV2K, self)._set_filesystem(dir_data) | ||
29 | + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') | ||
30 | + self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') | ||
31 | + if self.input_large: self.dir_lr += 'L' | ||
32 | + |
edsr/src/data/div2kjpeg.py
0 → 100644
1 | +import os | ||
2 | +from data import srdata | ||
3 | +from data import div2k | ||
4 | + | ||
5 | +class DIV2KJPEG(div2k.DIV2K): | ||
6 | + def __init__(self, args, name='', train=True, benchmark=False): | ||
7 | + self.q_factor = int(name.replace('DIV2K-Q', '')) | ||
8 | + super(DIV2KJPEG, self).__init__( | ||
9 | + args, name=name, train=train, benchmark=benchmark | ||
10 | + ) | ||
11 | + | ||
12 | + def _set_filesystem(self, dir_data): | ||
13 | + self.apath = os.path.join(dir_data, 'DIV2K') | ||
14 | + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') | ||
15 | + self.dir_lr = os.path.join( | ||
16 | + self.apath, 'DIV2K_Q{}'.format(self.q_factor) | ||
17 | + ) | ||
18 | + if self.input_large: self.dir_lr += 'L' | ||
19 | + self.ext = ('.png', '.jpg') | ||
20 | + |
edsr/src/data/sr291.py
0 → 100644
edsr/src/data/srdata.py
0 → 100644
1 | +import os | ||
2 | +import glob | ||
3 | +import random | ||
4 | +import pickle | ||
5 | + | ||
6 | +from data import common | ||
7 | + | ||
8 | +import numpy as np | ||
9 | +import imageio | ||
10 | +import torch | ||
11 | +import torch.utils.data as data | ||
12 | + | ||
13 | +class SRData(data.Dataset): | ||
14 | + def __init__(self, args, name='', train=True, benchmark=False): | ||
15 | + self.args = args | ||
16 | + self.name = name | ||
17 | + self.train = train | ||
18 | + self.split = 'train' if train else 'test' | ||
19 | + self.do_eval = True | ||
20 | + self.benchmark = benchmark | ||
21 | + self.input_large = (args.model == 'VDSR') | ||
22 | + self.scale = args.scale | ||
23 | + self.idx_scale = 0 | ||
24 | + | ||
25 | + self._set_filesystem(args.dir_data) | ||
26 | + if args.ext.find('img') < 0: | ||
27 | + path_bin = os.path.join(self.apath, 'bin') | ||
28 | + os.makedirs(path_bin, exist_ok=True) | ||
29 | + | ||
30 | + list_hr, list_lr = self._scan() | ||
31 | + if args.ext.find('img') >= 0 or benchmark: | ||
32 | + self.images_hr, self.images_lr = list_hr, list_lr | ||
33 | + elif args.ext.find('sep') >= 0: | ||
34 | + os.makedirs( | ||
35 | + self.dir_hr.replace(self.apath, path_bin), | ||
36 | + exist_ok=True | ||
37 | + ) | ||
38 | + for s in self.scale: | ||
39 | + os.makedirs( | ||
40 | + os.path.join( | ||
41 | + self.dir_lr.replace(self.apath, path_bin), | ||
42 | + 'X{}'.format(s) | ||
43 | + ), | ||
44 | + exist_ok=True | ||
45 | + ) | ||
46 | + | ||
47 | + self.images_hr, self.images_lr = [], [[] for _ in self.scale] | ||
48 | + for h in list_hr: | ||
49 | + b = h.replace(self.apath, path_bin) | ||
50 | + b = b.replace(self.ext[0], '.pt') | ||
51 | + self.images_hr.append(b) | ||
52 | + self._check_and_load(args.ext, h, b, verbose=True) | ||
53 | + for i, ll in enumerate(list_lr): | ||
54 | + for l in ll: | ||
55 | + b = l.replace(self.apath, path_bin) | ||
56 | + b = b.replace(self.ext[1], '.pt') | ||
57 | + self.images_lr[i].append(b) | ||
58 | + self._check_and_load(args.ext, l, b, verbose=True) | ||
59 | + if train: | ||
60 | + n_patches = args.batch_size * args.test_every | ||
61 | + n_images = len(args.data_train) * len(self.images_hr) | ||
62 | + if n_images == 0: | ||
63 | + self.repeat = 0 | ||
64 | + else: | ||
65 | + self.repeat = max(n_patches // n_images, 1) | ||
66 | + | ||
67 | + # Below functions as used to prepare images | ||
68 | + def _scan(self): | ||
69 | + names_hr = sorted( | ||
70 | + glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) | ||
71 | + ) | ||
72 | + names_lr = [[] for _ in self.scale] | ||
73 | + for f in names_hr: | ||
74 | + filename, _ = os.path.splitext(os.path.basename(f)) | ||
75 | + for si, s in enumerate(self.scale): | ||
76 | + names_lr[si].append(os.path.join( | ||
77 | + self.dir_lr, 'X{}/{}x{}{}'.format( | ||
78 | + s, filename, s, self.ext[1] | ||
79 | + ) | ||
80 | + )) | ||
81 | + | ||
82 | + return names_hr, names_lr | ||
83 | + | ||
84 | + def _set_filesystem(self, dir_data): | ||
85 | + self.apath = os.path.join(dir_data, self.name) | ||
86 | + self.dir_hr = os.path.join(self.apath, 'HR') | ||
87 | + self.dir_lr = os.path.join(self.apath, 'LR_bicubic') | ||
88 | + if self.input_large: self.dir_lr += 'L' | ||
89 | + self.ext = ('.png', '.png') | ||
90 | + | ||
91 | + def _check_and_load(self, ext, img, f, verbose=True): | ||
92 | + if not os.path.isfile(f) or ext.find('reset') >= 0: | ||
93 | + if verbose: | ||
94 | + print('Making a binary: {}'.format(f)) | ||
95 | + with open(f, 'wb') as _f: | ||
96 | + pickle.dump(imageio.imread(img), _f) | ||
97 | + | ||
98 | + def __getitem__(self, idx): | ||
99 | + lr, hr, filename = self._load_file(idx) | ||
100 | + pair = self.get_patch(lr, hr) | ||
101 | + pair = common.set_channel(*pair, n_channels=self.args.n_colors) | ||
102 | + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | ||
103 | + | ||
104 | + return pair_t[0], pair_t[1], filename | ||
105 | + | ||
106 | + def __len__(self): | ||
107 | + if self.train: | ||
108 | + return len(self.images_hr) * self.repeat | ||
109 | + else: | ||
110 | + return len(self.images_hr) | ||
111 | + | ||
112 | + def _get_index(self, idx): | ||
113 | + if self.train: | ||
114 | + return idx % len(self.images_hr) | ||
115 | + else: | ||
116 | + return idx | ||
117 | + | ||
118 | + def _load_file(self, idx): | ||
119 | + idx = self._get_index(idx) | ||
120 | + f_hr = self.images_hr[idx] | ||
121 | + f_lr = self.images_lr[self.idx_scale][idx] | ||
122 | + | ||
123 | + filename, _ = os.path.splitext(os.path.basename(f_hr)) | ||
124 | + if self.args.ext == 'img' or self.benchmark: | ||
125 | + hr = imageio.imread(f_hr) | ||
126 | + lr = imageio.imread(f_lr) | ||
127 | + elif self.args.ext.find('sep') >= 0: | ||
128 | + with open(f_hr, 'rb') as _f: | ||
129 | + hr = pickle.load(_f) | ||
130 | + with open(f_lr, 'rb') as _f: | ||
131 | + lr = pickle.load(_f) | ||
132 | + | ||
133 | + return lr, hr, filename | ||
134 | + | ||
135 | + def get_patch(self, lr, hr): | ||
136 | + scale = self.scale[self.idx_scale] | ||
137 | + if self.train: | ||
138 | + lr, hr = common.get_patch( | ||
139 | + lr, hr, | ||
140 | + patch_size=self.args.patch_size, | ||
141 | + scale=scale, | ||
142 | + multi=(len(self.scale) > 1), | ||
143 | + input_large=self.input_large | ||
144 | + ) | ||
145 | + if not self.args.no_augment: lr, hr = common.augment(lr, hr) | ||
146 | + else: | ||
147 | + ih, iw = lr.shape[:2] | ||
148 | + hr = hr[0:ih * scale, 0:iw * scale] | ||
149 | + | ||
150 | + return lr, hr | ||
151 | + | ||
152 | + def set_scale(self, idx_scale): | ||
153 | + if not self.input_large: | ||
154 | + self.idx_scale = idx_scale | ||
155 | + else: | ||
156 | + self.idx_scale = random.randint(0, len(self.scale) - 1) | ||
157 | + |
edsr/src/data/video.py
0 → 100644
1 | +import os | ||
2 | + | ||
3 | +from data import common | ||
4 | + | ||
5 | +import cv2 | ||
6 | +import numpy as np | ||
7 | +import imageio | ||
8 | + | ||
9 | +import torch | ||
10 | +import torch.utils.data as data | ||
11 | + | ||
12 | +class Video(data.Dataset): | ||
13 | + def __init__(self, args, name='Video', train=False, benchmark=False): | ||
14 | + self.args = args | ||
15 | + self.name = name | ||
16 | + self.scale = args.scale | ||
17 | + self.idx_scale = 0 | ||
18 | + self.train = False | ||
19 | + self.do_eval = False | ||
20 | + self.benchmark = benchmark | ||
21 | + | ||
22 | + self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) | ||
23 | + self.vidcap = cv2.VideoCapture(args.dir_demo) | ||
24 | + self.n_frames = 0 | ||
25 | + self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) | ||
26 | + | ||
27 | + def __getitem__(self, idx): | ||
28 | + success, lr = self.vidcap.read() | ||
29 | + if success: | ||
30 | + self.n_frames += 1 | ||
31 | + lr, = common.set_channel(lr, n_channels=self.args.n_colors) | ||
32 | + lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) | ||
33 | + | ||
34 | + return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) | ||
35 | + else: | ||
36 | + vidcap.release() | ||
37 | + return None | ||
38 | + | ||
39 | + def __len__(self): | ||
40 | + return self.total_frames | ||
41 | + | ||
42 | + def set_scale(self, idx_scale): | ||
43 | + self.idx_scale = idx_scale | ||
44 | + |
edsr/src/dataloader.py
0 → 100644
1 | +import threading | ||
2 | +import random | ||
3 | + | ||
4 | +import torch | ||
5 | +import torch.multiprocessing as multiprocessing | ||
6 | +from torch.utils.data import DataLoader | ||
7 | +from torch.utils.data import SequentialSampler | ||
8 | +from torch.utils.data import RandomSampler | ||
9 | +from torch.utils.data import BatchSampler | ||
10 | +from torch.utils.data import _utils | ||
11 | +from torch.utils.data.dataloader import _DataLoaderIter | ||
12 | + | ||
13 | +from torch.utils.data._utils import collate | ||
14 | +from torch.utils.data._utils import signal_handling | ||
15 | +from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL | ||
16 | +from torch.utils.data._utils import ExceptionWrapper | ||
17 | +from torch.utils.data._utils import IS_WINDOWS | ||
18 | +from torch.utils.data._utils.worker import ManagerWatchdog | ||
19 | + | ||
20 | +from torch._six import queue | ||
21 | + | ||
22 | +def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): | ||
23 | + try: | ||
24 | + collate._use_shared_memory = True | ||
25 | + signal_handling._set_worker_signal_handlers() | ||
26 | + | ||
27 | + torch.set_num_threads(1) | ||
28 | + random.seed(seed) | ||
29 | + torch.manual_seed(seed) | ||
30 | + | ||
31 | + data_queue.cancel_join_thread() | ||
32 | + | ||
33 | + if init_fn is not None: | ||
34 | + init_fn(worker_id) | ||
35 | + | ||
36 | + watchdog = ManagerWatchdog() | ||
37 | + | ||
38 | + while watchdog.is_alive(): | ||
39 | + try: | ||
40 | + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) | ||
41 | + except queue.Empty: | ||
42 | + continue | ||
43 | + | ||
44 | + if r is None: | ||
45 | + assert done_event.is_set() | ||
46 | + return | ||
47 | + elif done_event.is_set(): | ||
48 | + continue | ||
49 | + | ||
50 | + idx, batch_indices = r | ||
51 | + try: | ||
52 | + idx_scale = 0 | ||
53 | + if len(scale) > 1 and dataset.train: | ||
54 | + idx_scale = random.randrange(0, len(scale)) | ||
55 | + dataset.set_scale(idx_scale) | ||
56 | + | ||
57 | + samples = collate_fn([dataset[i] for i in batch_indices]) | ||
58 | + samples.append(idx_scale) | ||
59 | + except Exception: | ||
60 | + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) | ||
61 | + else: | ||
62 | + data_queue.put((idx, samples)) | ||
63 | + del samples | ||
64 | + | ||
65 | + except KeyboardInterrupt: | ||
66 | + pass | ||
67 | + | ||
68 | +class _MSDataLoaderIter(_DataLoaderIter): | ||
69 | + | ||
70 | + def __init__(self, loader): | ||
71 | + self.dataset = loader.dataset | ||
72 | + self.scale = loader.scale | ||
73 | + self.collate_fn = loader.collate_fn | ||
74 | + self.batch_sampler = loader.batch_sampler | ||
75 | + self.num_workers = loader.num_workers | ||
76 | + self.pin_memory = loader.pin_memory and torch.cuda.is_available() | ||
77 | + self.timeout = loader.timeout | ||
78 | + | ||
79 | + self.sample_iter = iter(self.batch_sampler) | ||
80 | + | ||
81 | + base_seed = torch.LongTensor(1).random_().item() | ||
82 | + | ||
83 | + if self.num_workers > 0: | ||
84 | + self.worker_init_fn = loader.worker_init_fn | ||
85 | + self.worker_queue_idx = 0 | ||
86 | + self.worker_result_queue = multiprocessing.Queue() | ||
87 | + self.batches_outstanding = 0 | ||
88 | + self.worker_pids_set = False | ||
89 | + self.shutdown = False | ||
90 | + self.send_idx = 0 | ||
91 | + self.rcvd_idx = 0 | ||
92 | + self.reorder_dict = {} | ||
93 | + self.done_event = multiprocessing.Event() | ||
94 | + | ||
95 | + base_seed = torch.LongTensor(1).random_()[0] | ||
96 | + | ||
97 | + self.index_queues = [] | ||
98 | + self.workers = [] | ||
99 | + for i in range(self.num_workers): | ||
100 | + index_queue = multiprocessing.Queue() | ||
101 | + index_queue.cancel_join_thread() | ||
102 | + w = multiprocessing.Process( | ||
103 | + target=_ms_loop, | ||
104 | + args=( | ||
105 | + self.dataset, | ||
106 | + index_queue, | ||
107 | + self.worker_result_queue, | ||
108 | + self.done_event, | ||
109 | + self.collate_fn, | ||
110 | + self.scale, | ||
111 | + base_seed + i, | ||
112 | + self.worker_init_fn, | ||
113 | + i | ||
114 | + ) | ||
115 | + ) | ||
116 | + w.daemon = True | ||
117 | + w.start() | ||
118 | + self.index_queues.append(index_queue) | ||
119 | + self.workers.append(w) | ||
120 | + | ||
121 | + if self.pin_memory: | ||
122 | + self.data_queue = queue.Queue() | ||
123 | + pin_memory_thread = threading.Thread( | ||
124 | + target=_utils.pin_memory._pin_memory_loop, | ||
125 | + args=( | ||
126 | + self.worker_result_queue, | ||
127 | + self.data_queue, | ||
128 | + torch.cuda.current_device(), | ||
129 | + self.done_event | ||
130 | + ) | ||
131 | + ) | ||
132 | + pin_memory_thread.daemon = True | ||
133 | + pin_memory_thread.start() | ||
134 | + self.pin_memory_thread = pin_memory_thread | ||
135 | + else: | ||
136 | + self.data_queue = self.worker_result_queue | ||
137 | + | ||
138 | + _utils.signal_handling._set_worker_pids( | ||
139 | + id(self), tuple(w.pid for w in self.workers) | ||
140 | + ) | ||
141 | + _utils.signal_handling._set_SIGCHLD_handler() | ||
142 | + self.worker_pids_set = True | ||
143 | + | ||
144 | + for _ in range(2 * self.num_workers): | ||
145 | + self._put_indices() | ||
146 | + | ||
147 | + | ||
148 | +class MSDataLoader(DataLoader): | ||
149 | + | ||
150 | + def __init__(self, cfg, *args, **kwargs): | ||
151 | + super(MSDataLoader, self).__init__( | ||
152 | + *args, **kwargs, num_workers=cfg.n_threads | ||
153 | + ) | ||
154 | + self.scale = cfg.scale | ||
155 | + | ||
156 | + def __iter__(self): | ||
157 | + return _MSDataLoaderIter(self) | ||
158 | + |
edsr/src/demo.sh
0 → 100644
1 | +# EDSR baseline model (x2) + JPEG augmentation | ||
2 | +python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset | ||
3 | +#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75 | ||
4 | + | ||
5 | +# EDSR baseline model (x3) - from EDSR baseline model (x2) | ||
6 | +#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] | ||
7 | + | ||
8 | +# EDSR baseline model (x4) - from EDSR baseline model (x2) | ||
9 | +#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] | ||
10 | + | ||
11 | +# EDSR in the paper (x2) | ||
12 | +#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset | ||
13 | + | ||
14 | +# EDSR in the paper (x3) - from EDSR (x2) | ||
15 | +#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir] | ||
16 | + | ||
17 | +# EDSR in the paper (x4) - from EDSR (x2) | ||
18 | +#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] | ||
19 | + | ||
20 | +# MDSR baseline model | ||
21 | +#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models | ||
22 | + | ||
23 | +# MDSR in the paper | ||
24 | +#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models | ||
25 | + | ||
26 | +# Standard benchmarks (Ex. EDSR_baseline_x4) | ||
27 | +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble | ||
28 | + | ||
29 | +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble | ||
30 | + | ||
31 | +# Test your own images | ||
32 | +#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results | ||
33 | + | ||
34 | +# Advanced - Test with JPEG images | ||
35 | +#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results | ||
36 | + | ||
37 | +# Advanced - Training with adversarial loss | ||
38 | +#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download | ||
39 | + | ||
40 | +# RDN BI model (x2) | ||
41 | +#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset | ||
42 | +# RDN BI model (x3) | ||
43 | +#python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset | ||
44 | +# RDN BI model (x4) | ||
45 | +#python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset | ||
46 | + | ||
47 | +# RCAN_BIX2_G10R20P48, input=48x48, output=96x96 | ||
48 | +# pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0 | ||
49 | +#python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96 | ||
50 | +# RCAN_BIX3_G10R20P48, input=48x48, output=144x144 | ||
51 | +#python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt | ||
52 | +# RCAN_BIX4_G10R20P48, input=48x48, output=192x192 | ||
53 | +#python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt | ||
54 | +# RCAN_BIX8_G10R20P48, input=48x48, output=384x384 | ||
55 | +#python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt | ||
56 | + |
edsr/src/loss/__init__.py
0 → 100644
1 | +import os | ||
2 | +from importlib import import_module | ||
3 | + | ||
4 | +import matplotlib | ||
5 | +matplotlib.use('Agg') | ||
6 | +import matplotlib.pyplot as plt | ||
7 | + | ||
8 | +import numpy as np | ||
9 | + | ||
10 | +import torch | ||
11 | +import torch.nn as nn | ||
12 | +import torch.nn.functional as F | ||
13 | + | ||
14 | +class Loss(nn.modules.loss._Loss): | ||
15 | + def __init__(self, args, ckp): | ||
16 | + super(Loss, self).__init__() | ||
17 | + print('Preparing loss function:') | ||
18 | + | ||
19 | + self.n_GPUs = args.n_GPUs | ||
20 | + self.loss = [] | ||
21 | + self.loss_module = nn.ModuleList() | ||
22 | + for loss in args.loss.split('+'): | ||
23 | + weight, loss_type = loss.split('*') | ||
24 | + if loss_type == 'MSE': | ||
25 | + loss_function = nn.MSELoss() | ||
26 | + elif loss_type == 'L1': | ||
27 | + loss_function = nn.L1Loss() | ||
28 | + elif loss_type.find('VGG') >= 0: | ||
29 | + module = import_module('loss.vgg') | ||
30 | + loss_function = getattr(module, 'VGG')( | ||
31 | + loss_type[3:], | ||
32 | + rgb_range=args.rgb_range | ||
33 | + ) | ||
34 | + elif loss_type.find('GAN') >= 0: | ||
35 | + module = import_module('loss.adversarial') | ||
36 | + loss_function = getattr(module, 'Adversarial')( | ||
37 | + args, | ||
38 | + loss_type | ||
39 | + ) | ||
40 | + | ||
41 | + self.loss.append({ | ||
42 | + 'type': loss_type, | ||
43 | + 'weight': float(weight), | ||
44 | + 'function': loss_function} | ||
45 | + ) | ||
46 | + if loss_type.find('GAN') >= 0: | ||
47 | + self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) | ||
48 | + | ||
49 | + if len(self.loss) > 1: | ||
50 | + self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) | ||
51 | + | ||
52 | + for l in self.loss: | ||
53 | + if l['function'] is not None: | ||
54 | + print('{:.3f} * {}'.format(l['weight'], l['type'])) | ||
55 | + self.loss_module.append(l['function']) | ||
56 | + | ||
57 | + self.log = torch.Tensor() | ||
58 | + | ||
59 | + device = torch.device('cpu' if args.cpu else 'cuda') | ||
60 | + self.loss_module.to(device) | ||
61 | + if args.precision == 'half': self.loss_module.half() | ||
62 | + if not args.cpu and args.n_GPUs > 1: | ||
63 | + self.loss_module = nn.DataParallel( | ||
64 | + self.loss_module, range(args.n_GPUs) | ||
65 | + ) | ||
66 | + | ||
67 | + if args.load != '': self.load(ckp.dir, cpu=args.cpu) | ||
68 | + | ||
69 | + def forward(self, sr, hr): | ||
70 | + losses = [] | ||
71 | + for i, l in enumerate(self.loss): | ||
72 | + if l['function'] is not None: | ||
73 | + loss = l['function'](sr, hr) | ||
74 | + effective_loss = l['weight'] * loss | ||
75 | + losses.append(effective_loss) | ||
76 | + self.log[-1, i] += effective_loss.item() | ||
77 | + elif l['type'] == 'DIS': | ||
78 | + self.log[-1, i] += self.loss[i - 1]['function'].loss | ||
79 | + | ||
80 | + loss_sum = sum(losses) | ||
81 | + if len(self.loss) > 1: | ||
82 | + self.log[-1, -1] += loss_sum.item() | ||
83 | + | ||
84 | + return loss_sum | ||
85 | + | ||
86 | + def step(self): | ||
87 | + for l in self.get_loss_module(): | ||
88 | + if hasattr(l, 'scheduler'): | ||
89 | + l.scheduler.step() | ||
90 | + | ||
91 | + def start_log(self): | ||
92 | + self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) | ||
93 | + | ||
94 | + def end_log(self, n_batches): | ||
95 | + self.log[-1].div_(n_batches) | ||
96 | + | ||
97 | + def display_loss(self, batch): | ||
98 | + n_samples = batch + 1 | ||
99 | + log = [] | ||
100 | + for l, c in zip(self.loss, self.log[-1]): | ||
101 | + log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) | ||
102 | + | ||
103 | + return ''.join(log) | ||
104 | + | ||
105 | + def plot_loss(self, apath, epoch): | ||
106 | + axis = np.linspace(1, epoch, epoch) | ||
107 | + for i, l in enumerate(self.loss): | ||
108 | + label = '{} Loss'.format(l['type']) | ||
109 | + fig = plt.figure() | ||
110 | + plt.title(label) | ||
111 | + plt.plot(axis, self.log[:, i].numpy(), label=label) | ||
112 | + plt.legend() | ||
113 | + plt.xlabel('Epochs') | ||
114 | + plt.ylabel('Loss') | ||
115 | + plt.grid(True) | ||
116 | + plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) | ||
117 | + plt.close(fig) | ||
118 | + | ||
119 | + def get_loss_module(self): | ||
120 | + if self.n_GPUs == 1: | ||
121 | + return self.loss_module | ||
122 | + else: | ||
123 | + return self.loss_module.module | ||
124 | + | ||
125 | + def save(self, apath): | ||
126 | + torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) | ||
127 | + torch.save(self.log, os.path.join(apath, 'loss_log.pt')) | ||
128 | + | ||
129 | + def load(self, apath, cpu=False): | ||
130 | + if cpu: | ||
131 | + kwargs = {'map_location': lambda storage, loc: storage} | ||
132 | + else: | ||
133 | + kwargs = {} | ||
134 | + | ||
135 | + self.load_state_dict(torch.load( | ||
136 | + os.path.join(apath, 'loss.pt'), | ||
137 | + **kwargs | ||
138 | + )) | ||
139 | + self.log = torch.load(os.path.join(apath, 'loss_log.pt')) | ||
140 | + for l in self.get_loss_module(): | ||
141 | + if hasattr(l, 'scheduler'): | ||
142 | + for _ in range(len(self.log)): l.scheduler.step() | ||
143 | + |
edsr/src/loss/adversarial.py
0 → 100644
1 | +import utility | ||
2 | +from types import SimpleNamespace | ||
3 | + | ||
4 | +from model import common | ||
5 | +from loss import discriminator | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn as nn | ||
9 | +import torch.nn.functional as F | ||
10 | +import torch.optim as optim | ||
11 | + | ||
12 | +class Adversarial(nn.Module): | ||
13 | + def __init__(self, args, gan_type): | ||
14 | + super(Adversarial, self).__init__() | ||
15 | + self.gan_type = gan_type | ||
16 | + self.gan_k = args.gan_k | ||
17 | + self.dis = discriminator.Discriminator(args) | ||
18 | + if gan_type == 'WGAN_GP': | ||
19 | + # see https://arxiv.org/pdf/1704.00028.pdf pp.4 | ||
20 | + optim_dict = { | ||
21 | + 'optimizer': 'ADAM', | ||
22 | + 'betas': (0, 0.9), | ||
23 | + 'epsilon': 1e-8, | ||
24 | + 'lr': 1e-5, | ||
25 | + 'weight_decay': args.weight_decay, | ||
26 | + 'decay': args.decay, | ||
27 | + 'gamma': args.gamma | ||
28 | + } | ||
29 | + optim_args = SimpleNamespace(**optim_dict) | ||
30 | + else: | ||
31 | + optim_args = args | ||
32 | + | ||
33 | + self.optimizer = utility.make_optimizer(optim_args, self.dis) | ||
34 | + | ||
35 | + def forward(self, fake, real): | ||
36 | + # updating discriminator... | ||
37 | + self.loss = 0 | ||
38 | + fake_detach = fake.detach() # do not backpropagate through G | ||
39 | + for _ in range(self.gan_k): | ||
40 | + self.optimizer.zero_grad() | ||
41 | + # d: B x 1 tensor | ||
42 | + d_fake = self.dis(fake_detach) | ||
43 | + d_real = self.dis(real) | ||
44 | + retain_graph = False | ||
45 | + if self.gan_type == 'GAN': | ||
46 | + loss_d = self.bce(d_real, d_fake) | ||
47 | + elif self.gan_type.find('WGAN') >= 0: | ||
48 | + loss_d = (d_fake - d_real).mean() | ||
49 | + if self.gan_type.find('GP') >= 0: | ||
50 | + epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) | ||
51 | + hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) | ||
52 | + hat.requires_grad = True | ||
53 | + d_hat = self.dis(hat) | ||
54 | + gradients = torch.autograd.grad( | ||
55 | + outputs=d_hat.sum(), inputs=hat, | ||
56 | + retain_graph=True, create_graph=True, only_inputs=True | ||
57 | + )[0] | ||
58 | + gradients = gradients.view(gradients.size(0), -1) | ||
59 | + gradient_norm = gradients.norm(2, dim=1) | ||
60 | + gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() | ||
61 | + loss_d += gradient_penalty | ||
62 | + # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks | ||
63 | + elif self.gan_type == 'RGAN': | ||
64 | + better_real = d_real - d_fake.mean(dim=0, keepdim=True) | ||
65 | + better_fake = d_fake - d_real.mean(dim=0, keepdim=True) | ||
66 | + loss_d = self.bce(better_real, better_fake) | ||
67 | + retain_graph = True | ||
68 | + | ||
69 | + # Discriminator update | ||
70 | + self.loss += loss_d.item() | ||
71 | + loss_d.backward(retain_graph=retain_graph) | ||
72 | + self.optimizer.step() | ||
73 | + | ||
74 | + if self.gan_type == 'WGAN': | ||
75 | + for p in self.dis.parameters(): | ||
76 | + p.data.clamp_(-1, 1) | ||
77 | + | ||
78 | + self.loss /= self.gan_k | ||
79 | + | ||
80 | + # updating generator... | ||
81 | + d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is | ||
82 | + if self.gan_type == 'GAN': | ||
83 | + label_real = torch.ones_like(d_fake_bp) | ||
84 | + loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) | ||
85 | + elif self.gan_type.find('WGAN') >= 0: | ||
86 | + loss_g = -d_fake_bp.mean() | ||
87 | + elif self.gan_type == 'RGAN': | ||
88 | + better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) | ||
89 | + better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) | ||
90 | + loss_g = self.bce(better_fake, better_real) | ||
91 | + | ||
92 | + # Generator loss | ||
93 | + return loss_g | ||
94 | + | ||
95 | + def state_dict(self, *args, **kwargs): | ||
96 | + state_discriminator = self.dis.state_dict(*args, **kwargs) | ||
97 | + state_optimizer = self.optimizer.state_dict() | ||
98 | + | ||
99 | + return dict(**state_discriminator, **state_optimizer) | ||
100 | + | ||
101 | + def bce(self, real, fake): | ||
102 | + label_real = torch.ones_like(real) | ||
103 | + label_fake = torch.zeros_like(fake) | ||
104 | + bce_real = F.binary_cross_entropy_with_logits(real, label_real) | ||
105 | + bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) | ||
106 | + bce_loss = bce_real + bce_fake | ||
107 | + return bce_loss | ||
108 | + | ||
109 | +# Some references | ||
110 | +# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py | ||
111 | +# OR | ||
112 | +# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py |
edsr/src/loss/discriminator.py
0 → 100644
1 | +from model import common | ||
2 | + | ||
3 | +import torch.nn as nn | ||
4 | + | ||
5 | +class Discriminator(nn.Module): | ||
6 | + ''' | ||
7 | + output is not normalized | ||
8 | + ''' | ||
9 | + def __init__(self, args): | ||
10 | + super(Discriminator, self).__init__() | ||
11 | + | ||
12 | + in_channels = args.n_colors | ||
13 | + out_channels = 64 | ||
14 | + depth = 7 | ||
15 | + | ||
16 | + def _block(_in_channels, _out_channels, stride=1): | ||
17 | + return nn.Sequential( | ||
18 | + nn.Conv2d( | ||
19 | + _in_channels, | ||
20 | + _out_channels, | ||
21 | + 3, | ||
22 | + padding=1, | ||
23 | + stride=stride, | ||
24 | + bias=False | ||
25 | + ), | ||
26 | + nn.BatchNorm2d(_out_channels), | ||
27 | + nn.LeakyReLU(negative_slope=0.2, inplace=True) | ||
28 | + ) | ||
29 | + | ||
30 | + m_features = [_block(in_channels, out_channels)] | ||
31 | + for i in range(depth): | ||
32 | + in_channels = out_channels | ||
33 | + if i % 2 == 1: | ||
34 | + stride = 1 | ||
35 | + out_channels *= 2 | ||
36 | + else: | ||
37 | + stride = 2 | ||
38 | + m_features.append(_block(in_channels, out_channels, stride=stride)) | ||
39 | + | ||
40 | + patch_size = args.patch_size // (2**((depth + 1) // 2)) | ||
41 | + m_classifier = [ | ||
42 | + nn.Linear(out_channels * patch_size**2, 1024), | ||
43 | + nn.LeakyReLU(negative_slope=0.2, inplace=True), | ||
44 | + nn.Linear(1024, 1) | ||
45 | + ] | ||
46 | + | ||
47 | + self.features = nn.Sequential(*m_features) | ||
48 | + self.classifier = nn.Sequential(*m_classifier) | ||
49 | + | ||
50 | + def forward(self, x): | ||
51 | + features = self.features(x) | ||
52 | + output = self.classifier(features.view(features.size(0), -1)) | ||
53 | + | ||
54 | + return output | ||
55 | + |
edsr/src/loss/vgg.py
0 → 100644
1 | +from model import common | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.nn as nn | ||
5 | +import torch.nn.functional as F | ||
6 | +import torchvision.models as models | ||
7 | + | ||
8 | +class VGG(nn.Module): | ||
9 | + def __init__(self, conv_index, rgb_range=1): | ||
10 | + super(VGG, self).__init__() | ||
11 | + vgg_features = models.vgg19(pretrained=True).features | ||
12 | + modules = [m for m in vgg_features] | ||
13 | + if conv_index.find('22') >= 0: | ||
14 | + self.vgg = nn.Sequential(*modules[:8]) | ||
15 | + elif conv_index.find('54') >= 0: | ||
16 | + self.vgg = nn.Sequential(*modules[:35]) | ||
17 | + | ||
18 | + vgg_mean = (0.485, 0.456, 0.406) | ||
19 | + vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) | ||
20 | + self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) | ||
21 | + for p in self.parameters(): | ||
22 | + p.requires_grad = False | ||
23 | + | ||
24 | + def forward(self, sr, hr): | ||
25 | + def _forward(x): | ||
26 | + x = self.sub_mean(x) | ||
27 | + x = self.vgg(x) | ||
28 | + return x | ||
29 | + | ||
30 | + vgg_sr = _forward(sr) | ||
31 | + with torch.no_grad(): | ||
32 | + vgg_hr = _forward(hr.detach()) | ||
33 | + | ||
34 | + loss = F.mse_loss(vgg_sr, vgg_hr) | ||
35 | + | ||
36 | + return loss |
edsr/src/main.py
0 → 100644
1 | +import torch | ||
2 | + | ||
3 | +import utility | ||
4 | +import data | ||
5 | +import model | ||
6 | +import loss | ||
7 | +from option import args | ||
8 | +from trainer import Trainer | ||
9 | + | ||
10 | +torch.manual_seed(args.seed) | ||
11 | +checkpoint = utility.checkpoint(args) | ||
12 | + | ||
13 | +def main(): | ||
14 | + global model | ||
15 | + if args.data_test == ['video']: | ||
16 | + from videotester import VideoTester | ||
17 | + model = model.Model(args, checkpoint) | ||
18 | + t = VideoTester(args, model, checkpoint) | ||
19 | + t.test() | ||
20 | + else: | ||
21 | + if checkpoint.ok: | ||
22 | + loader = data.Data(args) | ||
23 | + _model = model.Model(args, checkpoint) | ||
24 | + _loss = loss.Loss(args, checkpoint) if not args.test_only else None | ||
25 | + t = Trainer(args, loader, _model, _loss, checkpoint) | ||
26 | + while not t.terminate(): | ||
27 | + t.train() | ||
28 | + t.test() | ||
29 | + | ||
30 | + checkpoint.done() | ||
31 | + | ||
32 | +if __name__ == '__main__': | ||
33 | + main() |
edsr/src/model/__init__.py
0 → 100644
1 | +import os | ||
2 | +from importlib import import_module | ||
3 | + | ||
4 | +import torch | ||
5 | +import torch.nn as nn | ||
6 | +import torch.nn.parallel as P | ||
7 | +import torch.utils.model_zoo | ||
8 | + | ||
9 | +class Model(nn.Module): | ||
10 | + def __init__(self, args, ckp): | ||
11 | + super(Model, self).__init__() | ||
12 | + print('Making model...') | ||
13 | + | ||
14 | + self.scale = args.scale | ||
15 | + self.idx_scale = 0 | ||
16 | + self.input_large = (args.model == 'VDSR') | ||
17 | + self.self_ensemble = args.self_ensemble | ||
18 | + self.chop = args.chop | ||
19 | + self.precision = args.precision | ||
20 | + self.cpu = args.cpu | ||
21 | + self.device = torch.device('cpu' if args.cpu else 'cuda') | ||
22 | + self.n_GPUs = args.n_GPUs | ||
23 | + self.save_models = args.save_models | ||
24 | + | ||
25 | + module = import_module('model.' + args.model.lower()) | ||
26 | + self.model = module.make_model(args).to(self.device) | ||
27 | + if args.precision == 'half': | ||
28 | + self.model.half() | ||
29 | + | ||
30 | + self.load( | ||
31 | + ckp.get_path('model'), | ||
32 | + pre_train=args.pre_train, | ||
33 | + resume=args.resume, | ||
34 | + cpu=args.cpu | ||
35 | + ) | ||
36 | + print(self.model, file=ckp.log_file) | ||
37 | + | ||
38 | + def forward(self, x, idx_scale): | ||
39 | + self.idx_scale = idx_scale | ||
40 | + if hasattr(self.model, 'set_scale'): | ||
41 | + self.model.set_scale(idx_scale) | ||
42 | + | ||
43 | + if self.training: | ||
44 | + if self.n_GPUs > 1: | ||
45 | + return P.data_parallel(self.model, x, range(self.n_GPUs)) | ||
46 | + else: | ||
47 | + return self.model(x) | ||
48 | + else: | ||
49 | + if self.chop: | ||
50 | + forward_function = self.forward_chop | ||
51 | + else: | ||
52 | + forward_function = self.model.forward | ||
53 | + | ||
54 | + if self.self_ensemble: | ||
55 | + return self.forward_x8(x, forward_function=forward_function) | ||
56 | + else: | ||
57 | + return forward_function(x) | ||
58 | + | ||
59 | + def save(self, apath, epoch, is_best=False): | ||
60 | + save_dirs = [os.path.join(apath, 'model_latest.pt')] | ||
61 | + | ||
62 | + if is_best: | ||
63 | + save_dirs.append(os.path.join(apath, 'model_best.pt')) | ||
64 | + if self.save_models: | ||
65 | + save_dirs.append( | ||
66 | + os.path.join(apath, 'model_{}.pt'.format(epoch)) | ||
67 | + ) | ||
68 | + | ||
69 | + for s in save_dirs: | ||
70 | + torch.save(self.model.state_dict(), s) | ||
71 | + | ||
72 | + def load(self, apath, pre_train='', resume=-1, cpu=False): | ||
73 | + load_from = None | ||
74 | + kwargs = {} | ||
75 | + if cpu: | ||
76 | + kwargs = {'map_location': lambda storage, loc: storage} | ||
77 | + | ||
78 | + if resume == -1: | ||
79 | + load_from = torch.load( | ||
80 | + os.path.join(apath, 'model_latest.pt'), | ||
81 | + **kwargs | ||
82 | + ) | ||
83 | + elif resume == 0: | ||
84 | + if pre_train == 'download': | ||
85 | + print('Download the model') | ||
86 | + dir_model = os.path.join('..', 'models') | ||
87 | + os.makedirs(dir_model, exist_ok=True) | ||
88 | + load_from = torch.utils.model_zoo.load_url( | ||
89 | + self.model.url, | ||
90 | + model_dir=dir_model, | ||
91 | + **kwargs | ||
92 | + ) | ||
93 | + elif pre_train: | ||
94 | + print('Load the model from {}'.format(pre_train)) | ||
95 | + load_from = torch.load(pre_train, **kwargs) | ||
96 | + else: | ||
97 | + load_from = torch.load( | ||
98 | + os.path.join(apath, 'model_{}.pt'.format(resume)), | ||
99 | + **kwargs | ||
100 | + ) | ||
101 | + | ||
102 | + if load_from: | ||
103 | + self.model.load_state_dict(load_from, strict=False) | ||
104 | + | ||
105 | + def forward_chop(self, *args, shave=10, min_size=160000): | ||
106 | + scale = 1 if self.input_large else self.scale[self.idx_scale] | ||
107 | + n_GPUs = min(self.n_GPUs, 4) | ||
108 | + # height, width | ||
109 | + h, w = args[0].size()[-2:] | ||
110 | + | ||
111 | + top = slice(0, h//2 + shave) | ||
112 | + bottom = slice(h - h//2 - shave, h) | ||
113 | + left = slice(0, w//2 + shave) | ||
114 | + right = slice(w - w//2 - shave, w) | ||
115 | + x_chops = [torch.cat([ | ||
116 | + a[..., top, left], | ||
117 | + a[..., top, right], | ||
118 | + a[..., bottom, left], | ||
119 | + a[..., bottom, right] | ||
120 | + ]) for a in args] | ||
121 | + | ||
122 | + y_chops = [] | ||
123 | + if h * w < 4 * min_size: | ||
124 | + for i in range(0, 4, n_GPUs): | ||
125 | + x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] | ||
126 | + y = P.data_parallel(self.model, *x, range(n_GPUs)) | ||
127 | + if not isinstance(y, list): y = [y] | ||
128 | + if not y_chops: | ||
129 | + y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] | ||
130 | + else: | ||
131 | + for y_chop, _y in zip(y_chops, y): | ||
132 | + y_chop.extend(_y.chunk(n_GPUs, dim=0)) | ||
133 | + else: | ||
134 | + for p in zip(*x_chops): | ||
135 | + y = self.forward_chop(*p, shave=shave, min_size=min_size) | ||
136 | + if not isinstance(y, list): y = [y] | ||
137 | + if not y_chops: | ||
138 | + y_chops = [[_y] for _y in y] | ||
139 | + else: | ||
140 | + for y_chop, _y in zip(y_chops, y): y_chop.append(_y) | ||
141 | + | ||
142 | + h *= scale | ||
143 | + w *= scale | ||
144 | + top = slice(0, h//2) | ||
145 | + bottom = slice(h - h//2, h) | ||
146 | + bottom_r = slice(h//2 - h, None) | ||
147 | + left = slice(0, w//2) | ||
148 | + right = slice(w - w//2, w) | ||
149 | + right_r = slice(w//2 - w, None) | ||
150 | + | ||
151 | + # batch size, number of color channels | ||
152 | + b, c = y_chops[0][0].size()[:-2] | ||
153 | + y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] | ||
154 | + for y_chop, _y in zip(y_chops, y): | ||
155 | + _y[..., top, left] = y_chop[0][..., top, left] | ||
156 | + _y[..., top, right] = y_chop[1][..., top, right_r] | ||
157 | + _y[..., bottom, left] = y_chop[2][..., bottom_r, left] | ||
158 | + _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] | ||
159 | + | ||
160 | + if len(y) == 1: y = y[0] | ||
161 | + | ||
162 | + return y | ||
163 | + | ||
164 | + def forward_x8(self, *args, forward_function=None): | ||
165 | + def _transform(v, op): | ||
166 | + if self.precision != 'single': v = v.float() | ||
167 | + | ||
168 | + v2np = v.data.cpu().numpy() | ||
169 | + if op == 'v': | ||
170 | + tfnp = v2np[:, :, :, ::-1].copy() | ||
171 | + elif op == 'h': | ||
172 | + tfnp = v2np[:, :, ::-1, :].copy() | ||
173 | + elif op == 't': | ||
174 | + tfnp = v2np.transpose((0, 1, 3, 2)).copy() | ||
175 | + | ||
176 | + ret = torch.Tensor(tfnp).to(self.device) | ||
177 | + if self.precision == 'half': ret = ret.half() | ||
178 | + | ||
179 | + return ret | ||
180 | + | ||
181 | + list_x = [] | ||
182 | + for a in args: | ||
183 | + x = [a] | ||
184 | + for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) | ||
185 | + | ||
186 | + list_x.append(x) | ||
187 | + | ||
188 | + list_y = [] | ||
189 | + for x in zip(*list_x): | ||
190 | + y = forward_function(*x) | ||
191 | + if not isinstance(y, list): y = [y] | ||
192 | + if not list_y: | ||
193 | + list_y = [[_y] for _y in y] | ||
194 | + else: | ||
195 | + for _list_y, _y in zip(list_y, y): _list_y.append(_y) | ||
196 | + | ||
197 | + for _list_y in list_y: | ||
198 | + for i in range(len(_list_y)): | ||
199 | + if i > 3: | ||
200 | + _list_y[i] = _transform(_list_y[i], 't') | ||
201 | + if i % 4 > 1: | ||
202 | + _list_y[i] = _transform(_list_y[i], 'h') | ||
203 | + if (i % 4) % 2 == 1: | ||
204 | + _list_y[i] = _transform(_list_y[i], 'v') | ||
205 | + | ||
206 | + y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] | ||
207 | + if len(y) == 1: y = y[0] | ||
208 | + | ||
209 | + return y |
edsr/src/model/common.py
0 → 100644
1 | +import math | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.nn as nn | ||
5 | +import torch.nn.functional as F | ||
6 | + | ||
7 | +def default_conv(in_channels, out_channels, kernel_size, bias=True): | ||
8 | + return nn.Conv2d( | ||
9 | + in_channels, out_channels, kernel_size, | ||
10 | + padding=(kernel_size//2), bias=bias) | ||
11 | + | ||
12 | +class MeanShift(nn.Conv2d): | ||
13 | + def __init__( | ||
14 | + self, rgb_range, | ||
15 | + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): | ||
16 | + | ||
17 | + super(MeanShift, self).__init__(3, 3, kernel_size=1) | ||
18 | + std = torch.Tensor(rgb_std) | ||
19 | + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) | ||
20 | + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std | ||
21 | + for p in self.parameters(): | ||
22 | + p.requires_grad = False | ||
23 | + | ||
24 | +class BasicBlock(nn.Sequential): | ||
25 | + def __init__( | ||
26 | + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, | ||
27 | + bn=True, act=nn.ReLU(True)): | ||
28 | + | ||
29 | + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] | ||
30 | + if bn: | ||
31 | + m.append(nn.BatchNorm2d(out_channels)) | ||
32 | + if act is not None: | ||
33 | + m.append(act) | ||
34 | + | ||
35 | + super(BasicBlock, self).__init__(*m) | ||
36 | + | ||
37 | +class ResBlock(nn.Module): | ||
38 | + def __init__( | ||
39 | + self, conv, n_feats, kernel_size, | ||
40 | + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): | ||
41 | + | ||
42 | + super(ResBlock, self).__init__() | ||
43 | + m = [] | ||
44 | + for i in range(2): | ||
45 | + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) | ||
46 | + if bn: | ||
47 | + m.append(nn.BatchNorm2d(n_feats)) | ||
48 | + if i == 0: | ||
49 | + m.append(act) | ||
50 | + | ||
51 | + self.body = nn.Sequential(*m) | ||
52 | + self.res_scale = res_scale | ||
53 | + | ||
54 | + def forward(self, x): | ||
55 | + res = self.body(x).mul(self.res_scale) | ||
56 | + res += x | ||
57 | + | ||
58 | + return res | ||
59 | + | ||
60 | +class Upsampler(nn.Sequential): | ||
61 | + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): | ||
62 | + | ||
63 | + m = [] | ||
64 | + if (scale & (scale - 1)) == 0: # Is scale = 2^n? | ||
65 | + for _ in range(int(math.log(scale, 2))): | ||
66 | + m.append(conv(n_feats, 4 * n_feats, 3, bias)) | ||
67 | + m.append(nn.PixelShuffle(2)) | ||
68 | + if bn: | ||
69 | + m.append(nn.BatchNorm2d(n_feats)) | ||
70 | + if act == 'relu': | ||
71 | + m.append(nn.ReLU(True)) | ||
72 | + elif act == 'prelu': | ||
73 | + m.append(nn.PReLU(n_feats)) | ||
74 | + | ||
75 | + elif scale == 3: | ||
76 | + m.append(conv(n_feats, 9 * n_feats, 3, bias)) | ||
77 | + m.append(nn.PixelShuffle(3)) | ||
78 | + if bn: | ||
79 | + m.append(nn.BatchNorm2d(n_feats)) | ||
80 | + if act == 'relu': | ||
81 | + m.append(nn.ReLU(True)) | ||
82 | + elif act == 'prelu': | ||
83 | + m.append(nn.PReLU(n_feats)) | ||
84 | + else: | ||
85 | + raise NotImplementedError | ||
86 | + | ||
87 | + super(Upsampler, self).__init__(*m) | ||
88 | + |
edsr/src/model/ddbpn.py
0 → 100644
1 | +# Deep Back-Projection Networks For Super-Resolution | ||
2 | +# https://arxiv.org/abs/1803.02735 | ||
3 | + | ||
4 | +from model import common | ||
5 | + | ||
6 | +import torch | ||
7 | +import torch.nn as nn | ||
8 | + | ||
9 | + | ||
10 | +def make_model(args, parent=False): | ||
11 | + return DDBPN(args) | ||
12 | + | ||
13 | +def projection_conv(in_channels, out_channels, scale, up=True): | ||
14 | + kernel_size, stride, padding = { | ||
15 | + 2: (6, 2, 2), | ||
16 | + 4: (8, 4, 2), | ||
17 | + 8: (12, 8, 2) | ||
18 | + }[scale] | ||
19 | + if up: | ||
20 | + conv_f = nn.ConvTranspose2d | ||
21 | + else: | ||
22 | + conv_f = nn.Conv2d | ||
23 | + | ||
24 | + return conv_f( | ||
25 | + in_channels, out_channels, kernel_size, | ||
26 | + stride=stride, padding=padding | ||
27 | + ) | ||
28 | + | ||
29 | +class DenseProjection(nn.Module): | ||
30 | + def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): | ||
31 | + super(DenseProjection, self).__init__() | ||
32 | + if bottleneck: | ||
33 | + self.bottleneck = nn.Sequential(*[ | ||
34 | + nn.Conv2d(in_channels, nr, 1), | ||
35 | + nn.PReLU(nr) | ||
36 | + ]) | ||
37 | + inter_channels = nr | ||
38 | + else: | ||
39 | + self.bottleneck = None | ||
40 | + inter_channels = in_channels | ||
41 | + | ||
42 | + self.conv_1 = nn.Sequential(*[ | ||
43 | + projection_conv(inter_channels, nr, scale, up), | ||
44 | + nn.PReLU(nr) | ||
45 | + ]) | ||
46 | + self.conv_2 = nn.Sequential(*[ | ||
47 | + projection_conv(nr, inter_channels, scale, not up), | ||
48 | + nn.PReLU(inter_channels) | ||
49 | + ]) | ||
50 | + self.conv_3 = nn.Sequential(*[ | ||
51 | + projection_conv(inter_channels, nr, scale, up), | ||
52 | + nn.PReLU(nr) | ||
53 | + ]) | ||
54 | + | ||
55 | + def forward(self, x): | ||
56 | + if self.bottleneck is not None: | ||
57 | + x = self.bottleneck(x) | ||
58 | + | ||
59 | + a_0 = self.conv_1(x) | ||
60 | + b_0 = self.conv_2(a_0) | ||
61 | + e = b_0.sub(x) | ||
62 | + a_1 = self.conv_3(e) | ||
63 | + | ||
64 | + out = a_0.add(a_1) | ||
65 | + | ||
66 | + return out | ||
67 | + | ||
68 | +class DDBPN(nn.Module): | ||
69 | + def __init__(self, args): | ||
70 | + super(DDBPN, self).__init__() | ||
71 | + scale = args.scale[0] | ||
72 | + | ||
73 | + n0 = 128 | ||
74 | + nr = 32 | ||
75 | + self.depth = 6 | ||
76 | + | ||
77 | + rgb_mean = (0.4488, 0.4371, 0.4040) | ||
78 | + rgb_std = (1.0, 1.0, 1.0) | ||
79 | + self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) | ||
80 | + initial = [ | ||
81 | + nn.Conv2d(args.n_colors, n0, 3, padding=1), | ||
82 | + nn.PReLU(n0), | ||
83 | + nn.Conv2d(n0, nr, 1), | ||
84 | + nn.PReLU(nr) | ||
85 | + ] | ||
86 | + self.initial = nn.Sequential(*initial) | ||
87 | + | ||
88 | + self.upmodules = nn.ModuleList() | ||
89 | + self.downmodules = nn.ModuleList() | ||
90 | + channels = nr | ||
91 | + for i in range(self.depth): | ||
92 | + self.upmodules.append( | ||
93 | + DenseProjection(channels, nr, scale, True, i > 1) | ||
94 | + ) | ||
95 | + if i != 0: | ||
96 | + channels += nr | ||
97 | + | ||
98 | + channels = nr | ||
99 | + for i in range(self.depth - 1): | ||
100 | + self.downmodules.append( | ||
101 | + DenseProjection(channels, nr, scale, False, i != 0) | ||
102 | + ) | ||
103 | + channels += nr | ||
104 | + | ||
105 | + reconstruction = [ | ||
106 | + nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) | ||
107 | + ] | ||
108 | + self.reconstruction = nn.Sequential(*reconstruction) | ||
109 | + | ||
110 | + self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) | ||
111 | + | ||
112 | + def forward(self, x): | ||
113 | + x = self.sub_mean(x) | ||
114 | + x = self.initial(x) | ||
115 | + | ||
116 | + h_list = [] | ||
117 | + l_list = [] | ||
118 | + for i in range(self.depth - 1): | ||
119 | + if i == 0: | ||
120 | + l = x | ||
121 | + else: | ||
122 | + l = torch.cat(l_list, dim=1) | ||
123 | + h_list.append(self.upmodules[i](l)) | ||
124 | + l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) | ||
125 | + | ||
126 | + h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) | ||
127 | + out = self.reconstruction(torch.cat(h_list, dim=1)) | ||
128 | + out = self.add_mean(out) | ||
129 | + | ||
130 | + return out | ||
131 | + |
edsr/src/model/edsr.py
0 → 100644
1 | +from model import common | ||
2 | + | ||
3 | +import torch.nn as nn | ||
4 | + | ||
5 | +url = { | ||
6 | + 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', | ||
7 | + 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', | ||
8 | + 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', | ||
9 | + 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', | ||
10 | + 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', | ||
11 | + 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' | ||
12 | +} | ||
13 | + | ||
14 | +def make_model(args, parent=False): | ||
15 | + return EDSR(args) | ||
16 | + | ||
17 | +class EDSR(nn.Module): | ||
18 | + def __init__(self, args, conv=common.default_conv): | ||
19 | + super(EDSR, self).__init__() | ||
20 | + | ||
21 | + n_resblocks = args.n_resblocks | ||
22 | + n_feats = args.n_feats | ||
23 | + kernel_size = 3 | ||
24 | + scale = args.scale[0] | ||
25 | + act = nn.ReLU(True) | ||
26 | + url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) | ||
27 | + if url_name in url: | ||
28 | + self.url = url[url_name] | ||
29 | + else: | ||
30 | + self.url = None | ||
31 | + self.sub_mean = common.MeanShift(args.rgb_range) | ||
32 | + self.add_mean = common.MeanShift(args.rgb_range, sign=1) | ||
33 | + | ||
34 | + # define head module | ||
35 | + m_head = [conv(args.n_colors, n_feats, kernel_size)] | ||
36 | + | ||
37 | + # define body module | ||
38 | + m_body = [ | ||
39 | + common.ResBlock( | ||
40 | + conv, n_feats, kernel_size, act=act, res_scale=args.res_scale | ||
41 | + ) for _ in range(n_resblocks) | ||
42 | + ] | ||
43 | + m_body.append(conv(n_feats, n_feats, kernel_size)) | ||
44 | + | ||
45 | + # define tail module | ||
46 | + m_tail = [ | ||
47 | + common.Upsampler(conv, scale, n_feats, act=False), | ||
48 | + conv(n_feats, args.n_colors, kernel_size) | ||
49 | + ] | ||
50 | + | ||
51 | + self.head = nn.Sequential(*m_head) | ||
52 | + self.body = nn.Sequential(*m_body) | ||
53 | + self.tail = nn.Sequential(*m_tail) | ||
54 | + | ||
55 | + def forward(self, x): | ||
56 | + x = self.sub_mean(x) | ||
57 | + x = self.head(x) | ||
58 | + | ||
59 | + res = self.body(x) | ||
60 | + res += x | ||
61 | + | ||
62 | + x = self.tail(res) | ||
63 | + x = self.add_mean(x) | ||
64 | + | ||
65 | + return x | ||
66 | + | ||
67 | + def load_state_dict(self, state_dict, strict=True): | ||
68 | + own_state = self.state_dict() | ||
69 | + for name, param in state_dict.items(): | ||
70 | + if name in own_state: | ||
71 | + if isinstance(param, nn.Parameter): | ||
72 | + param = param.data | ||
73 | + try: | ||
74 | + own_state[name].copy_(param) | ||
75 | + except Exception: | ||
76 | + if name.find('tail') == -1: | ||
77 | + raise RuntimeError('While copying the parameter named {}, ' | ||
78 | + 'whose dimensions in the model are {} and ' | ||
79 | + 'whose dimensions in the checkpoint are {}.' | ||
80 | + .format(name, own_state[name].size(), param.size())) | ||
81 | + elif strict: | ||
82 | + if name.find('tail') == -1: | ||
83 | + raise KeyError('unexpected key "{}" in state_dict' | ||
84 | + .format(name)) | ||
85 | + |
edsr/src/model/mdsr.py
0 → 100644
1 | +from model import common | ||
2 | + | ||
3 | +import torch.nn as nn | ||
4 | + | ||
5 | +url = { | ||
6 | + 'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt', | ||
7 | + 'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt' | ||
8 | +} | ||
9 | + | ||
10 | +def make_model(args, parent=False): | ||
11 | + return MDSR(args) | ||
12 | + | ||
13 | +class MDSR(nn.Module): | ||
14 | + def __init__(self, args, conv=common.default_conv): | ||
15 | + super(MDSR, self).__init__() | ||
16 | + n_resblocks = args.n_resblocks | ||
17 | + n_feats = args.n_feats | ||
18 | + kernel_size = 3 | ||
19 | + act = nn.ReLU(True) | ||
20 | + self.scale_idx = 0 | ||
21 | + self.url = url['r{}f{}'.format(n_resblocks, n_feats)] | ||
22 | + self.sub_mean = common.MeanShift(args.rgb_range) | ||
23 | + self.add_mean = common.MeanShift(args.rgb_range, sign=1) | ||
24 | + | ||
25 | + m_head = [conv(args.n_colors, n_feats, kernel_size)] | ||
26 | + | ||
27 | + self.pre_process = nn.ModuleList([ | ||
28 | + nn.Sequential( | ||
29 | + common.ResBlock(conv, n_feats, 5, act=act), | ||
30 | + common.ResBlock(conv, n_feats, 5, act=act) | ||
31 | + ) for _ in args.scale | ||
32 | + ]) | ||
33 | + | ||
34 | + m_body = [ | ||
35 | + common.ResBlock( | ||
36 | + conv, n_feats, kernel_size, act=act | ||
37 | + ) for _ in range(n_resblocks) | ||
38 | + ] | ||
39 | + m_body.append(conv(n_feats, n_feats, kernel_size)) | ||
40 | + | ||
41 | + self.upsample = nn.ModuleList([ | ||
42 | + common.Upsampler(conv, s, n_feats, act=False) for s in args.scale | ||
43 | + ]) | ||
44 | + | ||
45 | + m_tail = [conv(n_feats, args.n_colors, kernel_size)] | ||
46 | + | ||
47 | + self.head = nn.Sequential(*m_head) | ||
48 | + self.body = nn.Sequential(*m_body) | ||
49 | + self.tail = nn.Sequential(*m_tail) | ||
50 | + | ||
51 | + def forward(self, x): | ||
52 | + x = self.sub_mean(x) | ||
53 | + x = self.head(x) | ||
54 | + x = self.pre_process[self.scale_idx](x) | ||
55 | + | ||
56 | + res = self.body(x) | ||
57 | + res += x | ||
58 | + | ||
59 | + x = self.upsample[self.scale_idx](res) | ||
60 | + x = self.tail(x) | ||
61 | + x = self.add_mean(x) | ||
62 | + | ||
63 | + return x | ||
64 | + | ||
65 | + def set_scale(self, scale_idx): | ||
66 | + self.scale_idx = scale_idx | ||
67 | + |
edsr/src/model/rcan.py
0 → 100644
1 | +## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks | ||
2 | +## https://arxiv.org/abs/1807.02758 | ||
3 | +from model import common | ||
4 | + | ||
5 | +import torch.nn as nn | ||
6 | + | ||
7 | +def make_model(args, parent=False): | ||
8 | + return RCAN(args) | ||
9 | + | ||
10 | +## Channel Attention (CA) Layer | ||
11 | +class CALayer(nn.Module): | ||
12 | + def __init__(self, channel, reduction=16): | ||
13 | + super(CALayer, self).__init__() | ||
14 | + # global average pooling: feature --> point | ||
15 | + self.avg_pool = nn.AdaptiveAvgPool2d(1) | ||
16 | + # feature channel downscale and upscale --> channel weight | ||
17 | + self.conv_du = nn.Sequential( | ||
18 | + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), | ||
19 | + nn.ReLU(inplace=True), | ||
20 | + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), | ||
21 | + nn.Sigmoid() | ||
22 | + ) | ||
23 | + | ||
24 | + def forward(self, x): | ||
25 | + y = self.avg_pool(x) | ||
26 | + y = self.conv_du(y) | ||
27 | + return x * y | ||
28 | + | ||
29 | +## Residual Channel Attention Block (RCAB) | ||
30 | +class RCAB(nn.Module): | ||
31 | + def __init__( | ||
32 | + self, conv, n_feat, kernel_size, reduction, | ||
33 | + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): | ||
34 | + | ||
35 | + super(RCAB, self).__init__() | ||
36 | + modules_body = [] | ||
37 | + for i in range(2): | ||
38 | + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) | ||
39 | + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) | ||
40 | + if i == 0: modules_body.append(act) | ||
41 | + modules_body.append(CALayer(n_feat, reduction)) | ||
42 | + self.body = nn.Sequential(*modules_body) | ||
43 | + self.res_scale = res_scale | ||
44 | + | ||
45 | + def forward(self, x): | ||
46 | + res = self.body(x) | ||
47 | + #res = self.body(x).mul(self.res_scale) | ||
48 | + res += x | ||
49 | + return res | ||
50 | + | ||
51 | +## Residual Group (RG) | ||
52 | +class ResidualGroup(nn.Module): | ||
53 | + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): | ||
54 | + super(ResidualGroup, self).__init__() | ||
55 | + modules_body = [] | ||
56 | + modules_body = [ | ||
57 | + RCAB( | ||
58 | + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ | ||
59 | + for _ in range(n_resblocks)] | ||
60 | + modules_body.append(conv(n_feat, n_feat, kernel_size)) | ||
61 | + self.body = nn.Sequential(*modules_body) | ||
62 | + | ||
63 | + def forward(self, x): | ||
64 | + res = self.body(x) | ||
65 | + res += x | ||
66 | + return res | ||
67 | + | ||
68 | +## Residual Channel Attention Network (RCAN) | ||
69 | +class RCAN(nn.Module): | ||
70 | + def __init__(self, args, conv=common.default_conv): | ||
71 | + super(RCAN, self).__init__() | ||
72 | + | ||
73 | + n_resgroups = args.n_resgroups | ||
74 | + n_resblocks = args.n_resblocks | ||
75 | + n_feats = args.n_feats | ||
76 | + kernel_size = 3 | ||
77 | + reduction = args.reduction | ||
78 | + scale = args.scale[0] | ||
79 | + act = nn.ReLU(True) | ||
80 | + | ||
81 | + # RGB mean for DIV2K | ||
82 | + self.sub_mean = common.MeanShift(args.rgb_range) | ||
83 | + | ||
84 | + # define head module | ||
85 | + modules_head = [conv(args.n_colors, n_feats, kernel_size)] | ||
86 | + | ||
87 | + # define body module | ||
88 | + modules_body = [ | ||
89 | + ResidualGroup( | ||
90 | + conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ | ||
91 | + for _ in range(n_resgroups)] | ||
92 | + | ||
93 | + modules_body.append(conv(n_feats, n_feats, kernel_size)) | ||
94 | + | ||
95 | + # define tail module | ||
96 | + modules_tail = [ | ||
97 | + common.Upsampler(conv, scale, n_feats, act=False), | ||
98 | + conv(n_feats, args.n_colors, kernel_size)] | ||
99 | + | ||
100 | + self.add_mean = common.MeanShift(args.rgb_range, sign=1) | ||
101 | + | ||
102 | + self.head = nn.Sequential(*modules_head) | ||
103 | + self.body = nn.Sequential(*modules_body) | ||
104 | + self.tail = nn.Sequential(*modules_tail) | ||
105 | + | ||
106 | + def forward(self, x): | ||
107 | + x = self.sub_mean(x) | ||
108 | + x = self.head(x) | ||
109 | + | ||
110 | + res = self.body(x) | ||
111 | + res += x | ||
112 | + | ||
113 | + x = self.tail(res) | ||
114 | + x = self.add_mean(x) | ||
115 | + | ||
116 | + return x | ||
117 | + | ||
118 | + def load_state_dict(self, state_dict, strict=False): | ||
119 | + own_state = self.state_dict() | ||
120 | + for name, param in state_dict.items(): | ||
121 | + if name in own_state: | ||
122 | + if isinstance(param, nn.Parameter): | ||
123 | + param = param.data | ||
124 | + try: | ||
125 | + own_state[name].copy_(param) | ||
126 | + except Exception: | ||
127 | + if name.find('tail') >= 0: | ||
128 | + print('Replace pre-trained upsampler to new one...') | ||
129 | + else: | ||
130 | + raise RuntimeError('While copying the parameter named {}, ' | ||
131 | + 'whose dimensions in the model are {} and ' | ||
132 | + 'whose dimensions in the checkpoint are {}.' | ||
133 | + .format(name, own_state[name].size(), param.size())) | ||
134 | + elif strict: | ||
135 | + if name.find('tail') == -1: | ||
136 | + raise KeyError('unexpected key "{}" in state_dict' | ||
137 | + .format(name)) | ||
138 | + | ||
139 | + if strict: | ||
140 | + missing = set(own_state.keys()) - set(state_dict.keys()) | ||
141 | + if len(missing) > 0: | ||
142 | + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) |
edsr/src/model/rdn.py
0 → 100644
1 | +# Residual Dense Network for Image Super-Resolution | ||
2 | +# https://arxiv.org/abs/1802.08797 | ||
3 | + | ||
4 | +from model import common | ||
5 | + | ||
6 | +import torch | ||
7 | +import torch.nn as nn | ||
8 | + | ||
9 | + | ||
10 | +def make_model(args, parent=False): | ||
11 | + return RDN(args) | ||
12 | + | ||
13 | +class RDB_Conv(nn.Module): | ||
14 | + def __init__(self, inChannels, growRate, kSize=3): | ||
15 | + super(RDB_Conv, self).__init__() | ||
16 | + Cin = inChannels | ||
17 | + G = growRate | ||
18 | + self.conv = nn.Sequential(*[ | ||
19 | + nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), | ||
20 | + nn.ReLU() | ||
21 | + ]) | ||
22 | + | ||
23 | + def forward(self, x): | ||
24 | + out = self.conv(x) | ||
25 | + return torch.cat((x, out), 1) | ||
26 | + | ||
27 | +class RDB(nn.Module): | ||
28 | + def __init__(self, growRate0, growRate, nConvLayers, kSize=3): | ||
29 | + super(RDB, self).__init__() | ||
30 | + G0 = growRate0 | ||
31 | + G = growRate | ||
32 | + C = nConvLayers | ||
33 | + | ||
34 | + convs = [] | ||
35 | + for c in range(C): | ||
36 | + convs.append(RDB_Conv(G0 + c*G, G)) | ||
37 | + self.convs = nn.Sequential(*convs) | ||
38 | + | ||
39 | + # Local Feature Fusion | ||
40 | + self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) | ||
41 | + | ||
42 | + def forward(self, x): | ||
43 | + return self.LFF(self.convs(x)) + x | ||
44 | + | ||
45 | +class RDN(nn.Module): | ||
46 | + def __init__(self, args): | ||
47 | + super(RDN, self).__init__() | ||
48 | + r = args.scale[0] | ||
49 | + G0 = args.G0 | ||
50 | + kSize = args.RDNkSize | ||
51 | + | ||
52 | + # number of RDB blocks, conv layers, out channels | ||
53 | + self.D, C, G = { | ||
54 | + 'A': (20, 6, 32), | ||
55 | + 'B': (16, 8, 64), | ||
56 | + }[args.RDNconfig] | ||
57 | + | ||
58 | + # Shallow feature extraction net | ||
59 | + self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) | ||
60 | + self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) | ||
61 | + | ||
62 | + # Redidual dense blocks and dense feature fusion | ||
63 | + self.RDBs = nn.ModuleList() | ||
64 | + for i in range(self.D): | ||
65 | + self.RDBs.append( | ||
66 | + RDB(growRate0 = G0, growRate = G, nConvLayers = C) | ||
67 | + ) | ||
68 | + | ||
69 | + # Global Feature Fusion | ||
70 | + self.GFF = nn.Sequential(*[ | ||
71 | + nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), | ||
72 | + nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) | ||
73 | + ]) | ||
74 | + | ||
75 | + # Up-sampling net | ||
76 | + if r == 2 or r == 3: | ||
77 | + self.UPNet = nn.Sequential(*[ | ||
78 | + nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), | ||
79 | + nn.PixelShuffle(r), | ||
80 | + nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) | ||
81 | + ]) | ||
82 | + elif r == 4: | ||
83 | + self.UPNet = nn.Sequential(*[ | ||
84 | + nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), | ||
85 | + nn.PixelShuffle(2), | ||
86 | + nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), | ||
87 | + nn.PixelShuffle(2), | ||
88 | + nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) | ||
89 | + ]) | ||
90 | + else: | ||
91 | + raise ValueError("scale must be 2 or 3 or 4.") | ||
92 | + | ||
93 | + def forward(self, x): | ||
94 | + f__1 = self.SFENet1(x) | ||
95 | + x = self.SFENet2(f__1) | ||
96 | + | ||
97 | + RDBs_out = [] | ||
98 | + for i in range(self.D): | ||
99 | + x = self.RDBs[i](x) | ||
100 | + RDBs_out.append(x) | ||
101 | + | ||
102 | + x = self.GFF(torch.cat(RDBs_out,1)) | ||
103 | + x += f__1 | ||
104 | + | ||
105 | + return self.UPNet(x) |
edsr/src/model/vdsr.py
0 → 100644
1 | +from model import common | ||
2 | + | ||
3 | +import torch.nn as nn | ||
4 | +import torch.nn.init as init | ||
5 | + | ||
6 | +url = { | ||
7 | + 'r20f64': '' | ||
8 | +} | ||
9 | + | ||
10 | +def make_model(args, parent=False): | ||
11 | + return VDSR(args) | ||
12 | + | ||
13 | +class VDSR(nn.Module): | ||
14 | + def __init__(self, args, conv=common.default_conv): | ||
15 | + super(VDSR, self).__init__() | ||
16 | + | ||
17 | + n_resblocks = args.n_resblocks | ||
18 | + n_feats = args.n_feats | ||
19 | + kernel_size = 3 | ||
20 | + self.url = url['r{}f{}'.format(n_resblocks, n_feats)] | ||
21 | + self.sub_mean = common.MeanShift(args.rgb_range) | ||
22 | + self.add_mean = common.MeanShift(args.rgb_range, sign=1) | ||
23 | + | ||
24 | + def basic_block(in_channels, out_channels, act): | ||
25 | + return common.BasicBlock( | ||
26 | + conv, in_channels, out_channels, kernel_size, | ||
27 | + bias=True, bn=False, act=act | ||
28 | + ) | ||
29 | + | ||
30 | + # define body module | ||
31 | + m_body = [] | ||
32 | + m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) | ||
33 | + for _ in range(n_resblocks - 2): | ||
34 | + m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) | ||
35 | + m_body.append(basic_block(n_feats, args.n_colors, None)) | ||
36 | + | ||
37 | + self.body = nn.Sequential(*m_body) | ||
38 | + | ||
39 | + def forward(self, x): | ||
40 | + x = self.sub_mean(x) | ||
41 | + res = self.body(x) | ||
42 | + res += x | ||
43 | + x = self.add_mean(res) | ||
44 | + | ||
45 | + return x | ||
46 | + |
edsr/src/option.py
0 → 100644
1 | +import argparse | ||
2 | +import template | ||
3 | + | ||
4 | +parser = argparse.ArgumentParser(description='EDSR and MDSR') | ||
5 | + | ||
6 | +parser.add_argument('--debug', action='store_true', | ||
7 | + help='Enables debug mode') | ||
8 | +parser.add_argument('--template', default='.', | ||
9 | + help='You can set various templates in option.py') | ||
10 | + | ||
11 | +# Hardware specifications | ||
12 | +parser.add_argument('--n_threads', type=int, default=6, | ||
13 | + help='number of threads for data loading') | ||
14 | +parser.add_argument('--cpu', action='store_true', | ||
15 | + help='use cpu only') | ||
16 | +parser.add_argument('--n_GPUs', type=int, default=1, | ||
17 | + help='number of GPUs') | ||
18 | +parser.add_argument('--seed', type=int, default=1, | ||
19 | + help='random seed') | ||
20 | + | ||
21 | +# Data specifications | ||
22 | +parser.add_argument('--dir_data', type=str, default='../../../dataset', | ||
23 | + help='dataset directory') | ||
24 | +parser.add_argument('--dir_demo', type=str, default='../test', | ||
25 | + help='demo image directory') | ||
26 | +parser.add_argument('--data_train', type=str, default='DIV2K', | ||
27 | + help='train dataset name') | ||
28 | +parser.add_argument('--data_test', type=str, default='DIV2K', | ||
29 | + help='test dataset name') | ||
30 | +parser.add_argument('--data_range', type=str, default='1-800/801-810', | ||
31 | + help='train/test data range') | ||
32 | +parser.add_argument('--ext', type=str, default='sep', | ||
33 | + help='dataset file extension') | ||
34 | +parser.add_argument('--scale', type=str, default='4', | ||
35 | + help='super resolution scale') | ||
36 | +parser.add_argument('--patch_size', type=int, default=192, | ||
37 | + help='output patch size') | ||
38 | +parser.add_argument('--rgb_range', type=int, default=255, | ||
39 | + help='maximum value of RGB') | ||
40 | +parser.add_argument('--n_colors', type=int, default=3, | ||
41 | + help='number of color channels to use') | ||
42 | +parser.add_argument('--chop', action='store_true', | ||
43 | + help='enable memory-efficient forward') | ||
44 | +parser.add_argument('--no_augment', action='store_true', | ||
45 | + help='do not use data augmentation') | ||
46 | + | ||
47 | +# Model specifications | ||
48 | +parser.add_argument('--model', default='EDSR', | ||
49 | + help='model name') | ||
50 | + | ||
51 | +parser.add_argument('--act', type=str, default='relu', | ||
52 | + help='activation function') | ||
53 | +parser.add_argument('--pre_train', type=str, default='', | ||
54 | + help='pre-trained model directory') | ||
55 | +parser.add_argument('--extend', type=str, default='.', | ||
56 | + help='pre-trained model directory') | ||
57 | +parser.add_argument('--n_resblocks', type=int, default=16, | ||
58 | + help='number of residual blocks') | ||
59 | +parser.add_argument('--n_feats', type=int, default=64, | ||
60 | + help='number of feature maps') | ||
61 | +parser.add_argument('--res_scale', type=float, default=1, | ||
62 | + help='residual scaling') | ||
63 | +parser.add_argument('--shift_mean', default=True, | ||
64 | + help='subtract pixel mean from the input') | ||
65 | +parser.add_argument('--dilation', action='store_true', | ||
66 | + help='use dilated convolution') | ||
67 | +parser.add_argument('--precision', type=str, default='single', | ||
68 | + choices=('single', 'half'), | ||
69 | + help='FP precision for test (single | half)') | ||
70 | + | ||
71 | +# Option for Residual dense network (RDN) | ||
72 | +parser.add_argument('--G0', type=int, default=64, | ||
73 | + help='default number of filters. (Use in RDN)') | ||
74 | +parser.add_argument('--RDNkSize', type=int, default=3, | ||
75 | + help='default kernel size. (Use in RDN)') | ||
76 | +parser.add_argument('--RDNconfig', type=str, default='B', | ||
77 | + help='parameters config of RDN. (Use in RDN)') | ||
78 | + | ||
79 | +# Option for Residual channel attention network (RCAN) | ||
80 | +parser.add_argument('--n_resgroups', type=int, default=10, | ||
81 | + help='number of residual groups') | ||
82 | +parser.add_argument('--reduction', type=int, default=16, | ||
83 | + help='number of feature maps reduction') | ||
84 | + | ||
85 | +# Training specifications | ||
86 | +parser.add_argument('--reset', action='store_true', | ||
87 | + help='reset the training') | ||
88 | +parser.add_argument('--test_every', type=int, default=1000, | ||
89 | + help='do test per every N batches') | ||
90 | +parser.add_argument('--epochs', type=int, default=300, | ||
91 | + help='number of epochs to train') | ||
92 | +parser.add_argument('--batch_size', type=int, default=16, | ||
93 | + help='input batch size for training') | ||
94 | +parser.add_argument('--split_batch', type=int, default=1, | ||
95 | + help='split the batch into smaller chunks') | ||
96 | +parser.add_argument('--self_ensemble', action='store_true', | ||
97 | + help='use self-ensemble method for test') | ||
98 | +parser.add_argument('--test_only', action='store_true', | ||
99 | + help='set this option to test the model') | ||
100 | +parser.add_argument('--gan_k', type=int, default=1, | ||
101 | + help='k value for adversarial loss') | ||
102 | + | ||
103 | +# Optimization specifications | ||
104 | +parser.add_argument('--lr', type=float, default=1e-4, | ||
105 | + help='learning rate') | ||
106 | +parser.add_argument('--decay', type=str, default='200', | ||
107 | + help='learning rate decay type') | ||
108 | +parser.add_argument('--gamma', type=float, default=0.5, | ||
109 | + help='learning rate decay factor for step decay') | ||
110 | +parser.add_argument('--optimizer', default='ADAM', | ||
111 | + choices=('SGD', 'ADAM', 'RMSprop'), | ||
112 | + help='optimizer to use (SGD | ADAM | RMSprop)') | ||
113 | +parser.add_argument('--momentum', type=float, default=0.9, | ||
114 | + help='SGD momentum') | ||
115 | +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), | ||
116 | + help='ADAM beta') | ||
117 | +parser.add_argument('--epsilon', type=float, default=1e-8, | ||
118 | + help='ADAM epsilon for numerical stability') | ||
119 | +parser.add_argument('--weight_decay', type=float, default=0, | ||
120 | + help='weight decay') | ||
121 | +parser.add_argument('--gclip', type=float, default=0, | ||
122 | + help='gradient clipping threshold (0 = no clipping)') | ||
123 | + | ||
124 | +# Loss specifications | ||
125 | +parser.add_argument('--loss', type=str, default='1*L1', | ||
126 | + help='loss function configuration') | ||
127 | +parser.add_argument('--skip_threshold', type=float, default='1e8', | ||
128 | + help='skipping batch that has large error') | ||
129 | + | ||
130 | +# Log specifications | ||
131 | +parser.add_argument('--save', type=str, default='test', | ||
132 | + help='file name to save') | ||
133 | +parser.add_argument('--load', type=str, default='', | ||
134 | + help='file name to load') | ||
135 | +parser.add_argument('--resume', type=int, default=0, | ||
136 | + help='resume from specific checkpoint') | ||
137 | +parser.add_argument('--save_models', action='store_true', | ||
138 | + help='save all intermediate models') | ||
139 | +parser.add_argument('--print_every', type=int, default=100, | ||
140 | + help='how many batches to wait before logging training status') | ||
141 | +parser.add_argument('--save_results', action='store_true', | ||
142 | + help='save output results') | ||
143 | +parser.add_argument('--save_gt', action='store_true', | ||
144 | + help='save low-resolution and high-resolution images together') | ||
145 | + | ||
146 | +args = parser.parse_args() | ||
147 | +template.set_template(args) | ||
148 | + | ||
149 | +args.scale = list(map(lambda x: int(x), args.scale.split('+'))) | ||
150 | +args.data_train = args.data_train.split('+') | ||
151 | +args.data_test = args.data_test.split('+') | ||
152 | + | ||
153 | +if args.epochs == 0: | ||
154 | + args.epochs = 1e8 | ||
155 | + | ||
156 | +for arg in vars(args): | ||
157 | + if vars(args)[arg] == 'True': | ||
158 | + vars(args)[arg] = True | ||
159 | + elif vars(args)[arg] == 'False': | ||
160 | + vars(args)[arg] = False | ||
161 | + |
edsr/src/template.py
0 → 100644
1 | +def set_template(args): | ||
2 | + # Set the templates here | ||
3 | + if args.template.find('jpeg') >= 0: | ||
4 | + args.data_train = 'DIV2K_jpeg' | ||
5 | + args.data_test = 'DIV2K_jpeg' | ||
6 | + args.epochs = 200 | ||
7 | + args.decay = '100' | ||
8 | + | ||
9 | + if args.template.find('EDSR_paper') >= 0: | ||
10 | + args.model = 'EDSR' | ||
11 | + args.n_resblocks = 32 | ||
12 | + args.n_feats = 256 | ||
13 | + args.res_scale = 0.1 | ||
14 | + | ||
15 | + if args.template.find('MDSR') >= 0: | ||
16 | + args.model = 'MDSR' | ||
17 | + args.patch_size = 48 | ||
18 | + args.epochs = 650 | ||
19 | + | ||
20 | + if args.template.find('DDBPN') >= 0: | ||
21 | + args.model = 'DDBPN' | ||
22 | + args.patch_size = 128 | ||
23 | + args.scale = '4' | ||
24 | + | ||
25 | + args.data_test = 'Set5' | ||
26 | + | ||
27 | + args.batch_size = 20 | ||
28 | + args.epochs = 1000 | ||
29 | + args.decay = '500' | ||
30 | + args.gamma = 0.1 | ||
31 | + args.weight_decay = 1e-4 | ||
32 | + | ||
33 | + args.loss = '1*MSE' | ||
34 | + | ||
35 | + if args.template.find('GAN') >= 0: | ||
36 | + args.epochs = 200 | ||
37 | + args.lr = 5e-5 | ||
38 | + args.decay = '150' | ||
39 | + | ||
40 | + if args.template.find('RCAN') >= 0: | ||
41 | + args.model = 'RCAN' | ||
42 | + args.n_resgroups = 10 | ||
43 | + args.n_resblocks = 20 | ||
44 | + args.n_feats = 64 | ||
45 | + args.chop = True | ||
46 | + | ||
47 | + if args.template.find('VDSR') >= 0: | ||
48 | + args.model = 'VDSR' | ||
49 | + args.n_resblocks = 20 | ||
50 | + args.n_feats = 64 | ||
51 | + args.patch_size = 41 | ||
52 | + args.lr = 1e-1 | ||
53 | + |
edsr/src/trainer.py
0 → 100644
1 | +import os | ||
2 | +import math | ||
3 | +from decimal import Decimal | ||
4 | + | ||
5 | +import utility | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn.utils as utils | ||
9 | +from tqdm import tqdm | ||
10 | + | ||
11 | +class Trainer(): | ||
12 | + def __init__(self, args, loader, my_model, my_loss, ckp): | ||
13 | + self.args = args | ||
14 | + self.scale = args.scale | ||
15 | + | ||
16 | + self.ckp = ckp | ||
17 | + self.loader_train = loader.loader_train | ||
18 | + self.loader_test = loader.loader_test | ||
19 | + self.model = my_model | ||
20 | + self.loss = my_loss | ||
21 | + self.optimizer = utility.make_optimizer(args, self.model) | ||
22 | + | ||
23 | + if self.args.load != '': | ||
24 | + self.optimizer.load(ckp.dir, epoch=len(ckp.log)) | ||
25 | + | ||
26 | + self.error_last = 1e8 | ||
27 | + | ||
28 | + def train(self): | ||
29 | + self.loss.step() | ||
30 | + epoch = self.optimizer.get_last_epoch() + 1 | ||
31 | + lr = self.optimizer.get_lr() | ||
32 | + | ||
33 | + self.ckp.write_log( | ||
34 | + '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) | ||
35 | + ) | ||
36 | + self.loss.start_log() | ||
37 | + self.model.train() | ||
38 | + | ||
39 | + timer_data, timer_model = utility.timer(), utility.timer() | ||
40 | + # TEMP | ||
41 | + self.loader_train.dataset.set_scale(0) | ||
42 | + for batch, (lr, hr, _,) in enumerate(self.loader_train): | ||
43 | + lr, hr = self.prepare(lr, hr) | ||
44 | + timer_data.hold() | ||
45 | + timer_model.tic() | ||
46 | + | ||
47 | + self.optimizer.zero_grad() | ||
48 | + sr = self.model(lr, 0) | ||
49 | + loss = self.loss(sr, hr) | ||
50 | + loss.backward() | ||
51 | + if self.args.gclip > 0: | ||
52 | + utils.clip_grad_value_( | ||
53 | + self.model.parameters(), | ||
54 | + self.args.gclip | ||
55 | + ) | ||
56 | + self.optimizer.step() | ||
57 | + | ||
58 | + timer_model.hold() | ||
59 | + | ||
60 | + if (batch + 1) % self.args.print_every == 0: | ||
61 | + self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( | ||
62 | + (batch + 1) * self.args.batch_size, | ||
63 | + len(self.loader_train.dataset), | ||
64 | + self.loss.display_loss(batch), | ||
65 | + timer_model.release(), | ||
66 | + timer_data.release())) | ||
67 | + | ||
68 | + timer_data.tic() | ||
69 | + | ||
70 | + self.loss.end_log(len(self.loader_train)) | ||
71 | + self.error_last = self.loss.log[-1, -1] | ||
72 | + self.optimizer.schedule() | ||
73 | + | ||
74 | + def test(self): | ||
75 | + torch.set_grad_enabled(False) | ||
76 | + | ||
77 | + epoch = self.optimizer.get_last_epoch() | ||
78 | + self.ckp.write_log('\nEvaluation:') | ||
79 | + self.ckp.add_log( | ||
80 | + torch.zeros(1, len(self.loader_test), len(self.scale)) | ||
81 | + ) | ||
82 | + self.model.eval() | ||
83 | + | ||
84 | + timer_test = utility.timer() | ||
85 | + if self.args.save_results: self.ckp.begin_background() | ||
86 | + for idx_data, d in enumerate(self.loader_test): | ||
87 | + for idx_scale, scale in enumerate(self.scale): | ||
88 | + d.dataset.set_scale(idx_scale) | ||
89 | + for lr, hr, filename in tqdm(d, ncols=80): | ||
90 | + lr, hr = self.prepare(lr, hr) | ||
91 | + sr = self.model(lr, idx_scale) | ||
92 | + sr = utility.quantize(sr, self.args.rgb_range) | ||
93 | + | ||
94 | + save_list = [sr] | ||
95 | + self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( | ||
96 | + sr, hr, scale, self.args.rgb_range, dataset=d | ||
97 | + ) | ||
98 | + if self.args.save_gt: | ||
99 | + save_list.extend([lr, hr]) | ||
100 | + | ||
101 | + if self.args.save_results: | ||
102 | + self.ckp.save_results(d, filename[0], save_list, scale) | ||
103 | + | ||
104 | + self.ckp.log[-1, idx_data, idx_scale] /= len(d) | ||
105 | + best = self.ckp.log.max(0) | ||
106 | + self.ckp.write_log( | ||
107 | + '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( | ||
108 | + d.dataset.name, | ||
109 | + scale, | ||
110 | + self.ckp.log[-1, idx_data, idx_scale], | ||
111 | + best[0][idx_data, idx_scale], | ||
112 | + best[1][idx_data, idx_scale] + 1 | ||
113 | + ) | ||
114 | + ) | ||
115 | + | ||
116 | + self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) | ||
117 | + self.ckp.write_log('Saving...') | ||
118 | + | ||
119 | + if self.args.save_results: | ||
120 | + self.ckp.end_background() | ||
121 | + | ||
122 | + if not self.args.test_only: | ||
123 | + self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) | ||
124 | + | ||
125 | + self.ckp.write_log( | ||
126 | + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True | ||
127 | + ) | ||
128 | + | ||
129 | + torch.set_grad_enabled(True) | ||
130 | + | ||
131 | + def prepare(self, *args): | ||
132 | + device = torch.device('cpu' if self.args.cpu else 'cuda') | ||
133 | + def _prepare(tensor): | ||
134 | + if self.args.precision == 'half': tensor = tensor.half() | ||
135 | + return tensor.to(device) | ||
136 | + | ||
137 | + return [_prepare(a) for a in args] | ||
138 | + | ||
139 | + def terminate(self): | ||
140 | + if self.args.test_only: | ||
141 | + self.test() | ||
142 | + return True | ||
143 | + else: | ||
144 | + epoch = self.optimizer.get_last_epoch() + 1 | ||
145 | + return epoch >= self.args.epochs | ||
146 | + |
edsr/src/utility.py
0 → 100644
1 | +import os | ||
2 | +import math | ||
3 | +import time | ||
4 | +import datetime | ||
5 | +from multiprocessing import Process | ||
6 | +from multiprocessing import Queue | ||
7 | + | ||
8 | +import matplotlib | ||
9 | +matplotlib.use('Agg') | ||
10 | +import matplotlib.pyplot as plt | ||
11 | + | ||
12 | +import numpy as np | ||
13 | +import imageio | ||
14 | + | ||
15 | +import torch | ||
16 | +import torch.optim as optim | ||
17 | +import torch.optim.lr_scheduler as lrs | ||
18 | + | ||
19 | +class timer(): | ||
20 | + def __init__(self): | ||
21 | + self.acc = 0 | ||
22 | + self.tic() | ||
23 | + | ||
24 | + def tic(self): | ||
25 | + self.t0 = time.time() | ||
26 | + | ||
27 | + def toc(self, restart=False): | ||
28 | + diff = time.time() - self.t0 | ||
29 | + if restart: self.t0 = time.time() | ||
30 | + return diff | ||
31 | + | ||
32 | + def hold(self): | ||
33 | + self.acc += self.toc() | ||
34 | + | ||
35 | + def release(self): | ||
36 | + ret = self.acc | ||
37 | + self.acc = 0 | ||
38 | + | ||
39 | + return ret | ||
40 | + | ||
41 | + def reset(self): | ||
42 | + self.acc = 0 | ||
43 | + | ||
44 | +class checkpoint(): | ||
45 | + def __init__(self, args): | ||
46 | + self.args = args | ||
47 | + self.ok = True | ||
48 | + self.log = torch.Tensor() | ||
49 | + now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') | ||
50 | + | ||
51 | + if not args.load: | ||
52 | + if not args.save: | ||
53 | + args.save = now | ||
54 | + self.dir = os.path.join('..', 'experiment', args.save) | ||
55 | + else: | ||
56 | + self.dir = os.path.join('..', 'experiment', args.load) | ||
57 | + if os.path.exists(self.dir): | ||
58 | + self.log = torch.load(self.get_path('psnr_log.pt')) | ||
59 | + print('Continue from epoch {}...'.format(len(self.log))) | ||
60 | + else: | ||
61 | + args.load = '' | ||
62 | + | ||
63 | + if args.reset: | ||
64 | + os.system('rm -rf ' + self.dir) | ||
65 | + args.load = '' | ||
66 | + | ||
67 | + os.makedirs(self.dir, exist_ok=True) | ||
68 | + os.makedirs(self.get_path('model'), exist_ok=True) | ||
69 | + for d in args.data_test: | ||
70 | + os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) | ||
71 | + | ||
72 | + open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' | ||
73 | + self.log_file = open(self.get_path('log.txt'), open_type) | ||
74 | + with open(self.get_path('config.txt'), open_type) as f: | ||
75 | + f.write(now + '\n\n') | ||
76 | + for arg in vars(args): | ||
77 | + f.write('{}: {}\n'.format(arg, getattr(args, arg))) | ||
78 | + f.write('\n') | ||
79 | + | ||
80 | + self.n_processes = 8 | ||
81 | + | ||
82 | + def get_path(self, *subdir): | ||
83 | + return os.path.join(self.dir, *subdir) | ||
84 | + | ||
85 | + def save(self, trainer, epoch, is_best=False): | ||
86 | + trainer.model.save(self.get_path('model'), epoch, is_best=is_best) | ||
87 | + trainer.loss.save(self.dir) | ||
88 | + trainer.loss.plot_loss(self.dir, epoch) | ||
89 | + | ||
90 | + self.plot_psnr(epoch) | ||
91 | + trainer.optimizer.save(self.dir) | ||
92 | + torch.save(self.log, self.get_path('psnr_log.pt')) | ||
93 | + | ||
94 | + def add_log(self, log): | ||
95 | + self.log = torch.cat([self.log, log]) | ||
96 | + | ||
97 | + def write_log(self, log, refresh=False): | ||
98 | + print(log) | ||
99 | + self.log_file.write(log + '\n') | ||
100 | + if refresh: | ||
101 | + self.log_file.close() | ||
102 | + self.log_file = open(self.get_path('log.txt'), 'a') | ||
103 | + | ||
104 | + def done(self): | ||
105 | + self.log_file.close() | ||
106 | + | ||
107 | + def plot_psnr(self, epoch): | ||
108 | + axis = np.linspace(1, epoch, epoch) | ||
109 | + for idx_data, d in enumerate(self.args.data_test): | ||
110 | + label = 'SR on {}'.format(d) | ||
111 | + fig = plt.figure() | ||
112 | + plt.title(label) | ||
113 | + for idx_scale, scale in enumerate(self.args.scale): | ||
114 | + plt.plot( | ||
115 | + axis, | ||
116 | + self.log[:, idx_data, idx_scale].numpy(), | ||
117 | + label='Scale {}'.format(scale) | ||
118 | + ) | ||
119 | + plt.legend() | ||
120 | + plt.xlabel('Epochs') | ||
121 | + plt.ylabel('PSNR') | ||
122 | + plt.grid(True) | ||
123 | + plt.savefig(self.get_path('test_{}.pdf'.format(d))) | ||
124 | + plt.close(fig) | ||
125 | + | ||
126 | + def begin_background(self): | ||
127 | + self.queue = Queue() | ||
128 | + | ||
129 | + def bg_target(queue): | ||
130 | + while True: | ||
131 | + if not queue.empty(): | ||
132 | + filename, tensor = queue.get() | ||
133 | + if filename is None: break | ||
134 | + imageio.imwrite(filename, tensor.numpy()) | ||
135 | + | ||
136 | + self.process = [ | ||
137 | + Process(target=bg_target, args=(self.queue,)) \ | ||
138 | + for _ in range(self.n_processes) | ||
139 | + ] | ||
140 | + | ||
141 | + for p in self.process: p.start() | ||
142 | + | ||
143 | + def end_background(self): | ||
144 | + for _ in range(self.n_processes): self.queue.put((None, None)) | ||
145 | + while not self.queue.empty(): time.sleep(1) | ||
146 | + for p in self.process: p.join() | ||
147 | + | ||
148 | + def save_results(self, dataset, filename, save_list, scale): | ||
149 | + if self.args.save_results: | ||
150 | + filename = self.get_path( | ||
151 | + 'results-{}'.format(dataset.dataset.name), | ||
152 | + '{}_x{}_'.format(filename, scale) | ||
153 | + ) | ||
154 | + | ||
155 | + postfix = ('SR', 'LR', 'HR') | ||
156 | + for v, p in zip(save_list, postfix): | ||
157 | + normalized = v[0].mul(255 / self.args.rgb_range) | ||
158 | + tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() | ||
159 | + self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) | ||
160 | + | ||
161 | +def quantize(img, rgb_range): | ||
162 | + pixel_range = 255 / rgb_range | ||
163 | + return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) | ||
164 | + | ||
165 | +def calc_psnr(sr, hr, scale, rgb_range, dataset=None): | ||
166 | + if hr.nelement() == 1: return 0 | ||
167 | + | ||
168 | + diff = (sr - hr) / rgb_range | ||
169 | + if dataset and dataset.dataset.benchmark: | ||
170 | + shave = scale | ||
171 | + if diff.size(1) > 1: | ||
172 | + gray_coeffs = [65.738, 129.057, 25.064] | ||
173 | + convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 | ||
174 | + diff = diff.mul(convert).sum(dim=1) | ||
175 | + else: | ||
176 | + shave = scale + 6 | ||
177 | + | ||
178 | + valid = diff[..., shave:-shave, shave:-shave] | ||
179 | + mse = valid.pow(2).mean() | ||
180 | + | ||
181 | + return -10 * math.log10(mse) | ||
182 | + | ||
183 | +def make_optimizer(args, target): | ||
184 | + ''' | ||
185 | + make optimizer and scheduler together | ||
186 | + ''' | ||
187 | + # optimizer | ||
188 | + trainable = filter(lambda x: x.requires_grad, target.parameters()) | ||
189 | + kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} | ||
190 | + | ||
191 | + if args.optimizer == 'SGD': | ||
192 | + optimizer_class = optim.SGD | ||
193 | + kwargs_optimizer['momentum'] = args.momentum | ||
194 | + elif args.optimizer == 'ADAM': | ||
195 | + optimizer_class = optim.Adam | ||
196 | + kwargs_optimizer['betas'] = args.betas | ||
197 | + kwargs_optimizer['eps'] = args.epsilon | ||
198 | + elif args.optimizer == 'RMSprop': | ||
199 | + optimizer_class = optim.RMSprop | ||
200 | + kwargs_optimizer['eps'] = args.epsilon | ||
201 | + | ||
202 | + # scheduler | ||
203 | + milestones = list(map(lambda x: int(x), args.decay.split('-'))) | ||
204 | + kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} | ||
205 | + scheduler_class = lrs.MultiStepLR | ||
206 | + | ||
207 | + class CustomOptimizer(optimizer_class): | ||
208 | + def __init__(self, *args, **kwargs): | ||
209 | + super(CustomOptimizer, self).__init__(*args, **kwargs) | ||
210 | + | ||
211 | + def _register_scheduler(self, scheduler_class, **kwargs): | ||
212 | + self.scheduler = scheduler_class(self, **kwargs) | ||
213 | + | ||
214 | + def save(self, save_dir): | ||
215 | + torch.save(self.state_dict(), self.get_dir(save_dir)) | ||
216 | + | ||
217 | + def load(self, load_dir, epoch=1): | ||
218 | + self.load_state_dict(torch.load(self.get_dir(load_dir))) | ||
219 | + if epoch > 1: | ||
220 | + for _ in range(epoch): self.scheduler.step() | ||
221 | + | ||
222 | + def get_dir(self, dir_path): | ||
223 | + return os.path.join(dir_path, 'optimizer.pt') | ||
224 | + | ||
225 | + def schedule(self): | ||
226 | + self.scheduler.step() | ||
227 | + | ||
228 | + def get_lr(self): | ||
229 | + return self.scheduler.get_lr()[0] | ||
230 | + | ||
231 | + def get_last_epoch(self): | ||
232 | + return self.scheduler.last_epoch | ||
233 | + | ||
234 | + optimizer = CustomOptimizer(trainable, **kwargs_optimizer) | ||
235 | + optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) | ||
236 | + return optimizer | ||
237 | + |
edsr/src/videotester.py
0 → 100644
1 | +import os | ||
2 | +import math | ||
3 | + | ||
4 | +import utility | ||
5 | +from data import common | ||
6 | + | ||
7 | +import torch | ||
8 | +import cv2 | ||
9 | + | ||
10 | +from tqdm import tqdm | ||
11 | + | ||
12 | +class VideoTester(): | ||
13 | + def __init__(self, args, my_model, ckp): | ||
14 | + self.args = args | ||
15 | + self.scale = args.scale | ||
16 | + | ||
17 | + self.ckp = ckp | ||
18 | + self.model = my_model | ||
19 | + | ||
20 | + self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) | ||
21 | + | ||
22 | + def test(self): | ||
23 | + torch.set_grad_enabled(False) | ||
24 | + | ||
25 | + self.ckp.write_log('\nEvaluation on video:') | ||
26 | + self.model.eval() | ||
27 | + | ||
28 | + timer_test = utility.timer() | ||
29 | + for idx_scale, scale in enumerate(self.scale): | ||
30 | + vidcap = cv2.VideoCapture(self.args.dir_demo) | ||
31 | + total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) | ||
32 | + vidwri = cv2.VideoWriter( | ||
33 | + self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), | ||
34 | + cv2.VideoWriter_fourcc(*'XVID'), | ||
35 | + vidcap.get(cv2.CAP_PROP_FPS), | ||
36 | + ( | ||
37 | + int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), | ||
38 | + int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | ||
39 | + ) | ||
40 | + ) | ||
41 | + | ||
42 | + tqdm_test = tqdm(range(total_frames), ncols=80) | ||
43 | + for _ in tqdm_test: | ||
44 | + success, lr = vidcap.read() | ||
45 | + if not success: break | ||
46 | + | ||
47 | + lr, = common.set_channel(lr, n_channels=self.args.n_colors) | ||
48 | + lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) | ||
49 | + lr, = self.prepare(lr.unsqueeze(0)) | ||
50 | + sr = self.model(lr, idx_scale) | ||
51 | + sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) | ||
52 | + | ||
53 | + normalized = sr * 255 / self.args.rgb_range | ||
54 | + ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() | ||
55 | + vidwri.write(ndarr) | ||
56 | + | ||
57 | + vidcap.release() | ||
58 | + vidwri.release() | ||
59 | + | ||
60 | + self.ckp.write_log( | ||
61 | + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True | ||
62 | + ) | ||
63 | + torch.set_grad_enabled(True) | ||
64 | + | ||
65 | + def prepare(self, *args): | ||
66 | + device = torch.device('cpu' if self.args.cpu else 'cuda') | ||
67 | + def _prepare(tensor): | ||
68 | + if self.args.precision == 'half': tensor = tensor.half() | ||
69 | + return tensor.to(device) | ||
70 | + | ||
71 | + return [_prepare(a) for a in args] | ||
72 | + |
-
Please register or login to post a comment