# encoder_train.py
import argparse
import os

import torch
import torch.nn as nn
import torchvision
from torchvision.utils import save_image

from model import AutoEncoder, pytorch_autoencoder, AutoEncoder_s
from get_mean_std import get_params

parser = argparse.ArgumentParser(description='Train an autoencoder on normal images')
parser.add_argument('--config', type=str, help="model type: 'my', 'pytorch', or any other value for AutoEncoder_s")
args = parser.parse_args()

# Training data containing only normal (defect-free) images
data_path = "../data/Fourth_data/Auto"
resize_size = 128
num_epochs = 100
batch_size = 128
learning_rate = 1e-3
device = torch.device("cuda:1")

# Output directory for reconstruction images and the checkpoint
os.makedirs('./dc_img', exist_ok=True)

# The autoencoder built by referring to the report and the autoencoder provided by PyTorch,
# selected via --config
if args.config == "my":
    model = AutoEncoder().to(device)
elif args.config == "pytorch":
    model = pytorch_autoencoder().to(device)
else:
    model = AutoEncoder_s().to(device)

print(model)
#mean, std = get_params(data_path, resize_size)

img_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((resize_size, resize_size)),
    torchvision.transforms.Grayscale(),
    torchvision.transforms.ToTensor(),
])
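# The commented-out get_params call above suggests dataset-wide normalization was considered.
# A minimal sketch of how it could be wired in, assuming get_params returns a scalar mean and
# std for the grayscale images (an assumption about get_mean_std.py, not confirmed here):
#
#   mean, std = get_params(data_path, resize_size)
#   img_transform = torchvision.transforms.Compose([
#       torchvision.transforms.Resize((resize_size, resize_size)),
#       torchvision.transforms.Grayscale(),
#       torchvision.transforms.ToTensor(),
#       torchvision.transforms.Normalize([mean], [std]),
#   ])
#
# With normalization enabled, outputs would also need to be de-normalized before save_image.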

dataset = torchvision.datasets.ImageFolder(data_path, transform=img_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

criterion = nn.L1Loss()  # pixel-wise L1 reconstruction loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data  # labels are ignored; the autoencoder reconstructs its own input
        img = img.to(device)
        output = model(img)
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # the reported loss is that of the last mini-batch in the epoch
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))

    if epoch % 10 == 0:
        save_image(output, './dc_img/image_{}.png'.format(epoch))

# Save the trained weights alongside the reconstruction images
torch.save(model.state_dict(), './dc_img/checkpoint.pth')
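# Reloading the checkpoint later for inference would look roughly like this (a sketch; it
# assumes the same --config choice, and therefore the same architecture, as during training):
#
#   model = AutoEncoder_s().to(device)
#   model.load_state_dict(torch.load('./dc_img/checkpoint.pth', map_location=device))
#   model.eval()
#   with torch.no_grad():
#       reconstruction = model(img)  # img: a (N, 1, 128, 128) tensor, preprocessed as above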