# train.py
# -*- coding: utf-8 -*-
import argparse
import os
import glob

import gluonnlp as nlp
import torch
from torch.utils.data import DataLoader, Dataset
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor

if __name__ == "__main__":
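    # Run on the GPU when one is available; `ctx` is the device string expected by the KoGPT2 loader.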
    ctx='cuda' if torch.cuda.is_available() else 'cpu'
    device=torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir='/code/model')
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers=0
    padding_id=vocab[vocab.padding_token]

    transform=ToTensor(tokenizer,vocab)
    print("Preparing dataloader...")
    trainset=DataLoader(ArticleDataset('/dataset',label='train', transform=transform),batch_size=64, num_workers=0,shuffle=True)
    validset=DataLoader(ArticleDataset('/dataset',label='valid', transform=transform),batch_size=64, num_workers=0)
    #testset=DataLoader(ArticleDataset('/dataset',label='test', transform=transform),batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epoches=200
    checkpoint_epoch=0
    learning_rate = 3e-5
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    save_path='/model/save'
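    # Resume from the most recently modified checkpoint under save_path, if any exists.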
    saves=glob.glob(save_path+'*.state')
    if len(saves)>0:
        last_save=max(saves,key=os.path.getmtime)
        checkpoint = torch.load(last_save, map_location=device)
        print(f"Loading save from {last_save}")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        print("No save exists.")


    model.to(device)
    model.train()

    last_valid_loss=float('infinity')
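    # Fine-tuning loop: one training pass and one validation pass per epoch.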
    for epoch in tqdm(range(checkpoint_epoch,epoches)):
        train_loss_list=[]
        valid_loss_list=[]
        for data in tqdm(trainset):
            optimizer.zero_grad()
            data = data.to(ctx)
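            # Padding positions get label -100 (ignored by the LM loss) and are masked out of attention.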
            label = torch.where(data!=padding_id, data, torch.ones_like(data)*-100)
            mask = torch.where(data!=padding_id,torch.ones_like(data),torch.zeros_like(data))
            output=model(data, labels=label, attention_mask=mask)
            loss, logits=output[0], output[1]
            #loss = loss.to(ctx)
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())
        # Validation pass: no gradient tracking and dropout disabled via eval().
        model.eval()
        with torch.no_grad():
            for v_data in tqdm(validset):
                v_data = v_data.to(ctx)
                v_label = torch.where(v_data!=padding_id, v_data, torch.ones_like(v_data)*-100)
                v_mask = torch.where(v_data!=padding_id,torch.ones_like(v_data),torch.zeros_like(v_data))
                v_output=model(v_data,labels=v_label, attention_mask=v_mask)
                v_loss, v_logits=v_output[0], v_output[1]
                valid_loss_list.append(v_loss.item())
        model.train()
        valid_loss=sum(valid_loss_list)/len(valid_loss_list)
        print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
        # Save a checkpoint when validation loss stops improving, and every 10th epoch.
        if valid_loss>last_valid_loss or (epoch%10==9):
            try:
                torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss
                    }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
            except Exception as e:
                print(e)
        last_valid_loss=valid_loss
        if epoch==checkpoint_epoch: # The first epoch must run with num_workers=0 so the dataset can be fully cached.
            trainset.dataset.set_use_cache(True)
            trainset.num_workers=num_workers
            validset.dataset.set_use_cache(True)
            validset.num_workers=num_workers