Showing 20 changed files with 933 additions and 0 deletions
carn/.gitignore
0 → 100644
carn/LICENSE
0 → 100644
MIT License

Copyright (c) 2018 Namhyuk Ahn

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
carn/README.md
0 → 100644
# Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network
Namhyuk Ahn, Byungkon Kang, Kyung-Ah Sohn.<br>
European Conference on Computer Vision (ECCV), 2018.
[[arXiv](https://arxiv.org/abs/1803.08664)]

<img src="assets/benchmark.png">

### Abstract
In recent years, deep learning methods have been successfully applied to single-image super-resolution tasks. Despite their strong performance, deep learning methods cannot easily be applied to real-world applications because they require heavy computation. In this paper, we address this issue by proposing an accurate and lightweight deep learning model for image super-resolution. In detail, we design an architecture that implements a cascading mechanism upon a residual network. We also present a variant of the proposed cascading residual network to further improve efficiency. Our extensive experiments show that even with far fewer parameters and operations, our models achieve performance comparable to that of state-of-the-art methods.

### FAQs
1. Can't reproduce the PSNR/SSIM reported in the paper: see [issue#6](https://github.com/nmhkahn/CARN-pytorch/issues/6).

### Requirements
- Python 3
- [PyTorch](https://github.com/pytorch/pytorch) (0.4.0), [torchvision](https://github.com/pytorch/vision)
- Numpy, Scipy
- Pillow, Scikit-image
- h5py
- importlib

### Dataset
We use the DIV2K dataset for training and the Set5, Set14, B100, and Urban100 datasets for the benchmark tests. Follow the steps below to prepare the datasets.

1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K) and unzip it into the `dataset` directory as below:
   ```
   dataset
   └── DIV2K
       ├── DIV2K_train_HR
       ├── DIV2K_train_LR_bicubic
       ├── DIV2K_valid_HR
       └── DIV2K_valid_LR_bicubic
   ```
2. To accelerate training, convert the training images to h5 format as follows (the h5py module must be installed); a sanity-check sketch follows this list.
```shell
$ cd dataset && python div2h5.py
```
3. The other benchmark datasets can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1t2le0-Wz7GZQ4M2mJqmRamw5o4ce2AVw?usp=sharing). As with DIV2K, put all the datasets in the `dataset` directory.
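As that sanity check, here is a minimal sketch (assuming `div2h5.py`, included below, has written `DIV2K_train.h5` into the `dataset` directory) that lists the groups the training code expects:
```python
import h5py

# each group holds one dataset per image, keyed "0", "1", ... (see div2h5.py)
with h5py.File("dataset/DIV2K_train.h5", "r") as f:
    for group in ["HR", "X2", "X3", "X4"]:
        print(group, len(f[group]), "images")
```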

### Test Pretrained Models
We provide the pretrained models in the `checkpoint` directory. To test CARN on a benchmark dataset:
```shell
$ python carn/sample.py --model carn \
                        --test_data_dir dataset/<dataset> \
                        --scale [2|3|4] \
                        --ckpt_path ./checkpoint/<path>.pth \
                        --sample_dir <sample_dir>
```
and for CARN-M:
```shell
$ python carn/sample.py --model carn_m \
                        --test_data_dir dataset/<dataset> \
                        --scale [2|3|4] \
                        --ckpt_path ./checkpoint/<path>.pth \
                        --sample_dir <sample_dir> \
                        --group 4
```
We provide our results on the four benchmark datasets (Set5, Set14, B100, and Urban100) on [Google Drive](https://drive.google.com/drive/folders/1R4vZMs3Adf8UlYbIzStY98qlsl5y1wxH?usp=sharing).

### Training Models
Here are our settings for training CARN and CARN-M. Note: we use two GPUs to accommodate the large batch size; if an OOM error arises, reduce the batch size.
```shell
# For CARN
$ python carn/train.py --patch_size 64 \
                       --batch_size 64 \
                       --max_steps 600000 \
                       --decay 400000 \
                       --model carn \
                       --ckpt_name carn \
                       --ckpt_dir checkpoint/carn \
                       --scale 0 \
                       --num_gpu 2
# For CARN-M
$ python carn/train.py --patch_size 64 \
                       --batch_size 64 \
                       --max_steps 600000 \
                       --decay 400000 \
                       --model carn_m \
                       --ckpt_name carn_m \
                       --ckpt_dir checkpoint/carn_m \
                       --scale 0 \
                       --group 4 \
                       --num_gpu 2
```
In the `--scale` argument, 2, 3, or 4 selects single-scale training and 0 selects multi-scale learning. `--group` sets the group size of the group convolutions. The differences from the previous version are: 1) we increase both the batch size and the patch size to 64; 2) instead of the `reduce_upsample` argument, which replaced the 3x3 convolutions of the upsample block with 1x1 convolutions, we use group convolution in the same way as in the efficient residual block.
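For intuition on why group convolution makes the model lighter, here is a minimal sketch (not part of the repository) comparing the parameter counts of a plain 3x3 convolution and the `groups=4` variant used in the efficient residual block:
```python
import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

# grouping splits the 64 channels into 4 independent groups,
# dividing the weight count by roughly the group size
print(n_params(nn.Conv2d(64, 64, 3, padding=1)))            # 36928
print(n_params(nn.Conv2d(64, 64, 3, padding=1, groups=4)))  # 9280
```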

### Results
**Note:** As pointed out in [#2](https://github.com/nmhkahn/CARN-pytorch/issues/2), the previous Urban100 benchmark dataset was incorrect. The issue was a mismatch between the HR image resolutions and the original dataset at the x2 and x3 scales. We have corrected this problem; the provided dataset and results are the fixed ones.

<img src="assets/table.png">
<img src="assets/visual.png">

### Citation
```
@article{ahn2018fast,
  title={Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network},
  author={Ahn, Namhyuk and Kang, Byungkon and Sohn, Kyung-Ah},
  journal={arXiv preprint arXiv:1803.08664},
  year={2018}
}
```
carn/assets/benchmark.png
0 → 100644
460 KB
carn/assets/table.png
0 → 100644
718 KB
carn/assets/visual.png
0 → 100644
5.91 MB
carn/carn/__init__.py
0 → 100644
File mode changed
carn/carn/dataset.py
0 → 100644
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms

def random_crop(hr, lr, size, scale):
    h, w = lr.shape[:-1]
    x = random.randint(0, w-size)
    y = random.randint(0, h-size)

    hsize = size*scale
    hx, hy = x*scale, y*scale

    crop_lr = lr[y:y+size, x:x+size].copy()
    crop_hr = hr[hy:hy+hsize, hx:hx+hsize].copy()

    return crop_hr, crop_lr


def random_flip_and_rotate(im1, im2):
    if random.random() < 0.5:
        im1 = np.flipud(im1)
        im2 = np.flipud(im2)

    if random.random() < 0.5:
        im1 = np.fliplr(im1)
        im2 = np.fliplr(im2)

    angle = random.choice([0, 1, 2, 3])
    im1 = np.rot90(im1, angle)
    im2 = np.rot90(im2, angle)

    # copy is required before the arrays are passed to the transform:
    # flipud/fliplr/rot90 return views, which ToTensor cannot handle
    return im1.copy(), im2.copy()


class TrainDataset(data.Dataset):
    def __init__(self, path, size, scale):
        super(TrainDataset, self).__init__()

        self.size = size
        h5f = h5py.File(path, "r")

        self.hr = [v[:] for v in h5f["HR"].values()]
        # scale == 0 performs multi-scale training over x2, x3 and x4
        if scale == 0:
            self.scale = [2, 3, 4]
            self.lr = [[v[:] for v in h5f["X{}".format(i)].values()] for i in self.scale]
        else:
            self.scale = [scale]
            self.lr = [[v[:] for v in h5f["X{}".format(scale)].values()]]

        h5f.close()

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        size = self.size

        item = [(self.hr[index], self.lr[i][index]) for i, _ in enumerate(self.lr)]
        item = [random_crop(hr, lr, size, self.scale[i]) for i, (hr, lr) in enumerate(item)]
        item = [random_flip_and_rotate(hr, lr) for hr, lr in item]

        return [(self.transform(hr), self.transform(lr)) for hr, lr in item]

    def __len__(self):
        return len(self.hr)


class TestDataset(data.Dataset):
    def __init__(self, dirname, scale):
        super(TestDataset, self).__init__()

        self.name = dirname.split("/")[-1]
        self.scale = scale

        if "DIV" in self.name:
            self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
            self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname),
                                             "X{}/*.png".format(scale)))
        else:
            all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
            self.hr = [name for name in all_files if "HR" in name]
            self.lr = [name for name in all_files if "LR" in name]

        self.hr.sort()
        self.lr.sort()

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        hr = Image.open(self.hr[index])
        lr = Image.open(self.lr[index])

        hr = hr.convert("RGB")
        lr = lr.convert("RGB")
        filename = self.hr[index].split("/")[-1]

        return self.transform(hr), self.transform(lr), filename

    def __len__(self):
        return len(self.hr)
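For reference, a small sketch (assuming `DIV2K_train.h5` was generated as in the README, and run with this `carn` source directory on `sys.path`, as `train.py` arranges) of what `TrainDataset` yields in multi-scale mode:
```python
from dataset import TrainDataset

# scale=0 -> multi-scale: item i holds the (hr, lr) pair for scale 2+i
ds = TrainDataset("dataset/DIV2K_train.h5", size=64, scale=0)
for (hr, lr), s in zip(ds[0], [2, 3, 4]):
    print(s, tuple(hr.shape), tuple(lr.shape))  # (3, 64*s, 64*s), (3, 64, 64)
```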
carn/carn/model/__init__.py
0 → 100644
File mode changed
carn/carn/model/carn.py
0 → 100644
import torch
import torch.nn as nn
import model.ops as ops

class Block(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(Block, self).__init__()

        self.b1 = ops.ResidualBlock(64, 64)
        self.b2 = ops.ResidualBlock(64, 64)
        self.b3 = ops.ResidualBlock(64, 64)
        self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)

    def forward(self, x):
        # local cascading: each residual block's output is concatenated
        # with all earlier features, then fused by a 1x1 conv (c1..c3)
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b2(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b3(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        return o3


class Net(nn.Module):
    def __init__(self, **kwargs):
        super(Net, self).__init__()

        scale = kwargs.get("scale")
        multi_scale = kwargs.get("multi_scale")
        group = kwargs.get("group", 1)

        self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        self.entry = nn.Conv2d(3, 64, 3, 1, 1)

        self.b1 = Block(64, 64)
        self.b2 = Block(64, 64)
        self.b3 = Block(64, 64)
        self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)

        self.upsample = ops.UpsampleBlock(64, scale=scale,
                                          multi_scale=multi_scale,
                                          group=group)
        self.exit = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x, scale):
        x = self.sub_mean(x)
        x = self.entry(x)
        # global cascading mirrors the local one, but over whole blocks
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b2(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b3(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        out = self.upsample(o3, scale=scale)

        out = self.exit(out)
        out = self.add_mean(out)

        return out
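A quick shape check for the network above (a hypothetical snippet, assuming it is run with this `carn` source directory on `sys.path` so the `model` package imports):
```python
import torch
from model import carn

# multi-scale net: one shared trunk, one upsample branch per scale
net = carn.Net(multi_scale=True, group=1)
x = torch.randn(1, 3, 32, 32)
for s in [2, 3, 4]:
    print(s, net(x, s).shape)  # torch.Size([1, 3, 32*s, 32*s])
```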
carn/carn/model/carn_m.py
0 → 100644
import torch
import torch.nn as nn
import model.ops as ops

class Block(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(Block, self).__init__()

        # a single efficient residual block, shared across all three stages
        self.b1 = ops.EResidualBlock(64, 64, group=group)
        self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)

    def forward(self, x):
        c0 = o0 = x

        # b1 is applied recursively: the three calls share the same weights
        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b1(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b1(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        return o3


class Net(nn.Module):
    def __init__(self, **kwargs):
        super(Net, self).__init__()

        scale = kwargs.get("scale")
        multi_scale = kwargs.get("multi_scale")
        group = kwargs.get("group", 1)

        self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        self.entry = nn.Conv2d(3, 64, 3, 1, 1)

        self.b1 = Block(64, 64, group=group)
        self.b2 = Block(64, 64, group=group)
        self.b3 = Block(64, 64, group=group)
        self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
        self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
        self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)

        self.upsample = ops.UpsampleBlock(64, scale=scale,
                                          multi_scale=multi_scale,
                                          group=group)
        self.exit = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x, scale):
        x = self.sub_mean(x)
        x = self.entry(x)
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b2(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b3(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        out = self.upsample(o3, scale=scale)

        out = self.exit(out)
        out = self.add_mean(out)

        return out
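To see what the weight sharing and group convolution buy, a hypothetical comparison (same import assumption as above; the constructor arguments mirror how `sample.py` builds the nets):
```python
from model import carn, carn_m

def n_params(net):
    return sum(p.numel() for p in net.parameters())

# expect CARN-M to come out several times smaller than CARN
print("CARN:  ", n_params(carn.Net(multi_scale=True, group=1)))
print("CARN-M:", n_params(carn_m.Net(multi_scale=True, group=4)))
```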
carn/carn/model/ops.py
0 → 100644
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(modules):
    # intentionally a no-op: layers keep PyTorch's default initialization
    pass


class MeanShift(nn.Module):
    def __init__(self, mean_rgb, sub):
        super(MeanShift, self).__init__()

        sign = -1 if sub else 1
        r = mean_rgb[0] * sign
        g = mean_rgb[1] * sign
        b = mean_rgb[2] * sign

        # a frozen 1x1 identity conv whose bias adds/subtracts the RGB mean
        self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
        self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.shifter.bias.data = torch.Tensor([r, g, b])

        # Freeze the mean shift layer
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        x = self.shifter(x)
        return x


class BasicBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True)
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        return out


class ResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out


class EResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        # grouped 3x3 convs followed by a 1x1 conv that mixes the groups
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out


class UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        if self.multi_scale:
            if scale == 2:
                return self.up2(x)
            elif scale == 3:
                return self.up3(x)
            elif scale == 4:
                return self.up4(x)
        else:
            return self.up(x)


class _UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(_UpsampleBlock, self).__init__()

        # sub-pixel upsampling: expand channels, then PixelShuffle into space
        modules = []
        if scale == 2 or scale == 4 or scale == 8:
            for _ in range(int(math.log(scale, 2))):
                modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        elif scale == 3:
            modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]

        self.body = nn.Sequential(*modules)
        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        return out
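The `_UpsampleBlock` above implements sub-pixel convolution. A minimal standalone sketch of the mechanism (not part of the repository):
```python
import torch
import torch.nn as nn

# the conv expands 64 -> 4*64 channels; PixelShuffle(2) then trades the
# factor-of-4 channel expansion for a 2x larger spatial grid
x = torch.randn(1, 64, 24, 24)
up = nn.Sequential(nn.Conv2d(64, 4 * 64, 3, 1, 1), nn.PixelShuffle(2))
print(up(x).shape)  # torch.Size([1, 64, 48, 48])
```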
carn/carn/sample.py
0 → 100644
import os
import json
import time
import importlib
import argparse
from collections import OrderedDict
import torch
from dataset import TestDataset
from PIL import Image

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--ckpt_path", type=str)
    parser.add_argument("--group", type=int, default=1)
    parser.add_argument("--sample_dir", type=str)
    parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100")
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--scale", type=int, default=4)
    parser.add_argument("--shave", type=int, default=20)

    return parser.parse_args()


def save_image(tensor, filename):
    tensor = tensor.cpu()
    ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
    im = Image.fromarray(ndarr)
    im.save(filename)


def sample(net, device, dataset, cfg):
    scale = cfg.scale
    for step, (hr, lr, name) in enumerate(dataset):
        if "DIV2K" in dataset.name:
            t1 = time.time()
            h, w = lr.size()[1:]
            h_half, w_half = int(h/2), int(w/2)
            h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave

            # split the large image into four overlapping patches to avoid OOM
            # (torch.empty allocates the buffer; the original torch.tensor(...)
            # built a 1-D tensor of the shape values instead)
            lr_patch = torch.empty((4, 3, h_chop, w_chop), dtype=torch.float)
            lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
            lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
            lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
            lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
            lr_patch = lr_patch.to(device)

            sr = net(lr_patch, cfg.scale).detach()

            h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
            w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale

            # merge the upscaled patches, dropping the overlapping borders
            result = torch.empty((3, h, w), dtype=torch.float).to(device)
            result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
            result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
            result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
            result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
            sr = result
            t2 = time.time()
        else:
            t1 = time.time()
            lr = lr.unsqueeze(0).to(device)
            sr = net(lr, cfg.scale).detach().squeeze(0)
            lr = lr.squeeze(0)
            t2 = time.time()

        model_name = cfg.ckpt_path.split(".")[0].split("/")[-1]
        sr_dir = os.path.join(cfg.sample_dir,
                              model_name,
                              cfg.test_data_dir.split("/")[-1],
                              "x{}".format(cfg.scale),
                              "SR")
        hr_dir = os.path.join(cfg.sample_dir,
                              model_name,
                              cfg.test_data_dir.split("/")[-1],
                              "x{}".format(cfg.scale),
                              "HR")

        os.makedirs(sr_dir, exist_ok=True)
        os.makedirs(hr_dir, exist_ok=True)

        sr_im_path = os.path.join(sr_dir, "{}".format(name.replace("HR", "SR")))
        hr_im_path = os.path.join(hr_dir, "{}".format(name))

        save_image(sr, sr_im_path)
        save_image(hr, hr_im_path)
        print("Saved {} ({}x{} -> {}x{}, {:.3f}s)"
              .format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2-t1))


def main(cfg):
    module = importlib.import_module("model.{}".format(cfg.model))
    net = module.Net(multi_scale=True,
                     group=cfg.group)
    print(json.dumps(vars(cfg), indent=4, sort_keys=True))

    state_dict = torch.load(cfg.ckpt_path)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        # name = k[7:]  # strip "module." for checkpoints saved via DataParallel
        new_state_dict[name] = v

    net.load_state_dict(new_state_dict)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)

    dataset = TestDataset(cfg.test_data_dir, cfg.scale)
    sample(net, device, dataset, cfg)


if __name__ == "__main__":
    cfg = parse_args()
    main(cfg)
carn/carn/solver.py
0 → 100644
import os
import random
import numpy as np
import skimage.measure as measure
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import TrainDataset, TestDataset

class Solver():
    def __init__(self, model, cfg):
        if cfg.scale > 0:
            self.refiner = model(scale=cfg.scale,
                                 group=cfg.group)
        else:
            self.refiner = model(multi_scale=True,
                                 group=cfg.group)

        if cfg.loss_fn in ["MSE"]:
            self.loss_fn = nn.MSELoss()
        elif cfg.loss_fn in ["L1"]:
            self.loss_fn = nn.L1Loss()
        elif cfg.loss_fn in ["SmoothL1"]:
            self.loss_fn = nn.SmoothL1Loss()

        self.optim = optim.Adam(
            filter(lambda p: p.requires_grad, self.refiner.parameters()),
            cfg.lr)

        self.train_data = TrainDataset(cfg.train_data_path,
                                       scale=cfg.scale,
                                       size=cfg.patch_size)
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=cfg.batch_size,
                                       num_workers=1,
                                       shuffle=True, drop_last=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.refiner = self.refiner.to(self.device)

        self.cfg = cfg
        self.step = 0

        self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name))
        if cfg.verbose:
            num_params = 0
            for param in self.refiner.parameters():
                num_params += param.nelement()
            print("# of params:", num_params)

        os.makedirs(cfg.ckpt_dir, exist_ok=True)

    def fit(self):
        cfg = self.cfg
        refiner = nn.DataParallel(self.refiner,
                                  device_ids=range(cfg.num_gpu))

        learning_rate = cfg.lr
        while True:
            for inputs in self.train_loader:
                self.refiner.train()

                if cfg.scale > 0:
                    scale = cfg.scale
                    hr, lr = inputs[-1][0], inputs[-1][1]
                else:
                    # pick one scale at random from the multi-scale batch
                    scale = random.randint(2, 4)
                    hr, lr = inputs[scale-2][0], inputs[scale-2][1]

                hr = hr.to(self.device)
                lr = lr.to(self.device)

                sr = refiner(lr, scale)
                loss = self.loss_fn(sr, hr)

                self.optim.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.refiner.parameters(), cfg.clip)
                self.optim.step()

                learning_rate = self.decay_learning_rate()
                for param_group in self.optim.param_groups:
                    param_group["lr"] = learning_rate

                self.step += 1
                if cfg.verbose and self.step % cfg.print_interval == 0:
                    if cfg.scale > 0:
                        psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
                        self.writer.add_scalar("Urban100", psnr, self.step)
                    else:
                        psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
                        self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
                        self.writer.add_scalar("Urban100_3x", psnr[1], self.step)
                        self.writer.add_scalar("Urban100_4x", psnr[2], self.step)

                    self.save(cfg.ckpt_dir, cfg.ckpt_name)

                # return (not break) so the outer while loop also terminates
                if self.step > cfg.max_steps: return

    def evaluate(self, test_data_dir, scale=2, num_step=0):
        cfg = self.cfg
        mean_psnr = 0
        self.refiner.eval()

        test_data = TestDataset(test_data_dir, scale=scale)
        test_loader = DataLoader(test_data,
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False)

        for step, inputs in enumerate(test_loader):
            hr = inputs[0].squeeze(0)
            lr = inputs[1].squeeze(0)
            name = inputs[2][0]

            h, w = lr.size()[1:]
            h_half, w_half = int(h/2), int(w/2)
            h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave

            # split the large image into 4 overlapping patches to avoid an OOM error
            lr_patch = torch.FloatTensor(4, 3, h_chop, w_chop)
            lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
            lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
            lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
            lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
            lr_patch = lr_patch.to(self.device)

            # run the refinement network on the patch batch
            sr = self.refiner(lr_patch, scale).data

            h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
            w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale

            # merge the split patch images
            result = torch.FloatTensor(3, h, w).to(self.device)
            result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
            result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
            result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
            result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
            sr = result

            hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
            sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()

            # evaluate PSNR on the RGB channels, not on Y of YCbCr,
            # so the values differ from the MATLAB evaluation
            bnd = scale
            im1 = hr[bnd:-bnd, bnd:-bnd]
            im2 = sr[bnd:-bnd, bnd:-bnd]
            mean_psnr += psnr(im1, im2) / len(test_data)

        return mean_psnr

    def load(self, path):
        self.refiner.load_state_dict(torch.load(path))
        splited = path.split(".")[0].split("_")[-1]
        try:
            self.step = int(splited)
        except ValueError:
            self.step = 0
        print("Load pretrained {} model".format(path))

    def save(self, ckpt_dir, ckpt_name):
        save_path = os.path.join(
            ckpt_dir, "{}_{}.pth".format(ckpt_name, self.step))
        torch.save(self.refiner.state_dict(), save_path)

    def decay_learning_rate(self):
        # halve the learning rate every cfg.decay steps
        lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
        return lr


def psnr(im1, im2):
    def im2double(im):
        min_val, max_val = 0, 255
        out = (im.astype(np.float64)-min_val) / (max_val-min_val)
        return out

    im1 = im2double(im1)
    im2 = im2double(im2)
    psnr = measure.compare_psnr(im1, im2, data_range=1)
    return psnr
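The schedule in `decay_learning_rate` is a plain step decay. A worked sketch with the README's training settings (base `--lr` 1e-4, `--decay 400000`):
```python
# lr = base_lr * 0.5 ** (step // decay): the rate halves every `decay` steps
base_lr, decay = 1e-4, 400000
for step in [0, 399999, 400000, 600000]:
    print(step, base_lr * 0.5 ** (step // decay))
# with max_steps=600000 the rate halves exactly once, at step 400000
```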
carn/carn/train.py
0 → 100644
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import json
import argparse
import importlib
from solver import Solver

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--ckpt_name", type=str)

    parser.add_argument("--print_interval", type=int, default=1000)
    parser.add_argument("--train_data_path", type=str,
                        default="dataset/DIV2K_train.h5")
    parser.add_argument("--ckpt_dir", type=str,
                        default="checkpoint")
    parser.add_argument("--sample_dir", type=str,
                        default="sample/")

    parser.add_argument("--num_gpu", type=int, default=1)
    parser.add_argument("--shave", type=int, default=20)
    parser.add_argument("--scale", type=int, default=2)

    # default=True keeps logging and checkpointing on, as the solver expects;
    # the original default="store_true" was a truthy string passed by mistake
    parser.add_argument("--verbose", action="store_true", default=True)

    parser.add_argument("--group", type=int, default=1)

    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_steps", type=int, default=200000)
    parser.add_argument("--decay", type=int, default=150000)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--clip", type=float, default=10.0)

    parser.add_argument("--loss_fn", type=str,
                        choices=["MSE", "L1", "SmoothL1"], default="L1")

    return parser.parse_args()

def main(cfg):
    # dynamic import using the --model argument (model/carn.py or model/carn_m.py)
    net = importlib.import_module("model.{}".format(cfg.model)).Net
    print(json.dumps(vars(cfg), indent=4, sort_keys=True))

    solver = Solver(net, cfg)
    solver.fit()

if __name__ == "__main__":
    cfg = parse_args()
    main(cfg)
carn/checkpoint/carn.pth
0 → 100644
No preview for this file type
carn/checkpoint/carn_m.pth
0 → 100644
No preview for this file type
carn/dataset/.gitignore
0 → 100644
carn/dataset/div2h5.py
0 → 100644
import os
import glob
import h5py
import scipy.misc as misc

dataset_dir = "DIV2K/"
dataset_type = "train"

f = h5py.File("DIV2K_{}.h5".format(dataset_type), "w")

for subdir in ["HR", "X2", "X3", "X4"]:
    if subdir in ["HR"]:
        im_paths = glob.glob(os.path.join(dataset_dir,
                                          "DIV2K_{}_HR".format(dataset_type),
                                          "*.png"))
    else:
        im_paths = glob.glob(os.path.join(dataset_dir,
                                          "DIV2K_{}_LR_bicubic".format(dataset_type),
                                          subdir, "*.png"))
    im_paths.sort()
    grp = f.create_group(subdir)

    for i, path in enumerate(im_paths):
        # scipy.misc.imread needs Pillow and was removed in scipy >= 1.2
        im = misc.imread(path)
        print(path)
        grp.create_dataset(str(i), data=im)
carn/sample/.gitignore
0 → 100644