bongminkim

update_pyfile

...@@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len 크게 해야 … ...@@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len 크게 해야 …
15 parser.add_argument('--batch_size', type=int, default=256) 15 parser.add_argument('--batch_size', type=int, default=256)
16 parser.add_argument('--num_epochs', type=int, default=22) 16 parser.add_argument('--num_epochs', type=int, default=22)
17 parser.add_argument('--warming_up_epochs', type=int, default=5) 17 parser.add_argument('--warming_up_epochs', type=int, default=5)
18 -parser.add_argument('--lr', type=float, default=0.0002) 18 +parser.add_argument('--lr', type=float, default=0.0002)#0.0002
19 parser.add_argument('--embedding_dim', type=int, default=160) 19 parser.add_argument('--embedding_dim', type=int, default=160)
20 parser.add_argument('--nlayers', type=int, default=2) 20 parser.add_argument('--nlayers', type=int, default=2)
21 parser.add_argument('--nhead', type=int, default=2) 21 parser.add_argument('--nhead', type=int, default=2)
22 parser.add_argument('--dropout', type=float, default=0.1) 22 parser.add_argument('--dropout', type=float, default=0.1)
23 parser.add_argument('--train', type=bool, default=True) 23 parser.add_argument('--train', type=bool, default=True)
24 -parser.add_argument('--per_soft', type=bool, default=False) 24 +parser.add_argument('--per_soft', type=bool, default=True)
25 parser.add_argument('--per_rough', type=bool, default=False) 25 parser.add_argument('--per_rough', type=bool, default=False)
26 args = parser.parse_args() 26 args = parser.parse_args()
27 27
...@@ -30,6 +30,7 @@ def epoch_time(start_time, end_time): ...@@ -30,6 +30,7 @@ def epoch_time(start_time, end_time):
30 elapsed_time = end_time - start_time 30 elapsed_time = end_time - start_time
31 elapsed_mins = int(elapsed_time / 60) 31 elapsed_mins = int(elapsed_time / 60)
32 elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 32 elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
33 +
33 return elapsed_mins, elapsed_secs 34 return elapsed_mins, elapsed_secs
34 35
35 # 학습 36 # 학습
......
1 import torch 1 import torch
2 +import torch.nn.functional as F
3 +from math import log
4 +from numpy import array
2 from get_data import tokenizer1 5 from get_data import tokenizer1
3 from torch.autograd import Variable 6 from torch.autograd import Variable
4 from chatspace import ChatSpace 7 from chatspace import ChatSpace
5 spacer = ChatSpace() 8 spacer = ChatSpace()
9 +from konlpy.tag import Mecab
10 +import re
11 +
# Punctuation and symbol characters removed before morphological analysis.
# Written as a raw string so the pattern carries no invalid string escapes.
_TOKENIZER_STRIP_RE = re.compile(r"""[-=+.,#/:$@*"※&%ㆍ!?』‘|()\[\]<>`'…》;]""")

# Shared Mecab tagger, created on first use instead of once per call.
_TOKENIZER_MECAB = None


def tokenizer1(text):
    """Strip punctuation from *text* and split it into Mecab morphemes.

    Args:
        text: The sentence to tokenize.

    Returns:
        A new list holding the morphemes ``Mecab().morphs`` produces for the
        cleaned sentence.
    """
    global _TOKENIZER_MECAB
    if _TOKENIZER_MECAB is None:
        _TOKENIZER_MECAB = Mecab()

    cleaned = _TOKENIZER_STRIP_RE.sub('', text)
    return list(_TOKENIZER_MECAB.morphs(cleaned))
16 +
17 +def _get_length_penalty(text, alpha=1.2, min_length=5):
18 + p_list = []
19 + for i in range(len(text)):
20 + temp_text = tokenizer1(text[i][0])
21 + length = len(temp_text)
22 + p_list.append(((5 + length) ** alpha) / (5 + 1) ** alpha)
23 +
24 + lp_list = [ text[j][1]/p_list[j] for j in range(len(text)) ]
25 + return lp_list
26 +
def compair_beam_and_greedy(beam_pair, greedy_pair):
    """Print the best beam candidate, the greedy candidate, and which one wins.

    Both inputs are lists of ``[sentence, score]`` pairs. Scores are first
    length-normalized with ``_get_length_penalty``; the beam candidate with
    the lowest normalized score is compared against the first greedy pair,
    and the lower of the two is reported as the one to use.

    Args:
        beam_pair: ``[sentence, score]`` pairs produced by beam decoding.
        greedy_pair: ``[sentence, score]`` pairs produced by greedy decoding;
            only the first entry is reported.

    Returns:
        None. The result is written to stdout.
    """
    beam_scores = _get_length_penalty(beam_pair)
    greedy_scores = _get_length_penalty(greedy_pair)

    # Keep the first beam candidate that is strictly lower than all before it.
    best_score = float('inf')
    best_sentence = ""
    for pair, candidate_score in zip(beam_pair, beam_scores):
        if candidate_score < best_score:
            best_score = candidate_score
            best_sentence = pair[0]

    print(" beam output > ", best_sentence, " |", best_score)
    print("greedy output > ", greedy_pair[0][0]," |", greedy_scores[0])
    print("use beam" if best_score < greedy_scores[0] else "use greedy")
44 +
def cal_score(pred, score):
    """Fold the most likely token's probability into a running score.

    Args:
        pred: Tensor of unnormalized scores; softmax is taken over the last
            dimension. Assumed to be 1-D so the maximum is a single value
            (the caller passes ``y_pred.squeeze(0)[-1]``).
        score: The running score accumulated so far.

    Returns:
        ``score`` multiplied by the negative natural log of the highest
        softmax probability.
    """
    probabilities = F.softmax(pred, dim=-1)
    top_probability = probabilities.max(dim=-1)[0].to('cpu').tolist()
    return score * -log(top_probability)
53 +
def Beam_Search(data, k, first, sequences):
    """Extend every partial sequence by one token per row and keep the best k.

    Each candidate's score is the parent score multiplied by the negative
    log-probability of the appended token; lower scores rank first.

    Args:
        data: Tensor of unnormalized scores. When ``first`` is true its
            leading dimension is squeezed away and each remaining row is one
            step; otherwise it is treated as a single row.
        k: Number of candidates kept after each row.
        first: Selects the squeeze (true) or unsqueeze (false) handling of
            ``data``.
        sequences: List of ``[token_list, score]`` pairs to extend.

    Returns:
        The ``k`` lowest-scoring ``[token_list, score]`` pairs after the last
        row, sorted by ascending score (ties keep generation order).
    """
    data = data.squeeze(0) if first else data.unsqueeze(0)

    # log_softmax stays finite where softmax would underflow to exactly 0.0
    # and make math.log raise. One bulk conversion also replaces the
    # per-element tensor -> Python transfer in the vocabulary loop.
    token_costs = (-F.log_softmax(data, dim=-1)).to('cpu').tolist()

    for row in token_costs:
        candidates = [
            [seq + [token], score * cost]
            for seq, score in sequences
            for token, cost in enumerate(row)
        ]
        candidates.sort(key=lambda candidate: candidate[1])
        sequences = candidates[:k]

    return sequences
75 +
def beam(args, dec_input, enc_input_index, model, first, device, k_, LABEL):
    """Decode ``k_`` candidate sentences and return the finished ones.

    The ``k_`` best first tokens (per ``Beam_Search``) each seed one
    hypothesis. On every step each unfinished hypothesis is extended with its
    single best next token; a hypothesis that has emitted the end token is
    no longer decoded and is padded with further end tokens.

    Args:
        args: Parsed options; ``args.max_len`` is the number of decoding steps.
        dec_input: Initial decoder input, broadcast to one row per hypothesis.
        enc_input_index: Encoder input passed to ``model`` on every call.
        model: Callable invoked as ``model(encoder_input, decoder_input)``.
        first: Must be true; the initial hypotheses are only built in that
            case.
        device: Device the model inputs are moved to.
        k_: Number of hypotheses.
        LABEL: Field whose ``vocab.itos`` maps token ids to strings.

    Returns:
        A list of ``[sentence, score]`` pairs, one for each hypothesis that
        contains the end token, in hypothesis order. ``sentence`` is the
        joined tokens before the end token (the first position is skipped)
        after ``spacer.space``; ``score`` is the hypothesis score recorded
        when its end token was first seen.

    Raises:
        NotImplementedError: If ``first`` is false. Previously this path
            failed with an unbound-variable error.
    """
    if not first:
        raise NotImplementedError(
            "beam() requires first=True: initial hypotheses are only built on the first call")

    # NOTE(review): 3 is assumed to be the '<eos>' index of LABEL.vocab;
    # inference() looks it up via LABEL.vocab.stoi['<eos>'] - confirm they agree.
    eos_id = 3
    cpu = torch.device('cpu')
    encoder_input = enc_input_index.to(device)

    # One decoder row per hypothesis, all starting from dec_input.
    temp_dec_input = torch.zeros([k_, 1], dtype=torch.long)
    temp_dec_input = temp_dec_input + dec_input
    next_tokens = torch.zeros([k_, 1], dtype=torch.long)
    finished_score = [float('inf')] * k_

    # Seed the hypotheses with the k_ best first tokens.
    y_pred = model(encoder_input, dec_input.to(device))
    hypotheses = Beam_Search(y_pred, k_, True, [[list(), 1.0]])
    for b in range(len(next_tokens)):
        next_tokens[b] = hypotheses[b][0][0]
    temp_dec_input = torch.cat([temp_dec_input.to(cpu), next_tokens.to(cpu)], dim=-1)

    for _ in range(args.max_len):
        for b in range(len(temp_dec_input)):
            if temp_dec_input[b][-1].item() == eos_id:
                # Finished: next_tokens[b] still holds the end token, so the
                # row is padded with it instead of being decoded again.
                continue
            y_pred = model(encoder_input, temp_dec_input[b].unsqueeze(0).to(device))
            best_seq, best_score = Beam_Search(
                y_pred.squeeze(0)[-1], k_, False, [hypotheses[b]])[0]
            # Keep only the newest token plus the accumulated score.
            hypotheses[b] = [[best_seq[-1]], best_score]
            next_tokens[b] = best_seq[-1]

        temp_dec_input = torch.cat([temp_dec_input.to(cpu), next_tokens.to(cpu)], dim=-1)

        # Record each hypothesis' score the first time its end token appears.
        for b in range(len(temp_dec_input)):
            if finished_score[b] == float('inf') and eos_id in temp_dec_input[b].tolist():
                finished_score[b] = hypotheses[b][1]

    return_sentence_beamVal_pair = []
    # Results were only ever gathered on the last decoding step, so nothing is
    # reported when no step ran.
    if args.max_len > 0:
        for b in range(k_):
            token_ids = temp_dec_input[b].tolist()
            if eos_id not in token_ids:
                continue
            # Skip position 0 (the initial decoder input) and stop before the
            # first end token.
            words = [LABEL.vocab.itos[t] for t in token_ids[1:token_ids.index(eos_id)]]
            pred_str = spacer.space("".join(words))
            return_sentence_beamVal_pair.append([pred_str, finished_score[b]])
    return return_sentence_beamVal_pair
6 151
7 def inference(device, args, TEXT, LABEL, model, sa_model): 152 def inference(device, args, TEXT, LABEL, model, sa_model):
8 from KoBERT.Sentiment_Analysis_BERT_main import bert_inference 153 from KoBERT.Sentiment_Analysis_BERT_main import bert_inference
...@@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model): ...@@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model):
40 185
41 model.eval() 186 model.eval()
42 pred = [] 187 pred = []
188 +
189 + beam_k = 10
190 + beam_sen_val_pair = beam(args, dec_input, enc_input_index, model, True, device, beam_k, LABEL)
191 + greedy_pair = []
43 for i in range(args.max_len): 192 for i in range(args.max_len):
44 y_pred = model(enc_input_index.to(device), dec_input.to(device)) 193 y_pred = model(enc_input_index.to(device), dec_input.to(device))
194 + if i == 0:
195 + score = cal_score(y_pred.squeeze(0)[-1], 1.0)
196 + else:
197 + score = cal_score(y_pred.squeeze(0)[-1], score )
198 +
45 y_pred_ids = y_pred.max(dim=-1)[1] 199 y_pred_ids = y_pred.max(dim=-1)[1]
200 +
46 if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']): 201 if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']):
47 y_pred_ids = y_pred_ids.squeeze(0) 202 y_pred_ids = y_pred_ids.squeeze(0)
48 - print(">", end=" ") 203 + #print(">", end=" ")
49 for idx in range(len(y_pred_ids)): 204 for idx in range(len(y_pred_ids)):
50 if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>': 205 if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
51 pred_sentence = "".join(pred) 206 pred_sentence = "".join(pred)
52 pred_str = spacer.space(pred_sentence) 207 pred_str = spacer.space(pred_sentence)
53 - print(pred_str) 208 + #print(pred_str, " |", score)
209 + greedy_pair.append([pred_str, score])
54 break 210 break
55 else: 211 else:
56 pred.append(LABEL.vocab.itos[y_pred_ids[idx]]) 212 pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
213 +
214 + compair_beam_and_greedy(beam_sen_val_pair, greedy_pair)
57 return 0 215 return 0
58 216
59 dec_input = torch.cat( 217 dec_input = torch.cat(
60 [dec_input.to(torch.device('cpu')), 218 [dec_input.to(torch.device('cpu')),
61 y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1) 219 y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
62 - return 0
...\ No newline at end of file ...\ No newline at end of file
220 +
221 +
......
...@@ -51,8 +51,8 @@ def data_preprocessing(args, device): ...@@ -51,8 +51,8 @@ def data_preprocessing(args, device):
51 # TEXT, LABEL 에 필요한 special token 만듦. 51 # TEXT, LABEL 에 필요한 special token 만듦.
52 text_specials, label_specials = make_special_token(args) 52 text_specials, label_specials = make_special_token(args)
53 53
54 - TEXT.build_vocab(train_data, vectors=vectors, max_size=15000, specials=text_specials) 54 + TEXT.build_vocab(train_data,vectors=vectors, max_size=15000, specials=text_specials)
55 - LABEL.build_vocab(train_data, vectors=vectors, max_size=15000, specials=label_specials) 55 + LABEL.build_vocab(train_data,vectors=vectors, max_size=15000, specials=label_specials)
56 56
57 train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True) 57 train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
58 test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True) 58 test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)
......