1 +*.swp
2 +DS_Store
3 +__pycache__
4 +
1 +MIT License
2 +
3 +Copyright (c) 2018 Namhyuk Ahn
4 +
5 +Permission is hereby granted, free of charge, to any person obtaining a copy
6 +of this software and associated documentation files (the "Software"), to deal
7 +in the Software without restriction, including without limitation the rights
8 +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 +copies of the Software, and to permit persons to whom the Software is
10 +furnished to do so, subject to the following conditions:
11 +
12 +The above copyright notice and this permission notice shall be included in all
13 +copies or substantial portions of the Software.
14 +
15 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 +SOFTWARE.
1 +# Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network
2 +Namhyuk Ahn, Byungkon Kang, Kyung-Ah Sohn.<br>
3 +European Conference on Computer Vision (ECCV), 2018.
4 +[[arXiv](https://arxiv.org/abs/1803.08664)]
5 +
6 +<img src="assets/benchmark.png">
7 +
8 +### Abstract
9 +In recent years, deep learning methods have been successfully applied to single-image super-resolution tasks. Despite their great performances, deep learning methods cannot be easily applied to real-world applications due to the requirement of heavy computation. In this paper, we address this issue by proposing an accurate and lightweight deep learning model for image super-resolution. In detail, we design an architecture that implements a cascading mechanism upon a residual network. We also present a variant model of the proposed cascading residual network to further improve efficiency. Our extensive experiments show that even with much fewer parameters and operations, our models achieve performance comparable to that of state-of-the-art methods.
10 +
11 +### FAQs
12 +1. If you cannot reproduce the PSNR/SSIM reported in the paper, see [issue#6](https://github.com/nmhkahn/CARN-pytorch/issues/6).
13 +
14 +### Requirements
15 +- Python 3
16 +- [PyTorch](https://github.com/pytorch/pytorch) (0.4.0), [torchvision](https://github.com/pytorch/vision)
17 +- Numpy, Scipy
18 +- Pillow, Scikit-image
19 +- h5py
20 +- importlib, tensorboardX
21 +
22 +### Dataset
23 +We use the DIV2K dataset for training and the Set5, Set14, B100 and Urban100 datasets for the benchmark test. Follow the steps below to prepare the datasets.
24 +
25 +1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K) and unzip it into the `dataset` directory as below:
26 + ```
27 + dataset
28 + └── DIV2K
29 + ├── DIV2K_train_HR
30 + ├── DIV2K_train_LR_bicubic
31 + ├── DIV2K_valid_HR
32 + └── DIV2K_valid_LR_bicubic
33 + ```
34 +2. To accelerate training, we first convert the training images to HDF5 format as follows (the `h5py` module must be installed); a quick way to inspect the resulting file is sketched after this list.
35 +```shell
36 +$ cd dataset && python div2h5.py
37 +```
38 +3. The other benchmark datasets can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1t2le0-Wz7GZQ4M2mJqmRamw5o4ce2AVw?usp=sharing). As with DIV2K, place all of them in the `dataset` directory.
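After step 2, `DIV2K_train.h5` contains one group per resolution (`HR`, `X2`, `X3`, `X4`), each holding one dataset per image; this is the layout `carn/dataset.py` expects. A minimal sanity check of the generated file (assuming it was written by `div2h5.py` with the default names) could look like:

```python
import h5py

# Open the file produced by div2h5.py and list its groups.
with h5py.File("dataset/DIV2K_train.h5", "r") as f:
    for group in ("HR", "X2", "X3", "X4"):
        images = list(f[group].values())
        # Each entry is one image stored as an HxWxC uint8 array.
        print(group, len(images), "images, first shape:", images[0].shape)
```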
39 +
40 +### Test Pretrained Models
41 +We provide the pretrained models in the `checkpoint` directory. To test CARN on a benchmark dataset:
42 +```shell
43 +$ python carn/sample.py --model carn \
44 + --test_data_dir dataset/<dataset> \
45 + --scale [2|3|4] \
46 + --ckpt_path ./checkpoint/<path>.pth \
47 + --sample_dir <sample_dir>
48 +```
49 +and for CARN-M,
50 +```shell
51 +$ python carn/sample.py --model carn_m \
52 + --test_data_dir dataset/<dataset> \
53 + --scale [2|3|4] \
54 + --ckpt_path ./checkpoint/<path>.pth \
55 + --sample_dir <sample_dir> \
56 + --group 4
57 +```
58 +We provide our results on the four benchmark datasets (Set5, Set14, B100 and Urban100) on [Google Drive](https://drive.google.com/drive/folders/1R4vZMs3Adf8UlYbIzStY98qlsl5y1wxH?usp=sharing).
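For quick experiments outside of `carn/sample.py`, the pretrained model can also be used programmatically. A minimal sketch, assuming the `carn` directory is on `PYTHONPATH`, a multi-scale checkpoint, and example file names of our own choosing:

```python
import importlib
import torch
from PIL import Image
from torchvision import transforms

# Build the multi-scale CARN the same way carn/sample.py does.
net = importlib.import_module("model.carn").Net(multi_scale=True, group=1)
# Path below is an example; point it at the checkpoint file you downloaded.
net.load_state_dict(torch.load("checkpoint/carn.pth", map_location="cpu"))
net.eval()

# Upscale one RGB image by x4; forward() takes the target scale as its second argument.
lr = transforms.ToTensor()(Image.open("lr_input.png").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    sr = net(lr, 4).squeeze(0).clamp(0, 1)
transforms.ToPILImage()(sr).save("sr_output.png")
```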
59 +
60 +### Training Models
61 +Below are the settings we use to train CARN and CARN-M. Note: we use two GPUs to accommodate the large batch size; if an OOM error arises, reduce the batch size.
62 +```shell
63 +# For CARN
64 +$ python carn/train.py --patch_size 64 \
65 + --batch_size 64 \
66 + --max_steps 600000 \
67 + --decay 400000 \
68 + --model carn \
69 + --ckpt_name carn \
70 + --ckpt_dir checkpoint/carn \
71 + --scale 0 \
72 + --num_gpu 2
73 +# For CARN-M
74 +$ python carn/train.py --patch_size 64 \
75 + --batch_size 64 \
76 + --max_steps 600000 \
77 + --decay 400000 \
78 + --model carn_m \
79 + --ckpt_name carn_m \
80 + --ckpt_dir checkpoint/carn_m \
81 + --scale 0 \
82 + --group 4 \
83 + --num_gpu 2
84 +```
85 +For the `--scale` argument, 2, 3 or 4 selects single-scale training, and 0 selects multi-scale training. `--group` sets the group size of the group convolutions. The differences from the previous version are: 1) we increase both the patch size and the batch size to 64; 2) instead of the `reduce_upsample` argument, which replaced the 3x3 convolution of the upsample block with a 1x1 convolution, we apply group convolution in the upsample block in the same way as in the efficient residual block (see the sketch below).
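To see why group convolution makes the blocks lighter, here is a small sketch (channel counts match the 64-channel blocks used in this repository; the `n_params` helper is only for this example) comparing a standard 3x3 convolution with its grouped counterpart:

```python
import torch.nn as nn

def n_params(module):
    # Count the learnable parameters of a module.
    return sum(p.numel() for p in module.parameters())

dense = nn.Conv2d(64, 64, 3, 1, 1)              # standard 3x3 convolution
grouped = nn.Conv2d(64, 64, 3, 1, 1, groups=4)  # group=4, as in CARN-M

# 64*64*3*3 + 64 = 36928 vs. 64*(64/4)*3*3 + 64 = 9280 parameters,
# i.e. the grouped convolution needs roughly a quarter of the weights.
print(n_params(dense), n_params(grouped))
```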
86 +
87 +### Results
88 +**Note:** As pointed out in [#2](https://github.com/nmhkahn/CARN-pytorch/issues/2), the Urban100 benchmark dataset we previously provided was incorrect: the HR image resolutions did not match the original dataset at the x2 and x3 scales. We have corrected this issue, and the dataset and results provided here are the fixed ones.
89 +
90 +<img src="assets/table.png">
91 +<img src="assets/visual.png">
92 +
93 +### Citation
94 +```
95 +@article{ahn2018fast,
96 + title={Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network},
97 + author={Ahn, Namhyuk and Kang, Byungkon and Sohn, Kyung-Ah},
98 + journal={arXiv preprint arXiv:1803.08664},
99 + year={2018}
100 +}
101 +```
1 +import os
2 +import glob
3 +import h5py
4 +import random
5 +import numpy as np
6 +from PIL import Image
7 +import torch.utils.data as data
8 +import torchvision.transforms as transforms
9 +
10 +def random_crop(hr, lr, size, scale):
11 + h, w = lr.shape[:-1]
12 + x = random.randint(0, w-size)
13 + y = random.randint(0, h-size)
14 +
15 + hsize = size*scale
16 + hx, hy = x*scale, y*scale
17 +
18 + crop_lr = lr[y:y+size, x:x+size].copy()
19 + crop_hr = hr[hy:hy+hsize, hx:hx+hsize].copy()
20 +
21 + return crop_hr, crop_lr
22 +
23 +
24 +def random_flip_and_rotate(im1, im2):
25 + if random.random() < 0.5:
26 + im1 = np.flipud(im1)
27 + im2 = np.flipud(im2)
28 +
29 + if random.random() < 0.5:
30 + im1 = np.fliplr(im1)
31 + im2 = np.fliplr(im2)
32 +
33 + angle = random.choice([0, 1, 2, 3])
34 + im1 = np.rot90(im1, angle)
35 + im2 = np.rot90(im2, angle)
36 +
37 +    # np.flip/np.rot90 return views, so copy before passing the arrays to the transform
38 + return im1.copy(), im2.copy()
39 +
40 +
41 +class TrainDataset(data.Dataset):
42 + def __init__(self, path, size, scale):
43 + super(TrainDataset, self).__init__()
44 +
45 + self.size = size
46 + h5f = h5py.File(path, "r")
47 +
48 + self.hr = [v[:] for v in h5f["HR"].values()]
49 + # perform multi-scale training
50 + if scale == 0:
51 + self.scale = [2, 3, 4]
52 + self.lr = [[v[:] for v in h5f["X{}".format(i)].values()] for i in self.scale]
53 + else:
54 + self.scale = [scale]
55 + self.lr = [[v[:] for v in h5f["X{}".format(scale)].values()]]
56 +
57 + h5f.close()
58 +
59 + self.transform = transforms.Compose([
60 + transforms.ToTensor()
61 + ])
62 +
63 + def __getitem__(self, index):
64 + size = self.size
65 +
66 + item = [(self.hr[index], self.lr[i][index]) for i, _ in enumerate(self.lr)]
67 + item = [random_crop(hr, lr, size, self.scale[i]) for i, (hr, lr) in enumerate(item)]
68 + item = [random_flip_and_rotate(hr, lr) for hr, lr in item]
69 +
70 + return [(self.transform(hr), self.transform(lr)) for hr, lr in item]
71 +
72 + def __len__(self):
73 + return len(self.hr)
74 +
75 +
76 +class TestDataset(data.Dataset):
77 + def __init__(self, dirname, scale):
78 + super(TestDataset, self).__init__()
79 +
80 + self.name = dirname.split("/")[-1]
81 + self.scale = scale
82 +
83 + if "DIV" in self.name:
84 + self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
85 + self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname),
86 + "X{}/*.png".format(scale)))
87 + else:
88 + all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
89 + self.hr = [name for name in all_files if "HR" in name]
90 + self.lr = [name for name in all_files if "LR" in name]
91 +
92 + self.hr.sort()
93 + self.lr.sort()
94 +
95 + self.transform = transforms.Compose([
96 + transforms.ToTensor()
97 + ])
98 +
99 + def __getitem__(self, index):
100 + hr = Image.open(self.hr[index])
101 + lr = Image.open(self.lr[index])
102 +
103 + hr = hr.convert("RGB")
104 + lr = lr.convert("RGB")
105 + filename = self.hr[index].split("/")[-1]
106 +
107 + return self.transform(hr), self.transform(lr), filename
108 +
109 + def __len__(self):
110 + return len(self.hr)
1 +import torch
2 +import torch.nn as nn
3 +import model.ops as ops
4 +
5 +class Block(nn.Module):
6 + def __init__(self,
7 + in_channels, out_channels,
8 + group=1):
9 + super(Block, self).__init__()
10 +
11 + self.b1 = ops.ResidualBlock(64, 64)
12 + self.b2 = ops.ResidualBlock(64, 64)
13 + self.b3 = ops.ResidualBlock(64, 64)
14 + self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
15 + self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
16 + self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
17 +
18 + def forward(self, x):
19 + c0 = o0 = x
20 +
21 + b1 = self.b1(o0)
22 + c1 = torch.cat([c0, b1], dim=1)
23 + o1 = self.c1(c1)
24 +
25 + b2 = self.b2(o1)
26 + c2 = torch.cat([c1, b2], dim=1)
27 + o2 = self.c2(c2)
28 +
29 + b3 = self.b3(o2)
30 + c3 = torch.cat([c2, b3], dim=1)
31 + o3 = self.c3(c3)
32 +
33 + return o3
34 +
35 +
36 +class Net(nn.Module):
37 + def __init__(self, **kwargs):
38 + super(Net, self).__init__()
39 +
40 + scale = kwargs.get("scale")
41 + multi_scale = kwargs.get("multi_scale")
42 + group = kwargs.get("group", 1)
43 +
44 + self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
45 + self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
46 +
47 + self.entry = nn.Conv2d(3, 64, 3, 1, 1)
48 +
49 + self.b1 = Block(64, 64)
50 + self.b2 = Block(64, 64)
51 + self.b3 = Block(64, 64)
52 + self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
53 + self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
54 + self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
55 +
56 + self.upsample = ops.UpsampleBlock(64, scale=scale,
57 + multi_scale=multi_scale,
58 + group=group)
59 + self.exit = nn.Conv2d(64, 3, 3, 1, 1)
60 +
61 + def forward(self, x, scale):
62 + x = self.sub_mean(x)
63 + x = self.entry(x)
64 + c0 = o0 = x
65 +
66 + b1 = self.b1(o0)
67 + c1 = torch.cat([c0, b1], dim=1)
68 + o1 = self.c1(c1)
69 +
70 + b2 = self.b2(o1)
71 + c2 = torch.cat([c1, b2], dim=1)
72 + o2 = self.c2(c2)
73 +
74 + b3 = self.b3(o2)
75 + c3 = torch.cat([c2, b3], dim=1)
76 + o3 = self.c3(c3)
77 +
78 + out = self.upsample(o3, scale=scale)
79 +
80 + out = self.exit(out)
81 + out = self.add_mean(out)
82 +
83 + return out
1 +import torch
2 +import torch.nn as nn
3 +import model.ops as ops
4 +
5 +class Block(nn.Module):
6 + def __init__(self,
7 + in_channels, out_channels,
8 + group=1):
9 + super(Block, self).__init__()
10 +
11 +        self.b1 = ops.EResidualBlock(64, 64, group=group)  # single residual-E block, shared (reused) by all three stages in forward()
12 + self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
13 + self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
14 + self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
15 +
16 + def forward(self, x):
17 + c0 = o0 = x
18 +
19 + b1 = self.b1(o0)
20 + c1 = torch.cat([c0, b1], dim=1)
21 + o1 = self.c1(c1)
22 +
23 + b2 = self.b1(o1)
24 + c2 = torch.cat([c1, b2], dim=1)
25 + o2 = self.c2(c2)
26 +
27 + b3 = self.b1(o2)
28 + c3 = torch.cat([c2, b3], dim=1)
29 + o3 = self.c3(c3)
30 +
31 + return o3
32 +
33 +
34 +class Net(nn.Module):
35 + def __init__(self, **kwargs):
36 + super(Net, self).__init__()
37 +
38 + scale = kwargs.get("scale")
39 + multi_scale = kwargs.get("multi_scale")
40 + group = kwargs.get("group", 1)
41 +
42 + self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
43 + self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
44 +
45 + self.entry = nn.Conv2d(3, 64, 3, 1, 1)
46 +
47 + self.b1 = Block(64, 64, group=group)
48 + self.b2 = Block(64, 64, group=group)
49 + self.b3 = Block(64, 64, group=group)
50 + self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
51 + self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
52 + self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
53 +
54 + self.upsample = ops.UpsampleBlock(64, scale=scale,
55 + multi_scale=multi_scale,
56 + group=group)
57 + self.exit = nn.Conv2d(64, 3, 3, 1, 1)
58 +
59 + def forward(self, x, scale):
60 + x = self.sub_mean(x)
61 + x = self.entry(x)
62 + c0 = o0 = x
63 +
64 + b1 = self.b1(o0)
65 + c1 = torch.cat([c0, b1], dim=1)
66 + o1 = self.c1(c1)
67 +
68 + b2 = self.b2(o1)
69 + c2 = torch.cat([c1, b2], dim=1)
70 + o2 = self.c2(c2)
71 +
72 + b3 = self.b3(o2)
73 + c3 = torch.cat([c2, b3], dim=1)
74 + o3 = self.c3(c3)
75 +
76 + out = self.upsample(o3, scale=scale)
77 +
78 + out = self.exit(out)
79 + out = self.add_mean(out)
80 +
81 + return out
1 +import math
2 +import torch
3 +import torch.nn as nn
4 +import torch.nn.init as init
5 +import torch.nn.functional as F
6 +
7 +def init_weights(modules):
8 +    pass  # no-op placeholder: layers keep PyTorch's default initialization
9 +
10 +
11 +class MeanShift(nn.Module):
12 + def __init__(self, mean_rgb, sub):
13 + super(MeanShift, self).__init__()
14 +
15 + sign = -1 if sub else 1
16 + r = mean_rgb[0] * sign
17 + g = mean_rgb[1] * sign
18 + b = mean_rgb[2] * sign
19 +
20 + self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
21 + self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
22 + self.shifter.bias.data = torch.Tensor([r, g, b])
23 +
24 + # Freeze the mean shift layer
25 + for params in self.shifter.parameters():
26 + params.requires_grad = False
27 +
28 + def forward(self, x):
29 + x = self.shifter(x)
30 + return x
31 +
32 +
33 +class BasicBlock(nn.Module):
34 + def __init__(self,
35 + in_channels, out_channels,
36 + ksize=3, stride=1, pad=1):
37 + super(BasicBlock, self).__init__()
38 +
39 + self.body = nn.Sequential(
40 + nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
41 + nn.ReLU(inplace=True)
42 + )
43 +
44 + init_weights(self.modules)
45 +
46 + def forward(self, x):
47 + out = self.body(x)
48 + return out
49 +
50 +
51 +class ResidualBlock(nn.Module):
52 + def __init__(self,
53 + in_channels, out_channels):
54 + super(ResidualBlock, self).__init__()
55 +
56 + self.body = nn.Sequential(
57 + nn.Conv2d(in_channels, out_channels, 3, 1, 1),
58 + nn.ReLU(inplace=True),
59 + nn.Conv2d(out_channels, out_channels, 3, 1, 1),
60 + )
61 +
62 + init_weights(self.modules)
63 +
64 + def forward(self, x):
65 + out = self.body(x)
66 + out = F.relu(out + x)
67 + return out
68 +
69 +
70 +class EResidualBlock(nn.Module):
71 + def __init__(self,
72 + in_channels, out_channels,
73 + group=1):
74 + super(EResidualBlock, self).__init__()
75 +
76 + self.body = nn.Sequential(
77 + nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
78 + nn.ReLU(inplace=True),
79 + nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
80 + nn.ReLU(inplace=True),
81 + nn.Conv2d(out_channels, out_channels, 1, 1, 0),
82 + )
83 +
84 + init_weights(self.modules)
85 +
86 + def forward(self, x):
87 + out = self.body(x)
88 + out = F.relu(out + x)
89 + return out
90 +
91 +
92 +class UpsampleBlock(nn.Module):
93 + def __init__(self,
94 + n_channels, scale, multi_scale,
95 + group=1):
96 + super(UpsampleBlock, self).__init__()
97 +
98 + if multi_scale:
99 + self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
100 + self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
101 + self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
102 + else:
103 + self.up = _UpsampleBlock(n_channels, scale=scale, group=group)
104 +
105 + self.multi_scale = multi_scale
106 +
107 + def forward(self, x, scale):
108 + if self.multi_scale:
109 + if scale == 2:
110 + return self.up2(x)
111 + elif scale == 3:
112 + return self.up3(x)
113 + elif scale == 4:
114 + return self.up4(x)
115 + else:
116 + return self.up(x)
117 +
118 +
119 +class _UpsampleBlock(nn.Module):
120 + def __init__(self,
121 + n_channels, scale,
122 + group=1):
123 + super(_UpsampleBlock, self).__init__()
124 +
125 + modules = []
126 + if scale == 2 or scale == 4 or scale == 8:
127 + for _ in range(int(math.log(scale, 2))):
128 + modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
129 + modules += [nn.PixelShuffle(2)]
130 + elif scale == 3:
131 + modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
132 + modules += [nn.PixelShuffle(3)]
133 +
134 + self.body = nn.Sequential(*modules)
135 + init_weights(self.modules)
136 +
137 + def forward(self, x):
138 + out = self.body(x)
139 + return out
1 +import os
2 +import json
3 +import time
4 +import importlib
5 +import argparse
6 +import numpy as np
7 +from collections import OrderedDict
8 +import torch
9 +import torch.nn as nn
10 +from torch.autograd import Variable
11 +from dataset import TestDataset
12 +from PIL import Image
13 +
14 +def parse_args():
15 + parser = argparse.ArgumentParser()
16 + parser.add_argument("--model", type=str)
17 + parser.add_argument("--ckpt_path", type=str)
18 + parser.add_argument("--group", type=int, default=1)
19 + parser.add_argument("--sample_dir", type=str)
20 + parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100")
21 + parser.add_argument("--cuda", action="store_true")
22 + parser.add_argument("--scale", type=int, default=4)
23 + parser.add_argument("--shave", type=int, default=20)
24 +
25 + return parser.parse_args()
26 +
27 +
28 +def save_image(tensor, filename):
29 + tensor = tensor.cpu()
30 + ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
31 + im = Image.fromarray(ndarr)
32 + im.save(filename)
33 +
34 +
35 +def sample(net, device, dataset, cfg):
36 + scale = cfg.scale
37 + for step, (hr, lr, name) in enumerate(dataset):
38 + if "DIV2K" in dataset.name:
39 + t1 = time.time()
40 + h, w = lr.size()[1:]
41 + h_half, w_half = int(h/2), int(w/2)
42 + h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave
43 +
44 +            lr_patch = torch.zeros((4, 3, h_chop, w_chop), dtype=torch.float)  # buffer for 4 overlapping patches (torch.tensor((4, ...)) would build a 1-D tensor from the tuple)
45 + lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
46 + lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
47 + lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
48 + lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
49 + lr_patch = lr_patch.to(device)
50 +
51 + sr = net(lr_patch, cfg.scale).detach()
52 +
53 + h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
54 + w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
55 +
56 +            result = torch.zeros((3, h, w), dtype=torch.float).to(device)  # buffer for the merged SR image
57 + result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
58 + result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
59 + result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
60 + result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
61 + sr = result
62 + t2 = time.time()
63 + else:
64 + t1 = time.time()
65 + lr = lr.unsqueeze(0).to(device)
66 + sr = net(lr, cfg.scale).detach().squeeze(0)
67 + lr = lr.squeeze(0)
68 + t2 = time.time()
69 +
70 + model_name = cfg.ckpt_path.split(".")[0].split("/")[-1]
71 + sr_dir = os.path.join(cfg.sample_dir,
72 + model_name,
73 + cfg.test_data_dir.split("/")[-1],
74 + "x{}".format(cfg.scale),
75 + "SR")
76 + hr_dir = os.path.join(cfg.sample_dir,
77 + model_name,
78 + cfg.test_data_dir.split("/")[-1],
79 + "x{}".format(cfg.scale),
80 + "HR")
81 +
82 + os.makedirs(sr_dir, exist_ok=True)
83 + os.makedirs(hr_dir, exist_ok=True)
84 +
85 + sr_im_path = os.path.join(sr_dir, "{}".format(name.replace("HR", "SR")))
86 + hr_im_path = os.path.join(hr_dir, "{}".format(name))
87 +
88 + save_image(sr, sr_im_path)
89 + save_image(hr, hr_im_path)
90 + print("Saved {} ({}x{} -> {}x{}, {:.3f}s)"
91 + .format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2-t1))
92 +
93 +
94 +def main(cfg):
95 + module = importlib.import_module("model.{}".format(cfg.model))
96 + net = module.Net(multi_scale=True,
97 + group=cfg.group)
98 + print(json.dumps(vars(cfg), indent=4, sort_keys=True))
99 +
100 + state_dict = torch.load(cfg.ckpt_path)
101 + new_state_dict = OrderedDict()
102 + for k, v in state_dict.items():
103 + name = k
104 + # name = k[7:] # remove "module."
105 + new_state_dict[name] = v
106 +
107 + net.load_state_dict(new_state_dict)
108 +
109 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
110 + net = net.to(device)
111 +
112 + dataset = TestDataset(cfg.test_data_dir, cfg.scale)
113 + sample(net, device, dataset, cfg)
114 +
115 +
116 +if __name__ == "__main__":
117 + cfg = parse_args()
118 + main(cfg)
1 +import os
2 +import random
3 +import numpy as np
4 +import scipy.misc as misc
5 +import skimage.measure as measure
6 +from tensorboardX import SummaryWriter
7 +import torch
8 +import torch.nn as nn
9 +import torch.optim as optim
10 +from torch.utils.data import DataLoader
11 +from dataset import TrainDataset, TestDataset
12 +
13 +class Solver():
14 + def __init__(self, model, cfg):
15 + if cfg.scale > 0:
16 + self.refiner = model(scale=cfg.scale,
17 + group=cfg.group)
18 + else:
19 + self.refiner = model(multi_scale=True,
20 + group=cfg.group)
21 +
22 + if cfg.loss_fn in ["MSE"]:
23 + self.loss_fn = nn.MSELoss()
24 + elif cfg.loss_fn in ["L1"]:
25 + self.loss_fn = nn.L1Loss()
26 + elif cfg.loss_fn in ["SmoothL1"]:
27 + self.loss_fn = nn.SmoothL1Loss()
28 +
29 + self.optim = optim.Adam(
30 + filter(lambda p: p.requires_grad, self.refiner.parameters()),
31 + cfg.lr)
32 +
33 + self.train_data = TrainDataset(cfg.train_data_path,
34 + scale=cfg.scale,
35 + size=cfg.patch_size)
36 + self.train_loader = DataLoader(self.train_data,
37 + batch_size=cfg.batch_size,
38 + num_workers=1,
39 + shuffle=True, drop_last=True)
40 +
41 +
42 + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43 + self.refiner = self.refiner.to(self.device)
44 +        self.loss_fn = self.loss_fn.to(self.device)
45 +
46 + self.cfg = cfg
47 + self.step = 0
48 +
49 + self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name))
50 + if cfg.verbose:
51 + num_params = 0
52 + for param in self.refiner.parameters():
53 + num_params += param.nelement()
54 + print("# of params:", num_params)
55 +
56 + os.makedirs(cfg.ckpt_dir, exist_ok=True)
57 +
58 + def fit(self):
59 + cfg = self.cfg
60 + refiner = nn.DataParallel(self.refiner,
61 + device_ids=range(cfg.num_gpu))
62 +
63 + learning_rate = cfg.lr
64 + while True:
65 + for inputs in self.train_loader:
66 + self.refiner.train()
67 +
68 + if cfg.scale > 0:
69 + scale = cfg.scale
70 + hr, lr = inputs[-1][0], inputs[-1][1]
71 + else:
72 + # only use one of multi-scale data
73 + # i know this is stupid but just temporary
74 + scale = random.randint(2, 4)
75 + hr, lr = inputs[scale-2][0], inputs[scale-2][1]
76 +
77 + hr = hr.to(self.device)
78 + lr = lr.to(self.device)
79 +
80 + sr = refiner(lr, scale)
81 + loss = self.loss_fn(sr, hr)
82 +
83 + self.optim.zero_grad()
84 + loss.backward()
85 +                nn.utils.clip_grad_norm_(self.refiner.parameters(), cfg.clip)
86 + self.optim.step()
87 +
88 + learning_rate = self.decay_learning_rate()
89 + for param_group in self.optim.param_groups:
90 + param_group["lr"] = learning_rate
91 +
92 + self.step += 1
93 + if cfg.verbose and self.step % cfg.print_interval == 0:
94 + if cfg.scale > 0:
95 + psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
96 + self.writer.add_scalar("Urban100", psnr, self.step)
97 + else:
98 + psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
99 + self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
100 + self.writer.add_scalar("Urban100_3x", psnr[1], self.step)
101 + self.writer.add_scalar("Urban100_4x", psnr[2], self.step)
102 +
103 + self.save(cfg.ckpt_dir, cfg.ckpt_name)
104 +
105 + if self.step > cfg.max_steps: break
106 +
107 + def evaluate(self, test_data_dir, scale=2, num_step=0):
108 + cfg = self.cfg
109 + mean_psnr = 0
110 + self.refiner.eval()
111 +
112 + test_data = TestDataset(test_data_dir, scale=scale)
113 + test_loader = DataLoader(test_data,
114 + batch_size=1,
115 + num_workers=1,
116 + shuffle=False)
117 +
118 + for step, inputs in enumerate(test_loader):
119 + hr = inputs[0].squeeze(0)
120 + lr = inputs[1].squeeze(0)
121 + name = inputs[2][0]
122 +
123 + h, w = lr.size()[1:]
124 + h_half, w_half = int(h/2), int(w/2)
125 + h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave
126 +
127 + # split large image to 4 patch to avoid OOM error
128 + lr_patch = torch.FloatTensor(4, 3, h_chop, w_chop)
129 + lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
130 + lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
131 + lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
132 + lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
133 + lr_patch = lr_patch.to(self.device)
134 +
135 + # run refine process in here!
136 + sr = self.refiner(lr_patch, scale).data
137 +
138 + h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
139 + w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
140 +
141 + # merge splited patch images
142 + result = torch.FloatTensor(3, h, w).to(self.device)
143 + result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
144 + result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
145 + result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
146 + result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
147 + sr = result
148 +
149 + hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
150 + sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
151 +
152 + # evaluate PSNR
153 + # this evaluation is different to MATLAB version
154 + # we evaluate PSNR in RGB channel not Y in YCbCR
155 + bnd = scale
156 + im1 = hr[bnd:-bnd, bnd:-bnd]
157 + im2 = sr[bnd:-bnd, bnd:-bnd]
158 + mean_psnr += psnr(im1, im2) / len(test_data)
159 +
160 + return mean_psnr
161 +
162 + def load(self, path):
163 + self.refiner.load_state_dict(torch.load(path))
164 + splited = path.split(".")[0].split("_")[-1]
165 + try:
166 +            self.step = int(splited)
167 + except ValueError:
168 + self.step = 0
169 + print("Load pretrained {} model".format(path))
170 +
171 + def save(self, ckpt_dir, ckpt_name):
172 + save_path = os.path.join(
173 + ckpt_dir, "{}_{}.pth".format(ckpt_name, self.step))
174 + torch.save(self.refiner.state_dict(), save_path)
175 +
176 + def decay_learning_rate(self):
177 + lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
178 + return lr
179 +
180 +
181 +def psnr(im1, im2):
182 + def im2double(im):
183 + min_val, max_val = 0, 255
184 + out = (im.astype(np.float64)-min_val) / (max_val-min_val)
185 + return out
186 +
187 + im1 = im2double(im1)
188 + im2 = im2double(im2)
189 + psnr = measure.compare_psnr(im1, im2, data_range=1)
190 + return psnr
1 +import os
2 +import sys
3 +sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
4 +import json
5 +import argparse
6 +import importlib
7 +from solver import Solver
8 +
9 +def parse_args():
10 + parser = argparse.ArgumentParser()
11 + parser.add_argument("--model", type=str)
12 + parser.add_argument("--ckpt_name", type=str)
13 +
14 + parser.add_argument("--print_interval", type=int, default=1000)
15 + parser.add_argument("--train_data_path", type=str,
16 + default="dataset/DIV2K_train.h5")
17 + parser.add_argument("--ckpt_dir", type=str,
18 + default="checkpoint")
19 + parser.add_argument("--sample_dir", type=str,
20 + default="sample/")
21 +
22 + parser.add_argument("--num_gpu", type=int, default=1)
23 + parser.add_argument("--shave", type=int, default=20)
24 + parser.add_argument("--scale", type=int, default=2)
25 +
26 +    parser.add_argument("--verbose", action="store_true", default=True)  # was a truthy string; keep verbose on by default
27 +
28 + parser.add_argument("--group", type=int, default=1)
29 +
30 + parser.add_argument("--patch_size", type=int, default=64)
31 + parser.add_argument("--batch_size", type=int, default=64)
32 + parser.add_argument("--max_steps", type=int, default=200000)
33 + parser.add_argument("--decay", type=int, default=150000)
34 + parser.add_argument("--lr", type=float, default=0.0001)
35 + parser.add_argument("--clip", type=float, default=10.0)
36 +
37 + parser.add_argument("--loss_fn", type=str,
38 + choices=["MSE", "L1", "SmoothL1"], default="L1")
39 +
40 + return parser.parse_args()
41 +
42 +def main(cfg):
43 + # dynamic import using --model argument
44 + net = importlib.import_module("model.{}".format(cfg.model)).Net
45 + print(json.dumps(vars(cfg), indent=4, sort_keys=True))
46 +
47 + solver = Solver(net, cfg)
48 + solver.fit()
49 +
50 +if __name__ == "__main__":
51 + cfg = parse_args()
52 + main(cfg)
1 +*
2 +!.gitignore
3 +!div2h5.py
1 +import os
2 +import glob
3 +import h5py
4 +import scipy.misc as misc
5 +import numpy as np
6 +
7 +dataset_dir = "DIV2K/"
8 +dataset_type = "train"
9 +
10 +f = h5py.File("DIV2K_{}.h5".format(dataset_type), "w")
11 +dt = h5py.special_dtype(vlen=np.dtype('uint8'))
12 +
13 +for subdir in ["HR", "X2", "X3", "X4"]:
14 + if subdir in ["HR"]:
15 + im_paths = glob.glob(os.path.join(dataset_dir,
16 + "DIV2K_{}_HR".format(dataset_type),
17 + "*.png"))
18 +
19 + else:
20 + im_paths = glob.glob(os.path.join(dataset_dir,
21 + "DIV2K_{}_LR_bicubic".format(dataset_type),
22 + subdir, "*.png"))
23 + im_paths.sort()
24 + grp = f.create_group(subdir)
25 +
26 + for i, path in enumerate(im_paths):
27 +        im = misc.imread(path)  # requires scipy<1.2; imageio.imread works as a drop-in replacement on newer versions
28 + print(path)
29 + grp.create_dataset(str(i), data=im)
1 +*
2 +!.gitignore