김재형

Test 코드 수정

......@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg):
def main(cfg):
module = importlib.import_module("model.{}".format(cfg.model))
net = module.Net(multi_scale=True,
net = module.Net(multi_scale=False,
scale=cfg.scale,
group=cfg.group)
print(json.dumps(vars(cfg), indent=4, sort_keys=True))
state_dict = torch.load(cfg.ckpt_path)
# print(state_dict.keys())
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k
......@@ -88,11 +90,13 @@ def main(cfg):
new_state_dict[name] = v
net.load_state_dict(new_state_dict)
net.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
dataset = TestDataset(cfg.test_data_dir, cfg.scale)
with torch.no_grad():
sample(net, device, dataset, cfg)
......