# cross_test.py
# -*- coding: utf-8 -*-
import argparse
from argparse import ArgumentError
import os
import glob
import time
import subprocess

import gluonnlp as nlp
from numpy.lib.function_base import delete
import torch
from torch.utils.data import DataLoader, Dataset
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor

def get_gpu_memory_map():
	"""Get the current gpu usage.
	Returns
	-------
	usage: dict
		Keys are device ids as integers.
		Values are memory usage as integers in MB.
	"""
	# Ask nvidia-smi for one bare number (used MB) per GPU, one per line.
	smi_output = subprocess.check_output(
		[
			'nvidia-smi', '--query-gpu=memory.used',
			'--format=csv,nounits,noheader'
		], encoding='utf-8')
	# Line order follows device ids, so enumerate gives id -> used-MB pairs.
	return {device_id: int(used_mb)
			for device_id, used_mb in enumerate(smi_output.strip().split('\n'))}


if __name__ == "__main__":
    # Cross-test driver: load a (possibly topic-specific) fine-tuned KoGPT2
    # checkpoint and measure its test loss on every requested topic's testset.
    parser = argparse.ArgumentParser(description='Train KoGPT2 with ArticleDataset.')
    parser.add_argument('--docker', action='store_true', help="Train on docker. Sets model cache path:/code/model, dataset path:/dataset, save path:/code/save.")
    parser.add_argument('--default', action='store_true', help="Use un-tuned KoGPT2")
    parser.add_argument('--model_topic', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('--epoch', type=int)
    parser.add_argument('--topic', nargs='+', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
    args = parser.parse_args()
    print(args)

    # Path layout differs between the docker image and a local checkout.
    model_cache_path = '/code/model' if args.docker else 'model'
    dataset_path = '/dataset' if args.docker else '../dataset'
    save_path = '/code/save' if args.docker else 'save'

    # Fall back to CPU when CUDA is unavailable, regardless of requested device.
    ctx = args.device if torch.cuda.is_available() else 'cpu'
    print(ctx)
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers = 32
    batch_size = 64
    padding_id = vocab[vocab.padding_token]
    topics = set(sorted(args.topic))
    transform = ToTensor(tokenizer, vocab, 128)

    print("Preparing dataloader...")
    # Loaders start with num_workers=0: the first full pass populates each
    # dataset's cache single-threaded; num_workers is raised after caching.
    dataloaders = {}
    dataloaders["all"] = DataLoader(ArticleDataset(dataset_path, label='test', transform=transform), batch_size=batch_size, num_workers=0)
    for topic in tqdm(topics):
        dataloaders[topic] = DataLoader(ArticleDataset(dataset_path, topics={topic}, label='test', transform=transform), batch_size=batch_size, num_workers=0)
    print("Prepared dataloader.")

    epoches = 30
    checkpoint_epoch = 0
    learning_rate = 3e-5
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer exists only so full checkpoints (which store optimizer state)
    # can be restored below; no training step is taken in this script.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    topic_all = ['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학']
    model_topic = topic_all if args.model_topic is None else sorted(list({args.model_topic}))
    model_epoch = '*' if args.epoch is None else args.epoch
    dev = ctx if ctx in {'cpu', 'cuda'} else 'cuda:*'
    # Checkpoint filenames embed the topic set's repr; this literal matches the
    # set-repr the training script produced for the all-topics run.
    braced = str("{'생활', '경제', 'IT_과학', '미용_건강', '스포츠', '사회', '연예', '문화', '정치'}") if args.model_topic is None else '{' + str(model_topic)[1:-1] + '}'
    saves = glob.glob(f'{save_path}/KoGPT2_checkpoint_{dev}_{braced}_{transform.max_len}_{model_epoch}.state')
    if not args.default:
        if len(saves) > 0:
            # Resume from the most recently modified matching checkpoint.
            last_save = max(saves, key=os.path.getmtime)
            checkpoint = torch.load(last_save, map_location=device)
            print(f"Loading save from {last_save}")
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            checkpoint_epoch = checkpoint['epoch']
            last_test_loss = checkpoint['loss']
        else:
            print("No save exists.")
            # BUGFIX: report the pattern that was actually globbed (dev/braced),
            # not the unrelated ctx/model_topic combination.
            raise FileNotFoundError(f'{save_path}/KoGPT2_checkpoint_{dev}_{braced}_{transform.max_len}_{model_epoch}.state')
    model.to(device)
    model.eval()

    # Warm each dataset's cache: one throwaway pass tokenizes everything, then
    # set_use_cache persists it; subsequent epochs can use many workers.
    cached_testset_path = f"{save_path}/test_{topic_all}_{transform.max_len}"
    if os.path.isfile(cached_testset_path + '.npy'):
        dataloaders["all"].dataset.load_from_file(cached_testset_path + '.npy')
    else:
        print("Caching testset... topic: all")
        for temp in tqdm(dataloaders["all"]):
            pass
        dataloaders["all"].dataset.set_use_cache(True, cached_testset_path)
    print("Cached. topic: all")
    dataloaders["all"].dataset.num_workers = num_workers
    for topic in tqdm(topics):
        # BUGFIX: was f"...test_{{topic}}_..." — the doubled braces emitted the
        # literal text "{topic}", so every topic shared one cache file.
        cached_testset_path = f"{save_path}/test_{topic}_{transform.max_len}"
        if os.path.isfile(cached_testset_path + '.npy'):
            dataloaders[topic].dataset.load_from_file(cached_testset_path + '.npy')
        else:
            print(f"Caching testset... topic: {topic}")
            for temp in tqdm(dataloaders[topic]):
                pass
            dataloaders[topic].dataset.set_use_cache(True, cached_testset_path)
        print(f"Cached. topic: {topic}")
        dataloaders[topic].dataset.num_workers = num_workers

    states = []
    # Evaluate the loaded model on every topic's testset; Ctrl-C stops early
    # but still writes the results gathered so far.
    with torch.no_grad():  # pure evaluation — don't build autograd graphs
        for topic in tqdm(dataloaders):
            try:
                test_loss_list = []
                for data in tqdm(dataloaders[topic]):
                    data = data.to(ctx)
                    # -100 labels are ignored by the LM loss; mask padding too.
                    label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
                    mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
                    output = model(data, labels=label, attention_mask=mask)
                    loss = output[0]
                    test_loss_list.append(loss.item())
                    # Free batch tensors eagerly to keep peak GPU memory down.
                    del label, mask, loss, output, data
                test_loss = sum(test_loss_list) / len(test_loss_list)
                print(f"data_topic: {topic}, model_topic: {model_topic} test loss: {test_loss}")
                states.append((topic, model_topic, test_loss))
            except KeyboardInterrupt:
                break

    # Persist one CSV-ish log row per (data topic, model topic) pair.
    log_path = f"{save_path}/test_{'DEFAULT' if args.default else model_topic}_{topics}_{transform.max_len}_{int(time.time())}.log"
    with open(log_path, 'w') as log:
        log.write("data_topic, model_topic, test loss,\n")
        for state in states:
            log.write(f"{state[0]}, {state[1]},{state[2]},\n")
    print(f"Log written at: {log_path}")