김재형

CARN 테스트 코드 추가

1 +import os
2 +import json
3 +import time
4 +import importlib
5 +import argparse
6 +import numpy as np
7 +from collections import OrderedDict
8 +import torch
9 +import torch.nn as nn
10 +import torch.utils.data as data
11 +from glob import glob
12 +from torch.autograd import Variable
13 +from PIL import Image
14 +import torchvision.transforms as transforms
15 +from tqdm import tqdm
16 +
17 +def parse_args():
18 + parser = argparse.ArgumentParser()
19 + parser.add_argument("--model", type=str)
20 + parser.add_argument("--ckpt_path", type=str)
21 + parser.add_argument("--group", type=int, default=1)
22 + parser.add_argument("--sample_dir", type=str)
23 + parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100")
24 + parser.add_argument("--cuda", action="store_true")
25 + parser.add_argument("--scale", type=int, default=4)
26 + parser.add_argument("--shave", type=int, default=20)
27 +
28 + return parser.parse_args()
29 +
30 +
31 +def save_image(tensor, filename):
32 + tensor = tensor.cpu()
33 + ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
34 + im = Image.fromarray(ndarr)
35 + im.save(filename)
36 +
37 +
38 +class TestDataset(data.Dataset):
39 + def __init__(self, dirname, scale):
40 + super(TestDataset, self).__init__()
41 +
42 + self.lr = glob(os.path.join(dirname, "*.png"))
43 + self.lr.sort()
44 +
45 + self.transform = transforms.Compose([
46 + transforms.ToTensor()
47 + ])
48 +
49 + def __getitem__(self, index):
50 + lr = Image.open(self.lr[index])
51 + lr = lr.convert("RGB")
52 + filename = self.lr[index].split("/")[-1]
53 +
54 + return self.transform(lr), filename
55 +
56 + def __len__(self):
57 + return len(self.lr)
58 +
59 +
60 +def sample(net, device, dataset, cfg):
61 + scale = cfg.scale
62 + for lr, name in tqdm(dataset):
63 + t1 = time.time()
64 + lr = lr.unsqueeze(0).to(device)
65 + sr = net(lr, cfg.scale).detach().squeeze(0)
66 + lr = lr.squeeze(0)
67 + t2 = time.time()
68 +
69 + sr_dir = os.path.join(cfg.sample_dir, cfg.test_data_dir.split("/")[-1])
70 +
71 + os.makedirs(sr_dir, exist_ok=True)
72 +
73 + sr_im_path = os.path.join(sr_dir, name)
74 + save_image(sr, sr_im_path)
75 +
76 +
77 +def main(cfg):
78 + module = importlib.import_module("model.{}".format(cfg.model))
79 + net = module.Net(multi_scale=True,
80 + group=cfg.group)
81 + print(json.dumps(vars(cfg), indent=4, sort_keys=True))
82 +
83 + state_dict = torch.load(cfg.ckpt_path)
84 + new_state_dict = OrderedDict()
85 + for k, v in state_dict.items():
86 + name = k
87 + # name = k[7:] # remove "module."
88 + new_state_dict[name] = v
89 +
90 + net.load_state_dict(new_state_dict)
91 +
92 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 + net = net.to(device)
94 +
95 + dataset = TestDataset(cfg.test_data_dir, cfg.scale)
96 + sample(net, device, dataset, cfg)
97 +
98 +
99 +if __name__ == "__main__":
100 + cfg = parse_args()
101 + main(cfg)