김재형

CARN 학습 시 PSNR, SSIM eval 코드 추가

......@@ -2,7 +2,7 @@ import os
import random
import numpy as np
import scipy.misc as misc
import skimage.measure as measure
import skimage.metrics as metrics
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
......@@ -92,8 +92,9 @@ class Solver():
self.step += 1
if cfg.verbose and self.step % cfg.print_interval == 0:
if cfg.scale > 0:
psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
self.writer.add_scalar("Urban100", psnr, self.step)
psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
self.writer.add_scalar("PSNR", psnr, self.step)
self.writer.add_scalar("SSIM", ssim, self.step)
else:
psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
......@@ -107,6 +108,7 @@ class Solver():
def evaluate(self, test_data_dir, scale=2, num_step=0):
cfg = self.cfg
mean_psnr = 0
mean_ssim = 0
self.refiner.eval()
test_data = TestDataset(test_data_dir, scale=scale)
......@@ -149,15 +151,16 @@ class Solver():
hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
# evaluate PSNR
# evaluate PSNR and SSIM
# this evaluation is different to MATLAB version
# we evaluate PSNR in RGB channel not Y in YCbCR
bnd = scale
im1 = hr[bnd:-bnd, bnd:-bnd]
im2 = sr[bnd:-bnd, bnd:-bnd]
im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
mean_psnr += psnr(im1, im2) / len(test_data)
mean_ssim += ssim(im1, im2) / len(test_data)
return mean_psnr
return mean_psnr, mean_ssim
def load(self, path):
self.refiner.load_state_dict(torch.load(path))
......@@ -177,14 +180,15 @@ class Solver():
lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
return lr
def psnr(im1, im2):
def im2double(im):
def im2double(im):
min_val, max_val = 0, 255
out = (im.astype(np.float64)-min_val) / (max_val-min_val)
return out
im1 = im2double(im1)
im2 = im2double(im2)
psnr = measure.compare_psnr(im1, im2, data_range=1)
def psnr(im1, im2):
psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
return psnr
def ssim(im1, im2):
ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
return ssim
......