Showing
1 changed file
with
17 additions
and
13 deletions
| ... | @@ -2,7 +2,7 @@ import os | ... | @@ -2,7 +2,7 @@ import os |
| 2 | import random | 2 | import random |
| 3 | import numpy as np | 3 | import numpy as np |
| 4 | import scipy.misc as misc | 4 | import scipy.misc as misc |
| 5 | -import skimage.measure as measure | 5 | +import skimage.metrics as metrics |
| 6 | from tensorboardX import SummaryWriter | 6 | from tensorboardX import SummaryWriter |
| 7 | import torch | 7 | import torch |
| 8 | import torch.nn as nn | 8 | import torch.nn as nn |
| ... | @@ -92,8 +92,9 @@ class Solver(): | ... | @@ -92,8 +92,9 @@ class Solver(): |
| 92 | self.step += 1 | 92 | self.step += 1 |
| 93 | if cfg.verbose and self.step % cfg.print_interval == 0: | 93 | if cfg.verbose and self.step % cfg.print_interval == 0: |
| 94 | if cfg.scale > 0: | 94 | if cfg.scale > 0: |
| 95 | - psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) | 95 | + psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) |
| 96 | - self.writer.add_scalar("Urban100", psnr, self.step) | 96 | + self.writer.add_scalar("PSNR", psnr, self.step) |
| 97 | + self.writer.add_scalar("SSIM", ssim, self.step) | ||
| 97 | else: | 98 | else: |
| 98 | psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] | 99 | psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] |
| 99 | self.writer.add_scalar("Urban100_2x", psnr[0], self.step) | 100 | self.writer.add_scalar("Urban100_2x", psnr[0], self.step) |
| ... | @@ -107,6 +108,7 @@ class Solver(): | ... | @@ -107,6 +108,7 @@ class Solver(): |
| 107 | def evaluate(self, test_data_dir, scale=2, num_step=0): | 108 | def evaluate(self, test_data_dir, scale=2, num_step=0): |
| 108 | cfg = self.cfg | 109 | cfg = self.cfg |
| 109 | mean_psnr = 0 | 110 | mean_psnr = 0 |
| 111 | + mean_ssim = 0 | ||
| 110 | self.refiner.eval() | 112 | self.refiner.eval() |
| 111 | 113 | ||
| 112 | test_data = TestDataset(test_data_dir, scale=scale) | 114 | test_data = TestDataset(test_data_dir, scale=scale) |
| ... | @@ -149,15 +151,16 @@ class Solver(): | ... | @@ -149,15 +151,16 @@ class Solver(): |
| 149 | hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() | 151 | hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() |
| 150 | sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() | 152 | sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() |
| 151 | 153 | ||
| 152 | - # evaluate PSNR | 154 | + # evaluate PSNR and SSIM |
| 153 | # this evaluation is different to MATLAB version | 155 | # this evaluation is different to MATLAB version |
| 154 | # we evaluate PSNR in RGB channel not Y in YCbCR | 156 | # we evaluate PSNR in RGB channel not Y in YCbCR |
| 155 | bnd = scale | 157 | bnd = scale |
| 156 | - im1 = hr[bnd:-bnd, bnd:-bnd] | 158 | + im1 = im2double(hr[bnd:-bnd, bnd:-bnd]) |
| 157 | - im2 = sr[bnd:-bnd, bnd:-bnd] | 159 | + im2 = im2double(sr[bnd:-bnd, bnd:-bnd]) |
| 158 | mean_psnr += psnr(im1, im2) / len(test_data) | 160 | mean_psnr += psnr(im1, im2) / len(test_data) |
| 161 | + mean_ssim += ssim(im1, im2) / len(test_data) | ||
| 159 | 162 | ||
| 160 | - return mean_psnr | 163 | + return mean_psnr, mean_ssim |
| 161 | 164 | ||
| 162 | def load(self, path): | 165 | def load(self, path): |
| 163 | self.refiner.load_state_dict(torch.load(path)) | 166 | self.refiner.load_state_dict(torch.load(path)) |
| ... | @@ -177,14 +180,15 @@ class Solver(): | ... | @@ -177,14 +180,15 @@ class Solver(): |
| 177 | lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) | 180 | lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) |
| 178 | return lr | 181 | return lr |
| 179 | 182 | ||
| 180 | - | 183 | +def im2double(im): |
| 181 | -def psnr(im1, im2): | ||
| 182 | - def im2double(im): | ||
| 183 | min_val, max_val = 0, 255 | 184 | min_val, max_val = 0, 255 |
| 184 | out = (im.astype(np.float64)-min_val) / (max_val-min_val) | 185 | out = (im.astype(np.float64)-min_val) / (max_val-min_val) |
| 185 | return out | 186 | return out |
| 186 | 187 | ||
| 187 | - im1 = im2double(im1) | 188 | +def psnr(im1, im2): |
| 188 | - im2 = im2double(im2) | 189 | + psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1) |
| 189 | - psnr = measure.compare_psnr(im1, im2, data_range=1) | ||
| 190 | return psnr | 190 | return psnr |
| 191 | + | ||
| 192 | +def ssim(im1, im2): | ||
| 193 | + ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True) | ||
| 194 | + return ssim | ... | ... |
-
Please register or login to post a comment