cross_test.py
# -*- coding: utf-8 -*-
import argparse
import glob
import os
import subprocess
import time

import torch
from torch.utils.data import DataLoader
from gluonnlp.data import SentencepieceTokenizer
from tqdm import tqdm

from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from util.data_loader import ArticleDataset, ToTensor
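
# Evaluates saved KoGPT2 checkpoints across per-topic test splits, e.g.:
#   python cross_test.py cuda:0 --model_topic 경제 --topic 경제 정치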


def get_gpu_memory_map():
    """Get the current gpu usage.

    Returns
    -------
    usage: dict
        Keys are device ids as integers.
        Values are memory usage as integers in MB.
    """
    result = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=memory.used',
         '--format=csv,nounits,noheader'],
        encoding='utf-8')
    # Convert the per-line output into {device_id: used_MB}.
    gpu_memory = [int(x) for x in result.strip().split('\n')]
    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
    return gpu_memory_map
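

# Example (assuming two visible GPUs; the values are MB as reported by nvidia-smi):
#   >>> get_gpu_memory_map()
#   {0: 1523, 1: 11}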


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Cross-test KoGPT2 checkpoints on ArticleDataset test splits.')
    parser.add_argument('--docker', action='store_true',
                        help="Run on docker. Sets model cache path: /code/model, dataset path: /dataset, save path: /code/save.")
    parser.add_argument('--default', action='store_true', help="Use the un-tuned KoGPT2.")
    parser.add_argument('--model_topic', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('--epoch', type=int)
    parser.add_argument('--topic', nargs='+',
                        choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'],
                        default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
    args = parser.parse_args()
    print(args)
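
    # Paths differ between the docker image and a local checkout.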
    model_cache_path = '/code/model' if args.docker else 'model'
    dataset_path = '/dataset' if args.docker else '../dataset'
    save_path = '/code/save' if args.docker else 'save'

    ctx = args.device if torch.cuda.is_available() else 'cpu'
    print(ctx)
    device = torch.device(ctx)
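
    # Load the pretrained KoGPT2 model, its vocabulary, and the SentencePiece tokenizer.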
    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)

    num_workers = 32
    batch_size = 64
    padding_id = vocab[vocab.padding_token]
    topics = set(args.topic)
    transform = ToTensor(tokenizer, vocab, 128)
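
    # One loader over the whole test split plus one per requested topic.
    # num_workers=0 for now, so each dataset's cache is built in the main process.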
print("Preparing dataloader...")
dataloaders={}
dataloaders["all"]=DataLoader(ArticleDataset(dataset_path,label='test', transform=transform),batch_size=batch_size, num_workers=0)
for topic in tqdm(topics):
dataloaders[topic]=DataLoader(ArticleDataset(dataset_path, topics={topic},label='test', transform=transform),batch_size=batch_size, num_workers=0)
print("Prepared dataloader.")
    checkpoint_epoch = 0
    learning_rate = 3e-5
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
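
    # Reconstruct the checkpoint filename pattern. The hard-coded set literal
    # presumably mirrors how the training script embedded str(set_of_topics)
    # in its save names, so it must match that exact ordering.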
    topic_all = ['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학']
    model_topic = topic_all if args.model_topic is None else [args.model_topic]
    model_epoch = '*' if args.epoch is None else args.epoch
    dev = ctx if ctx in {'cpu', 'cuda'} else 'cuda:*'
    braced = "{'생활', '경제', 'IT_과학', '미용_건강', '스포츠', '사회', '연예', '문화', '정치'}" if args.model_topic is None else '{' + str(model_topic)[1:-1] + '}'
    saves = glob.glob(f'{save_path}/KoGPT2_checkpoint_{dev}_{braced}_{transform.max_len}_{model_epoch}.state')
    if not args.default:
        if len(saves) > 0:
            last_save = max(saves, key=os.path.getmtime)
            print(f"Loading save from {last_save}")
            checkpoint = torch.load(last_save, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            checkpoint_epoch = checkpoint['epoch']
            last_test_loss = checkpoint['loss']
        else:
            print("No save exists.")
            raise FileNotFoundError(f'{save_path}/KoGPT2_checkpoint_{ctx}_{model_topic}_{transform.max_len}_{model_epoch}.state')
    model.to(device)
    model.eval()
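
    # Build or load the cached (pre-tokenized) test sets. One full pass over a
    # loader fills the dataset's cache; set_use_cache then (presumably) persists
    # it to the given .npy path so later runs can skip tokenization.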
    cached_testset_path = f"{save_path}/test_{topic_all}_{transform.max_len}"
    if os.path.isfile(cached_testset_path + '.npy'):
        dataloaders["all"].dataset.load_from_file(cached_testset_path + '.npy')
    else:
        print("Caching testset... topic: all")
        for temp in tqdm(dataloaders["all"]):
            pass
        dataloaders["all"].dataset.set_use_cache(True, cached_testset_path)
        print("Cached. topic: all")
    dataloaders["all"].dataset.num_workers = num_workers
    for topic in tqdm(topics):
        cached_testset_path = f"{save_path}/test_{topic}_{transform.max_len}"
        if os.path.isfile(cached_testset_path + '.npy'):
            dataloaders[topic].dataset.load_from_file(cached_testset_path + '.npy')
        else:
            print(f"Caching testset... topic: {topic}")
            for temp in tqdm(dataloaders[topic]):
                pass
            dataloaders[topic].dataset.set_use_cache(True, cached_testset_path)
            print(f"Cached. topic: {topic}")
        dataloaders[topic].dataset.num_workers = num_workers
    states = []
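
    # Evaluate on each topic (plus "all"); Ctrl-C stops early, but the partial
    # results collected in `states` are still written to the log below.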
for topic in tqdm(dataloaders):
try:
test_loss_list=[]
for data in tqdm(dataloaders[topic]):
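                    # Positions equal to padding_id get label -100 (ignored by the
                    # LM loss); the attention mask also zeroes padded positions.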
                    data = data.to(device)
                    label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
                    mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
                    output = model(data, labels=label, attention_mask=mask)
                    loss = output[0]
                    test_loss_list.append(loss.item())
                    del label, mask, loss, output, data
            test_loss = sum(test_loss_list) / len(test_loss_list)
            print(f"data_topic: {topic}, model_topic: {model_topic}, test loss: {test_loss}")
            states.append((topic, model_topic, test_loss))
        except KeyboardInterrupt:
            break
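
    # Write the collected results as a CSV-style log.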
    log_path = f"{save_path}/test_{'DEFAULT' if args.default else model_topic}_{topics}_{transform.max_len}_{int(time.time())}.log"
    with open(log_path, 'w') as log:
        log.write("data_topic, model_topic, test loss,\n")
        for state in states:
            log.write(f"{state[0]}, {state[1]}, {state[2]},\n")
    print(f"Log written at: {log_path}")