Showing
3 changed files
with
167 additions
and
7 deletions
| ... | @@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len í¬ê²Œ 해야 ì | ... | @@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len í¬ê²Œ 해야 ì |
| 15 | parser.add_argument('--batch_size', type=int, default=256) | 15 | parser.add_argument('--batch_size', type=int, default=256) |
| 16 | parser.add_argument('--num_epochs', type=int, default=22) | 16 | parser.add_argument('--num_epochs', type=int, default=22) |
| 17 | parser.add_argument('--warming_up_epochs', type=int, default=5) | 17 | parser.add_argument('--warming_up_epochs', type=int, default=5) |
| 18 | -parser.add_argument('--lr', type=float, default=0.0002) | 18 | +parser.add_argument('--lr', type=float, default=0.0002)#0.0002 |
| 19 | parser.add_argument('--embedding_dim', type=int, default=160) | 19 | parser.add_argument('--embedding_dim', type=int, default=160) |
| 20 | parser.add_argument('--nlayers', type=int, default=2) | 20 | parser.add_argument('--nlayers', type=int, default=2) |
| 21 | parser.add_argument('--nhead', type=int, default=2) | 21 | parser.add_argument('--nhead', type=int, default=2) |
| 22 | parser.add_argument('--dropout', type=float, default=0.1) | 22 | parser.add_argument('--dropout', type=float, default=0.1) |
| 23 | parser.add_argument('--train', type=bool, default=True) | 23 | parser.add_argument('--train', type=bool, default=True) |
| 24 | -parser.add_argument('--per_soft', type=bool, default=False) | 24 | +parser.add_argument('--per_soft', type=bool, default=True) |
| 25 | parser.add_argument('--per_rough', type=bool, default=False) | 25 | parser.add_argument('--per_rough', type=bool, default=False) |
| 26 | args = parser.parse_args() | 26 | args = parser.parse_args() |
| 27 | 27 | ||
| ... | @@ -30,6 +30,7 @@ def epoch_time(start_time, end_time): | ... | @@ -30,6 +30,7 @@ def epoch_time(start_time, end_time): |
| 30 | elapsed_time = end_time - start_time | 30 | elapsed_time = end_time - start_time |
| 31 | elapsed_mins = int(elapsed_time / 60) | 31 | elapsed_mins = int(elapsed_time / 60) |
| 32 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) | 32 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) |
| 33 | + | ||
| 33 | return elapsed_mins, elapsed_secs | 34 | return elapsed_mins, elapsed_secs |
| 34 | 35 | ||
| 35 | # 학습 | 36 | # 학습 | ... | ... |
| 1 | import torch | 1 | import torch |
| 2 | +import torch.nn.functional as F | ||
| 3 | +from math import log | ||
| 4 | +from numpy import array | ||
| 2 | from get_data import tokenizer1 | 5 | from get_data import tokenizer1 |
| 3 | from torch.autograd import Variable | 6 | from torch.autograd import Variable |
| 4 | from chatspace import ChatSpace | 7 | from chatspace import ChatSpace |
| 5 | spacer = ChatSpace() | 8 | spacer = ChatSpace() |
| 9 | +from konlpy.tag import Mecab | ||
| 10 | +import re | ||
| 11 | + | ||
| 12 | +def tokenizer1(text): | ||
| 13 | + result_text = re.sub('[-=+.,#/\:$@*\"※&%ㆍ!?』\\‘|\(\)\[\]\<\>`\'…》;]', '', text) | ||
| 14 | + a = Mecab().morphs(result_text) | ||
| 15 | + return ([a[i] for i in range(len(a))]) | ||
| 16 | + | ||
| 17 | +def _get_length_penalty(text, alpha=1.2, min_length=5): | ||
| 18 | + p_list = [] | ||
| 19 | + for i in range(len(text)): | ||
| 20 | + temp_text = tokenizer1(text[i][0]) | ||
| 21 | + length = len(temp_text) | ||
| 22 | + p_list.append(((5 + length) ** alpha) / (5 + 1) ** alpha) | ||
| 23 | + | ||
| 24 | + lp_list = [ text[j][1]/p_list[j] for j in range(len(text)) ] | ||
| 25 | + return lp_list | ||
| 26 | + | ||
| 27 | +def compair_beam_and_greedy(beam_pair, greedy_pair): | ||
| 28 | + lp_list = _get_length_penalty(beam_pair) | ||
| 29 | + gr_lp_list = _get_length_penalty(greedy_pair) | ||
| 30 | + | ||
| 31 | + low_val = float('inf') | ||
| 32 | + checked_sen = "" | ||
| 33 | + for idx in range(len(beam_pair)): | ||
| 34 | + if lp_list[idx] < low_val: | ||
| 35 | + low_val = lp_list[idx] | ||
| 36 | + checked_sen = beam_pair[idx][0] | ||
| 37 | + | ||
| 38 | + print(" beam output > ", checked_sen, " |", low_val) | ||
| 39 | + print("greedy output > ", greedy_pair[0][0]," |", gr_lp_list[0]) | ||
| 40 | + if low_val < gr_lp_list[0]: | ||
| 41 | + print("use beam") | ||
| 42 | + else: | ||
| 43 | + print("use greedy") | ||
| 44 | + | ||
| 45 | +def cal_score(pred, score): | ||
| 46 | + | ||
| 47 | + pred = F.softmax(pred, dim=-1) | ||
| 48 | + pred_ids = pred.max(dim=-1)[0] | ||
| 49 | + pred_ids = pred_ids.to('cpu').tolist() | ||
| 50 | + score = score * -log(pred_ids) | ||
| 51 | + | ||
| 52 | + return score | ||
| 53 | + | ||
| 54 | +def Beam_Search(data, k, first, sequences): | ||
| 55 | + #sequences = [[list(), 1.0]] | ||
| 56 | + | ||
| 57 | + if first: | ||
| 58 | + data = data.squeeze(0) | ||
| 59 | + else: | ||
| 60 | + data = data.unsqueeze(0) | ||
| 61 | + data = F.softmax(data, dim=-1) | ||
| 62 | + | ||
| 63 | + for row in data: | ||
| 64 | + all_candidates = list() | ||
| 65 | + for i in range(len(sequences)): | ||
| 66 | + seq, score = sequences[i] | ||
| 67 | + for j in range(len(row)): | ||
| 68 | + no_tensor_row = row[j].to('cpu').tolist() | ||
| 69 | + candidate = [seq + [j], score * -log(no_tensor_row)] | ||
| 70 | + all_candidates.append(candidate) | ||
| 71 | + ordered = sorted(all_candidates, key=lambda tup:tup[1]) | ||
| 72 | + sequences = ordered[:k] | ||
| 73 | + | ||
| 74 | + return(sequences) | ||
| 75 | + | ||
| 76 | +def beam(args, dec_input, enc_input_index, model, first, device, k_, LABEL): | ||
| 77 | + temp_dec_input = torch.zeros([k_,1], dtype=torch.long) | ||
| 78 | + temp_dec_input = temp_dec_input + dec_input | ||
| 79 | + deliver_high_beam_value = torch.zeros([k_,1], dtype=torch.long) | ||
| 80 | + return_sentence_beamVal_pair = [] | ||
| 81 | + check_k = [float('inf')] * k_ | ||
| 82 | + sequences = [[list(), 1.0]] | ||
| 83 | + end_sentence = [] | ||
| 84 | + end_idx = [] | ||
| 85 | + | ||
| 86 | + if first: | ||
| 87 | + y_pred = model(enc_input_index.to(device), dec_input.to(device)) | ||
| 88 | + first_beam_sequence = Beam_Search(y_pred, k_, True, sequences) | ||
| 89 | + | ||
| 90 | + for i in range(len(deliver_high_beam_value)): | ||
| 91 | + deliver_high_beam_value[i] = first_beam_sequence[i][0][0] | ||
| 92 | + | ||
| 93 | + temp_dec_input = torch.cat( | ||
| 94 | + [temp_dec_input.to(torch.device('cpu')), | ||
| 95 | + deliver_high_beam_value.to(torch.device('cpu'))], dim=-1) | ||
| 96 | + | ||
| 97 | + check_num = 0 | ||
| 98 | + beam_input_sequence = first_beam_sequence | ||
| 99 | + | ||
| 100 | + for i in range(args.max_len): | ||
| 101 | + which_value = [float('inf')] * k_ | ||
| 102 | + which_node = [0] * k_ | ||
| 103 | + | ||
| 104 | + for j in range(len(temp_dec_input)): | ||
| 105 | + if temp_dec_input[j][-1] == torch.LongTensor([3]): | ||
| 106 | + continue | ||
| 107 | + y_pred = model(enc_input_index.to(device), temp_dec_input[j].unsqueeze(0).to(device)) | ||
| 108 | + beam_seq = Beam_Search(y_pred.squeeze(0)[-1], k_, False, [beam_input_sequence[j]]) | ||
| 109 | + | ||
| 110 | + beam_input_sequence[j] = [[beam_seq[0][0][-1]], beam_seq[0][1]] | ||
| 111 | + which_node[j] = beam_seq[0][0][-1] # k개의 output중 누적확률 높은 거 get | ||
| 112 | + | ||
| 113 | + for l in range(len(deliver_high_beam_value)): | ||
| 114 | + if temp_dec_input[j][-1] == torch.LongTensor([3]): | ||
| 115 | + continue | ||
| 116 | + deliver_high_beam_value[l] = which_node[l] | ||
| 117 | + | ||
| 118 | + temp_dec_input = torch.cat( | ||
| 119 | + [temp_dec_input.to(torch.device('cpu')), | ||
| 120 | + deliver_high_beam_value.to(torch.device('cpu'))], dim=-1) | ||
| 121 | + | ||
| 122 | + for x in range(len(temp_dec_input)): | ||
| 123 | + for y in range(len(temp_dec_input[x])): | ||
| 124 | + if temp_dec_input[x][y] == torch.LongTensor([3]) and check_k[x] == float('inf'): | ||
| 125 | + check_k[x] = beam_input_sequence[x][1] | ||
| 126 | + | ||
| 127 | + if i+1 == args.max_len: | ||
| 128 | + for k in range(k_): | ||
| 129 | + for kk in range(len(temp_dec_input[k])): | ||
| 130 | + if temp_dec_input[k][kk] == torch.LongTensor([3]): | ||
| 131 | + check_num += 1 | ||
| 132 | + end_sentence.append(temp_dec_input[k]) | ||
| 133 | + end_idx.append(k) | ||
| 134 | + break | ||
| 135 | + | ||
| 136 | + for l in range(len(end_sentence)): | ||
| 137 | + pred = [] | ||
| 138 | + for idx in range(len(end_sentence[l])): | ||
| 139 | + | ||
| 140 | + if end_sentence[l][idx] == torch.LongTensor([3]): | ||
| 141 | + pred_sentence = "".join(pred) | ||
| 142 | + pred_str = spacer.space(pred_sentence) | ||
| 143 | + #print(pred_str, " |", check_k[end_idx[l]]) | ||
| 144 | + return_sentence_beamVal_pair.append([pred_str, check_k[end_idx[l]]]) | ||
| 145 | + break | ||
| 146 | + else: | ||
| 147 | + if idx == 0: | ||
| 148 | + continue | ||
| 149 | + pred.append(LABEL.vocab.itos[end_sentence[l][idx]]) | ||
| 150 | + return return_sentence_beamVal_pair | ||
| 6 | 151 | ||
| 7 | def inference(device, args, TEXT, LABEL, model, sa_model): | 152 | def inference(device, args, TEXT, LABEL, model, sa_model): |
| 8 | from KoBERT.Sentiment_Analysis_BERT_main import bert_inference | 153 | from KoBERT.Sentiment_Analysis_BERT_main import bert_inference |
| ... | @@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model): | ... | @@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model): |
| 40 | 185 | ||
| 41 | model.eval() | 186 | model.eval() |
| 42 | pred = [] | 187 | pred = [] |
| 188 | + | ||
| 189 | + beam_k = 10 | ||
| 190 | + beam_sen_val_pair = beam(args, dec_input, enc_input_index, model, True, device, beam_k, LABEL) | ||
| 191 | + greedy_pair = [] | ||
| 43 | for i in range(args.max_len): | 192 | for i in range(args.max_len): |
| 44 | y_pred = model(enc_input_index.to(device), dec_input.to(device)) | 193 | y_pred = model(enc_input_index.to(device), dec_input.to(device)) |
| 194 | + if i == 0: | ||
| 195 | + score = cal_score(y_pred.squeeze(0)[-1], 1.0) | ||
| 196 | + else: | ||
| 197 | + score = cal_score(y_pred.squeeze(0)[-1], score ) | ||
| 198 | + | ||
| 45 | y_pred_ids = y_pred.max(dim=-1)[1] | 199 | y_pred_ids = y_pred.max(dim=-1)[1] |
| 200 | + | ||
| 46 | if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']): | 201 | if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']): |
| 47 | y_pred_ids = y_pred_ids.squeeze(0) | 202 | y_pred_ids = y_pred_ids.squeeze(0) |
| 48 | - print(">", end=" ") | 203 | + #print(">", end=" ") |
| 49 | for idx in range(len(y_pred_ids)): | 204 | for idx in range(len(y_pred_ids)): |
| 50 | if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>': | 205 | if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>': |
| 51 | pred_sentence = "".join(pred) | 206 | pred_sentence = "".join(pred) |
| 52 | pred_str = spacer.space(pred_sentence) | 207 | pred_str = spacer.space(pred_sentence) |
| 53 | - print(pred_str) | 208 | + #print(pred_str, " |", score) |
| 209 | + greedy_pair.append([pred_str, score]) | ||
| 54 | break | 210 | break |
| 55 | else: | 211 | else: |
| 56 | pred.append(LABEL.vocab.itos[y_pred_ids[idx]]) | 212 | pred.append(LABEL.vocab.itos[y_pred_ids[idx]]) |
| 213 | + | ||
| 214 | + compair_beam_and_greedy(beam_sen_val_pair, greedy_pair) | ||
| 57 | return 0 | 215 | return 0 |
| 58 | 216 | ||
| 59 | dec_input = torch.cat( | 217 | dec_input = torch.cat( |
| 60 | [dec_input.to(torch.device('cpu')), | 218 | [dec_input.to(torch.device('cpu')), |
| 61 | y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1) | 219 | y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1) |
| 62 | - return 0 | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 220 | + | ||
| 221 | + | ... | ... |
| ... | @@ -51,8 +51,8 @@ def data_preprocessing(args, device): | ... | @@ -51,8 +51,8 @@ def data_preprocessing(args, device): |
| 51 | # TEXT, LABEL 에 필요한 special token 만듦. | 51 | # TEXT, LABEL 에 필요한 special token 만듦. |
| 52 | text_specials, label_specials = make_special_token(args) | 52 | text_specials, label_specials = make_special_token(args) |
| 53 | 53 | ||
| 54 | - TEXT.build_vocab(train_data, vectors=vectors, max_size=15000, specials=text_specials) | 54 | + TEXT.build_vocab(train_data,vectors=vectors, max_size=15000, specials=text_specials) |
| 55 | - LABEL.build_vocab(train_data, vectors=vectors, max_size=15000, specials=label_specials) | 55 | + LABEL.build_vocab(train_data,vectors=vectors, max_size=15000, specials=label_specials) |
| 56 | 56 | ||
| 57 | train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True) | 57 | train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True) |
| 58 | test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True) | 58 | test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True) | ... | ... |
-
Please register or login to post a comment