김민수

Revert "Added weekly report"

This reverts commit 146717f4.
# -*- coding: utf-8 -*-
import argparse
import os
import glob
import gluonnlp as nlp
import torch
from torch.utils.data import DataLoader, Dataset
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor
if __name__ == "__main__":
ctx='cuda' if torch.cuda.is_available() else 'cpu'
device=torch.device(ctx)
tokenizer_path = get_tokenizer(cachedir='/code/model')
model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir='/code/model')
tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
num_workers=0
padding_id=vocab[vocab.padding_token]
transform=ToTensor(tokenizer,vocab)
print("Preparing dataloader...")
trainset=DataLoader(ArticleDataset('/dataset',label='train', transform=transform),batch_size=64, num_workers=0,shuffle=True)
validset=DataLoader(ArticleDataset('/dataset',label='valid', transform=transform),batch_size=64, num_workers=0)
#testset=DataLoader(ArticleDataset('/dataset',label='test', transform=transform),batch_size=128, num_workers=4)
print("Prepared dataloader.")
epoches=200
checkpoint_epoch=0
learning_rate = 3e-5
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
save_path='/model/save'
saves=glob.glob(save_path+'*.state')
if len(saves)>0:
last_save=max(saves,key=os.path.getmtime)
checkpoint = torch.load(last_save)
print(f"Loading save from {last_save}")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint_epoch = checkpoint['epoch']
loss = checkpoint['loss']
else:
print("No save exists.")
model.to(device)
model.train()
last_valid_loss=float('infinity')
for epoch in tqdm(range(checkpoint_epoch,epoches)):
train_loss_list=[]
valid_loss_list=[]
for data in tqdm(trainset):
optimizer.zero_grad()
data = data.to(ctx)
label = torch.where(data!=padding_id, data, torch.ones_like(data)*-100)
mask = torch.where(data!=padding_id,torch.ones_like(data),torch.zeros_like(data))
output=model(data, labels=label, attention_mask=mask)
loss, logits=output[0], output[1]
#loss = loss.to(ctx)
loss.backward()
optimizer.step()
train_loss_list.append(loss.item())
with torch.no_grad():
for v_data in tqdm(validset):
v_data = v_data.to(ctx)
v_label = torch.where(data!=padding_id, v_data, torch.ones_like(v_data)*-100)
v_mask = torch.where(v_data!=padding_id,torch.ones_like(v_data),torch.zeros_like(v_data))
v_output=model(v_data,labels=v_label, attention_mask=v_mask)
v_loss, v_logits=v_output[0], v_output[1]
valid_loss_list.append(v_loss.item())
valid_loss=sum(valid_loss_list)/len(valid_loss_list)
print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
if valid_loss>last_valid_loss or (epoch%10==9):
try:
torch.save({
'epoch': epoch,
'train_no': i,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, f"{save_path}KoGPT2_checkpoint_{ctx}{i}.state")
except Exception as e:
print(e)
last_valid_loss=valid_loss
if epoch==checkpoint_epoch: # Must run entire epoch first with num_worker=0 to fully cache dataset.
trainset.dataset.set_use_cache(True)
trainset.num_workers=num_workers
validset.dataset.set_use_cache(True)
validset.num_workers=num_workers
과학
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3361/3361 [00:16<00:00, 197.95it/s]
count[256]: 30451/215067 (%)
count[512]: 137611/215067 (%)
count[768]: 185856/215067 (%)
count[1024]: 205300/215067 (%) --더 이상은 모델 한계로 불가능
count[1280]: 211386/215067 (%)
count[1536]: 213877/215067 (%)
count[1792]: 214932/215067 (%)
전체
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53825/53825 [04:19<00:00, 207.39it/s]
count[256]: 421097/3444755 (12.2%)
count[512]: 2110517/3444755 (61.2%)
count[768]: 2927091/3444755 (84.9%)
count[1024]: 3242747/3444755 (94.1%) --더 이상은 모델 한계로 불가능
count[1280]: 3355523/3444755 (97.4%)
count[1536]: 3410390/3444755 (99.0%)
count[1792]: 3437609/3444755 (99.7%)