# encoder_train.py
import argparse
import os

import torch
import torch.nn as nn
import torchvision
from torchvision.utils import save_image

from model import AutoEncoder, pytorch_autoencoder, AutoEncoder_s
from get_mean_std import get_params
# Command-line options.
# --config chooses the autoencoder variant that is trained below:
#   "my"      -> AutoEncoder
#   "pytorch" -> pytorch_autoencoder
#   anything else, or the flag omitted (value None) -> AutoEncoder_s
parser = argparse.ArgumentParser(description='Process autoencoder')
parser.add_argument(
    '--config',
    type=str,
    default=None,
    help="model variant: 'my' (AutoEncoder) or 'pytorch' (pytorch_autoencoder); "
         "any other value, or omitting the flag, selects AutoEncoder_s",
)
args = parser.parse_args()
# Training images: a folder that contains only "normal" samples.
data_path: str = "../data/Fourth_data/Auto"

# Training hyperparameters.
resize_size: int = 128  # side length images are resized to (square)
num_epochs: int = 100
batch_size: int = 128
learning_rate: float = 0.001
# Pick the network from --config: "my" is the autoencoder built following the
# report, "pytorch" is the autoencoder described as coming from PyTorch; every
# other value (including None) falls back to AutoEncoder_s.
_model_factories = {
    "my": AutoEncoder,
    "pytorch": pytorch_autoencoder,
}
model_cls = _model_factories.get(args.config, AutoEncoder_s)
model = model_cls().cuda("cuda:1")
print(model)
# Preprocessing: square resize, single-channel grayscale, then tensor conversion.
# No Normalize step is applied (get_params(data_path, resize_size) is not used),
# so pixel values stay in the range produced by ToTensor.
transform_steps = [
    torchvision.transforms.Resize((resize_size, resize_size)),
    torchvision.transforms.Grayscale(),
    torchvision.transforms.ToTensor(),
]
img_transform = torchvision.transforms.Compose(transform_steps)

# ImageFolder expects one sub-directory per class under data_path; the class
# labels it yields are ignored by the training loop.
dataset = torchvision.datasets.ImageFolder(root=data_path, transform=img_transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Reconstruction objective: mean absolute error between the model output and its input.
criterion = nn.L1Loss(reduction='mean')
# Adam over all model parameters, with a small L2 penalty via weight_decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)
# Sample reconstructions and the final checkpoint are written here. Neither
# save_image nor torch.save creates missing parent directories, so make it first.
output_dir = "./dc_img"
os.makedirs(output_dir, exist_ok=True)

# Send batches to whatever device the model lives on, so the two cannot diverge.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    running_loss = 0.0
    seen = 0
    for img, _ in dataloader:  # ImageFolder class labels are not used
        img = img.to(device)
        output = model(img)
        # Autoencoder objective: the target is the input itself.
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the epoch mean.
        running_loss += loss.item() * img.size(0)
        seen += img.size(0)

    if seen == 0:
        raise RuntimeError("dataloader yielded no batches; check data_path")

    # Report the mean loss over the whole epoch, not just the last batch.
    epoch_loss = running_loss / seen
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, epoch_loss))

    if epoch % 10 == 0:
        # Save the reconstructions from the last batch of this epoch.
        save_image(output.detach(), os.path.join(output_dir, 'image_{}.png'.format(epoch)))

torch.save(model.state_dict(), os.path.join(output_dir, 'checkpoint.pth'))