Showing
3 changed files
with
167 additions
and
7 deletions
... | @@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len 크게 해야 … | ... | @@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len 크게 해야 … |
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a boolean option could not
    be switched off from the command line.

    Args:
        value: The raw option text (or an already parsed ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: If the text is not a recognised boolean spelling. argparse
            turns this into its usual "invalid value" usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=22)
parser.add_argument('--warming_up_epochs', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--embedding_dim', type=int, default=160)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--nhead', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.1)
# Boolean switches go through _str2bool so that "--train False" really
# yields False; the bool defaults themselves are passed through untouched.
parser.add_argument('--train', type=_str2bool, default=True)
parser.add_argument('--per_soft', type=_str2bool, default=True)
parser.add_argument('--per_rough', type=_str2bool, default=False)
args = parser.parse_args()
27 | 27 | ||
... | @@ -30,6 +30,7 @@ def epoch_time(start_time, end_time): | ... | @@ -30,6 +30,7 @@ def epoch_time(start_time, end_time): |
30 | elapsed_time = end_time - start_time | 30 | elapsed_time = end_time - start_time |
31 | elapsed_mins = int(elapsed_time / 60) | 31 | elapsed_mins = int(elapsed_time / 60) |
32 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) | 32 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) |
33 | + | ||
33 | return elapsed_mins, elapsed_secs | 34 | return elapsed_mins, elapsed_secs |
34 | 35 | ||
35 | # 학습 | 36 | # 학습 | ... | ... |
1 | import torch | 1 | import torch |
2 | +import torch.nn.functional as F | ||
3 | +from math import log | ||
4 | +from numpy import array | ||
2 | from get_data import tokenizer1 | 5 | from get_data import tokenizer1 |
3 | from torch.autograd import Variable | 6 | from torch.autograd import Variable |
4 | from chatspace import ChatSpace | 7 | from chatspace import ChatSpace |
5 | spacer = ChatSpace() | 8 | spacer = ChatSpace() |
9 | +from konlpy.tag import Mecab | ||
10 | +import re | ||
11 | + | ||
def tokenizer1(text):
    """Strip punctuation from *text* and split the rest into morphemes.

    Args:
        text: Raw sentence to tokenize.

    Returns:
        A new list holding the morphemes produced by ``Mecab.morphs`` for the
        cleaned sentence.
    """
    # Reuse a single tagger instead of constructing a new Mecab object on
    # every call; it is created lazily on first use and kept on the function.
    tagger = getattr(tokenizer1, '_tagger', None)
    if tagger is None:
        tagger = tokenizer1._tagger = Mecab()

    # Raw string: the character class matches the same characters as before,
    # but sequences such as `\:` and `\(` are no longer invalid string escapes.
    cleaned = re.sub(r'[-=+.,#/\:$@*\"※&%ㆍ!?』\‘|\(\)\[\]\<\>`\'…》;]', '', text)
    return list(tagger.morphs(cleaned))
16 | + | ||
def _get_length_penalty(text, alpha=1.2, min_length=5):
    """Normalise each candidate's score by a length penalty.

    The penalty for a sentence of ``length`` tokens (as counted by
    ``tokenizer1``) is::

        ((min_length + length) ** alpha) / ((min_length + 1) ** alpha)

    Args:
        text: Sequence of ``[sentence, score]`` pairs.
        alpha: Exponent of the penalty; larger values penalise length more.
        min_length: Length offset used in the formula. It used to be accepted
            but ignored (5 was hard-coded); the default of 5 reproduces the
            previous results exactly.

    Returns:
        A list with ``score / penalty`` for every pair, in input order.
    """
    # The denominator does not depend on the sentence, so compute it once.
    denominator = (min_length + 1) ** alpha

    normalised = []
    for pair in text:
        length = len(tokenizer1(pair[0]))
        penalty = ((min_length + length) ** alpha) / denominator
        normalised.append(pair[1] / penalty)
    return normalised
26 | + | ||
def compair_beam_and_greedy(beam_pair, greedy_pair):
    """Print the best beam candidate next to the greedy one and pick a winner.

    Both inputs are lists of ``[sentence, score]`` pairs. Scores are first
    length-normalised with ``_get_length_penalty``; lower is better.

    Args:
        beam_pair: Candidates produced by the beam search.
        greedy_pair: Candidates produced by greedy decoding; only the first
            entry is used.

    Returns:
        None. The comparison is reported on standard output only.
    """
    beam_scores = _get_length_penalty(beam_pair)
    greedy_scores = _get_length_penalty(greedy_pair)

    # Keep the first candidate with the strictly smallest normalised score;
    # with no finite candidate the sentence stays empty and the score infinite.
    best_score, best_sentence = float('inf'), ""
    for candidate, candidate_score in zip(beam_pair, beam_scores):
        if candidate_score < best_score:
            best_score, best_sentence = candidate_score, candidate[0]

    print(" beam output > ", best_sentence, " |", best_score)
    print("greedy output > ", greedy_pair[0][0], " |", greedy_scores[0])
    print("use beam" if best_score < greedy_scores[0] else "use greedy")
44 | + | ||
def cal_score(pred, score):
    """Fold the confidence of the most likely token into a running score.

    Args:
        pred: Unnormalised scores for one decoding position; softmax is taken
            over the last dimension. Assumed to be 1-D so that the maximum is
            a single number -- TODO confirm against callers.
        score: Running score so far (callers pass ``1.0`` for the first token).

    Returns:
        ``score * -log(p_max)``, where ``p_max`` is the largest softmax
        probability in ``pred``.
    """
    # NOTE(review): the running score is multiplied by, not added to, the
    # negative log-probability. Beam_Search scores candidates the same way,
    # so the two stay comparable -- confirm that this is intended.
    probabilities = F.softmax(pred, dim=-1)
    top_probability = probabilities.max(dim=-1)[0].to('cpu').tolist()
    return score * -log(top_probability)
53 | + | ||
def Beam_Search(data, k, first, sequences):
    """Extend every partial sequence by one token per row and keep the best ``k``.

    Args:
        data: Raw model scores. When ``first`` is true the leading dimension
            is squeezed away; otherwise a leading dimension is added so that
            a single score vector is processed as one row. Softmax is applied
            over the last dimension.
        k: Number of candidates kept after each row.
        first: Selects how ``data`` is reshaped (see above).
        sequences: List of ``[token_list, score]`` pairs to extend.

    Returns:
        A list of at most ``k`` ``[token_list, score]`` pairs, sorted by
        ascending score (lower is better).
    """
    data = data.squeeze(0) if first else data.unsqueeze(0)
    data = F.softmax(data, dim=-1)

    for row in data:
        # One device-to-host copy per row instead of one per vocabulary entry.
        probabilities = row.to('cpu').tolist()

        all_candidates = []
        for seq, score in sequences:
            for token, probability in enumerate(probabilities):
                if probability <= 0.0:
                    # The probability underflowed to zero: log() is undefined
                    # there and used to raise ValueError. Such a token cannot
                    # be a sensible continuation, so it is skipped.
                    continue
                # NOTE(review): scores are combined by multiplication rather
                # than by summing log-probabilities; kept as is so results
                # stay comparable with cal_score -- confirm this is intended.
                all_candidates.append([seq + [token], score * -log(probability)])

        ordered = sorted(all_candidates, key=lambda candidate: candidate[1])
        sequences = ordered[:k]

    return sequences
75 | + | ||
def beam(args, dec_input, enc_input_index, model, first, device, k_, LABEL):
    """Decode up to ``k_`` candidate sentences, each paired with its score.

    The first decoder step is expanded into ``k_`` starting tokens with
    ``Beam_Search``. Every beam is then extended independently for
    ``args.max_len`` steps, each time with the best continuation reported by
    ``Beam_Search``. Beams that contain ``<eos>`` after the last step are
    converted to text with ``LABEL.vocab.itos`` and ``spacer.space``.

    Args:
        args: Object exposing ``max_len``, the number of decoding steps.
        dec_input: Initial decoder input; it is broadcast to ``k_`` rows.
            Presumably a ``(1, 1)`` long tensor holding the start token --
            TODO confirm against the caller.
        enc_input_index: Encoder input, passed unchanged to ``model``.
        model: Callable ``model(enc_input, dec_input)`` returning scores over
            the vocabulary for every decoder position.
        first: Must be true; the beams are seeded from the first decoder step.
        device: Device the model inputs are moved to.
        k_: Number of beams.
        LABEL: Field whose ``vocab`` provides ``stoi`` and ``itos``.

    Returns:
        A list of ``[sentence, score]`` pairs, one per beam that produced
        ``<eos>``. The score is the beam's score at the step where ``<eos>``
        first appeared.

    Raises:
        ValueError: If ``first`` is false. The previous code failed on that
            path with an ``UnboundLocalError`` because no seed beams existed.
    """
    if not first:
        raise ValueError(
            "beam() needs first=True: the beams are seeded from the first decoding step")

    # Look the end-of-sentence id up instead of hard-coding 3, matching how
    # inference() detects the end of the greedy output.
    eos_idx = LABEL.vocab.stoi['<eos>']
    cpu = torch.device('cpu')

    temp_dec_input = torch.zeros([k_, 1], dtype=torch.long) + dec_input
    deliver_high_beam_value = torch.zeros([k_, 1], dtype=torch.long)
    check_k = [float('inf')] * k_  # score of each beam when <eos> first shows up

    # Seed: the k_ best first tokens become the starting point of the beams.
    y_pred = model(enc_input_index.to(device), dec_input.to(device))
    beam_input_sequence = Beam_Search(y_pred, k_, True, [[list(), 1.0]])
    for b in range(k_):
        deliver_high_beam_value[b] = beam_input_sequence[b][0][0]
    temp_dec_input = torch.cat(
        [temp_dec_input.to(cpu), deliver_high_beam_value.to(cpu)], dim=-1)

    end_sentence = []
    end_idx = []

    for step in range(args.max_len):
        for j in range(k_):
            if temp_dec_input[j][-1].item() == eos_idx:
                # Finished beam: deliver_high_beam_value[j] still holds <eos>,
                # so the beam is padded with <eos> and stays finished. The old
                # code tested a stale index here, which could freeze every
                # beam or keep decoding beams that had already ended.
                continue
            y_pred = model(enc_input_index.to(device),
                           temp_dec_input[j].unsqueeze(0).to(device))
            beam_seq = Beam_Search(y_pred.squeeze(0)[-1], k_, False,
                                   [beam_input_sequence[j]])

            # Keep only the best continuation (lowest score) for this beam.
            best_token = beam_seq[0][0][-1]
            beam_input_sequence[j] = [[best_token], beam_seq[0][1]]
            deliver_high_beam_value[j] = best_token

        temp_dec_input = torch.cat(
            [temp_dec_input.to(cpu), deliver_high_beam_value.to(cpu)], dim=-1)

        # Remember the score of a beam the first time it contains <eos>.
        for x in range(k_):
            if check_k[x] == float('inf') and bool((temp_dec_input[x] == eos_idx).any()):
                check_k[x] = beam_input_sequence[x][1]

        # After the last step, collect every beam that reached <eos>.
        if step + 1 == args.max_len:
            for x in range(k_):
                if bool((temp_dec_input[x] == eos_idx).any()):
                    end_sentence.append(temp_dec_input[x])
                    end_idx.append(x)

    return_sentence_beamVal_pair = []
    for token_row, beam_index in zip(end_sentence, end_idx):
        token_ids = token_row.tolist()
        eos_position = token_ids.index(eos_idx)
        # Position 0 is the start token copied from dec_input; the sentence
        # is everything after it up to (not including) the first <eos>.
        words = [LABEL.vocab.itos[token_id] for token_id in token_ids[1:eos_position]]
        pred_str = spacer.space("".join(words))
        return_sentence_beamVal_pair.append([pred_str, check_k[beam_index]])
    return return_sentence_beamVal_pair
6 | 151 | ||
7 | def inference(device, args, TEXT, LABEL, model, sa_model): | 152 | def inference(device, args, TEXT, LABEL, model, sa_model): |
8 | from KoBERT.Sentiment_Analysis_BERT_main import bert_inference | 153 | from KoBERT.Sentiment_Analysis_BERT_main import bert_inference |
... | @@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model): | ... | @@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model): |
40 | 185 | ||
41 | model.eval() | 186 | model.eval() |
42 | pred = [] | 187 | pred = [] |
188 | + | ||
189 | + beam_k = 10 | ||
190 | + beam_sen_val_pair = beam(args, dec_input, enc_input_index, model, True, device, beam_k, LABEL) | ||
191 | + greedy_pair = [] | ||
43 | for i in range(args.max_len): | 192 | for i in range(args.max_len): |
44 | y_pred = model(enc_input_index.to(device), dec_input.to(device)) | 193 | y_pred = model(enc_input_index.to(device), dec_input.to(device)) |
194 | + if i == 0: | ||
195 | + score = cal_score(y_pred.squeeze(0)[-1], 1.0) | ||
196 | + else: | ||
197 | + score = cal_score(y_pred.squeeze(0)[-1], score ) | ||
198 | + | ||
45 | y_pred_ids = y_pred.max(dim=-1)[1] | 199 | y_pred_ids = y_pred.max(dim=-1)[1] |
200 | + | ||
46 | if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']): | 201 | if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']): |
47 | y_pred_ids = y_pred_ids.squeeze(0) | 202 | y_pred_ids = y_pred_ids.squeeze(0) |
48 | - print(">", end=" ") | 203 | + #print(">", end=" ") |
49 | for idx in range(len(y_pred_ids)): | 204 | for idx in range(len(y_pred_ids)): |
50 | if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>': | 205 | if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>': |
51 | pred_sentence = "".join(pred) | 206 | pred_sentence = "".join(pred) |
52 | pred_str = spacer.space(pred_sentence) | 207 | pred_str = spacer.space(pred_sentence) |
53 | - print(pred_str) | 208 | + #print(pred_str, " |", score) |
209 | + greedy_pair.append([pred_str, score]) | ||
54 | break | 210 | break |
55 | else: | 211 | else: |
56 | pred.append(LABEL.vocab.itos[y_pred_ids[idx]]) | 212 | pred.append(LABEL.vocab.itos[y_pred_ids[idx]]) |
213 | + | ||
214 | + compair_beam_and_greedy(beam_sen_val_pair, greedy_pair) | ||
57 | return 0 | 215 | return 0 |
58 | 216 | ||
59 | dec_input = torch.cat( | 217 | dec_input = torch.cat( |
60 | [dec_input.to(torch.device('cpu')), | 218 | [dec_input.to(torch.device('cpu')), |
61 | y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1) | 219 | y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1) |
62 | - return 0 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
220 | + | ||
221 | + | ... | ... |
... | @@ -51,8 +51,8 @@ def data_preprocessing(args, device): | ... | @@ -51,8 +51,8 @@ def data_preprocessing(args, device): |
51 | # TEXT, LABEL 에 필요한 special token 만듦. | 51 | # TEXT, LABEL 에 필요한 special token 만듦. |
52 | text_specials, label_specials = make_special_token(args) | 52 | text_specials, label_specials = make_special_token(args) |
53 | 53 | ||
54 | - TEXT.build_vocab(train_data, vectors=vectors, max_size=15000, specials=text_specials) | 54 | + TEXT.build_vocab(train_data,vectors=vectors, max_size=15000, specials=text_specials) |
55 | - LABEL.build_vocab(train_data, vectors=vectors, max_size=15000, specials=label_specials) | 55 | + LABEL.build_vocab(train_data,vectors=vectors, max_size=15000, specials=label_specials) |
56 | 56 | ||
57 | train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True) | 57 | train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True) |
58 | test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True) | 58 | test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True) | ... | ... |
-
Please register or login to post a comment