김재형

CARN 학습 시 PSNR, SSIM eval 코드 추가

...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
2 import random 2 import random
3 import numpy as np 3 import numpy as np
4 import scipy.misc as misc 4 import scipy.misc as misc
5 -import skimage.measure as measure 5 +import skimage.metrics as metrics
6 from tensorboardX import SummaryWriter 6 from tensorboardX import SummaryWriter
7 import torch 7 import torch
8 import torch.nn as nn 8 import torch.nn as nn
...@@ -92,8 +92,9 @@ class Solver(): ...@@ -92,8 +92,9 @@ class Solver():
92 self.step += 1 92 self.step += 1
93 if cfg.verbose and self.step % cfg.print_interval == 0: 93 if cfg.verbose and self.step % cfg.print_interval == 0:
94 if cfg.scale > 0: 94 if cfg.scale > 0:
95 - psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) 95 + psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
96 - self.writer.add_scalar("Urban100", psnr, self.step) 96 + self.writer.add_scalar("PSNR", psnr, self.step)
97 + self.writer.add_scalar("SSIM", ssim, self.step)
97 else: 98 else:
98 psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] 99 psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
99 self.writer.add_scalar("Urban100_2x", psnr[0], self.step) 100 self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
...@@ -107,6 +108,7 @@ class Solver(): ...@@ -107,6 +108,7 @@ class Solver():
107 def evaluate(self, test_data_dir, scale=2, num_step=0): 108 def evaluate(self, test_data_dir, scale=2, num_step=0):
108 cfg = self.cfg 109 cfg = self.cfg
109 mean_psnr = 0 110 mean_psnr = 0
111 + mean_ssim = 0
110 self.refiner.eval() 112 self.refiner.eval()
111 113
112 test_data = TestDataset(test_data_dir, scale=scale) 114 test_data = TestDataset(test_data_dir, scale=scale)
...@@ -149,15 +151,16 @@ class Solver(): ...@@ -149,15 +151,16 @@ class Solver():
149 hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 151 hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
150 sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 152 sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
151 153
152 - # evaluate PSNR 154 + # evaluate PSNR and SSIM
153 # this evaluation is different to MATLAB version 155 # this evaluation is different to MATLAB version
154 # we evaluate PSNR in RGB channel not Y in YCbCR 156 # we evaluate PSNR in RGB channel not Y in YCbCR
155 bnd = scale 157 bnd = scale
156 - im1 = hr[bnd:-bnd, bnd:-bnd] 158 + im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
157 - im2 = sr[bnd:-bnd, bnd:-bnd] 159 + im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
158 mean_psnr += psnr(im1, im2) / len(test_data) 160 mean_psnr += psnr(im1, im2) / len(test_data)
161 + mean_ssim += ssim(im1, im2) / len(test_data)
159 162
160 - return mean_psnr 163 + return mean_psnr, mean_ssim
161 164
162 def load(self, path): 165 def load(self, path):
163 self.refiner.load_state_dict(torch.load(path)) 166 self.refiner.load_state_dict(torch.load(path))
...@@ -177,14 +180,15 @@ class Solver(): ...@@ -177,14 +180,15 @@ class Solver():
177 lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) 180 lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
178 return lr 181 return lr
179 182
180 - 183 +def im2double(im):
181 -def psnr(im1, im2):
182 - def im2double(im):
183 min_val, max_val = 0, 255 184 min_val, max_val = 0, 255
184 out = (im.astype(np.float64)-min_val) / (max_val-min_val) 185 out = (im.astype(np.float64)-min_val) / (max_val-min_val)
185 return out 186 return out
186 187
187 - im1 = im2double(im1) 188 +def psnr(im1, im2):
188 - im2 = im2double(im2) 189 + psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
189 - psnr = measure.compare_psnr(im1, im2, data_range=1)
190 return psnr 190 return psnr
191 +
192 +def ssim(im1, im2):
193 + ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
194 + return ssim
......