김재형
1 +# Byte-compiled / optimized / DLL files
2 +__pycache__/
3 +*.py[cod]
4 +*$py.class
5 +
6 +# C extensions
7 +*.so
8 +
9 +# Distribution / packaging
10 +.Python
11 +env/
12 +build/
13 +develop-eggs/
14 +dist/
15 +downloads/
16 +eggs/
17 +.eggs/
18 +lib/
19 +lib64/
20 +parts/
21 +sdist/
22 +var/
23 +*.egg-info/
24 +.installed.cfg
25 +*.egg
26 +
27 +# PyInstaller
28 +# Usually these files are written by a python script from a template
29 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
30 +*.manifest
31 +*.spec
32 +
33 +# Installer logs
34 +pip-log.txt
35 +pip-delete-this-directory.txt
36 +
37 +# Unit test / coverage reports
38 +htmlcov/
39 +.tox/
40 +.coverage
41 +.coverage.*
42 +.cache
43 +nosetests.xml
44 +coverage.xml
45 +*,cover
46 +
47 +# Translations
48 +*.mo
49 +*.pot
50 +
51 +# Django stuff:
52 +*.log
53 +
54 +# Sphinx documentation
55 +docs/_build/
56 +
57 +# PyBuilder
58 +target/
59 +
60 +# PyTorch
61 +*.pt
62 +*.pdf
63 +*.png
64 +*.txt
65 +*.swp
66 +.vscode
1 +MIT License
2 +
3 +Copyright (c) 2018 Sanghyun Son
4 +
5 +Permission is hereby granted, free of charge, to any person obtaining a copy
6 +of this software and associated documentation files (the "Software"), to deal
7 +in the Software without restriction, including without limitation the rights
8 +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 +copies of the Software, and to permit persons to whom the Software is
10 +furnished to do so, subject to the following conditions:
11 +
12 +The above copyright notice and this permission notice shall be included in all
13 +copies or substantial portions of the Software.
14 +
15 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 +SOFTWARE.
This diff is collapsed. Click to expand it.
1 +*
2 +!.gitignore
3 +!/model/*.pt
File mode changed
1 +from importlib import import_module
2 +#from dataloader import MSDataLoader
3 +from torch.utils.data import dataloader
4 +from torch.utils.data import ConcatDataset
5 +
6 +# This is a simple wrapper function for ConcatDataset
7 +class MyConcatDataset(ConcatDataset):
8 + def __init__(self, datasets):
9 + super(MyConcatDataset, self).__init__(datasets)
10 + self.train = datasets[0].train
11 +
12 + def set_scale(self, idx_scale):
13 + for d in self.datasets:
14 + if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
15 +
16 +class Data:
17 + def __init__(self, args):
18 + self.loader_train = None
19 + if not args.test_only:
20 + datasets = []
21 + for d in args.data_train:
22 + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
23 + m = import_module('data.' + module_name.lower())
24 + datasets.append(getattr(m, module_name)(args, name=d))
25 +
26 + self.loader_train = dataloader.DataLoader(
27 + MyConcatDataset(datasets),
28 + batch_size=args.batch_size,
29 + shuffle=True,
30 + pin_memory=not args.cpu,
31 + num_workers=args.n_threads,
32 + )
33 +
34 + self.loader_test = []
35 + for d in args.data_test:
36 + if d in ['Set5', 'Set14', 'B100', 'Urban100']:
37 + m = import_module('data.benchmark')
38 + testset = getattr(m, 'Benchmark')(args, train=False, name=d)
39 + else:
40 + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
41 + m = import_module('data.' + module_name.lower())
42 + testset = getattr(m, module_name)(args, train=False, name=d)
43 +
44 + self.loader_test.append(
45 + dataloader.DataLoader(
46 + testset,
47 + batch_size=1,
48 + shuffle=False,
49 + pin_memory=not args.cpu,
50 + num_workers=args.n_threads,
51 + )
52 + )
1 +import os
2 +
3 +from data import common
4 +from data import srdata
5 +
6 +import numpy as np
7 +
8 +import torch
9 +import torch.utils.data as data
10 +
11 +class Benchmark(srdata.SRData):
12 + def __init__(self, args, name='', train=True, benchmark=True):
13 + super(Benchmark, self).__init__(
14 + args, name=name, train=train, benchmark=True
15 + )
16 +
17 + def _set_filesystem(self, dir_data):
18 + self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 + self.dir_hr = os.path.join(self.apath, 'HR')
20 + if self.input_large:
21 + self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22 + else:
23 + self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24 + self.ext = ('', '.png')
25 +
1 +import random
2 +
3 +import numpy as np
4 +import skimage.color as sc
5 +
6 +import torch
7 +
8 +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
9 + ih, iw = args[0].shape[:2]
10 +
11 + if not input_large:
12 + p = scale if multi else 1
13 + tp = p * patch_size
14 + ip = tp // scale
15 + else:
16 + tp = patch_size
17 + ip = patch_size
18 +
19 + ix = random.randrange(0, iw - ip + 1)
20 + iy = random.randrange(0, ih - ip + 1)
21 +
22 + if not input_large:
23 + tx, ty = scale * ix, scale * iy
24 + else:
25 + tx, ty = ix, iy
26 +
27 + ret = [
28 + args[0][iy:iy + ip, ix:ix + ip, :],
29 + *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
30 + ]
31 +
32 + return ret
33 +
34 +def set_channel(*args, n_channels=3):
35 + def _set_channel(img):
36 + if img.ndim == 2:
37 + img = np.expand_dims(img, axis=2)
38 +
39 + c = img.shape[2]
40 + if n_channels == 1 and c == 3:
41 + img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
42 + elif n_channels == 3 and c == 1:
43 + img = np.concatenate([img] * n_channels, 2)
44 +
45 + return img
46 +
47 + return [_set_channel(a) for a in args]
48 +
49 +def np2Tensor(*args, rgb_range=255):
50 + def _np2Tensor(img):
51 + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
52 + tensor = torch.from_numpy(np_transpose).float()
53 + tensor.mul_(rgb_range / 255)
54 +
55 + return tensor
56 +
57 + return [_np2Tensor(a) for a in args]
58 +
59 +def augment(*args, hflip=True, rot=True):
60 + hflip = hflip and random.random() < 0.5
61 + vflip = rot and random.random() < 0.5
62 + rot90 = rot and random.random() < 0.5
63 +
64 + def _augment(img):
65 + if hflip: img = img[:, ::-1, :]
66 + if vflip: img = img[::-1, :, :]
67 + if rot90: img = img.transpose(1, 0, 2)
68 +
69 + return img
70 +
71 + return [_augment(a) for a in args]
72 +
1 +import os
2 +
3 +from data import common
4 +
5 +import numpy as np
6 +import imageio
7 +
8 +import torch
9 +import torch.utils.data as data
10 +
11 +class Demo(data.Dataset):
12 + def __init__(self, args, name='Demo', train=False, benchmark=False):
13 + self.args = args
14 + self.name = name
15 + self.scale = args.scale
16 + self.idx_scale = 0
17 + self.train = False
18 + self.benchmark = benchmark
19 +
20 + self.filelist = []
21 + for f in os.listdir(args.dir_demo):
22 + if f.find('.png') >= 0 or f.find('.jp') >= 0:
23 + self.filelist.append(os.path.join(args.dir_demo, f))
24 + self.filelist.sort()
25 +
26 + def __getitem__(self, idx):
27 + filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
28 + lr = imageio.imread(self.filelist[idx])
29 + lr, = common.set_channel(lr, n_channels=self.args.n_colors)
30 + lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
31 +
32 + return lr_t, -1, filename
33 +
34 + def __len__(self):
35 + return len(self.filelist)
36 +
37 + def set_scale(self, idx_scale):
38 + self.idx_scale = idx_scale
39 +
1 +import os
2 +from data import srdata
3 +
4 +class DIV2K(srdata.SRData):
5 + def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6 + data_range = [r.split('-') for r in args.data_range.split('/')]
7 + if train:
8 + data_range = data_range[0]
9 + else:
10 + if args.test_only and len(data_range) == 1:
11 + data_range = data_range[0]
12 + else:
13 + data_range = data_range[1]
14 +
15 + self.begin, self.end = list(map(lambda x: int(x), data_range))
16 + super(DIV2K, self).__init__(
17 + args, name=name, train=train, benchmark=benchmark
18 + )
19 +
20 + def _scan(self):
21 + names_hr, names_lr = super(DIV2K, self)._scan()
22 + names_hr = names_hr[self.begin - 1:self.end]
23 + names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 +
25 + return names_hr, names_lr
26 +
27 + def _set_filesystem(self, dir_data):
28 + super(DIV2K, self)._set_filesystem(dir_data)
29 + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30 + self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31 + if self.input_large: self.dir_lr += 'L'
32 +
1 +import os
2 +from data import srdata
3 +from data import div2k
4 +
5 +class DIV2KJPEG(div2k.DIV2K):
6 + def __init__(self, args, name='', train=True, benchmark=False):
7 + self.q_factor = int(name.replace('DIV2K-Q', ''))
8 + super(DIV2KJPEG, self).__init__(
9 + args, name=name, train=train, benchmark=benchmark
10 + )
11 +
12 + def _set_filesystem(self, dir_data):
13 + self.apath = os.path.join(dir_data, 'DIV2K')
14 + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
15 + self.dir_lr = os.path.join(
16 + self.apath, 'DIV2K_Q{}'.format(self.q_factor)
17 + )
18 + if self.input_large: self.dir_lr += 'L'
19 + self.ext = ('.png', '.jpg')
20 +
1 +from data import srdata
2 +
3 +class SR291(srdata.SRData):
4 + def __init__(self, args, name='SR291', train=True, benchmark=False):
5 + super(SR291, self).__init__(args, name=name)
6 +
1 +import os
2 +import glob
3 +import random
4 +import pickle
5 +
6 +from data import common
7 +
8 +import numpy as np
9 +import imageio
10 +import torch
11 +import torch.utils.data as data
12 +
13 +class SRData(data.Dataset):
14 + def __init__(self, args, name='', train=True, benchmark=False):
15 + self.args = args
16 + self.name = name
17 + self.train = train
18 + self.split = 'train' if train else 'test'
19 + self.do_eval = True
20 + self.benchmark = benchmark
21 + self.input_large = (args.model == 'VDSR')
22 + self.scale = args.scale
23 + self.idx_scale = 0
24 +
25 + self._set_filesystem(args.dir_data)
26 + if args.ext.find('img') < 0:
27 + path_bin = os.path.join(self.apath, 'bin')
28 + os.makedirs(path_bin, exist_ok=True)
29 +
30 + list_hr, list_lr = self._scan()
31 + if args.ext.find('img') >= 0 or benchmark:
32 + self.images_hr, self.images_lr = list_hr, list_lr
33 + elif args.ext.find('sep') >= 0:
34 + os.makedirs(
35 + self.dir_hr.replace(self.apath, path_bin),
36 + exist_ok=True
37 + )
38 + for s in self.scale:
39 + os.makedirs(
40 + os.path.join(
41 + self.dir_lr.replace(self.apath, path_bin),
42 + 'X{}'.format(s)
43 + ),
44 + exist_ok=True
45 + )
46 +
47 + self.images_hr, self.images_lr = [], [[] for _ in self.scale]
48 + for h in list_hr:
49 + b = h.replace(self.apath, path_bin)
50 + b = b.replace(self.ext[0], '.pt')
51 + self.images_hr.append(b)
52 + self._check_and_load(args.ext, h, b, verbose=True)
53 + for i, ll in enumerate(list_lr):
54 + for l in ll:
55 + b = l.replace(self.apath, path_bin)
56 + b = b.replace(self.ext[1], '.pt')
57 + self.images_lr[i].append(b)
58 + self._check_and_load(args.ext, l, b, verbose=True)
59 + if train:
60 + n_patches = args.batch_size * args.test_every
61 + n_images = len(args.data_train) * len(self.images_hr)
62 + if n_images == 0:
63 + self.repeat = 0
64 + else:
65 + self.repeat = max(n_patches // n_images, 1)
66 +
67 + # Below functions as used to prepare images
68 + def _scan(self):
69 + names_hr = sorted(
70 + glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
71 + )
72 + names_lr = [[] for _ in self.scale]
73 + for f in names_hr:
74 + filename, _ = os.path.splitext(os.path.basename(f))
75 + for si, s in enumerate(self.scale):
76 + names_lr[si].append(os.path.join(
77 + self.dir_lr, 'X{}/{}x{}{}'.format(
78 + s, filename, s, self.ext[1]
79 + )
80 + ))
81 +
82 + return names_hr, names_lr
83 +
84 + def _set_filesystem(self, dir_data):
85 + self.apath = os.path.join(dir_data, self.name)
86 + self.dir_hr = os.path.join(self.apath, 'HR')
87 + self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
88 + if self.input_large: self.dir_lr += 'L'
89 + self.ext = ('.png', '.png')
90 +
91 + def _check_and_load(self, ext, img, f, verbose=True):
92 + if not os.path.isfile(f) or ext.find('reset') >= 0:
93 + if verbose:
94 + print('Making a binary: {}'.format(f))
95 + with open(f, 'wb') as _f:
96 + pickle.dump(imageio.imread(img), _f)
97 +
98 + def __getitem__(self, idx):
99 + lr, hr, filename = self._load_file(idx)
100 + pair = self.get_patch(lr, hr)
101 + pair = common.set_channel(*pair, n_channels=self.args.n_colors)
102 + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
103 +
104 + return pair_t[0], pair_t[1], filename
105 +
106 + def __len__(self):
107 + if self.train:
108 + return len(self.images_hr) * self.repeat
109 + else:
110 + return len(self.images_hr)
111 +
112 + def _get_index(self, idx):
113 + if self.train:
114 + return idx % len(self.images_hr)
115 + else:
116 + return idx
117 +
118 + def _load_file(self, idx):
119 + idx = self._get_index(idx)
120 + f_hr = self.images_hr[idx]
121 + f_lr = self.images_lr[self.idx_scale][idx]
122 +
123 + filename, _ = os.path.splitext(os.path.basename(f_hr))
124 + if self.args.ext == 'img' or self.benchmark:
125 + hr = imageio.imread(f_hr)
126 + lr = imageio.imread(f_lr)
127 + elif self.args.ext.find('sep') >= 0:
128 + with open(f_hr, 'rb') as _f:
129 + hr = pickle.load(_f)
130 + with open(f_lr, 'rb') as _f:
131 + lr = pickle.load(_f)
132 +
133 + return lr, hr, filename
134 +
135 + def get_patch(self, lr, hr):
136 + scale = self.scale[self.idx_scale]
137 + if self.train:
138 + lr, hr = common.get_patch(
139 + lr, hr,
140 + patch_size=self.args.patch_size,
141 + scale=scale,
142 + multi=(len(self.scale) > 1),
143 + input_large=self.input_large
144 + )
145 + if not self.args.no_augment: lr, hr = common.augment(lr, hr)
146 + else:
147 + ih, iw = lr.shape[:2]
148 + hr = hr[0:ih * scale, 0:iw * scale]
149 +
150 + return lr, hr
151 +
152 + def set_scale(self, idx_scale):
153 + if not self.input_large:
154 + self.idx_scale = idx_scale
155 + else:
156 + self.idx_scale = random.randint(0, len(self.scale) - 1)
157 +
1 +import os
2 +
3 +from data import common
4 +
5 +import cv2
6 +import numpy as np
7 +import imageio
8 +
9 +import torch
10 +import torch.utils.data as data
11 +
12 +class Video(data.Dataset):
13 + def __init__(self, args, name='Video', train=False, benchmark=False):
14 + self.args = args
15 + self.name = name
16 + self.scale = args.scale
17 + self.idx_scale = 0
18 + self.train = False
19 + self.do_eval = False
20 + self.benchmark = benchmark
21 +
22 + self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 + self.vidcap = cv2.VideoCapture(args.dir_demo)
24 + self.n_frames = 0
25 + self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 +
27 + def __getitem__(self, idx):
28 + success, lr = self.vidcap.read()
29 + if success:
30 + self.n_frames += 1
31 + lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 + lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 +
34 + return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 + else:
36 + vidcap.release()
37 + return None
38 +
39 + def __len__(self):
40 + return self.total_frames
41 +
42 + def set_scale(self, idx_scale):
43 + self.idx_scale = idx_scale
44 +
1 +import threading
2 +import random
3 +
4 +import torch
5 +import torch.multiprocessing as multiprocessing
6 +from torch.utils.data import DataLoader
7 +from torch.utils.data import SequentialSampler
8 +from torch.utils.data import RandomSampler
9 +from torch.utils.data import BatchSampler
10 +from torch.utils.data import _utils
11 +from torch.utils.data.dataloader import _DataLoaderIter
12 +
13 +from torch.utils.data._utils import collate
14 +from torch.utils.data._utils import signal_handling
15 +from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
16 +from torch.utils.data._utils import ExceptionWrapper
17 +from torch.utils.data._utils import IS_WINDOWS
18 +from torch.utils.data._utils.worker import ManagerWatchdog
19 +
20 +from torch._six import queue
21 +
22 +def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
23 + try:
24 + collate._use_shared_memory = True
25 + signal_handling._set_worker_signal_handlers()
26 +
27 + torch.set_num_threads(1)
28 + random.seed(seed)
29 + torch.manual_seed(seed)
30 +
31 + data_queue.cancel_join_thread()
32 +
33 + if init_fn is not None:
34 + init_fn(worker_id)
35 +
36 + watchdog = ManagerWatchdog()
37 +
38 + while watchdog.is_alive():
39 + try:
40 + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
41 + except queue.Empty:
42 + continue
43 +
44 + if r is None:
45 + assert done_event.is_set()
46 + return
47 + elif done_event.is_set():
48 + continue
49 +
50 + idx, batch_indices = r
51 + try:
52 + idx_scale = 0
53 + if len(scale) > 1 and dataset.train:
54 + idx_scale = random.randrange(0, len(scale))
55 + dataset.set_scale(idx_scale)
56 +
57 + samples = collate_fn([dataset[i] for i in batch_indices])
58 + samples.append(idx_scale)
59 + except Exception:
60 + data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
61 + else:
62 + data_queue.put((idx, samples))
63 + del samples
64 +
65 + except KeyboardInterrupt:
66 + pass
67 +
68 +class _MSDataLoaderIter(_DataLoaderIter):
69 +
70 + def __init__(self, loader):
71 + self.dataset = loader.dataset
72 + self.scale = loader.scale
73 + self.collate_fn = loader.collate_fn
74 + self.batch_sampler = loader.batch_sampler
75 + self.num_workers = loader.num_workers
76 + self.pin_memory = loader.pin_memory and torch.cuda.is_available()
77 + self.timeout = loader.timeout
78 +
79 + self.sample_iter = iter(self.batch_sampler)
80 +
81 + base_seed = torch.LongTensor(1).random_().item()
82 +
83 + if self.num_workers > 0:
84 + self.worker_init_fn = loader.worker_init_fn
85 + self.worker_queue_idx = 0
86 + self.worker_result_queue = multiprocessing.Queue()
87 + self.batches_outstanding = 0
88 + self.worker_pids_set = False
89 + self.shutdown = False
90 + self.send_idx = 0
91 + self.rcvd_idx = 0
92 + self.reorder_dict = {}
93 + self.done_event = multiprocessing.Event()
94 +
95 + base_seed = torch.LongTensor(1).random_()[0]
96 +
97 + self.index_queues = []
98 + self.workers = []
99 + for i in range(self.num_workers):
100 + index_queue = multiprocessing.Queue()
101 + index_queue.cancel_join_thread()
102 + w = multiprocessing.Process(
103 + target=_ms_loop,
104 + args=(
105 + self.dataset,
106 + index_queue,
107 + self.worker_result_queue,
108 + self.done_event,
109 + self.collate_fn,
110 + self.scale,
111 + base_seed + i,
112 + self.worker_init_fn,
113 + i
114 + )
115 + )
116 + w.daemon = True
117 + w.start()
118 + self.index_queues.append(index_queue)
119 + self.workers.append(w)
120 +
121 + if self.pin_memory:
122 + self.data_queue = queue.Queue()
123 + pin_memory_thread = threading.Thread(
124 + target=_utils.pin_memory._pin_memory_loop,
125 + args=(
126 + self.worker_result_queue,
127 + self.data_queue,
128 + torch.cuda.current_device(),
129 + self.done_event
130 + )
131 + )
132 + pin_memory_thread.daemon = True
133 + pin_memory_thread.start()
134 + self.pin_memory_thread = pin_memory_thread
135 + else:
136 + self.data_queue = self.worker_result_queue
137 +
138 + _utils.signal_handling._set_worker_pids(
139 + id(self), tuple(w.pid for w in self.workers)
140 + )
141 + _utils.signal_handling._set_SIGCHLD_handler()
142 + self.worker_pids_set = True
143 +
144 + for _ in range(2 * self.num_workers):
145 + self._put_indices()
146 +
147 +
148 +class MSDataLoader(DataLoader):
149 +
150 + def __init__(self, cfg, *args, **kwargs):
151 + super(MSDataLoader, self).__init__(
152 + *args, **kwargs, num_workers=cfg.n_threads
153 + )
154 + self.scale = cfg.scale
155 +
156 + def __iter__(self):
157 + return _MSDataLoaderIter(self)
158 +
1 +# EDSR baseline model (x2) + JPEG augmentation
2 +python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
3 +#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
4 +
5 +# EDSR baseline model (x3) - from EDSR baseline model (x2)
6 +#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
7 +
8 +# EDSR baseline model (x4) - from EDSR baseline model (x2)
9 +#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
10 +
11 +# EDSR in the paper (x2)
12 +#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
13 +
14 +# EDSR in the paper (x3) - from EDSR (x2)
15 +#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]
16 +
17 +# EDSR in the paper (x4) - from EDSR (x2)
18 +#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
19 +
20 +# MDSR baseline model
21 +#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
22 +
23 +# MDSR in the paper
24 +#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
25 +
26 +# Standard benchmarks (Ex. EDSR_baseline_x4)
27 +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble
28 +
29 +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
30 +
31 +# Test your own images
32 +#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
33 +
34 +# Advanced - Test with JPEG images
35 +#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
36 +
37 +# Advanced - Training with adversarial loss
38 +#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
39 +
40 +# RDN BI model (x2)
41 +#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
42 +# RDN BI model (x3)
43 +#python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset
44 +# RDN BI model (x4)
45 +#python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset
46 +
47 +# RCAN_BIX2_G10R20P48, input=48x48, output=96x96
48 +# pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0
49 +#python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96
50 +# RCAN_BIX3_G10R20P48, input=48x48, output=144x144
51 +#python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
52 +# RCAN_BIX4_G10R20P48, input=48x48, output=192x192
53 +#python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
54 +# RCAN_BIX8_G10R20P48, input=48x48, output=384x384
55 +#python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
56 +
1 +import os
2 +from importlib import import_module
3 +
4 +import matplotlib
5 +matplotlib.use('Agg')
6 +import matplotlib.pyplot as plt
7 +
8 +import numpy as np
9 +
10 +import torch
11 +import torch.nn as nn
12 +import torch.nn.functional as F
13 +
14 +class Loss(nn.modules.loss._Loss):
15 + def __init__(self, args, ckp):
16 + super(Loss, self).__init__()
17 + print('Preparing loss function:')
18 +
19 + self.n_GPUs = args.n_GPUs
20 + self.loss = []
21 + self.loss_module = nn.ModuleList()
22 + for loss in args.loss.split('+'):
23 + weight, loss_type = loss.split('*')
24 + if loss_type == 'MSE':
25 + loss_function = nn.MSELoss()
26 + elif loss_type == 'L1':
27 + loss_function = nn.L1Loss()
28 + elif loss_type.find('VGG') >= 0:
29 + module = import_module('loss.vgg')
30 + loss_function = getattr(module, 'VGG')(
31 + loss_type[3:],
32 + rgb_range=args.rgb_range
33 + )
34 + elif loss_type.find('GAN') >= 0:
35 + module = import_module('loss.adversarial')
36 + loss_function = getattr(module, 'Adversarial')(
37 + args,
38 + loss_type
39 + )
40 +
41 + self.loss.append({
42 + 'type': loss_type,
43 + 'weight': float(weight),
44 + 'function': loss_function}
45 + )
46 + if loss_type.find('GAN') >= 0:
47 + self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
48 +
49 + if len(self.loss) > 1:
50 + self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
51 +
52 + for l in self.loss:
53 + if l['function'] is not None:
54 + print('{:.3f} * {}'.format(l['weight'], l['type']))
55 + self.loss_module.append(l['function'])
56 +
57 + self.log = torch.Tensor()
58 +
59 + device = torch.device('cpu' if args.cpu else 'cuda')
60 + self.loss_module.to(device)
61 + if args.precision == 'half': self.loss_module.half()
62 + if not args.cpu and args.n_GPUs > 1:
63 + self.loss_module = nn.DataParallel(
64 + self.loss_module, range(args.n_GPUs)
65 + )
66 +
67 + if args.load != '': self.load(ckp.dir, cpu=args.cpu)
68 +
69 + def forward(self, sr, hr):
70 + losses = []
71 + for i, l in enumerate(self.loss):
72 + if l['function'] is not None:
73 + loss = l['function'](sr, hr)
74 + effective_loss = l['weight'] * loss
75 + losses.append(effective_loss)
76 + self.log[-1, i] += effective_loss.item()
77 + elif l['type'] == 'DIS':
78 + self.log[-1, i] += self.loss[i - 1]['function'].loss
79 +
80 + loss_sum = sum(losses)
81 + if len(self.loss) > 1:
82 + self.log[-1, -1] += loss_sum.item()
83 +
84 + return loss_sum
85 +
86 + def step(self):
87 + for l in self.get_loss_module():
88 + if hasattr(l, 'scheduler'):
89 + l.scheduler.step()
90 +
91 + def start_log(self):
92 + self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
93 +
94 + def end_log(self, n_batches):
95 + self.log[-1].div_(n_batches)
96 +
97 + def display_loss(self, batch):
98 + n_samples = batch + 1
99 + log = []
100 + for l, c in zip(self.loss, self.log[-1]):
101 + log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
102 +
103 + return ''.join(log)
104 +
105 + def plot_loss(self, apath, epoch):
106 + axis = np.linspace(1, epoch, epoch)
107 + for i, l in enumerate(self.loss):
108 + label = '{} Loss'.format(l['type'])
109 + fig = plt.figure()
110 + plt.title(label)
111 + plt.plot(axis, self.log[:, i].numpy(), label=label)
112 + plt.legend()
113 + plt.xlabel('Epochs')
114 + plt.ylabel('Loss')
115 + plt.grid(True)
116 + plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
117 + plt.close(fig)
118 +
119 + def get_loss_module(self):
120 + if self.n_GPUs == 1:
121 + return self.loss_module
122 + else:
123 + return self.loss_module.module
124 +
125 + def save(self, apath):
126 + torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
127 + torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
128 +
129 + def load(self, apath, cpu=False):
130 + if cpu:
131 + kwargs = {'map_location': lambda storage, loc: storage}
132 + else:
133 + kwargs = {}
134 +
135 + self.load_state_dict(torch.load(
136 + os.path.join(apath, 'loss.pt'),
137 + **kwargs
138 + ))
139 + self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
140 + for l in self.get_loss_module():
141 + if hasattr(l, 'scheduler'):
142 + for _ in range(len(self.log)): l.scheduler.step()
143 +
1 +import utility
2 +from types import SimpleNamespace
3 +
4 +from model import common
5 +from loss import discriminator
6 +
7 +import torch
8 +import torch.nn as nn
9 +import torch.nn.functional as F
10 +import torch.optim as optim
11 +
12 +class Adversarial(nn.Module):
13 + def __init__(self, args, gan_type):
14 + super(Adversarial, self).__init__()
15 + self.gan_type = gan_type
16 + self.gan_k = args.gan_k
17 + self.dis = discriminator.Discriminator(args)
18 + if gan_type == 'WGAN_GP':
19 + # see https://arxiv.org/pdf/1704.00028.pdf pp.4
20 + optim_dict = {
21 + 'optimizer': 'ADAM',
22 + 'betas': (0, 0.9),
23 + 'epsilon': 1e-8,
24 + 'lr': 1e-5,
25 + 'weight_decay': args.weight_decay,
26 + 'decay': args.decay,
27 + 'gamma': args.gamma
28 + }
29 + optim_args = SimpleNamespace(**optim_dict)
30 + else:
31 + optim_args = args
32 +
33 + self.optimizer = utility.make_optimizer(optim_args, self.dis)
34 +
35 + def forward(self, fake, real):
36 + # updating discriminator...
37 + self.loss = 0
38 + fake_detach = fake.detach() # do not backpropagate through G
39 + for _ in range(self.gan_k):
40 + self.optimizer.zero_grad()
41 + # d: B x 1 tensor
42 + d_fake = self.dis(fake_detach)
43 + d_real = self.dis(real)
44 + retain_graph = False
45 + if self.gan_type == 'GAN':
46 + loss_d = self.bce(d_real, d_fake)
47 + elif self.gan_type.find('WGAN') >= 0:
48 + loss_d = (d_fake - d_real).mean()
49 + if self.gan_type.find('GP') >= 0:
50 + epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
51 + hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
52 + hat.requires_grad = True
53 + d_hat = self.dis(hat)
54 + gradients = torch.autograd.grad(
55 + outputs=d_hat.sum(), inputs=hat,
56 + retain_graph=True, create_graph=True, only_inputs=True
57 + )[0]
58 + gradients = gradients.view(gradients.size(0), -1)
59 + gradient_norm = gradients.norm(2, dim=1)
60 + gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
61 + loss_d += gradient_penalty
62 + # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
63 + elif self.gan_type == 'RGAN':
64 + better_real = d_real - d_fake.mean(dim=0, keepdim=True)
65 + better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
66 + loss_d = self.bce(better_real, better_fake)
67 + retain_graph = True
68 +
69 + # Discriminator update
70 + self.loss += loss_d.item()
71 + loss_d.backward(retain_graph=retain_graph)
72 + self.optimizer.step()
73 +
74 + if self.gan_type == 'WGAN':
75 + for p in self.dis.parameters():
76 + p.data.clamp_(-1, 1)
77 +
78 + self.loss /= self.gan_k
79 +
80 + # updating generator...
81 + d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
82 + if self.gan_type == 'GAN':
83 + label_real = torch.ones_like(d_fake_bp)
84 + loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
85 + elif self.gan_type.find('WGAN') >= 0:
86 + loss_g = -d_fake_bp.mean()
87 + elif self.gan_type == 'RGAN':
88 + better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
89 + better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
90 + loss_g = self.bce(better_fake, better_real)
91 +
92 + # Generator loss
93 + return loss_g
94 +
95 + def state_dict(self, *args, **kwargs):
96 + state_discriminator = self.dis.state_dict(*args, **kwargs)
97 + state_optimizer = self.optimizer.state_dict()
98 +
99 + return dict(**state_discriminator, **state_optimizer)
100 +
101 + def bce(self, real, fake):
102 + label_real = torch.ones_like(real)
103 + label_fake = torch.zeros_like(fake)
104 + bce_real = F.binary_cross_entropy_with_logits(real, label_real)
105 + bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
106 + bce_loss = bce_real + bce_fake
107 + return bce_loss
108 +
109 +# Some references
110 +# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
111 +# OR
112 +# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
1 +from model import common
2 +
3 +import torch.nn as nn
4 +
5 +class Discriminator(nn.Module):
6 + '''
7 + output is not normalized
8 + '''
9 + def __init__(self, args):
10 + super(Discriminator, self).__init__()
11 +
12 + in_channels = args.n_colors
13 + out_channels = 64
14 + depth = 7
15 +
16 + def _block(_in_channels, _out_channels, stride=1):
17 + return nn.Sequential(
18 + nn.Conv2d(
19 + _in_channels,
20 + _out_channels,
21 + 3,
22 + padding=1,
23 + stride=stride,
24 + bias=False
25 + ),
26 + nn.BatchNorm2d(_out_channels),
27 + nn.LeakyReLU(negative_slope=0.2, inplace=True)
28 + )
29 +
30 + m_features = [_block(in_channels, out_channels)]
31 + for i in range(depth):
32 + in_channels = out_channels
33 + if i % 2 == 1:
34 + stride = 1
35 + out_channels *= 2
36 + else:
37 + stride = 2
38 + m_features.append(_block(in_channels, out_channels, stride=stride))
39 +
40 + patch_size = args.patch_size // (2**((depth + 1) // 2))
41 + m_classifier = [
42 + nn.Linear(out_channels * patch_size**2, 1024),
43 + nn.LeakyReLU(negative_slope=0.2, inplace=True),
44 + nn.Linear(1024, 1)
45 + ]
46 +
47 + self.features = nn.Sequential(*m_features)
48 + self.classifier = nn.Sequential(*m_classifier)
49 +
50 + def forward(self, x):
51 + features = self.features(x)
52 + output = self.classifier(features.view(features.size(0), -1))
53 +
54 + return output
55 +
1 +from model import common
2 +
3 +import torch
4 +import torch.nn as nn
5 +import torch.nn.functional as F
6 +import torchvision.models as models
7 +
8 +class VGG(nn.Module):
9 + def __init__(self, conv_index, rgb_range=1):
10 + super(VGG, self).__init__()
11 + vgg_features = models.vgg19(pretrained=True).features
12 + modules = [m for m in vgg_features]
13 + if conv_index.find('22') >= 0:
14 + self.vgg = nn.Sequential(*modules[:8])
15 + elif conv_index.find('54') >= 0:
16 + self.vgg = nn.Sequential(*modules[:35])
17 +
18 + vgg_mean = (0.485, 0.456, 0.406)
19 + vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
20 + self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
21 + for p in self.parameters():
22 + p.requires_grad = False
23 +
24 + def forward(self, sr, hr):
25 + def _forward(x):
26 + x = self.sub_mean(x)
27 + x = self.vgg(x)
28 + return x
29 +
30 + vgg_sr = _forward(sr)
31 + with torch.no_grad():
32 + vgg_hr = _forward(hr.detach())
33 +
34 + loss = F.mse_loss(vgg_sr, vgg_hr)
35 +
36 + return loss
1 +import torch
2 +
3 +import utility
4 +import data
5 +import model
6 +import loss
7 +from option import args
8 +from trainer import Trainer
9 +
10 +torch.manual_seed(args.seed)
11 +checkpoint = utility.checkpoint(args)
12 +
13 +def main():
14 + global model
15 + if args.data_test == ['video']:
16 + from videotester import VideoTester
17 + model = model.Model(args, checkpoint)
18 + t = VideoTester(args, model, checkpoint)
19 + t.test()
20 + else:
21 + if checkpoint.ok:
22 + loader = data.Data(args)
23 + _model = model.Model(args, checkpoint)
24 + _loss = loss.Loss(args, checkpoint) if not args.test_only else None
25 + t = Trainer(args, loader, _model, _loss, checkpoint)
26 + while not t.terminate():
27 + t.train()
28 + t.test()
29 +
30 + checkpoint.done()
31 +
32 +if __name__ == '__main__':
33 + main()
1 +import os
2 +from importlib import import_module
3 +
4 +import torch
5 +import torch.nn as nn
6 +import torch.nn.parallel as P
7 +import torch.utils.model_zoo
8 +
9 +class Model(nn.Module):
10 + def __init__(self, args, ckp):
11 + super(Model, self).__init__()
12 + print('Making model...')
13 +
14 + self.scale = args.scale
15 + self.idx_scale = 0
16 + self.input_large = (args.model == 'VDSR')
17 + self.self_ensemble = args.self_ensemble
18 + self.chop = args.chop
19 + self.precision = args.precision
20 + self.cpu = args.cpu
21 + self.device = torch.device('cpu' if args.cpu else 'cuda')
22 + self.n_GPUs = args.n_GPUs
23 + self.save_models = args.save_models
24 +
25 + module = import_module('model.' + args.model.lower())
26 + self.model = module.make_model(args).to(self.device)
27 + if args.precision == 'half':
28 + self.model.half()
29 +
30 + self.load(
31 + ckp.get_path('model'),
32 + pre_train=args.pre_train,
33 + resume=args.resume,
34 + cpu=args.cpu
35 + )
36 + print(self.model, file=ckp.log_file)
37 +
38 + def forward(self, x, idx_scale):
39 + self.idx_scale = idx_scale
40 + if hasattr(self.model, 'set_scale'):
41 + self.model.set_scale(idx_scale)
42 +
43 + if self.training:
44 + if self.n_GPUs > 1:
45 + return P.data_parallel(self.model, x, range(self.n_GPUs))
46 + else:
47 + return self.model(x)
48 + else:
49 + if self.chop:
50 + forward_function = self.forward_chop
51 + else:
52 + forward_function = self.model.forward
53 +
54 + if self.self_ensemble:
55 + return self.forward_x8(x, forward_function=forward_function)
56 + else:
57 + return forward_function(x)
58 +
59 + def save(self, apath, epoch, is_best=False):
60 + save_dirs = [os.path.join(apath, 'model_latest.pt')]
61 +
62 + if is_best:
63 + save_dirs.append(os.path.join(apath, 'model_best.pt'))
64 + if self.save_models:
65 + save_dirs.append(
66 + os.path.join(apath, 'model_{}.pt'.format(epoch))
67 + )
68 +
69 + for s in save_dirs:
70 + torch.save(self.model.state_dict(), s)
71 +
72 + def load(self, apath, pre_train='', resume=-1, cpu=False):
73 + load_from = None
74 + kwargs = {}
75 + if cpu:
76 + kwargs = {'map_location': lambda storage, loc: storage}
77 +
78 + if resume == -1:
79 + load_from = torch.load(
80 + os.path.join(apath, 'model_latest.pt'),
81 + **kwargs
82 + )
83 + elif resume == 0:
84 + if pre_train == 'download':
85 + print('Download the model')
86 + dir_model = os.path.join('..', 'models')
87 + os.makedirs(dir_model, exist_ok=True)
88 + load_from = torch.utils.model_zoo.load_url(
89 + self.model.url,
90 + model_dir=dir_model,
91 + **kwargs
92 + )
93 + elif pre_train:
94 + print('Load the model from {}'.format(pre_train))
95 + load_from = torch.load(pre_train, **kwargs)
96 + else:
97 + load_from = torch.load(
98 + os.path.join(apath, 'model_{}.pt'.format(resume)),
99 + **kwargs
100 + )
101 +
102 + if load_from:
103 + self.model.load_state_dict(load_from, strict=False)
104 +
105 + def forward_chop(self, *args, shave=10, min_size=160000):
106 + scale = 1 if self.input_large else self.scale[self.idx_scale]
107 + n_GPUs = min(self.n_GPUs, 4)
108 + # height, width
109 + h, w = args[0].size()[-2:]
110 +
111 + top = slice(0, h//2 + shave)
112 + bottom = slice(h - h//2 - shave, h)
113 + left = slice(0, w//2 + shave)
114 + right = slice(w - w//2 - shave, w)
115 + x_chops = [torch.cat([
116 + a[..., top, left],
117 + a[..., top, right],
118 + a[..., bottom, left],
119 + a[..., bottom, right]
120 + ]) for a in args]
121 +
122 + y_chops = []
123 + if h * w < 4 * min_size:
124 + for i in range(0, 4, n_GPUs):
125 + x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
126 + y = P.data_parallel(self.model, *x, range(n_GPUs))
127 + if not isinstance(y, list): y = [y]
128 + if not y_chops:
129 + y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
130 + else:
131 + for y_chop, _y in zip(y_chops, y):
132 + y_chop.extend(_y.chunk(n_GPUs, dim=0))
133 + else:
134 + for p in zip(*x_chops):
135 + y = self.forward_chop(*p, shave=shave, min_size=min_size)
136 + if not isinstance(y, list): y = [y]
137 + if not y_chops:
138 + y_chops = [[_y] for _y in y]
139 + else:
140 + for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
141 +
142 + h *= scale
143 + w *= scale
144 + top = slice(0, h//2)
145 + bottom = slice(h - h//2, h)
146 + bottom_r = slice(h//2 - h, None)
147 + left = slice(0, w//2)
148 + right = slice(w - w//2, w)
149 + right_r = slice(w//2 - w, None)
150 +
151 + # batch size, number of color channels
152 + b, c = y_chops[0][0].size()[:-2]
153 + y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
154 + for y_chop, _y in zip(y_chops, y):
155 + _y[..., top, left] = y_chop[0][..., top, left]
156 + _y[..., top, right] = y_chop[1][..., top, right_r]
157 + _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
158 + _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
159 +
160 + if len(y) == 1: y = y[0]
161 +
162 + return y
163 +
164 + def forward_x8(self, *args, forward_function=None):
165 + def _transform(v, op):
166 + if self.precision != 'single': v = v.float()
167 +
168 + v2np = v.data.cpu().numpy()
169 + if op == 'v':
170 + tfnp = v2np[:, :, :, ::-1].copy()
171 + elif op == 'h':
172 + tfnp = v2np[:, :, ::-1, :].copy()
173 + elif op == 't':
174 + tfnp = v2np.transpose((0, 1, 3, 2)).copy()
175 +
176 + ret = torch.Tensor(tfnp).to(self.device)
177 + if self.precision == 'half': ret = ret.half()
178 +
179 + return ret
180 +
181 + list_x = []
182 + for a in args:
183 + x = [a]
184 + for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
185 +
186 + list_x.append(x)
187 +
188 + list_y = []
189 + for x in zip(*list_x):
190 + y = forward_function(*x)
191 + if not isinstance(y, list): y = [y]
192 + if not list_y:
193 + list_y = [[_y] for _y in y]
194 + else:
195 + for _list_y, _y in zip(list_y, y): _list_y.append(_y)
196 +
197 + for _list_y in list_y:
198 + for i in range(len(_list_y)):
199 + if i > 3:
200 + _list_y[i] = _transform(_list_y[i], 't')
201 + if i % 4 > 1:
202 + _list_y[i] = _transform(_list_y[i], 'h')
203 + if (i % 4) % 2 == 1:
204 + _list_y[i] = _transform(_list_y[i], 'v')
205 +
206 + y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
207 + if len(y) == 1: y = y[0]
208 +
209 + return y
1 +import math
2 +
3 +import torch
4 +import torch.nn as nn
5 +import torch.nn.functional as F
6 +
7 +def default_conv(in_channels, out_channels, kernel_size, bias=True):
8 + return nn.Conv2d(
9 + in_channels, out_channels, kernel_size,
10 + padding=(kernel_size//2), bias=bias)
11 +
12 +class MeanShift(nn.Conv2d):
13 + def __init__(
14 + self, rgb_range,
15 + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
16 +
17 + super(MeanShift, self).__init__(3, 3, kernel_size=1)
18 + std = torch.Tensor(rgb_std)
19 + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
20 + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
21 + for p in self.parameters():
22 + p.requires_grad = False
23 +
24 +class BasicBlock(nn.Sequential):
25 + def __init__(
26 + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
27 + bn=True, act=nn.ReLU(True)):
28 +
29 + m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
30 + if bn:
31 + m.append(nn.BatchNorm2d(out_channels))
32 + if act is not None:
33 + m.append(act)
34 +
35 + super(BasicBlock, self).__init__(*m)
36 +
37 +class ResBlock(nn.Module):
38 + def __init__(
39 + self, conv, n_feats, kernel_size,
40 + bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
41 +
42 + super(ResBlock, self).__init__()
43 + m = []
44 + for i in range(2):
45 + m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
46 + if bn:
47 + m.append(nn.BatchNorm2d(n_feats))
48 + if i == 0:
49 + m.append(act)
50 +
51 + self.body = nn.Sequential(*m)
52 + self.res_scale = res_scale
53 +
54 + def forward(self, x):
55 + res = self.body(x).mul(self.res_scale)
56 + res += x
57 +
58 + return res
59 +
60 +class Upsampler(nn.Sequential):
61 + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
62 +
63 + m = []
64 + if (scale & (scale - 1)) == 0: # Is scale = 2^n?
65 + for _ in range(int(math.log(scale, 2))):
66 + m.append(conv(n_feats, 4 * n_feats, 3, bias))
67 + m.append(nn.PixelShuffle(2))
68 + if bn:
69 + m.append(nn.BatchNorm2d(n_feats))
70 + if act == 'relu':
71 + m.append(nn.ReLU(True))
72 + elif act == 'prelu':
73 + m.append(nn.PReLU(n_feats))
74 +
75 + elif scale == 3:
76 + m.append(conv(n_feats, 9 * n_feats, 3, bias))
77 + m.append(nn.PixelShuffle(3))
78 + if bn:
79 + m.append(nn.BatchNorm2d(n_feats))
80 + if act == 'relu':
81 + m.append(nn.ReLU(True))
82 + elif act == 'prelu':
83 + m.append(nn.PReLU(n_feats))
84 + else:
85 + raise NotImplementedError
86 +
87 + super(Upsampler, self).__init__(*m)
88 +
1 +# Deep Back-Projection Networks For Super-Resolution
2 +# https://arxiv.org/abs/1803.02735
3 +
4 +from model import common
5 +
6 +import torch
7 +import torch.nn as nn
8 +
9 +
10 +def make_model(args, parent=False):
11 + return DDBPN(args)
12 +
13 +def projection_conv(in_channels, out_channels, scale, up=True):
14 + kernel_size, stride, padding = {
15 + 2: (6, 2, 2),
16 + 4: (8, 4, 2),
17 + 8: (12, 8, 2)
18 + }[scale]
19 + if up:
20 + conv_f = nn.ConvTranspose2d
21 + else:
22 + conv_f = nn.Conv2d
23 +
24 + return conv_f(
25 + in_channels, out_channels, kernel_size,
26 + stride=stride, padding=padding
27 + )
28 +
29 +class DenseProjection(nn.Module):
30 + def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
31 + super(DenseProjection, self).__init__()
32 + if bottleneck:
33 + self.bottleneck = nn.Sequential(*[
34 + nn.Conv2d(in_channels, nr, 1),
35 + nn.PReLU(nr)
36 + ])
37 + inter_channels = nr
38 + else:
39 + self.bottleneck = None
40 + inter_channels = in_channels
41 +
42 + self.conv_1 = nn.Sequential(*[
43 + projection_conv(inter_channels, nr, scale, up),
44 + nn.PReLU(nr)
45 + ])
46 + self.conv_2 = nn.Sequential(*[
47 + projection_conv(nr, inter_channels, scale, not up),
48 + nn.PReLU(inter_channels)
49 + ])
50 + self.conv_3 = nn.Sequential(*[
51 + projection_conv(inter_channels, nr, scale, up),
52 + nn.PReLU(nr)
53 + ])
54 +
55 + def forward(self, x):
56 + if self.bottleneck is not None:
57 + x = self.bottleneck(x)
58 +
59 + a_0 = self.conv_1(x)
60 + b_0 = self.conv_2(a_0)
61 + e = b_0.sub(x)
62 + a_1 = self.conv_3(e)
63 +
64 + out = a_0.add(a_1)
65 +
66 + return out
67 +
68 +class DDBPN(nn.Module):
69 + def __init__(self, args):
70 + super(DDBPN, self).__init__()
71 + scale = args.scale[0]
72 +
73 + n0 = 128
74 + nr = 32
75 + self.depth = 6
76 +
77 + rgb_mean = (0.4488, 0.4371, 0.4040)
78 + rgb_std = (1.0, 1.0, 1.0)
79 + self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
80 + initial = [
81 + nn.Conv2d(args.n_colors, n0, 3, padding=1),
82 + nn.PReLU(n0),
83 + nn.Conv2d(n0, nr, 1),
84 + nn.PReLU(nr)
85 + ]
86 + self.initial = nn.Sequential(*initial)
87 +
88 + self.upmodules = nn.ModuleList()
89 + self.downmodules = nn.ModuleList()
90 + channels = nr
91 + for i in range(self.depth):
92 + self.upmodules.append(
93 + DenseProjection(channels, nr, scale, True, i > 1)
94 + )
95 + if i != 0:
96 + channels += nr
97 +
98 + channels = nr
99 + for i in range(self.depth - 1):
100 + self.downmodules.append(
101 + DenseProjection(channels, nr, scale, False, i != 0)
102 + )
103 + channels += nr
104 +
105 + reconstruction = [
106 + nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
107 + ]
108 + self.reconstruction = nn.Sequential(*reconstruction)
109 +
110 + self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
111 +
112 + def forward(self, x):
113 + x = self.sub_mean(x)
114 + x = self.initial(x)
115 +
116 + h_list = []
117 + l_list = []
118 + for i in range(self.depth - 1):
119 + if i == 0:
120 + l = x
121 + else:
122 + l = torch.cat(l_list, dim=1)
123 + h_list.append(self.upmodules[i](l))
124 + l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
125 +
126 + h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
127 + out = self.reconstruction(torch.cat(h_list, dim=1))
128 + out = self.add_mean(out)
129 +
130 + return out
131 +
1 +from model import common
2 +
3 +import torch.nn as nn
4 +
5 +url = {
6 + 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
7 + 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
8 + 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
9 + 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
10 + 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
11 + 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
12 +}
13 +
14 +def make_model(args, parent=False):
15 + return EDSR(args)
16 +
17 +class EDSR(nn.Module):
18 + def __init__(self, args, conv=common.default_conv):
19 + super(EDSR, self).__init__()
20 +
21 + n_resblocks = args.n_resblocks
22 + n_feats = args.n_feats
23 + kernel_size = 3
24 + scale = args.scale[0]
25 + act = nn.ReLU(True)
26 + url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
27 + if url_name in url:
28 + self.url = url[url_name]
29 + else:
30 + self.url = None
31 + self.sub_mean = common.MeanShift(args.rgb_range)
32 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
33 +
34 + # define head module
35 + m_head = [conv(args.n_colors, n_feats, kernel_size)]
36 +
37 + # define body module
38 + m_body = [
39 + common.ResBlock(
40 + conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
41 + ) for _ in range(n_resblocks)
42 + ]
43 + m_body.append(conv(n_feats, n_feats, kernel_size))
44 +
45 + # define tail module
46 + m_tail = [
47 + common.Upsampler(conv, scale, n_feats, act=False),
48 + conv(n_feats, args.n_colors, kernel_size)
49 + ]
50 +
51 + self.head = nn.Sequential(*m_head)
52 + self.body = nn.Sequential(*m_body)
53 + self.tail = nn.Sequential(*m_tail)
54 +
55 + def forward(self, x):
56 + x = self.sub_mean(x)
57 + x = self.head(x)
58 +
59 + res = self.body(x)
60 + res += x
61 +
62 + x = self.tail(res)
63 + x = self.add_mean(x)
64 +
65 + return x
66 +
67 + def load_state_dict(self, state_dict, strict=True):
68 + own_state = self.state_dict()
69 + for name, param in state_dict.items():
70 + if name in own_state:
71 + if isinstance(param, nn.Parameter):
72 + param = param.data
73 + try:
74 + own_state[name].copy_(param)
75 + except Exception:
76 + if name.find('tail') == -1:
77 + raise RuntimeError('While copying the parameter named {}, '
78 + 'whose dimensions in the model are {} and '
79 + 'whose dimensions in the checkpoint are {}.'
80 + .format(name, own_state[name].size(), param.size()))
81 + elif strict:
82 + if name.find('tail') == -1:
83 + raise KeyError('unexpected key "{}" in state_dict'
84 + .format(name))
85 +
1 +from model import common
2 +
3 +import torch.nn as nn
4 +
5 +url = {
6 + 'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt',
7 + 'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt'
8 +}
9 +
10 +def make_model(args, parent=False):
11 + return MDSR(args)
12 +
13 +class MDSR(nn.Module):
14 + def __init__(self, args, conv=common.default_conv):
15 + super(MDSR, self).__init__()
16 + n_resblocks = args.n_resblocks
17 + n_feats = args.n_feats
18 + kernel_size = 3
19 + act = nn.ReLU(True)
20 + self.scale_idx = 0
21 + self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
22 + self.sub_mean = common.MeanShift(args.rgb_range)
23 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
24 +
25 + m_head = [conv(args.n_colors, n_feats, kernel_size)]
26 +
27 + self.pre_process = nn.ModuleList([
28 + nn.Sequential(
29 + common.ResBlock(conv, n_feats, 5, act=act),
30 + common.ResBlock(conv, n_feats, 5, act=act)
31 + ) for _ in args.scale
32 + ])
33 +
34 + m_body = [
35 + common.ResBlock(
36 + conv, n_feats, kernel_size, act=act
37 + ) for _ in range(n_resblocks)
38 + ]
39 + m_body.append(conv(n_feats, n_feats, kernel_size))
40 +
41 + self.upsample = nn.ModuleList([
42 + common.Upsampler(conv, s, n_feats, act=False) for s in args.scale
43 + ])
44 +
45 + m_tail = [conv(n_feats, args.n_colors, kernel_size)]
46 +
47 + self.head = nn.Sequential(*m_head)
48 + self.body = nn.Sequential(*m_body)
49 + self.tail = nn.Sequential(*m_tail)
50 +
51 + def forward(self, x):
52 + x = self.sub_mean(x)
53 + x = self.head(x)
54 + x = self.pre_process[self.scale_idx](x)
55 +
56 + res = self.body(x)
57 + res += x
58 +
59 + x = self.upsample[self.scale_idx](res)
60 + x = self.tail(x)
61 + x = self.add_mean(x)
62 +
63 + return x
64 +
65 + def set_scale(self, scale_idx):
66 + self.scale_idx = scale_idx
67 +
1 +## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks
2 +## https://arxiv.org/abs/1807.02758
3 +from model import common
4 +
5 +import torch.nn as nn
6 +
7 +def make_model(args, parent=False):
8 + return RCAN(args)
9 +
10 +## Channel Attention (CA) Layer
11 +class CALayer(nn.Module):
12 + def __init__(self, channel, reduction=16):
13 + super(CALayer, self).__init__()
14 + # global average pooling: feature --> point
15 + self.avg_pool = nn.AdaptiveAvgPool2d(1)
16 + # feature channel downscale and upscale --> channel weight
17 + self.conv_du = nn.Sequential(
18 + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
19 + nn.ReLU(inplace=True),
20 + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
21 + nn.Sigmoid()
22 + )
23 +
24 + def forward(self, x):
25 + y = self.avg_pool(x)
26 + y = self.conv_du(y)
27 + return x * y
28 +
29 +## Residual Channel Attention Block (RCAB)
30 +class RCAB(nn.Module):
31 + def __init__(
32 + self, conv, n_feat, kernel_size, reduction,
33 + bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
34 +
35 + super(RCAB, self).__init__()
36 + modules_body = []
37 + for i in range(2):
38 + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
39 + if bn: modules_body.append(nn.BatchNorm2d(n_feat))
40 + if i == 0: modules_body.append(act)
41 + modules_body.append(CALayer(n_feat, reduction))
42 + self.body = nn.Sequential(*modules_body)
43 + self.res_scale = res_scale
44 +
45 + def forward(self, x):
46 + res = self.body(x)
47 + #res = self.body(x).mul(self.res_scale)
48 + res += x
49 + return res
50 +
51 +## Residual Group (RG)
52 +class ResidualGroup(nn.Module):
53 + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
54 + super(ResidualGroup, self).__init__()
55 + modules_body = []
56 + modules_body = [
57 + RCAB(
58 + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
59 + for _ in range(n_resblocks)]
60 + modules_body.append(conv(n_feat, n_feat, kernel_size))
61 + self.body = nn.Sequential(*modules_body)
62 +
63 + def forward(self, x):
64 + res = self.body(x)
65 + res += x
66 + return res
67 +
68 +## Residual Channel Attention Network (RCAN)
69 +class RCAN(nn.Module):
70 + def __init__(self, args, conv=common.default_conv):
71 + super(RCAN, self).__init__()
72 +
73 + n_resgroups = args.n_resgroups
74 + n_resblocks = args.n_resblocks
75 + n_feats = args.n_feats
76 + kernel_size = 3
77 + reduction = args.reduction
78 + scale = args.scale[0]
79 + act = nn.ReLU(True)
80 +
81 + # RGB mean for DIV2K
82 + self.sub_mean = common.MeanShift(args.rgb_range)
83 +
84 + # define head module
85 + modules_head = [conv(args.n_colors, n_feats, kernel_size)]
86 +
87 + # define body module
88 + modules_body = [
89 + ResidualGroup(
90 + conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
91 + for _ in range(n_resgroups)]
92 +
93 + modules_body.append(conv(n_feats, n_feats, kernel_size))
94 +
95 + # define tail module
96 + modules_tail = [
97 + common.Upsampler(conv, scale, n_feats, act=False),
98 + conv(n_feats, args.n_colors, kernel_size)]
99 +
100 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
101 +
102 + self.head = nn.Sequential(*modules_head)
103 + self.body = nn.Sequential(*modules_body)
104 + self.tail = nn.Sequential(*modules_tail)
105 +
106 + def forward(self, x):
107 + x = self.sub_mean(x)
108 + x = self.head(x)
109 +
110 + res = self.body(x)
111 + res += x
112 +
113 + x = self.tail(res)
114 + x = self.add_mean(x)
115 +
116 + return x
117 +
118 + def load_state_dict(self, state_dict, strict=False):
119 + own_state = self.state_dict()
120 + for name, param in state_dict.items():
121 + if name in own_state:
122 + if isinstance(param, nn.Parameter):
123 + param = param.data
124 + try:
125 + own_state[name].copy_(param)
126 + except Exception:
127 + if name.find('tail') >= 0:
128 + print('Replace pre-trained upsampler to new one...')
129 + else:
130 + raise RuntimeError('While copying the parameter named {}, '
131 + 'whose dimensions in the model are {} and '
132 + 'whose dimensions in the checkpoint are {}.'
133 + .format(name, own_state[name].size(), param.size()))
134 + elif strict:
135 + if name.find('tail') == -1:
136 + raise KeyError('unexpected key "{}" in state_dict'
137 + .format(name))
138 +
139 + if strict:
140 + missing = set(own_state.keys()) - set(state_dict.keys())
141 + if len(missing) > 0:
142 + raise KeyError('missing keys in state_dict: "{}"'.format(missing))
1 +# Residual Dense Network for Image Super-Resolution
2 +# https://arxiv.org/abs/1802.08797
3 +
4 +from model import common
5 +
6 +import torch
7 +import torch.nn as nn
8 +
9 +
10 +def make_model(args, parent=False):
11 + return RDN(args)
12 +
13 +class RDB_Conv(nn.Module):
14 + def __init__(self, inChannels, growRate, kSize=3):
15 + super(RDB_Conv, self).__init__()
16 + Cin = inChannels
17 + G = growRate
18 + self.conv = nn.Sequential(*[
19 + nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
20 + nn.ReLU()
21 + ])
22 +
23 + def forward(self, x):
24 + out = self.conv(x)
25 + return torch.cat((x, out), 1)
26 +
27 +class RDB(nn.Module):
28 + def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
29 + super(RDB, self).__init__()
30 + G0 = growRate0
31 + G = growRate
32 + C = nConvLayers
33 +
34 + convs = []
35 + for c in range(C):
36 + convs.append(RDB_Conv(G0 + c*G, G))
37 + self.convs = nn.Sequential(*convs)
38 +
39 + # Local Feature Fusion
40 + self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
41 +
42 + def forward(self, x):
43 + return self.LFF(self.convs(x)) + x
44 +
45 +class RDN(nn.Module):
46 + def __init__(self, args):
47 + super(RDN, self).__init__()
48 + r = args.scale[0]
49 + G0 = args.G0
50 + kSize = args.RDNkSize
51 +
52 + # number of RDB blocks, conv layers, out channels
53 + self.D, C, G = {
54 + 'A': (20, 6, 32),
55 + 'B': (16, 8, 64),
56 + }[args.RDNconfig]
57 +
58 + # Shallow feature extraction net
59 + self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
60 + self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
61 +
62 + # Redidual dense blocks and dense feature fusion
63 + self.RDBs = nn.ModuleList()
64 + for i in range(self.D):
65 + self.RDBs.append(
66 + RDB(growRate0 = G0, growRate = G, nConvLayers = C)
67 + )
68 +
69 + # Global Feature Fusion
70 + self.GFF = nn.Sequential(*[
71 + nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
72 + nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
73 + ])
74 +
75 + # Up-sampling net
76 + if r == 2 or r == 3:
77 + self.UPNet = nn.Sequential(*[
78 + nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
79 + nn.PixelShuffle(r),
80 + nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
81 + ])
82 + elif r == 4:
83 + self.UPNet = nn.Sequential(*[
84 + nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
85 + nn.PixelShuffle(2),
86 + nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
87 + nn.PixelShuffle(2),
88 + nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
89 + ])
90 + else:
91 + raise ValueError("scale must be 2 or 3 or 4.")
92 +
93 + def forward(self, x):
94 + f__1 = self.SFENet1(x)
95 + x = self.SFENet2(f__1)
96 +
97 + RDBs_out = []
98 + for i in range(self.D):
99 + x = self.RDBs[i](x)
100 + RDBs_out.append(x)
101 +
102 + x = self.GFF(torch.cat(RDBs_out,1))
103 + x += f__1
104 +
105 + return self.UPNet(x)
1 +from model import common
2 +
3 +import torch.nn as nn
4 +import torch.nn.init as init
5 +
6 +url = {
7 + 'r20f64': ''
8 +}
9 +
10 +def make_model(args, parent=False):
11 + return VDSR(args)
12 +
13 +class VDSR(nn.Module):
14 + def __init__(self, args, conv=common.default_conv):
15 + super(VDSR, self).__init__()
16 +
17 + n_resblocks = args.n_resblocks
18 + n_feats = args.n_feats
19 + kernel_size = 3
20 + self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
21 + self.sub_mean = common.MeanShift(args.rgb_range)
22 + self.add_mean = common.MeanShift(args.rgb_range, sign=1)
23 +
24 + def basic_block(in_channels, out_channels, act):
25 + return common.BasicBlock(
26 + conv, in_channels, out_channels, kernel_size,
27 + bias=True, bn=False, act=act
28 + )
29 +
30 + # define body module
31 + m_body = []
32 + m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
33 + for _ in range(n_resblocks - 2):
34 + m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
35 + m_body.append(basic_block(n_feats, args.n_colors, None))
36 +
37 + self.body = nn.Sequential(*m_body)
38 +
39 + def forward(self, x):
40 + x = self.sub_mean(x)
41 + res = self.body(x)
42 + res += x
43 + x = self.add_mean(res)
44 +
45 + return x
46 +
1 +import argparse
2 +import template
3 +
4 +parser = argparse.ArgumentParser(description='EDSR and MDSR')
5 +
6 +parser.add_argument('--debug', action='store_true',
7 + help='Enables debug mode')
8 +parser.add_argument('--template', default='.',
9 + help='You can set various templates in option.py')
10 +
11 +# Hardware specifications
12 +parser.add_argument('--n_threads', type=int, default=6,
13 + help='number of threads for data loading')
14 +parser.add_argument('--cpu', action='store_true',
15 + help='use cpu only')
16 +parser.add_argument('--n_GPUs', type=int, default=1,
17 + help='number of GPUs')
18 +parser.add_argument('--seed', type=int, default=1,
19 + help='random seed')
20 +
21 +# Data specifications
22 +parser.add_argument('--dir_data', type=str, default='../../../dataset',
23 + help='dataset directory')
24 +parser.add_argument('--dir_demo', type=str, default='../test',
25 + help='demo image directory')
26 +parser.add_argument('--data_train', type=str, default='DIV2K',
27 + help='train dataset name')
28 +parser.add_argument('--data_test', type=str, default='DIV2K',
29 + help='test dataset name')
30 +parser.add_argument('--data_range', type=str, default='1-800/801-810',
31 + help='train/test data range')
32 +parser.add_argument('--ext', type=str, default='sep',
33 + help='dataset file extension')
34 +parser.add_argument('--scale', type=str, default='4',
35 + help='super resolution scale')
36 +parser.add_argument('--patch_size', type=int, default=192,
37 + help='output patch size')
38 +parser.add_argument('--rgb_range', type=int, default=255,
39 + help='maximum value of RGB')
40 +parser.add_argument('--n_colors', type=int, default=3,
41 + help='number of color channels to use')
42 +parser.add_argument('--chop', action='store_true',
43 + help='enable memory-efficient forward')
44 +parser.add_argument('--no_augment', action='store_true',
45 + help='do not use data augmentation')
46 +
47 +# Model specifications
48 +parser.add_argument('--model', default='EDSR',
49 + help='model name')
50 +
51 +parser.add_argument('--act', type=str, default='relu',
52 + help='activation function')
53 +parser.add_argument('--pre_train', type=str, default='',
54 + help='pre-trained model directory')
55 +parser.add_argument('--extend', type=str, default='.',
56 + help='pre-trained model directory')
57 +parser.add_argument('--n_resblocks', type=int, default=16,
58 + help='number of residual blocks')
59 +parser.add_argument('--n_feats', type=int, default=64,
60 + help='number of feature maps')
61 +parser.add_argument('--res_scale', type=float, default=1,
62 + help='residual scaling')
63 +parser.add_argument('--shift_mean', default=True,
64 + help='subtract pixel mean from the input')
65 +parser.add_argument('--dilation', action='store_true',
66 + help='use dilated convolution')
67 +parser.add_argument('--precision', type=str, default='single',
68 + choices=('single', 'half'),
69 + help='FP precision for test (single | half)')
70 +
71 +# Option for Residual dense network (RDN)
72 +parser.add_argument('--G0', type=int, default=64,
73 + help='default number of filters. (Use in RDN)')
74 +parser.add_argument('--RDNkSize', type=int, default=3,
75 + help='default kernel size. (Use in RDN)')
76 +parser.add_argument('--RDNconfig', type=str, default='B',
77 + help='parameters config of RDN. (Use in RDN)')
78 +
79 +# Option for Residual channel attention network (RCAN)
80 +parser.add_argument('--n_resgroups', type=int, default=10,
81 + help='number of residual groups')
82 +parser.add_argument('--reduction', type=int, default=16,
83 + help='number of feature maps reduction')
84 +
85 +# Training specifications
86 +parser.add_argument('--reset', action='store_true',
87 + help='reset the training')
88 +parser.add_argument('--test_every', type=int, default=1000,
89 + help='do test per every N batches')
90 +parser.add_argument('--epochs', type=int, default=300,
91 + help='number of epochs to train')
92 +parser.add_argument('--batch_size', type=int, default=16,
93 + help='input batch size for training')
94 +parser.add_argument('--split_batch', type=int, default=1,
95 + help='split the batch into smaller chunks')
96 +parser.add_argument('--self_ensemble', action='store_true',
97 + help='use self-ensemble method for test')
98 +parser.add_argument('--test_only', action='store_true',
99 + help='set this option to test the model')
100 +parser.add_argument('--gan_k', type=int, default=1,
101 + help='k value for adversarial loss')
102 +
103 +# Optimization specifications
104 +parser.add_argument('--lr', type=float, default=1e-4,
105 + help='learning rate')
106 +parser.add_argument('--decay', type=str, default='200',
107 + help='learning rate decay type')
108 +parser.add_argument('--gamma', type=float, default=0.5,
109 + help='learning rate decay factor for step decay')
110 +parser.add_argument('--optimizer', default='ADAM',
111 + choices=('SGD', 'ADAM', 'RMSprop'),
112 + help='optimizer to use (SGD | ADAM | RMSprop)')
113 +parser.add_argument('--momentum', type=float, default=0.9,
114 + help='SGD momentum')
115 +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
116 + help='ADAM beta')
117 +parser.add_argument('--epsilon', type=float, default=1e-8,
118 + help='ADAM epsilon for numerical stability')
119 +parser.add_argument('--weight_decay', type=float, default=0,
120 + help='weight decay')
121 +parser.add_argument('--gclip', type=float, default=0,
122 + help='gradient clipping threshold (0 = no clipping)')
123 +
124 +# Loss specifications
125 +parser.add_argument('--loss', type=str, default='1*L1',
126 + help='loss function configuration')
127 +parser.add_argument('--skip_threshold', type=float, default='1e8',
128 + help='skipping batch that has large error')
129 +
130 +# Log specifications
131 +parser.add_argument('--save', type=str, default='test',
132 + help='file name to save')
133 +parser.add_argument('--load', type=str, default='',
134 + help='file name to load')
135 +parser.add_argument('--resume', type=int, default=0,
136 + help='resume from specific checkpoint')
137 +parser.add_argument('--save_models', action='store_true',
138 + help='save all intermediate models')
139 +parser.add_argument('--print_every', type=int, default=100,
140 + help='how many batches to wait before logging training status')
141 +parser.add_argument('--save_results', action='store_true',
142 + help='save output results')
143 +parser.add_argument('--save_gt', action='store_true',
144 + help='save low-resolution and high-resolution images together')
145 +
146 +args = parser.parse_args()
147 +template.set_template(args)
148 +
149 +args.scale = list(map(lambda x: int(x), args.scale.split('+')))
150 +args.data_train = args.data_train.split('+')
151 +args.data_test = args.data_test.split('+')
152 +
153 +if args.epochs == 0:
154 + args.epochs = 1e8
155 +
156 +for arg in vars(args):
157 + if vars(args)[arg] == 'True':
158 + vars(args)[arg] = True
159 + elif vars(args)[arg] == 'False':
160 + vars(args)[arg] = False
161 +
1 +def set_template(args):
2 + # Set the templates here
3 + if args.template.find('jpeg') >= 0:
4 + args.data_train = 'DIV2K_jpeg'
5 + args.data_test = 'DIV2K_jpeg'
6 + args.epochs = 200
7 + args.decay = '100'
8 +
9 + if args.template.find('EDSR_paper') >= 0:
10 + args.model = 'EDSR'
11 + args.n_resblocks = 32
12 + args.n_feats = 256
13 + args.res_scale = 0.1
14 +
15 + if args.template.find('MDSR') >= 0:
16 + args.model = 'MDSR'
17 + args.patch_size = 48
18 + args.epochs = 650
19 +
20 + if args.template.find('DDBPN') >= 0:
21 + args.model = 'DDBPN'
22 + args.patch_size = 128
23 + args.scale = '4'
24 +
25 + args.data_test = 'Set5'
26 +
27 + args.batch_size = 20
28 + args.epochs = 1000
29 + args.decay = '500'
30 + args.gamma = 0.1
31 + args.weight_decay = 1e-4
32 +
33 + args.loss = '1*MSE'
34 +
35 + if args.template.find('GAN') >= 0:
36 + args.epochs = 200
37 + args.lr = 5e-5
38 + args.decay = '150'
39 +
40 + if args.template.find('RCAN') >= 0:
41 + args.model = 'RCAN'
42 + args.n_resgroups = 10
43 + args.n_resblocks = 20
44 + args.n_feats = 64
45 + args.chop = True
46 +
47 + if args.template.find('VDSR') >= 0:
48 + args.model = 'VDSR'
49 + args.n_resblocks = 20
50 + args.n_feats = 64
51 + args.patch_size = 41
52 + args.lr = 1e-1
53 +
1 +import os
2 +import math
3 +from decimal import Decimal
4 +
5 +import utility
6 +
7 +import torch
8 +import torch.nn.utils as utils
9 +from tqdm import tqdm
10 +
11 +class Trainer():
12 + def __init__(self, args, loader, my_model, my_loss, ckp):
13 + self.args = args
14 + self.scale = args.scale
15 +
16 + self.ckp = ckp
17 + self.loader_train = loader.loader_train
18 + self.loader_test = loader.loader_test
19 + self.model = my_model
20 + self.loss = my_loss
21 + self.optimizer = utility.make_optimizer(args, self.model)
22 +
23 + if self.args.load != '':
24 + self.optimizer.load(ckp.dir, epoch=len(ckp.log))
25 +
26 + self.error_last = 1e8
27 +
28 + def train(self):
29 + self.loss.step()
30 + epoch = self.optimizer.get_last_epoch() + 1
31 + lr = self.optimizer.get_lr()
32 +
33 + self.ckp.write_log(
34 + '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
35 + )
36 + self.loss.start_log()
37 + self.model.train()
38 +
39 + timer_data, timer_model = utility.timer(), utility.timer()
40 + # TEMP
41 + self.loader_train.dataset.set_scale(0)
42 + for batch, (lr, hr, _,) in enumerate(self.loader_train):
43 + lr, hr = self.prepare(lr, hr)
44 + timer_data.hold()
45 + timer_model.tic()
46 +
47 + self.optimizer.zero_grad()
48 + sr = self.model(lr, 0)
49 + loss = self.loss(sr, hr)
50 + loss.backward()
51 + if self.args.gclip > 0:
52 + utils.clip_grad_value_(
53 + self.model.parameters(),
54 + self.args.gclip
55 + )
56 + self.optimizer.step()
57 +
58 + timer_model.hold()
59 +
60 + if (batch + 1) % self.args.print_every == 0:
61 + self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
62 + (batch + 1) * self.args.batch_size,
63 + len(self.loader_train.dataset),
64 + self.loss.display_loss(batch),
65 + timer_model.release(),
66 + timer_data.release()))
67 +
68 + timer_data.tic()
69 +
70 + self.loss.end_log(len(self.loader_train))
71 + self.error_last = self.loss.log[-1, -1]
72 + self.optimizer.schedule()
73 +
74 + def test(self):
75 + torch.set_grad_enabled(False)
76 +
77 + epoch = self.optimizer.get_last_epoch()
78 + self.ckp.write_log('\nEvaluation:')
79 + self.ckp.add_log(
80 + torch.zeros(1, len(self.loader_test), len(self.scale))
81 + )
82 + self.model.eval()
83 +
84 + timer_test = utility.timer()
85 + if self.args.save_results: self.ckp.begin_background()
86 + for idx_data, d in enumerate(self.loader_test):
87 + for idx_scale, scale in enumerate(self.scale):
88 + d.dataset.set_scale(idx_scale)
89 + for lr, hr, filename in tqdm(d, ncols=80):
90 + lr, hr = self.prepare(lr, hr)
91 + sr = self.model(lr, idx_scale)
92 + sr = utility.quantize(sr, self.args.rgb_range)
93 +
94 + save_list = [sr]
95 + self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
96 + sr, hr, scale, self.args.rgb_range, dataset=d
97 + )
98 + if self.args.save_gt:
99 + save_list.extend([lr, hr])
100 +
101 + if self.args.save_results:
102 + self.ckp.save_results(d, filename[0], save_list, scale)
103 +
104 + self.ckp.log[-1, idx_data, idx_scale] /= len(d)
105 + best = self.ckp.log.max(0)
106 + self.ckp.write_log(
107 + '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
108 + d.dataset.name,
109 + scale,
110 + self.ckp.log[-1, idx_data, idx_scale],
111 + best[0][idx_data, idx_scale],
112 + best[1][idx_data, idx_scale] + 1
113 + )
114 + )
115 +
116 + self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
117 + self.ckp.write_log('Saving...')
118 +
119 + if self.args.save_results:
120 + self.ckp.end_background()
121 +
122 + if not self.args.test_only:
123 + self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
124 +
125 + self.ckp.write_log(
126 + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
127 + )
128 +
129 + torch.set_grad_enabled(True)
130 +
131 + def prepare(self, *args):
132 + device = torch.device('cpu' if self.args.cpu else 'cuda')
133 + def _prepare(tensor):
134 + if self.args.precision == 'half': tensor = tensor.half()
135 + return tensor.to(device)
136 +
137 + return [_prepare(a) for a in args]
138 +
139 + def terminate(self):
140 + if self.args.test_only:
141 + self.test()
142 + return True
143 + else:
144 + epoch = self.optimizer.get_last_epoch() + 1
145 + return epoch >= self.args.epochs
146 +
1 +import os
2 +import math
3 +import time
4 +import datetime
5 +from multiprocessing import Process
6 +from multiprocessing import Queue
7 +
8 +import matplotlib
9 +matplotlib.use('Agg')
10 +import matplotlib.pyplot as plt
11 +
12 +import numpy as np
13 +import imageio
14 +
15 +import torch
16 +import torch.optim as optim
17 +import torch.optim.lr_scheduler as lrs
18 +
19 +class timer():
20 + def __init__(self):
21 + self.acc = 0
22 + self.tic()
23 +
24 + def tic(self):
25 + self.t0 = time.time()
26 +
27 + def toc(self, restart=False):
28 + diff = time.time() - self.t0
29 + if restart: self.t0 = time.time()
30 + return diff
31 +
32 + def hold(self):
33 + self.acc += self.toc()
34 +
35 + def release(self):
36 + ret = self.acc
37 + self.acc = 0
38 +
39 + return ret
40 +
41 + def reset(self):
42 + self.acc = 0
43 +
44 +class checkpoint():
45 + def __init__(self, args):
46 + self.args = args
47 + self.ok = True
48 + self.log = torch.Tensor()
49 + now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
50 +
51 + if not args.load:
52 + if not args.save:
53 + args.save = now
54 + self.dir = os.path.join('..', 'experiment', args.save)
55 + else:
56 + self.dir = os.path.join('..', 'experiment', args.load)
57 + if os.path.exists(self.dir):
58 + self.log = torch.load(self.get_path('psnr_log.pt'))
59 + print('Continue from epoch {}...'.format(len(self.log)))
60 + else:
61 + args.load = ''
62 +
63 + if args.reset:
64 + os.system('rm -rf ' + self.dir)
65 + args.load = ''
66 +
67 + os.makedirs(self.dir, exist_ok=True)
68 + os.makedirs(self.get_path('model'), exist_ok=True)
69 + for d in args.data_test:
70 + os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
71 +
72 + open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w'
73 + self.log_file = open(self.get_path('log.txt'), open_type)
74 + with open(self.get_path('config.txt'), open_type) as f:
75 + f.write(now + '\n\n')
76 + for arg in vars(args):
77 + f.write('{}: {}\n'.format(arg, getattr(args, arg)))
78 + f.write('\n')
79 +
80 + self.n_processes = 8
81 +
82 + def get_path(self, *subdir):
83 + return os.path.join(self.dir, *subdir)
84 +
85 + def save(self, trainer, epoch, is_best=False):
86 + trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
87 + trainer.loss.save(self.dir)
88 + trainer.loss.plot_loss(self.dir, epoch)
89 +
90 + self.plot_psnr(epoch)
91 + trainer.optimizer.save(self.dir)
92 + torch.save(self.log, self.get_path('psnr_log.pt'))
93 +
94 + def add_log(self, log):
95 + self.log = torch.cat([self.log, log])
96 +
97 + def write_log(self, log, refresh=False):
98 + print(log)
99 + self.log_file.write(log + '\n')
100 + if refresh:
101 + self.log_file.close()
102 + self.log_file = open(self.get_path('log.txt'), 'a')
103 +
104 + def done(self):
105 + self.log_file.close()
106 +
107 + def plot_psnr(self, epoch):
108 + axis = np.linspace(1, epoch, epoch)
109 + for idx_data, d in enumerate(self.args.data_test):
110 + label = 'SR on {}'.format(d)
111 + fig = plt.figure()
112 + plt.title(label)
113 + for idx_scale, scale in enumerate(self.args.scale):
114 + plt.plot(
115 + axis,
116 + self.log[:, idx_data, idx_scale].numpy(),
117 + label='Scale {}'.format(scale)
118 + )
119 + plt.legend()
120 + plt.xlabel('Epochs')
121 + plt.ylabel('PSNR')
122 + plt.grid(True)
123 + plt.savefig(self.get_path('test_{}.pdf'.format(d)))
124 + plt.close(fig)
125 +
126 + def begin_background(self):
127 + self.queue = Queue()
128 +
129 + def bg_target(queue):
130 + while True:
131 + if not queue.empty():
132 + filename, tensor = queue.get()
133 + if filename is None: break
134 + imageio.imwrite(filename, tensor.numpy())
135 +
136 + self.process = [
137 + Process(target=bg_target, args=(self.queue,)) \
138 + for _ in range(self.n_processes)
139 + ]
140 +
141 + for p in self.process: p.start()
142 +
143 + def end_background(self):
144 + for _ in range(self.n_processes): self.queue.put((None, None))
145 + while not self.queue.empty(): time.sleep(1)
146 + for p in self.process: p.join()
147 +
148 + def save_results(self, dataset, filename, save_list, scale):
149 + if self.args.save_results:
150 + filename = self.get_path(
151 + 'results-{}'.format(dataset.dataset.name),
152 + '{}_x{}_'.format(filename, scale)
153 + )
154 +
155 + postfix = ('SR', 'LR', 'HR')
156 + for v, p in zip(save_list, postfix):
157 + normalized = v[0].mul(255 / self.args.rgb_range)
158 + tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
159 + self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
160 +
161 +def quantize(img, rgb_range):
162 + pixel_range = 255 / rgb_range
163 + return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
164 +
165 +def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
166 + if hr.nelement() == 1: return 0
167 +
168 + diff = (sr - hr) / rgb_range
169 + if dataset and dataset.dataset.benchmark:
170 + shave = scale
171 + if diff.size(1) > 1:
172 + gray_coeffs = [65.738, 129.057, 25.064]
173 + convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
174 + diff = diff.mul(convert).sum(dim=1)
175 + else:
176 + shave = scale + 6
177 +
178 + valid = diff[..., shave:-shave, shave:-shave]
179 + mse = valid.pow(2).mean()
180 +
181 + return -10 * math.log10(mse)
182 +
183 +def make_optimizer(args, target):
184 + '''
185 + make optimizer and scheduler together
186 + '''
187 + # optimizer
188 + trainable = filter(lambda x: x.requires_grad, target.parameters())
189 + kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
190 +
191 + if args.optimizer == 'SGD':
192 + optimizer_class = optim.SGD
193 + kwargs_optimizer['momentum'] = args.momentum
194 + elif args.optimizer == 'ADAM':
195 + optimizer_class = optim.Adam
196 + kwargs_optimizer['betas'] = args.betas
197 + kwargs_optimizer['eps'] = args.epsilon
198 + elif args.optimizer == 'RMSprop':
199 + optimizer_class = optim.RMSprop
200 + kwargs_optimizer['eps'] = args.epsilon
201 +
202 + # scheduler
203 + milestones = list(map(lambda x: int(x), args.decay.split('-')))
204 + kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
205 + scheduler_class = lrs.MultiStepLR
206 +
207 + class CustomOptimizer(optimizer_class):
208 + def __init__(self, *args, **kwargs):
209 + super(CustomOptimizer, self).__init__(*args, **kwargs)
210 +
211 + def _register_scheduler(self, scheduler_class, **kwargs):
212 + self.scheduler = scheduler_class(self, **kwargs)
213 +
214 + def save(self, save_dir):
215 + torch.save(self.state_dict(), self.get_dir(save_dir))
216 +
217 + def load(self, load_dir, epoch=1):
218 + self.load_state_dict(torch.load(self.get_dir(load_dir)))
219 + if epoch > 1:
220 + for _ in range(epoch): self.scheduler.step()
221 +
222 + def get_dir(self, dir_path):
223 + return os.path.join(dir_path, 'optimizer.pt')
224 +
225 + def schedule(self):
226 + self.scheduler.step()
227 +
228 + def get_lr(self):
229 + return self.scheduler.get_lr()[0]
230 +
231 + def get_last_epoch(self):
232 + return self.scheduler.last_epoch
233 +
234 + optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
235 + optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
236 + return optimizer
237 +
1 +import os
2 +import math
3 +
4 +import utility
5 +from data import common
6 +
7 +import torch
8 +import cv2
9 +
10 +from tqdm import tqdm
11 +
12 +class VideoTester():
13 + def __init__(self, args, my_model, ckp):
14 + self.args = args
15 + self.scale = args.scale
16 +
17 + self.ckp = ckp
18 + self.model = my_model
19 +
20 + self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
21 +
22 + def test(self):
23 + torch.set_grad_enabled(False)
24 +
25 + self.ckp.write_log('\nEvaluation on video:')
26 + self.model.eval()
27 +
28 + timer_test = utility.timer()
29 + for idx_scale, scale in enumerate(self.scale):
30 + vidcap = cv2.VideoCapture(self.args.dir_demo)
31 + total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
32 + vidwri = cv2.VideoWriter(
33 + self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
34 + cv2.VideoWriter_fourcc(*'XVID'),
35 + vidcap.get(cv2.CAP_PROP_FPS),
36 + (
37 + int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
38 + int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
39 + )
40 + )
41 +
42 + tqdm_test = tqdm(range(total_frames), ncols=80)
43 + for _ in tqdm_test:
44 + success, lr = vidcap.read()
45 + if not success: break
46 +
47 + lr, = common.set_channel(lr, n_channels=self.args.n_colors)
48 + lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
49 + lr, = self.prepare(lr.unsqueeze(0))
50 + sr = self.model(lr, idx_scale)
51 + sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
52 +
53 + normalized = sr * 255 / self.args.rgb_range
54 + ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
55 + vidwri.write(ndarr)
56 +
57 + vidcap.release()
58 + vidwri.release()
59 +
60 + self.ckp.write_log(
61 + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
62 + )
63 + torch.set_grad_enabled(True)
64 +
65 + def prepare(self, *args):
66 + device = torch.device('cpu' if self.args.cpu else 'cuda')
67 + def _prepare(tensor):
68 + if self.args.precision == 'half': tensor = tensor.half()
69 + return tensor.to(device)
70 +
71 + return [_prepare(a) for a in args]
72 +