Showing
1 changed file
with
17 additions
and
13 deletions
... | @@ -2,7 +2,7 @@ import os | ... | @@ -2,7 +2,7 @@ import os |
2 | import random | 2 | import random |
3 | import numpy as np | 3 | import numpy as np |
4 | import scipy.misc as misc | 4 | import scipy.misc as misc |
5 | -import skimage.measure as measure | 5 | +import skimage.metrics as metrics |
6 | from tensorboardX import SummaryWriter | 6 | from tensorboardX import SummaryWriter |
7 | import torch | 7 | import torch |
8 | import torch.nn as nn | 8 | import torch.nn as nn |
... | @@ -92,8 +92,9 @@ class Solver(): | ... | @@ -92,8 +92,9 @@ class Solver(): |
92 | self.step += 1 | 92 | self.step += 1 |
93 | if cfg.verbose and self.step % cfg.print_interval == 0: | 93 | if cfg.verbose and self.step % cfg.print_interval == 0: |
94 | if cfg.scale > 0: | 94 | if cfg.scale > 0: |
95 | - psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) | 95 | + psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) |
96 | - self.writer.add_scalar("Urban100", psnr, self.step) | 96 | + self.writer.add_scalar("PSNR", psnr, self.step) |
97 | + self.writer.add_scalar("SSIM", ssim, self.step) | ||
97 | else: | 98 | else: |
98 | psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] | 99 | psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] |
99 | self.writer.add_scalar("Urban100_2x", psnr[0], self.step) | 100 | self.writer.add_scalar("Urban100_2x", psnr[0], self.step) |
... | @@ -107,6 +108,7 @@ class Solver(): | ... | @@ -107,6 +108,7 @@ class Solver(): |
107 | def evaluate(self, test_data_dir, scale=2, num_step=0): | 108 | def evaluate(self, test_data_dir, scale=2, num_step=0): |
108 | cfg = self.cfg | 109 | cfg = self.cfg |
109 | mean_psnr = 0 | 110 | mean_psnr = 0 |
111 | + mean_ssim = 0 | ||
110 | self.refiner.eval() | 112 | self.refiner.eval() |
111 | 113 | ||
112 | test_data = TestDataset(test_data_dir, scale=scale) | 114 | test_data = TestDataset(test_data_dir, scale=scale) |
... | @@ -149,15 +151,16 @@ class Solver(): | ... | @@ -149,15 +151,16 @@ class Solver(): |
149 | hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() | 151 | hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() |
150 | sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() | 152 | sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() |
151 | 153 | ||
152 | - # evaluate PSNR | 154 | + # evaluate PSNR and SSIM |
153 | # this evaluation is different to MATLAB version | 155 | # this evaluation is different to MATLAB version |
154 | # we evaluate PSNR in RGB channel not Y in YCbCR | 156 | # we evaluate PSNR in RGB channel not Y in YCbCR |
155 | bnd = scale | 157 | bnd = scale |
156 | - im1 = hr[bnd:-bnd, bnd:-bnd] | 158 | + im1 = im2double(hr[bnd:-bnd, bnd:-bnd]) |
157 | - im2 = sr[bnd:-bnd, bnd:-bnd] | 159 | + im2 = im2double(sr[bnd:-bnd, bnd:-bnd]) |
158 | mean_psnr += psnr(im1, im2) / len(test_data) | 160 | mean_psnr += psnr(im1, im2) / len(test_data) |
161 | + mean_ssim += ssim(im1, im2) / len(test_data) | ||
159 | 162 | ||
160 | - return mean_psnr | 163 | + return mean_psnr, mean_ssim |
161 | 164 | ||
162 | def load(self, path): | 165 | def load(self, path): |
163 | self.refiner.load_state_dict(torch.load(path)) | 166 | self.refiner.load_state_dict(torch.load(path)) |
... | @@ -177,14 +180,15 @@ class Solver(): | ... | @@ -177,14 +180,15 @@ class Solver(): |
177 | lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) | 180 | lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) |
178 | return lr | 181 | return lr |
179 | 182 | ||
180 | - | 183 | +def im2double(im): |
181 | -def psnr(im1, im2): | ||
182 | - def im2double(im): | ||
183 | min_val, max_val = 0, 255 | 184 | min_val, max_val = 0, 255 |
184 | out = (im.astype(np.float64)-min_val) / (max_val-min_val) | 185 | out = (im.astype(np.float64)-min_val) / (max_val-min_val) |
185 | return out | 186 | return out |
186 | 187 | ||
187 | - im1 = im2double(im1) | 188 | +def psnr(im1, im2): |
188 | - im2 = im2double(im2) | 189 | + psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1) |
189 | - psnr = measure.compare_psnr(im1, im2, data_range=1) | ||
190 | return psnr | 190 | return psnr |
191 | + | ||
192 | +def ssim(im1, im2): | ||
193 | + ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True) | ||
194 | + return ssim | ... | ... |
-
Please register or login to post a comment