bongminkim

chatbot_py files

# --- Styling.py ---

import torch
import csv
import hgtk
from konlpy.tag import Mecab
import random

mecab = Mecab()
positive_emo = ['ㅎㅎ', '~']
negative_emo = ['...', 'ㅠㅠ']
missing_banmal_words = []  # collects words not found in the banmal dictionary

# Morphological analysis via Mecab.
def mecab_token_pos_flat_fn(string):
    tokens_ko = mecab.pos(string)
    return [str(pos[0]) + '/' + str(pos[1]) for pos in tokens_ko]

# Helper for the rough style: find the pronoun (NP) '저' or '제' and replace it with '나' or '내'.
def exchange_NP(target, args):
    keyword = []
    ko_sp = mecab_token_pos_flat_fn(target)
    for idx, word in enumerate(ko_sp):
        if word.find('NP') > 0:
            keyword.append(word.split('/'))
            _idx = idx
            break
    if keyword == []:
        return '', -1, False

    if keyword[0][0] == '저':
        keyword[0][0] = '나'
    elif keyword[0][0] == '제':
        keyword[0][0] = '내'
    else:
        return keyword[0], _idx, False

    return keyword[0][0], _idx, True

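# Illustrative trace (editor's example; the exact Mecab tags are an assumption):
#   exchange_NP('저는 좋아요.', args)
#   -> tokens like ['저/NP', '는/JX', '좋/VA', '아요/EF', './SF']
#   -> returns ('나', 0, True): the informal replacement, the token index, and a found flag.
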
# Converts a word to the soft or rough speaking style.
def make_special_word(target, args, search_ec):
    # Split the sentence with Mecab (example output: ['오늘/MAG', '날씨/NNG', '좋/VA', '다/EF', './SF'])
    ko_sp = mecab_token_pos_flat_fn(target)

    keyword = []

    # If a token contains a final ending 'EF' (or, when search_ec is set, a
    # connective ending 'EC'), extract its index and keyword.
    for idx, word in enumerate(ko_sp):
        if word.find('EF') > 0:
            keyword.append(word.split('/'))
            _idx = idx
            break
        if search_ec:
            if ko_sp[-2].find('EC') > 0:
                keyword.append(ko_sp[-2].split('/'))
                _idx = len(ko_sp) - 1
                break
            else:
                continue

    # Return early if there is no 'EF'.
    if keyword == []:
        return '', -1
    else:
        keyword = keyword[0]

    if args.per_rough:
        return keyword[0], _idx

    # Decompose the keyword into jamo with hgtk (example output: 'ㅎㅏᴥㅅㅔᴥㅇㅛᴥ').
    h_separation = hgtk.text.decompose(keyword[0])
    total_word = ''

    for idx, word in enumerate(h_separation):
        total_word += word

    # Styling: attach the final consonant 'ㅇ' to the 'EF' ending.
    total_word = replaceRight(total_word, "ᴥ", "ㅇᴥ", 1)

    # Recompose the jamo: '하세요' -> '하세용'.
    h_combine = hgtk.text.compose(total_word)

    return h_combine, _idx

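# Illustrative trace (editor's example, with args.per_rough == False and the
# Mecab tagging shown in the comment above assumed):
#   make_special_word('오늘 날씨 좋다.', args, False)
#   -> keyword '다/EF' at index 3; decompose '다' -> 'ㄷㅏᴥ'
#   -> attach 'ㅇ' before the final 'ᴥ' -> 'ㄷㅏㅇᴥ'; compose -> '당'
#   -> returns ('당', 3): the styled ending and its token index.
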
# Builds the special tokens.
def make_special_token(args):
    # Special tokens that express sentiment.
    target_special_voca = []

    banmal_dict = get_rough_dic()

    # From the chatbot answers in the training set, extract the 'EF' endings and
    # create special tokens with the final consonant 'ㅇ' attached.
    with open('chatbot_0325_ALLLABEL_train.txt', 'r', encoding='utf-8') as f:
        rdr = csv.reader(f, delimiter='\t')
        for idx, line in enumerate(rdr):
            target = line[2]  # chatbot answer
            exchange_word, _ = make_special_word(target, args, False)
            target_special_voca.append(str(exchange_word))
    target_special_voca = list(set(target_special_voca))

    banmal_special_voca = []
    for i in range(len(target_special_voca)):
        try:
            banmal_special_voca.append(banmal_dict[target_special_voca[i]])
        except KeyError:
            if args.per_rough:
                print("not included in the banmal dictionary")

    # Add a few emoticons.
    target_special_voca.append('ㅎㅎ')
    target_special_voca.append('~')
    target_special_voca.append('ㅠㅠ')
    target_special_voca.append('...')
    target_special_voca = target_special_voca + banmal_special_voca

    # '<posi>' marks positive, '<nega>' marks negative.
    return ['<posi>', '<nega>'], target_special_voca

# Like Python's str.replace, but replacing from the right.
def replaceRight(original, old, new, count_right):
    repeat = 0
    text = original

    count_find = original.count(old)
    if count_right > count_find:  # if asked for more replacements than old occurs,
        repeat = count_find       # replace every occurrence (count_find times)
    else:
        repeat = count_right      # otherwise replace the requested count_right times

    for _ in range(repeat):
        find_index = text.rfind(old)  # rfind searches from the right
        text = text[:find_index] + new + text[find_index + len(old):]

    return text

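# Behaviour sketch (editor's example): only the rightmost occurrences change.
#   replaceRight('a-b-c', '-', '+', 1)  -> 'a-b+c'
#   replaceRight('a-b-c', '-', '+', 5)  -> 'a+b+c'  (count clamped to the matches found)
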
# Applies the styling transformation to the tensors fed into and out of the transformer.
def styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args, TEXT, LABEL):

    pad_tensor = torch.tensor([LABEL.vocab.stoi['<pad>']]).type(dtype=torch.int32).cuda()

    temp_enc = enc_input.data.cpu().numpy()
    batch_sentiment_list = []

    # Soft personality.
    if args.per_soft:
        # Rewrite the encoder input as: 나는 너를 좋아해 <posi> <pad> <pad> ...
        for i in range(len(temp_enc)):
            for j in range(args.max_len):
                if temp_enc[i][j] == 1 and enc_label[i] == 0:
                    temp_enc[i][j] = TEXT.vocab.stoi["<nega>"]
                    batch_sentiment_list.append(0)
                    break
                elif temp_enc[i][j] == 1 and enc_label[i] == 1:
                    temp_enc[i][j] = TEXT.vocab.stoi["<posi>"]
                    batch_sentiment_list.append(1)
                    break

        enc_input = torch.tensor(temp_enc, dtype=torch.int32).cuda()

        for i in range(len(dec_outputs)):
            dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)

        temp_dec = dec_outputs.data.cpu().numpy()

        dec_outputs_sentiment_list = []  # sentiment expression appended to each decoder output

        # Rewrite the decoder outputs as: 저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...
        for i in range(len(temp_dec)):  # i = batch size
            temp_sentence = ''
            sa_ = batch_sentiment_list[i]
            if sa_ == 0:
                sa_ = random.choice(negative_emo)
            elif sa_ == 1:
                sa_ = random.choice(positive_emo)
            dec_outputs_sentiment_list.append(sa_)

            for ix, token_i in enumerate(temp_dec[i]):
                if LABEL.vocab.itos[token_i] in ('<sos>', '<eos>', '<pad>'):
                    continue
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the period changes the morphological analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)

            if exchange_word == '':
                for j in range(len(temp_dec[i])):
                    if temp_dec[i][j] == LABEL.vocab.stoi['<eos>']:
                        temp_dec[i][j] = LABEL.vocab.stoi[sa_]
                        temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
                        break
                continue

            for j in range(len(temp_dec[i])):
                if LABEL.vocab.itos[temp_dec[i][j]] == '<eos>':
                    temp_dec[i][j - 1] = LABEL.vocab.stoi[exchange_word]
                    temp_dec[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                    temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
                    break
                elif temp_dec[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec[i]):
                    print("\t-ERROR- No <EOS> token")
                    exit()

        dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).cuda()

        temp_dec_input = dec_input.data.cpu().numpy()
        # Rewrite the decoder input as: <sos> 저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...
        for i in range(len(temp_dec_input)):
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec_input[i]):
                if LABEL.vocab.itos[token_i] in ('<sos>', '<eos>', '<pad>'):
                    continue
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the period changes the morphological analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)

            if exchange_word == '':
                for j in range(len(temp_dec_input[i])):
                    if temp_dec_input[i][j] == LABEL.vocab.stoi['<eos>']:
                        temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                        temp_dec_input[i][j + 1] = LABEL.vocab.stoi['<eos>']
                        break
                continue

            for j in range(len(temp_dec_input[i])):
                if LABEL.vocab.itos[temp_dec_input[i][j]] == '<eos>':
                    temp_dec_input[i][j - 1] = LABEL.vocab.stoi[exchange_word]
                    temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                    temp_dec_input[i][j + 1] = LABEL.vocab.stoi['<eos>']
                    break
                elif temp_dec_input[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec_input[i]):
                    print("\t-ERROR- No <EOS> token")
                    exit()

        dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).cuda()

    # Rough personality.
    elif args.per_rough:
        banmal_dic = get_rough_dic()

        for i in range(len(dec_outputs)):
            dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)

        temp_dec = dec_outputs.data.cpu().numpy()

        # Rewrite the decoder outputs as: 나도 좋아 <eos> <pad> <pad> ...
        for i in range(len(temp_dec)):  # i = batch size
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec[i]):
                if LABEL.vocab.itos[token_i] == '<eos>':
                    break
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the period changes the morphological analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)
            exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence, args)

            if exist:
                temp_dec[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]

            if exchange_word == '':
                continue
            try:
                exchange_word = banmal_dic[exchange_word]
            except KeyError:
                missing_banmal_words.append(exchange_word)
                print("not included in the banmal dictionary")

            temp_dec[i][idx] = LABEL.vocab.stoi[exchange_word]
            temp_dec[i][idx + 1] = LABEL.vocab.stoi['<eos>']
            for k in range(idx + 2, args.max_len):
                temp_dec[i][k] = LABEL.vocab.stoi['<pad>']

            # for j in range(len(temp_dec[i])):
            #     if LABEL.vocab.itos[temp_dec[i][j]] == '<eos>':
            #         break
            #     print(LABEL.vocab.itos[temp_dec[i][j]], end='')
            # print()

        dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).cuda()

        temp_dec_input = dec_input.data.cpu().numpy()
        # Rewrite the decoder input as: <sos> 나도 좋아 <eos> <pad> <pad> ...
        for i in range(len(temp_dec_input)):
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec_input[i]):
                if ix == 0:
                    continue  # skip the <sos> token
                if LABEL.vocab.itos[token_i] == '<eos>':
                    break
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the period changes the morphological analysis
            exchange_word, idx = make_special_word(temp_sentence, args, True)
            exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence, args)
            idx = idx + 1  # shift by one for the <sos> token
            NP_idx = NP_idx + 1

            if exist:
                temp_dec_input[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]

            if exchange_word == '':
                continue

            try:
                exchange_word = banmal_dic[exchange_word]
            except KeyError:
                print("not included in the banmal dictionary")

            temp_dec_input[i][idx] = LABEL.vocab.stoi[exchange_word]
            temp_dec_input[i][idx + 1] = LABEL.vocab.stoi['<eos>']

            for k in range(idx + 2, args.max_len):
                temp_dec_input[i][k] = LABEL.vocab.stoi['<pad>']

            # for j in range(len(temp_dec_input[i])):
            #     if LABEL.vocab.itos[temp_dec_input[i][j]] == '<eos>':
            #         break
            #     print(LABEL.vocab.itos[temp_dec_input[i][j]], end='')
            # print()

        dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).cuda()

    return enc_input, dec_input, dec_outputs

# Dictionary that maps polite (존댓말) endings to banmal (반말, informal) endings.
def get_rough_dic():
    my_exword = {
        '돌아와요': '돌아와',
        '으세요': '으셈',
        '잊어버려요': '잊어버려',
        '나온대요': '나온대',
        '될까요': '될까',
        '할텐데': '할텐데',
        '옵니다': '온다',
        '봅니다': '본다',
        '네요': '네',
        '된답니다': '된대',
        '데요': '데',
        '봐요': '봐',
        '부러워요': '부러워',
        '바랄게요': '바랄게',
        '지나갑니다': '지나간다',
        '이뻐요': '이뻐',
        '지요': '지',
        '사세요': '사라',
        '던가요': '던가',
        '모릅니다': '몰라',
        '은가요': '은가',
        '심해요': '심해',
        '몰라요': '몰라',
        '라요': '라',
        '더라고요': '더라고',
        '입니다': '이라고',
        '는다면요': '는다면',
        '멋져요': '멋져',
        '다면요': '다면',
        '다니': '다나',
        '져요': '져',
        '만드세요': '만들어',
        '야죠': '야지',
        '죠': '지',
        '해줄게요': '해줄게',
        '대요': '대',
        '돌아갑시다': '돌아가자',
        '해보여요': '해봐',
        '라뇨': '라니',
        '편합니다': '편해',
        '합시다': '하자',
        '드세요': '먹어',
        '아름다워요': '아름답네',
        '드립니다': '줄게',
        '받아들여요': '받아들여',
        '건가요': '건가',
        '쏟아진다': '쏟아지네',
        '슬퍼요': '슬퍼',
        '해서요': '해서',
        '다릅니다': '다르다',
        '니다': '니',
        '내려요': '내려',
        '마셔요': '마셔',
        '아세요': '아냐',
        '변해요': '변해',
        '드려요': '드려',
        '아요': '아',
        '어서요': '어서',
        '뜁니다': '뛴다',
        '속상해요': '속상해',
        '래요': '래',
        '까요': '까',
        '어야죠': '어야지',
        '라니': '라니',
        '해집니다': '해진다',
        '으련만': '으련만',
        '지워져요': '지워져',
        '잘라요': '잘라',
        '고요': '고',
        '셔야죠': '셔야지',
        '다쳐요': '다쳐',
        '는구나': '는구만',
        '은데요': '은데',
        '일까요': '일까',
        '인가요': '인가',
        '아닐까요': '아닐까',
        '텐데요': '텐데',
        '할게요': '할게',
        '보입니다': '보이네',
        '에요': '야',
        '걸요': '걸',
        '한답니다': '한대',
        '을까요': '을까',
        '못해요': '못해',
        '베푸세요': '베풀어',
        '어때요': '어때',
        '더라구요': '더라구',
        '노라': '노라',
        '반가워요': '반가워',
        '군요': '군',
        '만납시다': '만나자',
        '어떠세요': '어때',
        '달라져요': '달라져',
        '예뻐요': '예뻐',
        '됩니다': '된다',
        '봅시다': '보자',
        '한대요': '한대',
        '싸워요': '싸워',
        '와요': '와',
        '인데요': '인데',
        '야': '야',
        '줄게요': '줄게',
        '기에요': '기',
        '던데요': '던데',
        '걸까요': '걸까',
        '신가요': '신가',
        '어요': '어',
        '따져요': '따져',
        '갈게요': '갈게',
        '봐': '봐',
        '나요': '나',
        '니까요': '니까',
        '마요': '마',
        '씁니다': '쓴다',
        '집니다': '진다',
        '건데요': '건데',
        '지웁시다': '지우자',
        '바랍니다': '바래',
        '는데요': '는데',
        '으니까요': '으니까',
        '셔요': '셔',
        '네여': '네',
        '달라요': '달라',
        '거려요': '거려',
        '보여요': '보여',
        '겁니다': '껄',
        '다': '다',
        '그래요': '그래',
        '한가요': '한가',
        '잖아요': '잖아',
        '한데요': '한데',
        '우세요': '우셈',
        '해야죠': '해야지',
        '세요': '셈',
        '걸려요': '걸려',
        '텐데': '텐데',
        '어딘가': '어딘가',
        '요': '',
        '흘러갑니다': '흘러간다',
        '줘요': '줘',
        '편해요': '편해',
        '거예요': '거야',
        '예요': '야',
        '습니다': '어',
        '아닌가요': '아닌가',
        '합니다': '한다',
        '사라집니다': '사라져',
        '드릴게요': '줄게',
        '다면': '다면',
        '그럴까요': '그럴까',
        '해요': '해',
        '답니다': '다',
        '주무세요': '자라',
        '마세요': '마라',
        '아픈가요': '아프냐',
        '그런가요': '그런가',
        '했잖아요': '했잖아',
        '버려요': '버려',
        '갑니다': '간다',
        '가요': '가',
        '라면요': '라면',
        '아야죠': '아야지',
        '살펴봐요': '살펴봐',
        '남겨요': '남겨',
        '내려놔요': '내려놔',
        '떨려요': '떨려',
        '랍니다': '란다',
        '돼요': '돼',
        '버텨요': '버텨',
        '만나': '만나',
        '일러요': '일러',
        '을게요': '을게',
        '갑시다': '가자',
        '나아요': '나아',
        '어려요': '어려',
        '온대요': '온대',
        '다고요': '다고',
        '할래요': '할래',
        '된대요': '된대',
        '어울려요': '어울려',
        '는군요': '는군',
        '볼까요': '볼까',
        '드릴까요': '줄까',
        '라던데요': '라던데',
        '올게요': '올게',
        '기뻐요': '기뻐',
        '아닙니다': '아냐',
        '둬요': '둬',
        '십니다': '십',
        '아파요': '아파',
        '생겨요': '생겨',
        '해줘요': '해줘',
        '로군요': '로군요',
        '시켜요': '시켜',
        '느껴져요': '느껴져',
        '가재요': '가재',
        '어 ': ' ',
        '느려요': '느려',
        '볼게요': '볼게',
        '쉬워요': '쉬워',
        '나빠요': '나빠',
        '불러줄게요': '불러줄게',
        '살쪄요': '살쪄',
        '봐야겠어요': '봐야겠어',
        '네': '네',
        '어': '어',
        '든지요': '든지',
        '드신다': '드심',
        '가져요': '가져',
        '할까요': '할까',
        '졸려요': '졸려',
        '그럴게요': '그럴게',
        '': '',
        '어린가': '어린가',
        '나와요': '나와',
        '빨라요': '빨라',
        '겠죠': '겠지',
        '졌어요': '졌어',
        '해봐요': '해봐',
        '게요': '게',
        '해드릴까요': '해줄까',
        '인걸요': '인걸',
        '했어요': '했어',
        '원해요': '원해',
        '는걸요': '는걸',
        '좋아합니다': '좋아해',
        '했으면': '했으면',
        '나갑니다': '나간다',
        '왔어요': '왔어',
        '해봅시다': '해보자',
        '물어봐요': '물어봐',
        '생겼어요': '생겼어',
        '해': '해',
        '다녀올게요': '다녀올게',
        '납시다': '나자'
    }
    return my_exword
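
# Editor's note: a minimal, hedged usage sketch of the helpers above. It runs
# only when this file is executed directly and assumes hgtk/Mecab are installed;
# the sample inputs are illustrative, not part of the original source.
if __name__ == '__main__':
    demo_dict = get_rough_dic()
    print(demo_dict['할게요'])                    # -> '할게' (polite ending mapped to banmal)
    print(replaceRight('aᴥbᴥ', 'ᴥ', 'ㅇᴥ', 1))   # -> 'aᴥbㅇᴥ' (only the rightmost 'ᴥ' changes)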
# --- inference module (original file name not shown) ---

import torch
from get_data import tokenizer1
from chatspace import ChatSpace
spacer = ChatSpace()

def inference(device, args, TEXT, LABEL, model, sa_model):
    from KoBERT.Sentiment_Analysis_BERT_main import bert_inference
    sentence = input("문장을 입력하세요 : ")
    se_list = [sentence]

    # https://github.com/SKTBrain/KoBERT
    # Classify the input sentence as positive or negative with SKT's KoBERT sentiment analysis.
    sa_label = int(bert_inference(sa_model, se_list))

    sa_token = ''
    # Adjust the encoder input according to the SA label.
    if sa_label == 0:
        sa_token = TEXT.vocab.stoi['<nega>']
    else:
        sa_token = TEXT.vocab.stoi['<posi>']

    enc_input = tokenizer1(sentence)
    enc_input_index = []

    for tok in enc_input:
        enc_input_index.append(TEXT.vocab.stoi[tok])

    # Convert the encoder input string to an index tensor, appending <pad> up to max_len.
    if args.per_soft:
        enc_input_index.append(sa_token)

    for j in range(args.max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])

    enc_input_index = torch.LongTensor([enc_input_index])

    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])
    # print("긍정" if sa_label == 1 else "부정")

    model.eval()
    pred = []
    for i in range(args.max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]
        if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    pred_sentence = "".join(pred)
                    pred_str = spacer.space(pred_sentence)
                    print(pred_str)
                    break
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 0

        # Greedy decoding: append the newest predicted token to the decoder input.
        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 0
# --- get_data.py ---

import torch
from torchtext import data
from torchtext.data import TabularDataset
from torchtext.data import BucketIterator
from torchtext.vocab import Vectors
from konlpy.tag import Mecab
import re
from Styling import styling, make_special_token

# Tokenizer: strip punctuation/special characters, then split into Mecab morphemes.
def tokenizer1(text):
    result_text = re.sub('[-=+.,#/\:$@*\"※&%ㆍ!?』\\‘|\(\)\[\]\<\>`\'…》;]', '', text)
    return Mecab().morphs(result_text)
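
# Illustrative behaviour (editor's example; the exact morphemes depend on the
# installed Mecab dictionary):
#   tokenizer1('안녕하세요!') -> '!' is stripped by the regex, then Mecab splits,
#   e.g. ['안녕', '하', '세요']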

# Preprocesses the data and returns the loaders.
def data_preprocessing(args, device):

    # ID is unused. SA is the sentiment-analysis label (0 or 1).
    ID = data.Field(sequential=False,
                    use_vocab=False)

    TEXT = data.Field(sequential=True,
                      use_vocab=True,
                      tokenize=tokenizer1,
                      batch_first=True,
                      fix_length=args.max_len,
                      dtype=torch.int32
                      )

    LABEL = data.Field(sequential=True,
                       use_vocab=True,
                       tokenize=tokenizer1,
                       batch_first=True,
                       fix_length=args.max_len,
                       init_token='<sos>',
                       eos_token='<eos>',
                       dtype=torch.int32
                       )

    SA = data.Field(sequential=False,
                    use_vocab=False)

    train_data, test_data = TabularDataset.splits(
        path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
        fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
    )

    vectors = Vectors(name="kr-projected.txt")

    # Build the special tokens needed for TEXT and LABEL.
    text_specials, label_specials = make_special_token(args)

    TEXT.build_vocab(train_data, vectors=vectors, max_size=15000, specials=text_specials)
    LABEL.build_vocab(train_data, vectors=vectors, max_size=15000, specials=label_specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader
# --- metrics/utility module (original file name not shown) ---

import torch

# Computes accuracy, excluding padding positions.
def acc(yhat, y):
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value
        acc = (yhat == y).float()[y != 1].mean()  # vocab index 1 is <pad>; exclude padding from accuracy
    return acc

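# Worked example (editor's note, assuming vocab index 1 is <pad> as above):
#   yhat.max(-1)[1] = [5, 7, 1, 1] and y = [5, 2, 1, 1]
#   the mask y != 1 keeps positions 0 and 1 -> correctness [1., 0.] -> acc = 0.5
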
# During training, prints a sample model input (Q) with the target and predicted answers.
def train_test(step, y_pred, dec_output, real_value_index, enc_input, args, TEXT, LABEL):

    if 0 <= step < 3:
        _, ix = y_pred[real_value_index].data.topk(1)
        train_Q = enc_input[0]
        print("<<Q>> :", end=" ")
        for i in train_Q:
            if TEXT.vocab.itos[i] == "<pad>":
                break
            print(TEXT.vocab.itos[i], end=" ")

        print("\n<<trg A>> :", end=" ")
        for jj, jx in enumerate(dec_output[real_value_index]):
            if LABEL.vocab.itos[jx] == "<eos>":
                break
            print(LABEL.vocab.itos[jx], end=" ")

        print("\n<<pred A>> :", end=" ")
        for jj, pred_ix in enumerate(ix):
            if jj == args.max_len:
                break
            if LABEL.vocab.itos[pred_ix] == '<eos>':
                break
            print(LABEL.vocab.itos[pred_ix], end=" ")
        print("\n")
# --- transformer model module (original file name not shown) ---

import torch
import torch.nn as nn
import math
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Transformer(nn.Module):
    def __init__(self, args, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        self.d_model = args.embedding_dim
        self.n_head = args.nhead
        self.num_encoder_layers = args.nlayers
        self.num_decoder_layers = args.nlayers
        self.dim_feedforward = args.embedding_dim
        self.dropout = args.dropout

        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        self.transformer = torch.nn.Transformer(d_model=self.d_model,
                                                nhead=self.n_head,
                                                num_encoder_layers=self.num_encoder_layers,
                                                num_decoder_layers=self.num_decoder_layers,
                                                dim_feedforward=self.dim_feedforward,
                                                dropout=self.dropout)
        self.proj_vocab_layer = nn.Linear(
            in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))

        # self.apply(self._initialize)

    def forward(self, en_input, de_input):
        x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
        x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
        x_en_embed = self.pos_encoder(x_en_embed)
        x_de_embed = self.pos_encoder(x_de_embed)

        # Masking: padding masks for source/target and a causal mask for the decoder.
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1))

        # nn.Transformer expects (seq_len, batch, d_model); swap the batch and time axes.
        x_en_embed = torch.einsum('ijk->jik', x_en_embed)
        x_de_embed = torch.einsum('ijk->jik', x_de_embed)

        feature = self.transformer(src=x_en_embed,
                                   tgt=x_de_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask.to(device))

        logits = self.proj_vocab_layer(feature)
        logits = torch.einsum('ijk->jik', logits)

        return logits

    def _initialize(self, layer):
        if isinstance(layer, nn.Linear):
            nn.init.kaiming_uniform_(layer.weight)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

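# For reference (editor's note), the buffer built above is the standard
# sinusoidal encoding from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
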
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):

    """ Gradually warms up (increases) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of the epoch, whereas others are called at the beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
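
# Editor's note: a minimal usage sketch for GradualWarmupScheduler, assuming
# Adam wrapped around a toy model and a ReduceLROnPlateau hand-off; all
# hyperparameters below are illustrative, not from the original source.
if __name__ == '__main__':
    toy_model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4)
    plateau = ReduceLROnPlateau(optimizer, factor=0.5, patience=2)
    # Warm the LR up to 8x the base rate over 10 epochs, then defer to plateau.
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10,
                                       after_scheduler=plateau)
    for epoch in range(1, 15):
        scheduler.step(epoch, metrics=1.0)  # metrics would be the validation loss
        print(epoch, optimizer.param_groups[0]['lr'])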