Showing
5 changed files
with
862 additions
and
0 deletions
Chatbot/Styling.py
0 → 100644
1 | +import torch | ||
2 | +import csv | ||
3 | +import hgtk | ||
4 | +from konlpy.tag import Mecab | ||
5 | +import random | ||
6 | + | ||
mecab = Mecab()  # shared Mecab POS tagger instance used by all helpers below
empty_list = []  # NOTE(review): appears unused in this file — candidate for removal
positive_emo = ['ㅎㅎ', '~']    # emoticons randomly appended to positive answers (see styling)
negative_emo = ['...', 'ㅠㅠ']  # emoticons randomly appended to negative answers (see styling)
asdf = []  # debug accumulator: endings missing from the banmal dictionary (appended in styling)
12 | + | ||
# Morphological analysis via Mecab.
def mecab_token_pos_flat_fn(string):
    """POS-tag `string` with Mecab and return tokens formatted as 'surface/TAG'."""
    return ['{}/{}'.format(surface, tag) for surface, tag in mecab.pos(string)]
17 | + | ||
# For the rough persona: find the first pronoun (NP) token and map the polite
# forms 저/제 to the plain-speech forms 나/내.
def exchange_NP(target, args):
    """Locate the first NP-tagged token of `target` for banmal conversion.

    target: sentence string to analyse.
    args: unused here; kept so the signature matches the sibling helpers.

    Returns (word, index, replaced):
      - ('', -1, False) when no NP token exists;
      - (replacement, idx, True) when the pronoun was 저 or 제;
      - (surface, idx, False) for any other pronoun (no banmal mapping).
    """
    ko_sp = mecab_token_pos_flat_fn(target)

    # '> 0' (not '>= 0') so the match is always in the tag part, never at the
    # start of the surface form.
    for idx, word in enumerate(ko_sp):
        if word.find('NP') > 0:
            pronoun = word.split('/')[0]
            break
    else:
        return '', -1, False

    if pronoun == '저':
        return '나', idx, True
    if pronoun == '제':
        return '내', idx, True

    # Fix: this branch previously returned the raw [surface, tag] list, making
    # the first element's type inconsistent (list vs str). Callers ignore the
    # word whenever the flag is False, so returning the surface string is safe.
    return pronoun, idx, False
38 | + | ||
# Convert a sentence's final ending into the soft / rough styled form.
def make_special_word(target, args, search_ec):
    """Find the sentence-final ending of `target` and return (styled_word, index).

    target: sentence string to analyse (callers usually append a '.' first,
        because it changes the POS analysis).
    args: runtime options; only `args.per_rough` is read here.
    search_ec: when True, also accept a connective ending 'EC' in the
        second-to-last token if no final ending 'EF' is found first.

    Returns ('', -1) when no ending is found. With per_rough the raw ending is
    returned (banmal mapping happens in the caller); otherwise the ending is
    re-composed with a trailing 'ㅇ' jongseong (e.g. 하세요 -> 하세용).
    """
    # Morpheme split, e.g. ['오늘/MAG', '날씨/NNG', '좋/VA', '다/EF', './SF'].
    ko_sp = mecab_token_pos_flat_fn(target)

    keyword = []

    # Extract the first token tagged 'EF' (final ending) and remember its index.
    for idx, word in enumerate(ko_sp):
        if word.find('EF') > 0:
            keyword.append(word.split('/'))
            _idx = idx
            break
        if search_ec:
            # NOTE(review): this checks the fixed token ko_sp[-2] on every
            # iteration, so it fires on the first word lacking 'EF'; also
            # _idx is set to len(ko_sp) - 1 although the token comes from
            # ko_sp[-2] — presumably to account for the '.' appended by
            # callers. Confirm intent before restructuring.
            if ko_sp[-2].find('EC') > 0:
                keyword.append(ko_sp[-2].split('/'))
                _idx = len(ko_sp) -1
                break
            else:
                continue

    # No usable ending found.
    if keyword == []:
        return '', -1
    else:
        keyword = keyword[0]

    # Rough persona: hand back the plain ending surface form.
    if args.per_rough:
        return keyword[0], _idx

    # Decompose the ending into jamo with hgtk (example output: 하ᴥ세요).
    h_separation = hgtk.text.decompose(keyword[0])
    total_word = ''

    for idx, word in enumerate(h_separation):
        total_word += word

    # Style the 'EF' by inserting jongseong 'ㅇ' before the last syllable break.
    total_word = replaceRight(total_word, "ᴥ", "ㅇᴥ", 1)

    # Re-compose the jamo: 하세요 -> 하세용.
    h_combine = hgtk.text.compose(total_word)

    return h_combine, _idx
83 | + | ||
# Build the special tokens added to the vocabularies.
def make_special_token(args):
    """Collect special tokens for the TEXT and LABEL vocabularies.

    Scans the chatbot answers of the training set, extracts every styled final
    ending (EF + trailing 'ㅇ'), adds their banmal forms plus a few emoticon
    tokens, and returns (text_specials, label_specials).
    """
    banmal_dict = get_rough_dic()

    # Every styled ending that occurs in a training-set answer (column 2).
    styled_endings = set()
    with open('chatbot_0325_ALLLABEL_train.txt', 'r', encoding='utf-8') as f:
        for row in csv.reader(f, delimiter='\t'):
            ending, _ = make_special_word(row[2], args, False)
            styled_endings.add(str(ending))
    target_special_voca = list(styled_endings)

    # Banmal (plain-speech) counterparts, where the dictionary knows them.
    banmal_special_voca = []
    for word in target_special_voca:
        try:
            banmal_special_voca.append(banmal_dict[word])
        except KeyError:
            if args.per_rough:
                print("not include banmal dictionary")

    # Emoticon tokens used by the soft persona.
    target_special_voca.append('ㅎㅎ')
    target_special_voca.append('~')
    target_special_voca.append('ㅠㅠ')
    target_special_voca.append('...')
    target_special_voca = target_special_voca + banmal_special_voca

    # '<posi>' marks positive input sentiment, '<nega>' negative.
    return ['<posi>', '<nega>'], target_special_voca
118 | + | ||
# str.replace counterpart that works from the right-hand side of the string.
def replaceRight(original, old, new, count_right):
    """Replace up to `count_right` occurrences of `old` with `new`, scanning
    from the right end of `original`.

    Never replaces more occurrences than actually exist; returns `original`
    unchanged when `old` is absent.
    """
    text = original
    # Cap the replacement count at the number of occurrences present.
    repeat = min(count_right, original.count(old))

    for _ in range(repeat):
        find_index = text.rfind(old)  # rightmost occurrence
        if find_index < 0:  # defensive; should not happen given the cap above
            break
        # Bug fix: skip the full length of `old`. The original sliced
        # `find_index + 1`, which only removed one character and therefore
        # corrupted the result whenever len(old) > 1.
        text = text[:find_index] + new + text[find_index + len(old):]

    return text
135 | + | ||
# Apply persona "styling" to the transformer input / output tensors.
def styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args, TEXT, LABEL):
    """Rewrite batch tensors to match the selected persona.

    per_soft: appends a sentiment token (<posi>/<nega>) to the encoder input,
        replaces each answer's final ending with its soft form (하세요 -> 하세용)
        and appends a random sentiment-matching emoticon before <eos>.
    per_rough: converts polite endings to banmal via get_rough_dic() and
        swaps the pronouns 저/제 for 나/내.

    Returns the updated (enc_input, dec_input, dec_outputs) tensors.
    NOTE: tensors are rebuilt with .cuda(), so a GPU is required here.
    """

    pad_tensor = torch.tensor([LABEL.vocab.stoi['<pad>']]).type(dtype=torch.int32).cuda()

    temp_enc = enc_input.data.cpu().numpy()
    batch_sentiment_list = []

    # Soft persona.
    if args.per_soft:
        # Encoder input becomes: 나는 너를 좋아해 <posi> <pad> <pad> ...
        # NOTE(review): token id 1 is assumed to be the first padding position
        # that gets overwritten with the sentiment token — confirm vocab layout.
        for i in range(len(temp_enc)):
            for j in range(args.max_len):
                if temp_enc[i][j] == 1 and enc_label[i] == 0:
                    temp_enc[i][j] = TEXT.vocab.stoi["<nega>"]
                    batch_sentiment_list.append(0)
                    break
                elif temp_enc[i][j] == 1 and enc_label[i] == 1:
                    temp_enc[i][j] = TEXT.vocab.stoi["<posi>"]
                    batch_sentiment_list.append(1)
                    break

        enc_input = torch.tensor(temp_enc, dtype=torch.int32).cuda()

        # Append one <pad> per row to make room for the extra emoticon token.
        for i in range(len(dec_outputs)):
            dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)

        temp_dec = dec_outputs.data.cpu().numpy()

        dec_outputs_sentiment_list = []  # emoticon chosen for each batch row

        # Decoder outputs become: 저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...
        for i in range(len(temp_dec)):  # i = batch index
            temp_sentence = ''
            # Pick a random emoticon matching the input sentiment.
            sa_ = batch_sentiment_list[i]
            if sa_ == 0:
                sa_ = random.choice(negative_emo)
            elif sa_ == 1:
                sa_ = random.choice(positive_emo)
            dec_outputs_sentiment_list.append(sa_)

            # Rebuild the plain answer string (special tokens skipped).
            for ix, token_i in enumerate(temp_dec[i]):
                if LABEL.vocab.itos[token_i] == '<sos>' or LABEL.vocab.itos[token_i] == '<eos>' or LABEL.vocab.itos[token_i] == '<pad>':
                    continue
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the trailing '.' changes the POS analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)

            # No styled ending found: just insert the emoticon before <eos>.
            if exchange_word == '':
                for j in range(len(temp_dec[i])):
                    if temp_dec[i][j] == LABEL.vocab.stoi['<eos>']:
                        temp_dec[i][j] = LABEL.vocab.stoi[sa_]
                        temp_dec[i][j+1] = LABEL.vocab.stoi['<eos>']
                        break
                continue

            # Replace the token before <eos> with the styled ending, then the
            # emoticon, then re-append <eos>.
            for j in range(len(temp_dec[i])):
                if LABEL.vocab.itos[temp_dec[i][j]] == '<eos>':
                    temp_dec[i][j - 1] = LABEL.vocab.stoi[exchange_word]
                    temp_dec[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                    temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
                    break
                elif temp_dec[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec[i]):
                    print("\t-ERROR- No <EOS> token")
                    exit()

        dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).cuda()

        temp_dec_input = dec_input.data.cpu().numpy()
        # Decoder input becomes: <sos> 저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...
        for i in range(len(temp_dec_input)):
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec_input[i]):
                if LABEL.vocab.itos[token_i] == '<sos>' or LABEL.vocab.itos[token_i] == '<eos>' or LABEL.vocab.itos[token_i] == '<pad>':
                    continue
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the trailing '.' changes the POS analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)

            if exchange_word == '':
                for j in range(len(temp_dec_input[i])):
                    if temp_dec_input[i][j] == LABEL.vocab.stoi['<eos>']:
                        temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                        temp_dec_input[i][j+1] = LABEL.vocab.stoi['<eos>']
                        break
                continue

            for j in range(len(temp_dec_input[i])):
                if LABEL.vocab.itos[temp_dec_input[i][j]] == '<eos>':
                    temp_dec_input[i][j-1] = LABEL.vocab.stoi[exchange_word]
                    temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                    temp_dec_input[i][j+1] = LABEL.vocab.stoi['<eos>']
                    break
                elif temp_dec_input[i][j] != LABEL.vocab.stoi['<eos>'] and j+1 == len(temp_dec_input[i]):
                    print("\t-ERROR- No <EOS> token")
                    exit()

        dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).cuda()

    # Rough persona.
    elif args.per_rough:
        banmal_dic = get_rough_dic()

        # Append one <pad> per row (room for index shifts below).
        for i in range(len(dec_outputs)):
            dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)

        temp_dec = dec_outputs.data.cpu().numpy()

        # Decoder outputs become: 나도 좋아 <eos> <pad> <pad> ...
        for i in range(len(temp_dec)):  # i = batch index
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec[i]):
                if LABEL.vocab.itos[token_i] == '<eos>':
                    break
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence+'.'  # the trailing '.' changes the POS analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)
            exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence, args)

            # Swap 저/제 for 나/내 when a mappable pronoun was found.
            if exist:
                temp_dec[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]

            if exchange_word == '':
                continue
            try:
                exchange_word = banmal_dic[exchange_word]
            except KeyError:
                # Remember unmapped endings for offline inspection.
                asdf.append(exchange_word)
                print("not include banmal dictionary")
                pass

            # Banmal ending, then <eos>, then pad out the rest of the row.
            temp_dec[i][idx] = LABEL.vocab.stoi[exchange_word]
            temp_dec[i][idx+1] = LABEL.vocab.stoi['<eos>']
            for k in range(idx+2, args.max_len):
                temp_dec[i][k] = LABEL.vocab.stoi['<pad>']

        dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).cuda()

        temp_dec_input = dec_input.data.cpu().numpy()
        # Decoder input becomes: <sos> 나도 좋아 <eos> <pad> <pad> ...
        for i in range(len(temp_dec_input)):
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec_input[i]):
                if ix == 0 :
                    continue  # skip the leading <sos> token
                if LABEL.vocab.itos[token_i] == '<eos>':
                    break
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the trailing '.' changes the POS analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)
            exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence, args)
            idx = idx + 1  # shift indices by one for the leading <sos>
            NP_idx = NP_idx + 1

            if exist:
                temp_dec_input[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]

            if exchange_word == '':
                continue

            try:
                exchange_word = banmal_dic[exchange_word]
            except KeyError:
                print("not include banmal dictionary")
                pass

            temp_dec_input[i][idx] = LABEL.vocab.stoi[exchange_word]
            temp_dec_input[i][idx + 1] = LABEL.vocab.stoi['<eos>']

            for k in range(idx+2, args.max_len):
                temp_dec_input[i][k] = LABEL.vocab.stoi['<pad>']

        dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).cuda()

    return enc_input, dec_input, dec_outputs
323 | + | ||
# Dictionary for converting polite endings to plain speech (banmal).
def get_rough_dic():
    """Return a mapping from polite Korean sentence endings to their
    plain-speech (banmal) equivalents, used by the rough persona.

    NOTE(review): a few values look like typos and should be confirmed
    against the intended output — e.g. '지가간다' (지나간다?), '뱐헤' (변해?),
    '어떄' (어때?), '간기'. They are runtime data, so they are left untouched.
    """
    my_exword = {
        '돌아와요': '돌아와',
        '으세요': '으셈',
        '잊어버려요': '잊어버려',
        '나온대요': '나온대',
        '될까요': '될까',
        '할텐데': '할텐데',
        '옵니다': '온다',
        '봅니다': '본다',
        '네요': '네',
        '된답니다': '된대',
        '데요': '데',
        '봐요': '봐',
        '부러워요': '부러워',
        '바랄게요': '바랄게',
        '지나갑니다': "지가간다",
        '이뻐요': "이뻐",
        '지요': "지",
        '사세요': "사라",
        '던가요': "던가",
        '모릅니다': "몰라",
        '은가요': "은가",
        '심해요': "심해",
        '몰라요': "몰라",
        '라요': "라",
        '더라고요': '더라고',
        '입니다': '이라고',
        '는다면요': '는다면',
        '멋져요': '멋져',
        '다면요': '다면',
        '다니': '다나',
        '져요': '져',
        '만드세요': '만들어',
        '야죠': '야지',
        '죠': '지',
        '해줄게요': '해줄게',
        '대요': '대',
        '돌아갑시다': '돌아가자',
        '해보여요': '해봐',
        '라뇨': '라니',
        '편합니다': '편해',
        '합시다': '하자',
        '드세요': '먹어',
        '아름다워요': '아름답네',
        '드립니다': '줄게',
        '받아들여요': '받아들여',
        '건가요': '간기',
        '쏟아진다': '쏟아지네',
        '슬퍼요': '슬퍼',
        '해서요': '해서',
        '다릅니다': '다르다',
        '니다': '니',
        '내려요': '내려',
        '마셔요': '마셔',
        '아세요': '아냐',
        '변해요': '뱐헤',
        '드려요': '드려',
        '아요': '아',
        '어서요': '어서',
        '뜁니다': '뛴다',
        '속상해요': '속상해',
        '래요': '래',
        '까요': '까',
        '어야죠': '어야지',
        '라니': '라니',
        '해집니다': '해진다',
        '으련만': '으련만',
        '지워져요': '지워져',
        '잘라요': '잘라',
        '고요': '고',
        '셔야죠': '셔야지',
        '다쳐요': '다쳐',
        '는구나': '는구만',
        '은데요': '은데',
        '일까요': '일까',
        '인가요': '인가',
        '아닐까요': '아닐까',
        '텐데요': '텐데',
        '할게요': '할게',
        '보입니다': '보이네',
        '에요': '야',
        '걸요': '걸',
        '한답니다': '한대',
        '을까요': '을까',
        '못해요': '못해',
        '베푸세요': '베풀어',
        '어때요': '어떄',
        '더라구요': '더라구',
        '노라': '노라',
        '반가워요': '반가워',
        '군요': '군',
        '만납시다': '만나자',
        '어떠세요': '어때',
        '달라져요': '달라져',
        '예뻐요': '예뻐',
        '됩니다': '된다',
        '봅시다': '보자',
        '한대요': '한대',
        '싸워요': '싸워',
        '와요': '와',
        '인데요': '인데',
        '야': '야',
        '줄게요': '줄게',
        '기에요': '기',
        '던데요': '던데',
        '걸까요': '걸까',
        '신가요': '신가',
        '어요': '어',
        '따져요': '따져',
        '갈게요': '갈게',
        '봐': '봐',
        '나요': '나',
        '니까요': '니까',
        '마요': '마',
        '씁니다': '쓴다',
        '집니다': '진다',
        '건데요': '건데',
        '지웁시다': '지우자',
        '바랍니다': '바래',
        '는데요': '는데',
        '으니까요': '으니까',
        '셔요': '셔',
        '네여': '네',
        '달라요': '달라',
        '거려요': '거려',
        '보여요': '보여',
        '겁니다': '껄',
        '다': '다',
        '그래요': '그래',
        '한가요': '한가',
        '잖아요': '잖아',
        '한데요': '한데',
        '우세요': '우셈',
        '해야죠': '해야지',
        '세요': '셈',
        '걸려요': '걸려',
        '텐데': '텐데',
        '어딘가': '어딘가',
        '요': '',
        '흘러갑니다': '흘러간다',
        '줘요': '줘',
        '편해요': '편해',
        '거예요': '거야',
        '예요': '야',
        '습니다': '어',
        '아닌가요': '아닌가',
        '합니다': '한다',
        '사라집니다': '사라져',
        '드릴게요': '줄게',
        '다면': '다면',
        '그럴까요': '그럴까',
        '해요': '해',
        '답니다': '다',
        '주무세요': '자라',
        '마세요': '마라',
        '아픈가요': '아프냐',
        '그런가요': '그런가',
        '했잖아요': '했잖아',
        '버려요': '버려',
        '갑니다': '간다',
        '가요': '가',
        '라면요': '라면',
        '아야죠': '아야지',
        '살펴봐요': '살펴봐',
        '남겨요': '남겨',
        '내려놔요': '내려놔',
        '떨려요': '떨려',
        '랍니다': '란다',
        '돼요': '돼',
        '버텨요': '버텨',
        '만나': '만나',
        '일러요': '일러',
        '을게요': '을게',
        '갑시다': '가자',
        '나아요': '나아',
        '어려요': '어려',
        '온대요': '온대',
        '다고요': '다고',
        '할래요': '할래',
        '된대요': '된대',
        '어울려요': '어울려',
        '는군요': '는군',
        '볼까요': '볼까',
        '드릴까요': '줄까',
        '라던데요': '라던데',
        '올게요': '올게',
        '기뻐요': '기뻐',
        '아닙니다': '아냐',
        '둬요': '둬',
        '십니다': '십',
        '아파요': '아파',
        '생겨요': '생겨',
        '해줘요': '해줘',
        '로군요': '로군요',
        '시켜요': '시켜',
        '느껴져요': '느껴져',
        '가재요': '가재',
        '어 ': ' ',
        '느려요': '느려',
        '볼게요': '볼게',
        '쉬워요': '쉬워',
        '나빠요': '나빠',
        '불러줄게요': '불러줄게',
        '살쪄요': '살쪄',
        '봐야겠어요': '봐야겠어',
        '네': '네',
        '어': '어',
        '든지요': '든지',
        '드신다': '드심',
        '가져요': '가져',
        '할까요': '할까',
        '졸려요': '졸려',
        '그럴게요': '그럴게',
        '': '',
        '어린가': '어린가',
        '나와요': '나와',
        '빨라요': '빨라',
        '겠죠': '겠지',
        '졌어요': '졌어',
        '해봐요': '해봐',
        '게요': '게',
        '해드릴까요': '해줄까',
        '인걸요': '인걸',
        '했어요': '했어',
        '원해요': '원해',
        '는걸요': '는걸',
        '좋아합니다': '좋아해',
        '했으면': '했으면',
        '나갑니다': '나간다',
        '왔어요': '왔어',
        '해봅시다': '해보자',
        '물어봐요': '물어봐',
        '생겼어요': '생겼어',
        '해': '해',
        '다녀올게요': '다녀올게',
        '납시다': '나자'
    }
    return my_exword
... | \ No newline at end of file | ... | \ No newline at end of file |
Chatbot/generation.py
0 → 100644
1 | +import torch | ||
2 | +from get_data import tokenizer1 | ||
3 | +from torch.autograd import Variable | ||
4 | +from chatspace import ChatSpace | ||
spacer = ChatSpace()  # ChatSpace model used to restore word spacing in generated replies
6 | + | ||
def inference(device, args, TEXT, LABEL, model, sa_model):
    """Interactive single-turn chat inference; always returns 0.

    Reads one sentence from stdin, classifies its sentiment with KoBERT,
    appends the matching sentiment token to the encoder input (soft persona
    only), then decodes greedily with the transformer until <eos> and prints
    the reply with spacing restored by ChatSpace.
    """
    # Imported lazily so KoBERT is only loaded when inference is actually used.
    from KoBERT.Sentiment_Analysis_BERT_main import bert_inference
    sentence = input("문장을 입력하세요 : ")  # prompt: "enter a sentence"
    se_list = [sentence]

    # https://github.com/SKTBrain/KoBERT
    # Sentiment of the input sentence via SKT's KoBERT (0 = negative, 1 = positive).
    sa_label = int(bert_inference(sa_model, se_list))

    sa_token = ''
    # Pick the sentiment special token matching the predicted label.
    if sa_label == 0:
        sa_token = TEXT.vocab.stoi['<nega>']
    else:
        sa_token = TEXT.vocab.stoi['<posi>']

    enc_input = tokenizer1(sentence)
    enc_input_index = []

    for tok in enc_input:
        enc_input_index.append(TEXT.vocab.stoi[tok])

    # Soft persona: append the sentiment token, then pad up to max_len.
    if args.per_soft:
        enc_input_index.append(sa_token)

    for j in range(args.max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])

    enc_input_index = Variable(torch.LongTensor([enc_input_index]))

    # Greedy decoding starts from a lone <sos> and grows one token per step.
    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])

    model.eval()
    pred = []
    for i in range(args.max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]  # greedy: argmax over the vocabulary
        # Stop as soon as the newest position predicts <eos>, then print.
        if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']):
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    pred_sentence = "".join(pred)
                    # ChatSpace re-inserts the spaces lost by joining tokens.
                    pred_str = spacer.space(pred_sentence)
                    print(pred_str)
                    break
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 0

        # Feed the newest prediction back in as the next decoder input step.
        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 0
... | \ No newline at end of file | ... | \ No newline at end of file |
Chatbot/get_data.py
0 → 100644
1 | +import torch | ||
2 | +from torchtext import data | ||
3 | +from torchtext.data import TabularDataset | ||
4 | +from torchtext.data import BucketIterator | ||
5 | +from torchtext.vocab import Vectors | ||
6 | +from konlpy.tag import Mecab | ||
7 | +import re | ||
8 | +from Styling import styling, make_special_token | ||
9 | + | ||
10 | +# tokenizer | ||
11 | +def tokenizer1(text): | ||
12 | + result_text = re.sub('[-=+.,#/\:$@*\"※&%ㆍ!?』\\‘|\(\)\[\]\<\>`\'…》;]', '', text) | ||
13 | + a = Mecab().morphs(result_text) | ||
14 | + return ([a[i] for i in range(len(a))]) | ||
15 | + | ||
# Data preprocessing: build fields, datasets, vocabularies and loaders.
def data_preprocessing(args, device):
    """Load the chatbot TSV data and return (TEXT, LABEL, train_loader, test_loader).

    Builds torchtext fields for the question (TEXT) and answer (LABEL) columns,
    reads the train/test TSV files, attaches pre-trained vectors, injects the
    persona special tokens into both vocabularies, and wraps the datasets in
    BucketIterators.
    """

    # ID column is not used for training; SA is the sentiment label (0/1).
    ID = data.Field(sequential=False,
                    use_vocab=False)

    # Question side: tokenized and padded/truncated to args.max_len.
    TEXT = data.Field(sequential=True,
                      use_vocab=True,
                      tokenize=tokenizer1,
                      batch_first=True,
                      fix_length=args.max_len,
                      dtype=torch.int32
                      )

    # Answer side: additionally wrapped in <sos> ... <eos>.
    LABEL = data.Field(sequential=True,
                       use_vocab=True,
                       tokenize=tokenizer1,
                       batch_first=True,
                       fix_length=args.max_len,
                       init_token='<sos>',
                       eos_token='<eos>',
                       dtype=torch.int32
                       )

    SA = data.Field(sequential=False,
                    use_vocab=False)

    train_data, test_data = TabularDataset.splits(
        path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
        fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
    )

    # Pre-trained Korean word vectors (file expected in the working directory).
    vectors = Vectors(name="kr-projected.txt")

    # Special tokens needed by TEXT / LABEL (sentiment markers, styled endings).
    text_specials, label_specials = make_special_token(args)

    TEXT.build_vocab(train_data, vectors=vectors, max_size=15000, specials=text_specials)
    LABEL.build_vocab(train_data, vectors=vectors, max_size=15000, specials=label_specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader
Chatbot/metric.py
0 → 100644
1 | +import torch | ||
2 | + | ||
# Token-level accuracy, ignoring padding positions.
def acc(yhat, y, pad_idx=1):
    """Return the accuracy between logits `yhat` and integer targets `y`.

    yhat: logits whose last dimension is the vocabulary ([..., vocab]).
    y: target token indices, same leading shape as yhat without the last dim.
    pad_idx: vocabulary index of <pad>; these target positions are excluded
        from the average (defaults to 1, the value previously hard-coded).
    """
    with torch.no_grad():
        pred = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value
        # Mask out padding targets before averaging so padding does not
        # inflate the score.
        return (pred == y).float()[y != pad_idx].mean()
9 | + | ||
# Pretty-print one training example (question, gold answer, prediction)
# during the first few steps.
def train_test(step, y_pred, dec_output, real_value_index, enc_input, args, TEXT, LABEL):
    """Print the first batch's Q / target A / predicted A for steps 0..2."""
    if not (0 <= step < 3):
        return

    _, top_ids = y_pred[real_value_index].data.topk(1)

    # Question: first row of the encoder input, up to the first <pad>.
    print("<<Q>> :", end=" ")
    for tok in enc_input[0]:
        if TEXT.vocab.itos[tok] == "<pad>":
            break
        print(TEXT.vocab.itos[tok], end=" ")

    # Gold answer, up to <eos>.
    print("\n<<trg A>> :", end=" ")
    for tok in dec_output[real_value_index]:
        if LABEL.vocab.itos[tok] == "<eos>":
            break
        print(LABEL.vocab.itos[tok], end=" ")

    # Predicted answer, capped at max_len and stopped at <eos>.
    print("\n<<pred A>> :", end=" ")
    for pos, tok in enumerate(top_ids):
        if pos == args.max_len:
            break
        if LABEL.vocab.itos[tok] == '<eos>':
            break
        print(LABEL.vocab.itos[tok], end=" ")
    print("\n")
35 | + print("\n") |
Chatbot/model.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import math | ||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # prefer GPU when available
5 | + | ||
class Transformer(nn.Module):
    """Seq2seq chatbot model wrapping torch.nn.Transformer.

    Embeds source/target token ids, adds sinusoidal positional encodings,
    runs the standard encoder-decoder transformer with padding and causal
    masks, and projects decoder features back to target-vocabulary logits.
    """

    def __init__(self, args, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        # Hyper-parameters taken from the argparse namespace.
        self.d_model = args.embedding_dim
        self.n_head = args.nhead
        self.num_encoder_layers = args.nlayers
        self.num_decoder_layers = args.nlayers
        self.dim_feedforward = args.embedding_dim
        self.dropout = args.dropout

        # Fields (kept for vocab lookups in forward).
        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        # NOTE: attribute name 'transfomrer' (sic) is kept so that existing
        # checkpoints / state_dicts continue to load.
        self.transfomrer = torch.nn.Transformer(d_model=self.d_model,
                                                nhead=self.n_head,
                                                num_encoder_layers=self.num_encoder_layers,
                                                num_decoder_layers=self.num_decoder_layers,
                                                dim_feedforward=self.dim_feedforward,
                                                dropout=self.dropout)
        self.proj_vocab_layer = nn.Linear(
            in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))

        #self.apply(self._initailze)

    def forward(self, en_input, de_input):
        """Map (batch, src_len) and (batch, tgt_len) id tensors to
        (batch, tgt_len, tgt_vocab) logits."""
        # Embed, scale by sqrt(d_model), add positional encodings.
        src_emb = self.pos_encoder(self.src_embedding(en_input.long()) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoder(self.trg_embedding(de_input.long()) * math.sqrt(self.d_model))

        # Padding masks are True where the token is <pad>.
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        # Causal mask: each target position attends only to earlier positions.
        tgt_mask = self.transfomrer.generate_square_subsequent_mask(de_input.size(1))

        # nn.Transformer expects (seq_len, batch, d_model).
        feature = self.transfomrer(src=src_emb.transpose(0, 1),
                                   tgt=tgt_emb.transpose(0, 1),
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask.to(device))

        # Project to vocabulary logits and return as (batch, seq_len, vocab).
        logits = self.proj_vocab_layer(feature)
        return logits.transpose(0, 1)

    def _initailze(self, layer):
        # Kaiming init for Linear layers (currently unused; see the
        # commented-out self.apply call in __init__).
        if isinstance(layer, (nn.Linear)):
            nn.init.kaiming_uniform_(layer.weight)
65 | + | ||
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017) plus dropout.

    Expects inputs of shape (seq_len, batch, d_model) and adds a fixed
    position-dependent signal to every embedding.
    """

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # positions: (max_len, 1); freqs: geometric progression over even dims.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dims: sine
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dims: cosine
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis;
        # buffer name 'pe' kept for state_dict compatibility.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Add the positional table for the first x.size(0) positions."""
        return self.dropout(x + self.pe[:x.size(0), :])
83 | + | ||
84 | +from torch.optim.lr_scheduler import _LRScheduler | ||
85 | +from torch.optim.lr_scheduler import ReduceLROnPlateau | ||
86 | + | ||
class GradualWarmupScheduler(_LRScheduler):

    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # True once warm-up has ended and after_scheduler has taken over
        # (its base_lrs get rescaled exactly once at that point).
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        # Past the warm-up window: either delegate to the follow-up scheduler
        # (rescaling its base lrs once) or hold at base_lr * multiplier.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Linear ramp: base_lr at epoch 0 up to base_lr * multiplier at total_epoch.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # ReduceLROnPlateau has no get_lr(); apply the warm-up ramp manually,
        # then forward the metric once warm-up is over.
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # NOTE(review): `epoch` cannot be None here (defaulted above), so
            # the first branch looks unreachable — confirm original intent.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        # Dispatch: ReduceLROnPlateau needs the metric-driven path; all other
        # schedulers follow the standard _LRScheduler.step protocol.
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
-
Please register or login to post a comment