김민수

Revert "Added weekly report"

This reverts commit 146717f4.
-# -*- coding: utf-8 -*-
-import argparse
-import os
-import glob
-
-import gluonnlp as nlp
-import torch
-from torch.utils.data import DataLoader, Dataset
-from gluonnlp.data import SentencepieceTokenizer
-from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
-from kogpt2.utils import get_tokenizer
-from tqdm import tqdm
-from util.data_loader import ArticleDataset, ToTensor
-
-if __name__ == "__main__":
-    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
-    device = torch.device(ctx)
-    tokenizer_path = get_tokenizer(cachedir='/code/model')
-    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
-    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
-    num_workers = 0
-    padding_id = vocab[vocab.padding_token]
-
-    transform = ToTensor(tokenizer, vocab)
-    print("Preparing dataloader...")
-    trainset = DataLoader(ArticleDataset('/dataset', label='train', transform=transform), batch_size=64, num_workers=0, shuffle=True)
-    validset = DataLoader(ArticleDataset('/dataset', label='valid', transform=transform), batch_size=64, num_workers=0)
-    #testset = DataLoader(ArticleDataset('/dataset', label='test', transform=transform), batch_size=128, num_workers=4)
-    print("Prepared dataloader.")
-    epochs = 200
-    checkpoint_epoch = 0
-    learning_rate = 3e-5
-    criterion = torch.nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-
-    save_path = '/model/save'
-    saves = glob.glob(save_path + '*.state')
-    if len(saves) > 0:
-        last_save = max(saves, key=os.path.getmtime)
-        checkpoint = torch.load(last_save)
-        print(f"Loading save from {last_save}")
-        model.load_state_dict(checkpoint['model_state_dict'])
-        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        checkpoint_epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-    else:
-        print("No save exists.")
-
-    model.to(device)
-    model.train()
-
-    last_valid_loss = float('infinity')
-    for epoch in tqdm(range(checkpoint_epoch, epochs)):
-        train_loss_list = []
-        valid_loss_list = []
-        for data in tqdm(trainset):
-            optimizer.zero_grad()
-            data = data.to(ctx)
-            label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
-            mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
-            output = model(data, labels=label, attention_mask=mask)
-            loss, logits = output[0], output[1]
-            loss.backward()
-            optimizer.step()
-            train_loss_list.append(loss.item())
-        with torch.no_grad():
-            for v_data in tqdm(validset):
-                v_data = v_data.to(ctx)
-                v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
-                v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
-                v_output = model(v_data, labels=v_label, attention_mask=v_mask)
-                v_loss, v_logits = v_output[0], v_output[1]
-                valid_loss_list.append(v_loss.item())
-        valid_loss = sum(valid_loss_list) / len(valid_loss_list)
-        print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
-        if valid_loss > last_valid_loss or (epoch % 10 == 9):
-            try:
-                torch.save({
-                    'epoch': epoch,
-                    'model_state_dict': model.state_dict(),
-                    'optimizer_state_dict': optimizer.state_dict(),
-                    'loss': loss
-                }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
-            except Exception as e:
-                print(e)
-        last_valid_loss = valid_loss
-        if epoch == checkpoint_epoch:  # Must run the entire first epoch with num_workers=0 to fully cache the dataset.
-            trainset.dataset.set_use_cache(True)
-            trainset.num_workers = num_workers
-            validset.dataset.set_use_cache(True)
-            validset.num_workers = num_workers
-
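A note on the label masking in the training loop above: padding positions are rewritten to -100 because torch.nn.CrossEntropyLoss, which the GPT-2 LM head uses internally, skips targets equal to its default ignore_index of -100. A minimal sketch of the trick; padding_id=3 and the toy batch are made-up values, not KoGPT2's actual ids:

import torch

# Hypothetical padding id and batch of token ids, padded at the end.
padding_id = 3
data = torch.tensor([[5, 9, 2, 3, 3]])

# Targets equal to -100 are ignored by CrossEntropyLoss by default,
# so padded positions contribute nothing to the loss.
label = torch.where(data != padding_id, data, torch.full_like(data, -100))
# The attention mask is 1 on real tokens, 0 on padding, keeping the
# model from attending to padded positions.
mask = (data != padding_id).long()

print(label)  # tensor([[   5,    9,    2, -100, -100]])
print(mask)   # tensor([[1, 1, 1, 0, 0]])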
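The comment about running the whole first epoch with num_workers=0 hints at how the caching works: only the main process keeps state across epochs, while DataLoader workers are short-lived child processes whose in-memory caches are discarded. A hypothetical sketch of the pattern ArticleDataset appears to implement; set_use_cache is taken from the call sites above, everything else here is assumed:

from torch.utils.data import Dataset

class CachingDataset(Dataset):
    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform
        self.cache = [None] * len(samples)
        self.use_cache = False

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.use_cache and self.cache[idx] is not None:
            return self.cache[idx]
        # With num_workers=0 this runs in the main process, so the cache
        # filled here survives into later epochs; worker processes would
        # each fill a private copy and lose it on teardown.
        item = self.transform(self.samples[idx])
        self.cache[idx] = item
        return item

Raising num_workers only after the first epoch, as the script does, then serves every later epoch from the already-filled cache.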
-과학 (Science)
-100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3361/3361 [00:16<00:00, 197.95it/s]
-count[256]: 30451/215067 (%)
-count[512]: 137611/215067 (%)
-count[768]: 185856/215067 (%)
-count[1024]: 205300/215067 (%) --not possible beyond this due to the model's limit
-count[1280]: 211386/215067 (%)
-count[1536]: 213877/215067 (%)
-count[1792]: 214932/215067 (%)
-전체 (All)
-100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53825/53825 [04:19<00:00, 207.39it/s]
-count[256]: 421097/3444755 (12.2%)
-count[512]: 2110517/3444755 (61.2%)
-count[768]: 2927091/3444755 (84.9%)
-count[1024]: 3242747/3444755 (94.1%) --not possible beyond this due to the model's limit
-count[1280]: 3355523/3444755 (97.4%)
-count[1536]: 3410390/3444755 (99.0%)
-count[1792]: 3437609/3444755 (99.7%)
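These counts read as a cumulative histogram: for each token-length threshold, how many tokenized items fit entirely within it, 1024 being KoGPT2's maximum context length. A rough sketch of how such numbers could be produced; the threshold list mirrors the log, while the per-item token counts passed in are made-up illustration data:

# Hypothetical reconstruction of the coverage computation behind the log.
def coverage(token_lengths, thresholds=(256, 512, 768, 1024, 1280, 1536, 1792)):
    total = len(token_lengths)
    for t in thresholds:
        n = sum(1 for length in token_lengths if length <= t)
        print(f"count[{t}]: {n}/{total} ({100 * n / total:.1f}%)")

coverage([100, 300, 700, 2000])  # toy input, not the real dataset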