김민수

Added weekly report

# -*- coding: utf-8 -*-
import glob
import os

import torch
from torch.utils.data import DataLoader
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm

from util.data_loader import ArticleDataset, ToTensor

if __name__ == "__main__":
    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers = 4  # worker count to switch to once the dataset cache is filled after the first epoch
    padding_id = vocab[vocab.padding_token]

    transform = ToTensor(tokenizer, vocab)
    print("Preparing dataloader...")
    trainset = DataLoader(ArticleDataset('/dataset', label='train', transform=transform), batch_size=64, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset('/dataset', label='valid', transform=transform), batch_size=64, num_workers=0)
    #testset = DataLoader(ArticleDataset('/dataset', label='test', transform=transform), batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epochs = 200
    checkpoint_epoch = 0
    learning_rate = 3e-5
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from the most recent checkpoint if one exists.
    save_path = '/model/save'
    saves = glob.glob(save_path + '*.state')
    if len(saves) > 0:
        last_save = max(saves, key=os.path.getmtime)
        print(f"Loading save from {last_save}")
        checkpoint = torch.load(last_save, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        print("No save exists.")

    model.to(device)
    model.train()

    last_valid_loss = float('infinity')
    for epoch in tqdm(range(checkpoint_epoch, epochs)):
        train_loss_list = []
        valid_loss_list = []
        model.train()  # re-enable dropout after the eval pass below
        for data in tqdm(trainset):
            optimizer.zero_grad()
            data = data.to(device)
            # Padding positions get label -100 (ignored by the LM loss) and are masked out of attention.
            label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
            mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
            output = model(data, labels=label, attention_mask=mask)
            loss, logits = output[0], output[1]
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())
        model.eval()  # disable dropout for evaluation
        with torch.no_grad():
            for v_data in tqdm(validset):
                v_data = v_data.to(device)
                v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                v_loss = v_output[0]
                valid_loss_list.append(v_loss.item())
        valid_loss = sum(valid_loss_list) / len(valid_loss_list)
        print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
        # Checkpoint when validation loss stops improving, and unconditionally every 10 epochs.
        if valid_loss > last_valid_loss or (epoch % 10 == 9):
            try:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
            except Exception as e:
                print(e)
        last_valid_loss = valid_loss
        if epoch == checkpoint_epoch:  # The whole first epoch must run with num_workers=0 so the dataset cache fills completely.
            trainset.dataset.set_use_cache(True)
            trainset.num_workers = num_workers
            validset.dataset.set_use_cache(True)
            validset.num_workers = num_workers
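The set_use_cache / num_workers switch relies on a caching dataset in util/data_loader.py, which this commit does not show. A minimal sketch of the pattern being assumed here (class and attribute names are illustrative, not the actual ArticleDataset): the first epoch runs entirely in the main process (num_workers=0), so every transformed sample lands in a single shared cache; once the cache is complete, samples are served from memory and the DataLoader can safely switch to multiple workers.

from torch.utils.data import Dataset

class CachedDataset(Dataset):  # hypothetical stand-in for ArticleDataset
    def __init__(self, samples, transform):
        self.samples = samples             # raw article strings
        self.transform = transform         # e.g. the ToTensor transform above
        self.cache = [None] * len(samples)
        self.use_cache = False

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.use_cache and self.cache[idx] is not None:
            return self.cache[idx]         # served from memory after epoch 0
        item = self.transform(self.samples[idx])
        self.cache[idx] = item             # filled while num_workers == 0
        return item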
Token-length coverage statistics:

Science category (과학)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3361/3361 [00:16<00:00, 197.95it/s]
count[256]: 30451/215067 (14.1%)
count[512]: 137611/215067 (63.9%)
count[768]: 185856/215067 (86.4%)
count[1024]: 205300/215067 (95.4%) -- longer is impossible due to the model's 1024-token limit
count[1280]: 211386/215067 (98.2%)
count[1536]: 213877/215067 (99.4%)
count[1792]: 214932/215067 (99.9%)
All categories (전체)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53825/53825 [04:19<00:00, 207.39it/s]
count[256]: 421097/3444755 (12.2%)
count[512]: 2110517/3444755 (61.2%)
count[768]: 2927091/3444755 (84.9%)
count[1024]: 3242747/3444755 (94.1%) -- longer is impossible due to the model's 1024-token limit
count[1280]: 3355523/3444755 (97.4%)
count[1536]: 3410390/3444755 (99.0%)
count[1792]: 3437609/3444755 (99.7%)
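The counts are cumulative: count[N] is how many articles tokenize to N tokens or fewer, so a maximum length of 1024 (KoGPT2's context limit) still covers 94.1% of the full corpus. A minimal sketch of how such a tally could be produced, assuming an iterable of raw article strings (loading them is not shown here):

from gluonnlp.data import SentencepieceTokenizer
from kogpt2.utils import get_tokenizer
from tqdm import tqdm

def length_coverage(articles, buckets=(256, 512, 768, 1024, 1280, 1536, 1792)):
    # For each bucket, count the articles whose token length fits within it.
    tokenizer = SentencepieceTokenizer(get_tokenizer(cachedir='/code/model'))
    counts = {b: 0 for b in buckets}
    total = 0
    for article in tqdm(articles):
        n = len(tokenizer(article))
        total += 1
        for b in buckets:
            if n <= b:
                counts[b] += 1
    for b in buckets:
        print(f"count[{b}]: {counts[b]}/{total} ({100 * counts[b] / total:.1f}%)")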