# encoder_test.py (1.51 KB)
import torch
import torchvision
import torch.nn as nn
import argparse
from model import AutoEncoder, pytorch_autoencoder
from get_mean_std import get_params
from torchvision.utils import save_image
# Evaluate a trained autoencoder checkpoint on a folder of test images.
# For every batch: save the reconstruction, save the residual
# (input - reconstruction), and print the L1 reconstruction loss.

parser = argparse.ArgumentParser(description='Process autoencoder')
parser.add_argument('--config', type=str, help='select type')
args = parser.parse_args()

# Test data taken only from the "Scratch" category.
data_path = "../data/Fourth_data/Auto_test"
checkpoint_path = "./dc_img/checkpoint.pth"
resize_size = 128
batch_size = 128
# Single place to choose the GPU (previously "cuda:1" was repeated inline).
device = torch.device("cuda:1")

# Two model choices: the autoencoder built with reference to the report
# (--config my), or the PyTorch-provided autoencoder (any other value,
# including when --config is omitted).
if args.config == "my":
    model = AutoEncoder().to(device)
else:
    model = pytorch_autoencoder().to(device)
# map_location puts the loaded tensors directly on the target device instead
# of the device they were saved from (which could be a different GPU).
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
# Inference only: put any train/eval-dependent layers (e.g. dropout or
# batch-norm, if the model has them) into evaluation mode.
model.eval()
print("checkpoint loaded finish!")

# Resize to a fixed square, convert to single-channel grayscale, then to a
# float tensor in [0, 1].
img_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((resize_size, resize_size)),
    torchvision.transforms.Grayscale(),
    torchvision.transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(data_path, transform=img_transform)
# shuffle=False keeps output file indices aligned with dataset order.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
criterion = nn.L1Loss()

# Evaluation needs no gradients; no_grad avoids building an autograd graph
# per batch, saving GPU memory and time.
with torch.no_grad():
    for idx, (img, _) in enumerate(dataloader):
        img = img.to(device)
        output = model(img)
        save_image(output, f'./dc_img/test_output_{idx}.png')
        loss = criterion(output, img)
        # Signed residual. NOTE(review): save_image clips values outside
        # [0, 1], so pixels where the reconstruction is brighter than the
        # input are saved as black. Confirm this is the intended view.
        diff = img - output
        save_image(diff, f'./dc_img/scratch_dif_{idx}.png')
        print(f"loss : {loss.item()}")