Showing
1 changed file
with
101 additions
and
0 deletions
carn/carn/test.py
0 → 100644
1 | +import os | ||
2 | +import json | ||
3 | +import time | ||
4 | +import importlib | ||
5 | +import argparse | ||
6 | +import numpy as np | ||
7 | +from collections import OrderedDict | ||
8 | +import torch | ||
9 | +import torch.nn as nn | ||
10 | +import torch.utils.data as data | ||
11 | +from glob import glob | ||
12 | +from torch.autograd import Variable | ||
13 | +from PIL import Image | ||
14 | +import torchvision.transforms as transforms | ||
15 | +from tqdm import tqdm | ||
16 | + | ||
17 | +def parse_args(): | ||
18 | + parser = argparse.ArgumentParser() | ||
19 | + parser.add_argument("--model", type=str) | ||
20 | + parser.add_argument("--ckpt_path", type=str) | ||
21 | + parser.add_argument("--group", type=int, default=1) | ||
22 | + parser.add_argument("--sample_dir", type=str) | ||
23 | + parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100") | ||
24 | + parser.add_argument("--cuda", action="store_true") | ||
25 | + parser.add_argument("--scale", type=int, default=4) | ||
26 | + parser.add_argument("--shave", type=int, default=20) | ||
27 | + | ||
28 | + return parser.parse_args() | ||
29 | + | ||
30 | + | ||
31 | +def save_image(tensor, filename): | ||
32 | + tensor = tensor.cpu() | ||
33 | + ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() | ||
34 | + im = Image.fromarray(ndarr) | ||
35 | + im.save(filename) | ||
36 | + | ||
37 | + | ||
38 | +class TestDataset(data.Dataset): | ||
39 | + def __init__(self, dirname, scale): | ||
40 | + super(TestDataset, self).__init__() | ||
41 | + | ||
42 | + self.lr = glob(os.path.join(dirname, "*.png")) | ||
43 | + self.lr.sort() | ||
44 | + | ||
45 | + self.transform = transforms.Compose([ | ||
46 | + transforms.ToTensor() | ||
47 | + ]) | ||
48 | + | ||
49 | + def __getitem__(self, index): | ||
50 | + lr = Image.open(self.lr[index]) | ||
51 | + lr = lr.convert("RGB") | ||
52 | + filename = self.lr[index].split("/")[-1] | ||
53 | + | ||
54 | + return self.transform(lr), filename | ||
55 | + | ||
56 | + def __len__(self): | ||
57 | + return len(self.lr) | ||
58 | + | ||
59 | + | ||
60 | +def sample(net, device, dataset, cfg): | ||
61 | + scale = cfg.scale | ||
62 | + for lr, name in tqdm(dataset): | ||
63 | + t1 = time.time() | ||
64 | + lr = lr.unsqueeze(0).to(device) | ||
65 | + sr = net(lr, cfg.scale).detach().squeeze(0) | ||
66 | + lr = lr.squeeze(0) | ||
67 | + t2 = time.time() | ||
68 | + | ||
69 | + sr_dir = os.path.join(cfg.sample_dir, cfg.test_data_dir.split("/")[-1]) | ||
70 | + | ||
71 | + os.makedirs(sr_dir, exist_ok=True) | ||
72 | + | ||
73 | + sr_im_path = os.path.join(sr_dir, name) | ||
74 | + save_image(sr, sr_im_path) | ||
75 | + | ||
76 | + | ||
77 | +def main(cfg): | ||
78 | + module = importlib.import_module("model.{}".format(cfg.model)) | ||
79 | + net = module.Net(multi_scale=True, | ||
80 | + group=cfg.group) | ||
81 | + print(json.dumps(vars(cfg), indent=4, sort_keys=True)) | ||
82 | + | ||
83 | + state_dict = torch.load(cfg.ckpt_path) | ||
84 | + new_state_dict = OrderedDict() | ||
85 | + for k, v in state_dict.items(): | ||
86 | + name = k | ||
87 | + # name = k[7:] # remove "module." | ||
88 | + new_state_dict[name] = v | ||
89 | + | ||
90 | + net.load_state_dict(new_state_dict) | ||
91 | + | ||
92 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
93 | + net = net.to(device) | ||
94 | + | ||
95 | + dataset = TestDataset(cfg.test_data_dir, cfg.scale) | ||
96 | + sample(net, device, dataset, cfg) | ||
97 | + | ||
98 | + | ||
99 | +if __name__ == "__main__": | ||
100 | + cfg = parse_args() | ||
101 | + main(cfg) |
-
Please register or login to post a comment