Showing 3 changed files with 0 additions and 116 deletions
report/캡스톤 디자인 2 주간보고서-3.docx (Capstone Design 2 weekly report #3)
deleted
100644 → 0
No preview for this file type
train.py
deleted
100644 → 0
# -*- coding: utf-8 -*-
import glob
import os

import torch
from torch.utils.data import DataLoader
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm

from util.data_loader import ArticleDataset, ToTensor

if __name__ == "__main__":
    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers = 0
    padding_id = vocab[vocab.padding_token]

    transform = ToTensor(tokenizer, vocab)
    print("Preparing dataloader...")
    # Start with num_workers=0 so the first epoch populates the dataset's
    # in-process cache; workers are re-enabled after the first epoch below.
    trainset = DataLoader(ArticleDataset('/dataset', label='train', transform=transform),
                          batch_size=64, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset('/dataset', label='valid', transform=transform),
                          batch_size=64, num_workers=0)
    # testset = DataLoader(ArticleDataset('/dataset', label='test', transform=transform),
    #                      batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epochs = 200
    checkpoint_epoch = 0
    learning_rate = 3e-5
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from the most recent checkpoint, if one exists.
    save_path = '/model/save'
    saves = glob.glob(save_path + '*.state')
    if len(saves) > 0:
        last_save = max(saves, key=os.path.getmtime)
        checkpoint = torch.load(last_save)
        print(f"Loading save from {last_save}")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        print("No save exists.")

    model.to(device)
    model.train()

    last_valid_loss = float('infinity')
    for epoch in tqdm(range(checkpoint_epoch, epochs)):
        train_loss_list = []
        valid_loss_list = []
        for data in tqdm(trainset):
            optimizer.zero_grad()
            data = data.to(device)
            # Padding positions get label -100 (the loss's ignore_index) and
            # are masked out of attention; the model returns the LM loss itself.
            label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
            mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
            output = model(data, labels=label, attention_mask=mask)
            loss, logits = output[0], output[1]
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())
        with torch.no_grad():
            for v_data in tqdm(validset):
                v_data = v_data.to(device)
                v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                v_loss, v_logits = v_output[0], v_output[1]
                valid_loss_list.append(v_loss.item())
        valid_loss = sum(valid_loss_list) / len(valid_loss_list)
        print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
        # Save when validation loss stops improving, and every 10th epoch.
        if valid_loss > last_valid_loss or (epoch % 10 == 9):
            try:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
            except Exception as e:
                print(e)
        last_valid_loss = valid_loss
        if epoch == checkpoint_epoch:  # Must run one full epoch with num_workers=0 to fully cache the dataset.
            trainset.dataset.set_use_cache(True)
            trainset.num_workers = num_workers
            validset.dataset.set_use_cache(True)
            validset.num_workers = num_workers
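One detail worth noting in the loop above: padding positions are relabeled as -100, which is the default ignore_index of PyTorch's cross-entropy loss, so padding contributes nothing to the language-modeling loss. A minimal standalone sketch of that mechanism (plain torch, independent of the project code):

import torch

criterion = torch.nn.CrossEntropyLoss()      # ignore_index defaults to -100
logits = torch.randn(4, 10)                  # 4 token positions, vocabulary of 10
labels = torch.tensor([3, 7, -100, -100])    # last two positions are padding
print(criterion(logits, labels))             # loss averaged over the first two positions only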
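The first-epoch caching step at the bottom depends on num_workers=0: with no worker processes, __getitem__ runs in the main process, so whatever the dataset memoizes during the first pass survives into later epochs, after which workers can be turned back on for the now-cheap cached reads. util/data_loader.py is not part of this diff, so the following is only a hypothetical sketch of the pattern ArticleDataset.set_use_cache appears to implement:

from torch.utils.data import Dataset

class CachingDataset(Dataset):
    # Hypothetical stand-in for ArticleDataset; names and structure are guesses.
    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform
        self.cache = {}
        self.use_cache = False

    def set_use_cache(self, use_cache):
        # Only safe to enable after one full epoch with num_workers=0;
        # worker processes would each fill a private copy of the cache.
        self.use_cache = use_cache

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        if self.use_cache:
            return self.cache[idx]
        item = self.items[idx] if self.transform is None else self.transform(self.items[idx])
        self.cache[idx] = item
        return item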
통계 (statistics)
deleted
100644 → 0
과학 (science category)
100%|██████████| 3361/3361 [00:16<00:00, 197.95it/s]
count[256]: 30451/215067 (14.2%)
count[512]: 137611/215067 (64.0%)
count[768]: 185856/215067 (86.4%)
count[1024]: 205300/215067 (95.5%) -- not possible beyond this due to the model's context limit
count[1280]: 211386/215067 (98.3%)
count[1536]: 213877/215067 (99.4%)
count[1792]: 214932/215067 (99.9%)
전체 (all categories)
100%|██████████| 53825/53825 [04:19<00:00, 207.39it/s]
count[256]: 421097/3444755 (12.2%)
count[512]: 2110517/3444755 (61.2%)
count[768]: 2927091/3444755 (84.9%)
count[1024]: 3242747/3444755 (94.1%) -- not possible beyond this due to the model's context limit
count[1280]: 3355523/3444755 (97.4%)
count[1536]: 3410390/3444755 (99.0%)
count[1792]: 3437609/3444755 (99.7%)
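These coverage numbers presumably come from tokenizing every article and counting how many fit under each length threshold; GPT-2's 1024-token context window is why 1024 is flagged as the hard ceiling. A rough sketch of how such a table could be produced with the same tokenizer train.py uses (the articles iterable is an assumption, not code from this repository):

from gluonnlp.data import SentencepieceTokenizer
from kogpt2.utils import get_tokenizer
from tqdm import tqdm

tokenizer = SentencepieceTokenizer(get_tokenizer(cachedir='/code/model'), num_best=0, alpha=0)
thresholds = [256, 512, 768, 1024, 1280, 1536, 1792]

def report_coverage(articles):
    # articles: iterable of raw text strings (hypothetical input)
    lengths = [len(tokenizer(text)) for text in tqdm(articles)]
    for t in thresholds:
        fit = sum(1 for n in lengths if n <= t)
        print(f"count[{t}]: {fit}/{len(lengths)} ({fit / len(lengths):.1%})")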