hanbin9775
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
2 import random 2 import random
3 import numpy as np 3 import numpy as np
4 import scipy.misc as misc 4 import scipy.misc as misc
5 -import skimage.measure as measure 5 +import skimage.metrics as metrics
6 from tensorboardX import SummaryWriter 6 from tensorboardX import SummaryWriter
7 import torch 7 import torch
8 import torch.nn as nn 8 import torch.nn as nn
...@@ -92,8 +92,9 @@ class Solver(): ...@@ -92,8 +92,9 @@ class Solver():
92 self.step += 1 92 self.step += 1
93 if cfg.verbose and self.step % cfg.print_interval == 0: 93 if cfg.verbose and self.step % cfg.print_interval == 0:
94 if cfg.scale > 0: 94 if cfg.scale > 0:
95 - psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) 95 + psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
96 - self.writer.add_scalar("Urban100", psnr, self.step) 96 + self.writer.add_scalar("PSNR", psnr, self.step)
97 + self.writer.add_scalar("SSIM", ssim, self.step)
97 else: 98 else:
98 psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] 99 psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
99 self.writer.add_scalar("Urban100_2x", psnr[0], self.step) 100 self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
...@@ -107,6 +108,7 @@ class Solver(): ...@@ -107,6 +108,7 @@ class Solver():
107 def evaluate(self, test_data_dir, scale=2, num_step=0): 108 def evaluate(self, test_data_dir, scale=2, num_step=0):
108 cfg = self.cfg 109 cfg = self.cfg
109 mean_psnr = 0 110 mean_psnr = 0
111 + mean_ssim = 0
110 self.refiner.eval() 112 self.refiner.eval()
111 113
112 test_data = TestDataset(test_data_dir, scale=scale) 114 test_data = TestDataset(test_data_dir, scale=scale)
...@@ -149,15 +151,16 @@ class Solver(): ...@@ -149,15 +151,16 @@ class Solver():
149 hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 151 hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
150 sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 152 sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
151 153
152 - # evaluate PSNR 154 + # evaluate PSNR and SSIM
153 # this evaluation is different to MATLAB version 155 # this evaluation is different to MATLAB version
154 # we evaluate PSNR in RGB channel not Y in YCbCR 156 # we evaluate PSNR in RGB channel not Y in YCbCR
155 bnd = scale 157 bnd = scale
156 - im1 = hr[bnd:-bnd, bnd:-bnd] 158 + im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
157 - im2 = sr[bnd:-bnd, bnd:-bnd] 159 + im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
158 mean_psnr += psnr(im1, im2) / len(test_data) 160 mean_psnr += psnr(im1, im2) / len(test_data)
161 + mean_ssim += ssim(im1, im2) / len(test_data)
159 162
160 - return mean_psnr 163 + return mean_psnr, mean_ssim
161 164
162 def load(self, path): 165 def load(self, path):
163 self.refiner.load_state_dict(torch.load(path)) 166 self.refiner.load_state_dict(torch.load(path))
...@@ -177,14 +180,15 @@ class Solver(): ...@@ -177,14 +180,15 @@ class Solver():
177 lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) 180 lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
178 return lr 181 return lr
179 182
180 - 183 +def im2double(im):
181 -def psnr(im1, im2):
182 - def im2double(im):
183 min_val, max_val = 0, 255 184 min_val, max_val = 0, 255
184 out = (im.astype(np.float64)-min_val) / (max_val-min_val) 185 out = (im.astype(np.float64)-min_val) / (max_val-min_val)
185 return out 186 return out
186 187
187 - im1 = im2double(im1) 188 +def psnr(im1, im2):
188 - im2 = im2double(im2) 189 + psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
189 - psnr = measure.compare_psnr(im1, im2, data_range=1)
190 return psnr 190 return psnr
191 +
192 +def ssim(im1, im2):
193 + ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
194 + return ssim
......
...@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): ...@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg):
76 76
77 def main(cfg): 77 def main(cfg):
78 module = importlib.import_module("model.{}".format(cfg.model)) 78 module = importlib.import_module("model.{}".format(cfg.model))
79 - net = module.Net(multi_scale=True, 79 + net = module.Net(multi_scale=False,
80 + scale=cfg.scale,
80 group=cfg.group) 81 group=cfg.group)
81 print(json.dumps(vars(cfg), indent=4, sort_keys=True)) 82 print(json.dumps(vars(cfg), indent=4, sort_keys=True))
82 83
83 state_dict = torch.load(cfg.ckpt_path) 84 state_dict = torch.load(cfg.ckpt_path)
85 + # print(state_dict.keys())
84 new_state_dict = OrderedDict() 86 new_state_dict = OrderedDict()
85 for k, v in state_dict.items(): 87 for k, v in state_dict.items():
86 name = k 88 name = k
...@@ -88,11 +90,13 @@ def main(cfg): ...@@ -88,11 +90,13 @@ def main(cfg):
88 new_state_dict[name] = v 90 new_state_dict[name] = v
89 91
90 net.load_state_dict(new_state_dict) 92 net.load_state_dict(new_state_dict)
93 + net.eval()
91 94
92 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 95 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 net = net.to(device) 96 net = net.to(device)
94 97
95 dataset = TestDataset(cfg.test_data_dir, cfg.scale) 98 dataset = TestDataset(cfg.test_data_dir, cfg.scale)
99 + with torch.no_grad():
96 sample(net, device, dataset, cfg) 100 sample(net, device, dataset, cfg)
97 101
98 102
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -2,10 +2,22 @@ ...@@ -2,10 +2,22 @@
2 "cells": [ 2 "cells": [
3 { 3 {
4 "cell_type": "code", 4 "cell_type": "code",
5 - "execution_count": 15, 5 + "execution_count": 1,
6 "id": "automotive-circus", 6 "id": "automotive-circus",
7 "metadata": {}, 7 "metadata": {},
8 - "outputs": [], 8 + "outputs": [
9 + {
10 + "output_type": "error",
11 + "ename": "ModuleNotFoundError",
12 + "evalue": "No module named 'cv2'",
13 + "traceback": [
14 + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
15 + "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
16 + "\u001b[1;32m<ipython-input-1-03d1a01a87c6>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mglob\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mglob\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mgt_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mglob\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../bbb_sunflower_1080p/*.png\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
17 + "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
18 + ]
19 + }
20 + ],
9 "source": [ 21 "source": [
10 "from glob import glob\n", 22 "from glob import glob\n",
11 "import cv2\n", 23 "import cv2\n",
......
...@@ -5,7 +5,19 @@ ...@@ -5,7 +5,19 @@
5 "execution_count": 1, 5 "execution_count": 1,
6 "id": "ahead-paste", 6 "id": "ahead-paste",
7 "metadata": {}, 7 "metadata": {},
8 - "outputs": [], 8 + "outputs": [
9 + {
10 + "output_type": "error",
11 + "ename": "ModuleNotFoundError",
12 + "evalue": "No module named 'cv2'",
13 + "traceback": [
14 + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
15 + "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
16 + "\u001b[1;32m<ipython-input-1-ff55b1ddb4f1>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mglob\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mglob\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mimages\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mglob\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../bbb_sunflower_540p/*.png\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
17 + "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
18 + ]
19 + }
20 + ],
9 "source": [ 21 "source": [
10 "from glob import glob\n", 22 "from glob import glob\n",
11 "import cv2\n", 23 "import cv2\n",
......
1 +{
2 + "cells": [
3 + {
4 + "cell_type": "code",
5 + "execution_count": 1,
6 + "id": "ahead-paste",
7 + "metadata": {},
8 + "outputs": [],
9 + "source": [
10 + "from glob import glob\n",
11 + "import cv2\n",
12 + "\n",
13 + "images = sorted(glob(\"./tennis_test_1080p/*.png\"))"
14 + ]
15 + },
16 + {
17 + "cell_type": "code",
18 + "execution_count": 2,
19 + "id": "rapid-tension",
20 + "metadata": {},
21 + "outputs": [],
22 + "source": [
23 + "from pathlib import Path\n",
24 + "Path(\"./dataset/Urban100/x2\").mkdir(parents=True, exist_ok=True)"
25 + ]
26 + },
27 + {
28 + "cell_type": "code",
29 + "execution_count": 3,
30 + "id": "visible-texas",
31 + "metadata": {},
32 + "outputs": [
33 + {
34 + "name": "stderr",
35 + "output_type": "stream",
36 + "text": [
37 + "100%|██████████| 125/125 [00:18<00:00, 6.61it/s]\n"
38 + ]
39 + }
40 + ],
41 + "source": [
42 + "from tqdm import tqdm\n",
43 + "for image in tqdm(images):\n",
44 + " hr = cv2.imread(image, cv2.IMREAD_COLOR)\n",
45 + " lr = cv2.resize(hr, dsize=(960, 540), interpolation=cv2.INTER_CUBIC)\n",
46 + "\n",
47 + " cv2.imwrite(\"./dataset/Urban100/x2/\" + Path(image).stem + \"_HR.png\", hr)\n",
48 + " cv2.imwrite(\"./dataset/Urban100/x2/\" + Path(image).stem + \"_LR.png\", lr)"
49 + ]
50 + },
51 + {
52 + "cell_type": "code",
53 + "execution_count": null,
54 + "id": "fallen-religion",
55 + "metadata": {},
56 + "outputs": [],
57 + "source": []
58 + }
59 + ],
60 + "metadata": {
61 + "kernelspec": {
62 + "display_name": "Python 3",
63 + "language": "python",
64 + "name": "python3"
65 + },
66 + "language_info": {
67 + "codemirror_mode": {
68 + "name": "ipython",
69 + "version": 3
70 + },
71 + "file_extension": ".py",
72 + "mimetype": "text/x-python",
73 + "name": "python",
74 + "nbconvert_exporter": "python",
75 + "pygments_lexer": "ipython3",
76 + "version": "3.7.7"
77 + }
78 + },
79 + "nbformat": 4,
80 + "nbformat_minor": 5
81 +}
No preview for this file type