Showing 3 changed files with 116 additions and 0 deletions
report/캡스톤 디자인 2 주간보고서-3.docx (Capstone Design 2 weekly report #3)
0 → 100644
No preview for this file type
train.py
0 → 100644
# -*- coding: utf-8 -*-
import argparse
import os
import glob

import gluonnlp as nlp
import torch
from torch.utils.data import DataLoader, Dataset
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor

if __name__ == "__main__":
    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers = 0
    padding_id = vocab[vocab.padding_token]

    transform = ToTensor(tokenizer, vocab)
    print("Preparing dataloader...")
    # num_workers must stay 0 for the first epoch so the dataset cache is
    # populated in the main process (see set_use_cache at the bottom).
    trainset = DataLoader(ArticleDataset('/dataset', label='train', transform=transform),
                          batch_size=64, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset('/dataset', label='valid', transform=transform),
                          batch_size=64, num_workers=0)
    # testset = DataLoader(ArticleDataset('/dataset', label='test', transform=transform),
    #                      batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epochs = 200
    checkpoint_epoch = 0
    learning_rate = 3e-5
    criterion = torch.nn.CrossEntropyLoss()  # unused: the model returns the LM loss when labels are passed
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from the most recent checkpoint if one exists.
    # save_path is a filename prefix, not a directory.
    save_path = '/model/save'
    saves = glob.glob(save_path + '*.state')
    if len(saves) > 0:
        last_save = max(saves, key=os.path.getmtime)
        checkpoint = torch.load(last_save)
        print(f"Loading save from {last_save}")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        print("No save exists.")

    model.to(device)

    last_valid_loss = float('infinity')
    for epoch in tqdm(range(checkpoint_epoch, epochs)):
        train_loss_list = []
        valid_loss_list = []
        model.train()
        for data in tqdm(trainset):
            optimizer.zero_grad()
            data = data.to(device)
            # Padding positions get label -100 so the LM loss ignores them,
            # and attention mask 0 so the model does not attend to them.
            label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
            mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
            output = model(data, labels=label, attention_mask=mask)
            loss, logits = output[0], output[1]
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())
        model.eval()  # disable dropout for validation
        with torch.no_grad():
            for v_data in tqdm(validset):
                v_data = v_data.to(device)
                v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                v_loss, v_logits = v_output[0], v_output[1]
                valid_loss_list.append(v_loss.item())
        valid_loss = sum(valid_loss_list) / len(valid_loss_list)
        print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
        # Checkpoint when validation loss rises (a possible overfitting point)
        # and periodically every 10 epochs.
        if valid_loss > last_valid_loss or (epoch % 10 == 9):
            try:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
            except Exception as e:
                print(e)
        last_valid_loss = valid_loss
        if epoch == checkpoint_epoch:  # Must run entire epoch first with num_workers=0 to fully cache dataset.
            trainset.dataset.set_use_cache(True)
            trainset.num_workers = num_workers
            validset.dataset.set_use_cache(True)
            validset.num_workers = num_workers
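The label construction above leans on a PyTorch convention: targets equal to -100 are skipped by the language-model loss, since torch.nn.CrossEntropyLoss defaults to ignore_index=-100. A minimal sketch of just that step, with a toy padding id of 0 standing in for vocab[vocab.padding_token]:

import torch

padding_id = 0  # toy value; train.py uses vocab[vocab.padding_token]
data = torch.tensor([[5, 8, 3, 0, 0]])  # one right-padded sequence

# -100 at padding positions -> ignored by the loss;
# 0 in the attention mask -> not attended to.
label = torch.where(data != padding_id, data, torch.full_like(data, -100))
mask = (data != padding_id).long()

print(label)  # tensor([[   5,    8,    3, -100, -100]])
print(mask)   # tensor([[1, 1, 1, 0, 0]])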
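The set_use_cache() / num_workers switch after the first epoch assumes ArticleDataset tokenizes lazily and memoizes each item. That implementation is not part of this diff, so the following is only a hypothetical sketch of the contract train.py relies on; the class and attribute names are illustrative:

from torch.utils.data import Dataset

class CachingDataset(Dataset):
    """Tokenizes on first access and memoizes the result (illustrative only)."""
    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform
        self.use_cache = False  # the first epoch fills the cache
        self.cache = {}

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.use_cache and idx in self.cache:
            return self.cache[idx]
        item = self.transform(self.samples[idx])
        self.cache[idx] = item
        return item

Running the first epoch with num_workers=0 keeps every __getitem__ call in the main process, so the cache that later epochs read from is the one that was actually filled; with worker processes, each worker would have filled its own copy.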
통계 (statistics)
0 → 100644
과학 (Science) category
100%|██████████| 3361/3361 [00:16<00:00, 197.95it/s]
count[256]: 30451/215067 (%)
count[512]: 137611/215067 (%)
count[768]: 185856/215067 (%)
count[1024]: 205300/215067 (%) -- beyond this is impossible due to the model's limit
count[1280]: 211386/215067 (%)
count[1536]: 213877/215067 (%)
count[1792]: 214932/215067 (%)
전체 (All categories)
100%|██████████| 53825/53825 [04:19<00:00, 207.39it/s]
count[256]: 421097/3444755 (12.2%)
count[512]: 2110517/3444755 (61.2%)
count[768]: 2927091/3444755 (84.9%)
count[1024]: 3242747/3444755 (94.1%) -- beyond this is impossible due to the model's limit
count[1280]: 3355523/3444755 (97.4%)
count[1536]: 3410390/3444755 (99.0%)
count[1792]: 3437609/3444755 (99.7%)
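The counts read as cumulative statistics: count[N] is the number of articles whose tokenized length is at most N tokens, which is why 1024 is flagged as the ceiling (KoGPT2, like GPT-2, has a 1024-position context window). A hypothetical sketch of how such a table could be produced with the same tokenizer; the generating script is not in this diff, and the articles list is a stand-in:

from gluonnlp.data import SentencepieceTokenizer
from kogpt2.utils import get_tokenizer

tokenizer = SentencepieceTokenizer(get_tokenizer(), num_best=0, alpha=0)

articles = ["짧은 기사 본문.", "조금 더 긴 기사 본문 예시."]  # stand-in corpus
thresholds = [256, 512, 768, 1024, 1280, 1536, 1792]
counts = {t: 0 for t in thresholds}

for text in articles:
    n = len(tokenizer(text))  # token count per article
    for t in thresholds:
        if n <= t:
            counts[t] += 1

total = len(articles)
for t in thresholds:
    print(f"count[{t}]: {counts[t]}/{total} ({100 * counts[t] / total:.1f}%)")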
Mentioned in commit ee62c0cc