# train.py
# 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
import argparse
import os
import glob
import gluonnlp as nlp
import torch
from torch.utils.data import DataLoader, Dataset
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor
if __name__ == "__main__":
    # Fine-tune the pretrained KoGPT2 model on the article dataset.
    #
    # Flow: load tokenizer/model -> build train/valid loaders -> resume from
    # the newest checkpoint if one exists -> for every epoch run one training
    # pass and one validation pass, then conditionally write a checkpoint.
    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(ctx)

    tokenizer_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)

    # Worker count applied to both loaders once the first epoch has finished.
    num_workers = 0
    padding_id = vocab[vocab.padding_token]
    transform = ToTensor(tokenizer, vocab)

    print("Preparing dataloader...")
    # Both loaders start with num_workers=0; see the caching note at the
    # bottom of the epoch loop.
    trainset = DataLoader(
        ArticleDataset('/dataset', label='train', transform=transform),
        batch_size=64,
        num_workers=0,
        shuffle=True,
    )
    validset = DataLoader(
        ArticleDataset('/dataset', label='valid', transform=transform),
        batch_size=64,
        num_workers=0,
    )
    #testset=DataLoader(ArticleDataset('/dataset',label='test', transform=transform),batch_size=128, num_workers=4)
    print("Prepared dataloader.")

    epoches = 200
    checkpoint_epoch = 0
    learning_rate = 3e-5
    # NOTE(review): `criterion` is never used below; the loss comes from the
    # model call itself (first element of its output).
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Checkpoint files are named "<save_path>KoGPT2_checkpoint_<ctx><epoch>.state",
    # so the glob below uses save_path as a plain filename prefix.
    save_path = '/model/save'
    saves = glob.glob(save_path + '*.state')

    # Put the model on the target device before restoring any state so that
    # parameters and checkpoint tensors live on the same device.
    model.to(device)

    if saves:
        # Resume from the most recently modified checkpoint file.
        last_save = max(saves, key=os.path.getmtime)
        # map_location lets a checkpoint written on GPU be resumed on CPU.
        checkpoint = torch.load(last_save, map_location=device)
        print(f"Loading save from {last_save}")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # A checkpoint is written after its epoch has completed, so training
        # continues with the following epoch instead of repeating it.
        checkpoint_epoch = checkpoint['epoch'] + 1
    else:
        print("No save exists.")

    last_valid_loss = float('infinity')
    for epoch in tqdm(range(checkpoint_epoch, epoches)):
        train_loss_list = []
        valid_loss_list = []

        # ---- training pass ----
        model.train()
        for data in tqdm(trainset):
            optimizer.zero_grad()
            data = data.to(device)
            not_pad = data != padding_id
            # Padding positions get label -100 and attention mask 0.
            # Presumably -100 is the ignore index of the model's internal
            # loss -- TODO confirm against the kogpt2 model implementation.
            label = torch.where(not_pad, data, torch.full_like(data, -100))
            mask = not_pad.to(data.dtype)
            output = model(data, labels=label, attention_mask=mask)
            loss, logits = output[0], output[1]
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())

        # ---- validation pass ----
        # Switch to eval mode so layers such as dropout behave
        # deterministically while the validation loss is measured.
        model.eval()
        with torch.no_grad():
            for v_data in tqdm(validset):
                v_data = v_data.to(device)
                # The mask must come from the validation batch itself, not
                # from the last training batch.
                v_not_pad = v_data != padding_id
                v_label = torch.where(
                    v_not_pad, v_data, torch.full_like(v_data, -100)
                )
                v_mask = v_not_pad.to(v_data.dtype)
                v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                v_loss, v_logits = v_output[0], v_output[1]
                valid_loss_list.append(v_loss.item())

        train_loss = sum(train_loss_list) / len(train_loss_list)
        valid_loss = sum(valid_loss_list) / len(valid_loss_list)
        print(f"epoch: {epoch} train loss: {train_loss} valid loss: {valid_loss}")

        # Checkpoint every 10th epoch, and whenever the validation loss rose
        # compared with the previous epoch.
        # NOTE(review): saving on a *worse* validation loss is kept from the
        # original condition -- confirm this is the intended trigger.
        if valid_loss > last_valid_loss or (epoch % 10 == 9):
            try:
                torch.save({
                    'epoch': epoch,
                    # The original referenced an undefined name here; the
                    # epoch index is used as the run counter instead.
                    # NOTE(review): confirm this is the intended meaning of
                    # 'train_no'.
                    'train_no': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Store a plain float (loss of the last training batch)
                    # rather than the tensor itself.
                    'loss': loss.item(),
                }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
            except Exception as e:
                # Best effort: a failed save must not abort training.
                print(e)
        last_valid_loss = valid_loss

        if epoch == checkpoint_epoch:  # Must run entire epoch first with num_worker=0 to fully cache dataset.
            trainset.dataset.set_use_cache(True)
            trainset.num_workers = num_workers
            validset.dataset.set_use_cache(True)
            validset.num_workers = num_workers