김민수

Add light_model and improve the README

1 +FROM ufoym/deepo:pytorch-cpu
2 +# https://github.com/Beomi/deepo-nlp/blob/master/Dockerfile
3 +# Install JVM for Konlpy
4 +RUN apt-get update && \
5 + apt-get upgrade -y && \
6 + apt-get install -y \
7 + openjdk-8-jdk wget curl git python3-dev \
8 + language-pack-ko
9 +
10 +RUN locale-gen en_US.UTF-8 && \
11 + update-locale LANG=en_US.UTF-8
12 +
13 +# Install zsh
14 +RUN apt-get install -y zsh && \
15 + sh -c "$(curl -fsSL https://raw.github.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
16 +
17 +# Install other packages
18 +RUN pip install --upgrade pip
19 +RUN pip install autopep8
20 +RUN pip install konlpy
21 +RUN pip install torchtext pytorch_pretrained_bert
22 +# Install dependency of styling chatbot
23 +RUN pip install hgtk chatspace
24 +
25 +# Add Mecab-Ko
26 +RUN curl -L https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash
27 +# Install the styling chatbot by BM-K
28 +RUN git clone https://github.com/km19809/light_model.git
29 +RUN pip install -r light_model/requirements.txt
30 +
31 +# Add non-root user
32 +RUN adduser --disabled-password --gecos "" user
33 +
34 +# Reset Workdir
35 +WORKDIR /light_model
1 +# Lightweight model of the styling chatbot
2 +This repository hosts a lightweight model for web deployment.\
3 +The original repository is here: [link](https://github.com/km19809/Styling-Chatbot-with-Transformer)
4 +
5 +## Requirements
6 +
7 +The list below may change during development, so please refer to requirements.txt.
8 +```
9 +torch~=1.4.0
10 +Flask~=1.1.2
11 +torchtext~=0.6.0
12 +hgtk~=0.1.3
13 +konlpy~=0.5.2
14 +chatspace~=1.0.1
15 +```
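
The dependencies can be installed in the usual way, for example (assuming pip targets the environment that will run the chatbot):

```
pip install -r requirements.txt
```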
16 +
17 +## Usage
18 +`light_chatbot.py [--train] [--per_soft|--per_rough]`
19 +
20 +* train: trains a new model and saves it. \
21 +If omitted, the saved model is loaded and tested.
22 +* per_soft: trains or tests the soft speech style.\
23 +With per_rough, the rough speech style is trained or tested instead.\
24 +The two options are mutually exclusive. An example invocation follows this list.
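
For example, training the soft-style model and then chatting with it in test mode might look like this (a sketch; it assumes the dataset files referenced in light_chatbot.py are present):

```
python light_chatbot.py --train --per_soft
python light_chatbot.py --per_soft
```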
25 +
26 +`app.py`
27 +
28 +A simple Flask server for testing the chatbot.
1 +import torch
2 +import csv
3 +import hgtk
4 +from konlpy.tag import Mecab
5 +import random
6 +
7 +mecab = Mecab()
9 +positive_emo = ['ㅎㅎ', '~']
10 +negative_emo = ['...', 'ㅠㅠ']
11 +missing_banmal_words = []  # words not found in the banmal dictionary
12 +
13 +
14 +# Morphological analysis with Mecab.
15 +def mecab_token_pos_flat_fn(string: str):
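    # example output format: ['오늘/MAG', '날씨/NNG', '좋/VA', '다/EF', './SF']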
16 + tokens_ko = mecab.pos(string)
17 + return [str(pos[0]) + '/' + str(pos[1]) for pos in tokens_ko]
18 +
19 +
20 +# Helper for the rough style: finds the pronoun NP ('저', '제') and replaces it with '나' or '내'.
21 +def exchange_NP(target: str):
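    # e.g. if the analysis contains '저/NP', returns ('나', index_of_that_token, True)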
22 + keyword = []
23 + ko_sp = mecab_token_pos_flat_fn(target)
24 +    _idx = -1  # default value on failure
25 + for idx, word in enumerate(ko_sp):
26 + if word.find('NP') > 0:
27 + keyword.append(word.split('/'))
28 + _idx = idx
29 + break
30 +    if not keyword:  # no NP was found
31 + return '', _idx, False
32 +
33 + if keyword[0][0] == '저':
34 + keyword[0][0] = '나'
35 + elif keyword[0][0] == '제':
36 + keyword[0][0] = '내'
37 + else:
38 + return keyword[0], _idx, False
39 +
40 + return keyword[0][0], _idx, True
41 +
42 +
43 +# Converts a word to the soft or rough speech style.
44 +def make_special_word(target: str, per_rough: bool, search_ec: bool):
45 +    # Split the sentence with Mecab (example output: ['오늘/MAG', '날씨/NNG', '좋/VA', '다/EF', './SF'])
46 +    ko_sp = mecab_token_pos_flat_fn(target)
47 +
48 +    keyword = []
49 +    _idx = -1  # default value on failure
50 +    # If a word contains the ending 'EF' (or 'EC'), extract its index and keyword.
51 + for idx, word in enumerate(ko_sp):
52 + if word.find('EF') > 0:
53 + keyword.append(word.split('/'))
54 + _idx = idx
55 + break
56 + if search_ec:
57 + if ko_sp[-2].find('EC') > 0:
58 + keyword.append(ko_sp[-2].split('/'))
59 + _idx = len(ko_sp) - 1
60 + break
61 + else:
62 + continue
63 +
64 +    # Return if no 'EF' was found.
65 + if not keyword:
66 + return '', _idx
67 + else:
68 + _keyword = keyword[0]
69 +
70 + if per_rough:
71 + return _keyword[0], _idx
72 +
73 +    # Decompose the keyword into jamo with hgtk (example output: 하ᴥ세요)
74 + h_separation = hgtk.text.decompose(_keyword[0])
75 +    total_word = ''.join(h_separation)
76 +
77 +    # Style the 'EF' ending by attaching the final consonant 'ㅇ'
78 +    total_word = replace_right(total_word, "ᴥ", "ㅇᴥ", 1)
79 +
80 +    # Recompose the jamo, e.g. '하세요' -> '하세용'.
81 +    h_combine = hgtk.text.compose(total_word)
85 +
86 + return h_combine, _idx
87 +
88 +
89 +# Builds the special tokens.
90 +def make_special_token(per_rough: bool):
91 +    # special tokens for expressing emotion
92 + target_special_voca = []
93 +
94 + banmal_dict = get_rough_dic()
95 +
96 +    # From the chatbot answers in the training set, extract 'EF' endings and attach the final consonant 'ㅇ' to build special tokens
97 + with open('chatbot_0325_ALLLABEL_train.txt', 'r', encoding='utf-8') as f:
98 + rdr = csv.reader(f, delimiter='\t')
99 + for idx, line in enumerate(rdr):
100 + target = line[2] # chatbot answer
101 + exchange_word, _ = make_special_word(target, per_rough, False)
102 + target_special_voca.append(str(exchange_word))
103 + target_special_voca = list(set(target_special_voca))
104 +
105 + banmal_special_voca = []
106 + for i in range(len(target_special_voca)):
107 +        try:
108 +            banmal_special_voca.append(banmal_dict[target_special_voca[i]])
109 +        except KeyError:
110 +            if per_rough:
111 +                print("not in the banmal dictionary")
113 +
114 +    # Add emoticon tokens.
115 + target_special_voca.append('ㅎㅎ')
116 + target_special_voca.append('~')
117 + target_special_voca.append('ㅠㅠ')
118 + target_special_voca.append('...')
119 + target_special_voca = target_special_voca + banmal_special_voca
120 +
121 +    # '<posi>' marks positive, '<nega>' marks negative
122 + return ['<posi>', '<nega>'], target_special_voca
123 +
124 +
125 +# Like str.replace, but replacing from the rightmost occurrence.
126 +def replace_right(original: str, old: str, new: str, count_right: int):
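    # Example (len(old) is assumed to be 1, as at every call site here; 'ᴥ' is the
    # syllable separator produced by hgtk.text.decompose):
    #   replace_right("ㅎㅏᴥㅅㅔᴥㅇㅛᴥ", "ᴥ", "ㅇᴥ", 1) == "ㅎㅏᴥㅅㅔᴥㅇㅛㅇᴥ"  # composes to '하세용'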
127 + text = original
128 +
129 + count_find = original.count(old)
130 +    # If count_right exceeds the number of occurrences of old, replace all of them (count_find); otherwise replace count_right occurrences.
131 + repeat = count_find if count_right > count_find else count_right
132 + for _ in range(repeat):
133 +        find_index = text.rfind(old)  # rfind locates the rightmost occurrence
134 + text = text[:find_index] + new + text[find_index + 1:]
135 +
136 + return text
137 +
138 +
139 +# Applies style changes to the tensors fed into the transformer as input and output.
140 +def styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len: int, per_soft: bool, per_rough: bool, TEXT, LABEL):
141 + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
142 +
143 + pad_tensor = torch.tensor([LABEL.vocab.stoi['<pad>']]).type(dtype=torch.int32).to(device)
144 +
145 + temp_enc = enc_input.data.cpu().numpy()
146 + batch_sentiment_list = []
147 +
148 +    # soft personality
149 +    if per_soft:
150 +        # encoder input: rewritten into the form '나는 너를 좋아해 <posi> <pad> <pad> ...'
151 + for i in range(len(temp_enc)):
152 + for j in range(max_len):
153 + if temp_enc[i][j] == 1 and enc_label[i] == 0:
154 + temp_enc[i][j] = TEXT.vocab.stoi["<nega>"]
155 + batch_sentiment_list.append(0)
156 + break
157 + elif temp_enc[i][j] == 1 and enc_label[i] == 1:
158 + temp_enc[i][j] = TEXT.vocab.stoi["<posi>"]
159 + batch_sentiment_list.append(1)
160 + break
161 +
162 + enc_input = torch.tensor(temp_enc, dtype=torch.int32).to(device)
163 +
164 + for i in range(len(dec_outputs)):
165 + dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)
166 +
167 + temp_dec = dec_outputs.data.cpu().numpy()
168 +
169 +        dec_outputs_sentiment_list = []  # emotion token chosen for each decoder sequence
170 +
171 +        # decoder outputs: rewritten into the form '저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...'
172 +        for i in range(len(temp_dec)):  # iterate over the batch
173 + temp_sentence = ''
174 + sa_ = batch_sentiment_list[i]
175 + if sa_ == 0:
176 + sa_ = random.choice(negative_emo)
177 + elif sa_ == 1:
178 + sa_ = random.choice(positive_emo)
179 + dec_outputs_sentiment_list.append(sa_)
180 +
181 + for ix, token_i in enumerate(temp_dec[i]):
182 + if LABEL.vocab.itos[token_i] in ['<sos>', '<eos>', '<pad>']:
183 + continue
184 + temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
185 +            temp_sentence = temp_sentence + '.'  # the trailing period affects the morphological analysis
186 + exchange_word, idx = make_special_word(temp_sentence, per_rough, True)
187 +
188 + if exchange_word == '':
189 + for j in range(len(temp_dec[i])):
190 + if temp_dec[i][j] == LABEL.vocab.stoi['<eos>']:
191 + temp_dec[i][j] = LABEL.vocab.stoi[sa_]
192 + temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
193 + break
194 + continue
195 +
196 + for j in range(len(temp_dec[i])):
197 + if LABEL.vocab.itos[temp_dec[i][j]] == '<eos>':
198 + temp_dec[i][j - 1] = LABEL.vocab.stoi[exchange_word]
199 + temp_dec[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
200 + temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
201 + break
202 + elif temp_dec[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec[i]):
203 + print("\t-ERROR- No <EOS> token")
204 + exit()
205 +
206 + dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).to(device)
207 +
208 + temp_dec_input = dec_input.data.cpu().numpy()
209 +        # decoder input: rewritten into the form '<sos> 저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...'
210 + for i in range(len(temp_dec_input)):
211 + temp_sentence = ''
212 + for ix, token_i in enumerate(temp_dec_input[i]):
213 + if LABEL.vocab.itos[token_i] in ['<sos>', '<eos>', '<pad>']:
214 + continue
215 + temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
216 +            temp_sentence = temp_sentence + '.'  # the trailing period affects the morphological analysis
217 + exchange_word, idx = make_special_word(temp_sentence, per_rough, True)
218 +
219 + if exchange_word == '':
220 + for j in range(len(temp_dec_input[i])):
221 + if temp_dec_input[i][j] == LABEL.vocab.stoi['<eos>']:
222 + temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
223 + temp_dec_input[i][j + 1] = LABEL.vocab.stoi['<eos>']
224 + break
225 + continue
226 +
227 + for j in range(len(temp_dec_input[i])):
228 + if LABEL.vocab.itos[temp_dec_input[i][j]] == '<eos>':
229 + temp_dec_input[i][j - 1] = LABEL.vocab.stoi[exchange_word]
230 + temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
231 + temp_dec_input[i][j + 1] = LABEL.vocab.stoi['<eos>']
232 + break
233 + elif temp_dec_input[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec_input[i]):
234 + print("\t-ERROR- No <EOS> token")
235 + exit()
236 +
237 + dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).to(device)
238 +
239 +    # rough personality
240 + elif per_rough:
241 + banmal_dic = get_rough_dic()
242 +
243 + for i in range(len(dec_outputs)):
244 + dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)
245 +
246 + temp_dec = dec_outputs.data.cpu().numpy()
247 +
248 +        # decoder outputs: rewritten into the form '나도 좋아 <eos> <pad> <pad> ...'
249 +        for i in range(len(temp_dec)):  # iterate over the batch
250 + temp_sentence = ''
251 + for ix, token_i in enumerate(temp_dec[i]):
252 + if LABEL.vocab.itos[token_i] == '<eos>':
253 + break
254 + temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
255 +            temp_sentence = temp_sentence + '.'  # the trailing period affects the morphological analysis
256 + exchange_word, idx = make_special_word(temp_sentence, per_rough, True)
257 + exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence)
258 +
259 + if exist:
260 + temp_dec[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]
261 +
262 + if exchange_word == '':
263 + continue
264 +            try:
265 +                exchange_word = banmal_dic[exchange_word]
266 +            except KeyError:
267 +                missing_banmal_words.append(exchange_word)
268 +                print("not in the banmal dictionary")
270 +
271 + temp_dec[i][idx] = LABEL.vocab.stoi[exchange_word]
272 + temp_dec[i][idx + 1] = LABEL.vocab.stoi['<eos>']
273 + for k in range(idx + 2, max_len):
274 + temp_dec[i][k] = LABEL.vocab.stoi['<pad>']
275 +
276 + # for j in range(len(temp_dec[i])):
277 + # if LABEL.vocab.itos[temp_dec[i][j]]=='<eos>':
278 + # break
279 + # print(LABEL.vocab.itos[temp_dec[i][j]], end='')
280 + # print()
281 +
282 + dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).to(device)
283 +
284 + temp_dec_input = dec_input.data.cpu().numpy()
285 +        # decoder input: rewritten into the form '<sos> 나도 좋아 <eos> <pad> <pad> ...'
286 + for i in range(len(temp_dec_input)):
287 + temp_sentence = ''
288 + for ix, token_i in enumerate(temp_dec_input[i]):
289 + if ix == 0:
290 + continue # because of token <sos>
291 + if LABEL.vocab.itos[token_i] == '<eos>':
292 + break
293 + temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
294 +            temp_sentence = temp_sentence + '.'  # the trailing period affects the morphological analysis
295 + exchange_word, idx = make_special_word(temp_sentence, per_rough, True)
296 + exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence)
297 + idx = idx + 1 # because of token <sos>
298 + NP_idx = NP_idx + 1
299 +
300 + if exist:
301 + temp_dec_input[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]
302 +
303 + if exchange_word == '':
304 + continue
305 +
306 +            try:
307 +                exchange_word = banmal_dic[exchange_word]
308 +            except KeyError:
309 +                print("not in the banmal dictionary")
311 +
312 + temp_dec_input[i][idx] = LABEL.vocab.stoi[exchange_word]
313 + temp_dec_input[i][idx + 1] = LABEL.vocab.stoi['<eos>']
314 +
315 + for k in range(idx + 2, max_len):
316 + temp_dec_input[i][k] = LABEL.vocab.stoi['<pad>']
317 +
318 + # for j in range(len(temp_dec_input[i])):
319 + # if LABEL.vocab.itos[temp_dec_input[i][j]]=='<eos>':
320 + # break
321 + # print(LABEL.vocab.itos[temp_dec_input[i][j]], end='')
322 + # print()
323 +
324 + dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).to(device)
325 +
326 + return enc_input, dec_input, dec_outputs
327 +
328 +
329 +# Dictionary for converting polite endings to banmal (informal speech).
330 +def get_rough_dic():
331 + my_exword = {
332 + '돌아와요': '돌아와',
333 + '으세요': '으셈',
334 + '잊어버려요': '잊어버려',
335 + '나온대요': '나온대',
336 + '될까요': '될까',
337 + '할텐데': '할텐데',
338 + '옵니다': '온다',
339 + '봅니다': '본다',
340 + '네요': '네',
341 + '된답니다': '된대',
342 + '데요': '데',
343 + '봐요': '봐',
344 + '부러워요': '부러워',
345 + '바랄게요': '바랄게',
346 + '지나갑니다': "지가간다",
347 + '이뻐요': "이뻐",
348 + '지요': "지",
349 + '사세요': "사라",
350 + '던가요': "던가",
351 + '모릅니다': "몰라",
352 + '은가요': "은가",
353 + '심해요': "심해",
354 + '몰라요': "몰라",
355 + '라요': "라",
356 + '더라고요': '더라고',
357 + '입니다': '이라고',
358 + '는다면요': '는다면',
359 + '멋져요': '멋져',
360 + '다면요': '다면',
361 + '다니': '다나',
362 + '져요': '져',
363 + '만드세요': '만들어',
364 + '야죠': '야지',
365 + '죠': '지',
366 + '해줄게요': '해줄게',
367 + '대요': '대',
368 + '돌아갑시다': '돌아가자',
369 + '해보여요': '해봐',
370 + '라뇨': '라니',
371 + '편합니다': '편해',
372 + '합시다': '하자',
373 + '드세요': '먹어',
374 + '아름다워요': '아름답네',
375 + '드립니다': '줄게',
376 + '받아들여요': '받아들여',
377 + '건가요': '간기',
378 + '쏟아진다': '쏟아지네',
379 + '슬퍼요': '슬퍼',
380 + '해서요': '해서',
381 + '다릅니다': '다르다',
382 + '니다': '니',
383 + '내려요': '내려',
384 + '마셔요': '마셔',
385 + '아세요': '아냐',
386 + '변해요': '뱐헤',
387 + '드려요': '드려',
388 + '아요': '아',
389 + '어서요': '어서',
390 + '뜁니다': '뛴다',
391 + '속상해요': '속상해',
392 + '래요': '래',
393 + '까요': '까',
394 + '어야죠': '어야지',
395 + '라니': '라니',
396 + '해집니다': '해진다',
397 + '으련만': '으련만',
398 + '지워져요': '지워져',
399 + '잘라요': '잘라',
400 + '고요': '고',
401 + '셔야죠': '셔야지',
402 + '다쳐요': '다쳐',
403 + '는구나': '는구만',
404 + '은데요': '은데',
405 + '일까요': '일까',
406 + '인가요': '인가',
407 + '아닐까요': '아닐까',
408 + '텐데요': '텐데',
409 + '할게요': '할게',
410 + '보입니다': '보이네',
411 + '에요': '야',
412 + '걸요': '걸',
413 + '한답니다': '한대',
414 + '을까요': '을까',
415 + '못해요': '못해',
416 + '베푸세요': '베풀어',
417 + '어때요': '어떄',
418 + '더라구요': '더라구',
419 + '노라': '노라',
420 + '반가워요': '반가워',
421 + '군요': '군',
422 + '만납시다': '만나자',
423 + '어떠세요': '어때',
424 + '달라져요': '달라져',
425 + '예뻐요': '예뻐',
426 + '됩니다': '된다',
427 + '봅시다': '보자',
428 + '한대요': '한대',
429 + '싸워요': '싸워',
430 + '와요': '와',
431 + '인데요': '인데',
432 + '야': '야',
433 + '줄게요': '줄게',
434 + '기에요': '기',
435 + '던데요': '던데',
436 + '걸까요': '걸까',
437 + '신가요': '신가',
438 + '어요': '어',
439 + '따져요': '따져',
440 + '갈게요': '갈게',
441 + '봐': '봐',
442 + '나요': '나',
443 + '니까요': '니까',
444 + '마요': '마',
445 + '씁니다': '쓴다',
446 + '집니다': '진다',
447 + '건데요': '건데',
448 + '지웁시다': '지우자',
449 + '바랍니다': '바래',
450 + '는데요': '는데',
451 + '으니까요': '으니까',
452 + '셔요': '셔',
453 + '네여': '네',
454 + '달라요': '달라',
455 + '거려요': '거려',
456 + '보여요': '보여',
457 + '겁니다': '껄',
458 + '다': '다',
459 + '그래요': '그래',
460 + '한가요': '한가',
461 + '잖아요': '잖아',
462 + '한데요': '한데',
463 + '우세요': '우셈',
464 + '해야죠': '해야지',
465 + '세요': '셈',
466 + '걸려요': '걸려',
467 + '텐데': '텐데',
468 + '어딘가': '어딘가',
469 + '요': '',
470 + '흘러갑니다': '흘러간다',
471 + '줘요': '줘',
472 + '편해요': '편해',
473 + '거예요': '거야',
474 + '예요': '야',
475 + '습니다': '어',
476 + '아닌가요': '아닌가',
477 + '합니다': '한다',
478 + '사라집니다': '사라져',
479 + '드릴게요': '줄게',
480 + '다면': '다면',
481 + '그럴까요': '그럴까',
482 + '해요': '해',
483 + '답니다': '다',
484 + '주무세요': '자라',
485 + '마세요': '마라',
486 + '아픈가요': '아프냐',
487 + '그런가요': '그런가',
488 + '했잖아요': '했잖아',
489 + '버려요': '버려',
490 + '갑니다': '간다',
491 + '가요': '가',
492 + '라면요': '라면',
493 + '아야죠': '아야지',
494 + '살펴봐요': '살펴봐',
495 + '남겨요': '남겨',
496 + '내려놔요': '내려놔',
497 + '떨려요': '떨려',
498 + '랍니다': '란다',
499 + '돼요': '돼',
500 + '버텨요': '버텨',
501 + '만나': '만나',
502 + '일러요': '일러',
503 + '을게요': '을게',
504 + '갑시다': '가자',
505 + '나아요': '나아',
506 + '어려요': '어려',
507 + '온대요': '온대',
508 + '다고요': '다고',
509 + '할래요': '할래',
510 + '된대요': '된대',
511 + '어울려요': '어울려',
512 + '는군요': '는군',
513 + '볼까요': '볼까',
514 + '드릴까요': '줄까',
515 + '라던데요': '라던데',
516 + '올게요': '올게',
517 + '기뻐요': '기뻐',
518 + '아닙니다': '아냐',
519 + '둬요': '둬',
520 + '십니다': '십',
521 + '아파요': '아파',
522 + '생겨요': '생겨',
523 + '해줘요': '해줘',
524 + '로군요': '로군요',
525 + '시켜요': '시켜',
526 + '느껴져요': '느껴져',
527 + '가재요': '가재',
528 + '어 ': ' ',
529 + '느려요': '느려',
530 + '볼게요': '볼게',
531 + '쉬워요': '쉬워',
532 + '나빠요': '나빠',
533 + '불러줄게요': '불러줄게',
534 + '살쪄요': '살쪄',
535 + '봐야겠어요': '봐야겠어',
536 + '네': '네',
537 + '어': '어',
538 + '든지요': '든지',
539 + '드신다': '드심',
540 + '가져요': '가져',
541 + '할까요': '할까',
542 + '졸려요': '졸려',
543 + '그럴게요': '그럴게',
544 + '': '',
545 + '어린가': '어린가',
546 + '나와요': '나와',
547 + '빨라요': '빨라',
548 + '겠죠': '겠지',
549 + '졌어요': '졌어',
550 + '해봐요': '해봐',
551 + '게요': '게',
552 + '해드릴까요': '해줄까',
553 + '인걸요': '인걸',
554 + '했어요': '했어',
555 + '원해요': '원해',
556 + '는걸요': '는걸',
557 + '좋아합니다': '좋아해',
558 + '했으면': '했으면',
559 + '나갑니다': '나간다',
560 + '왔어요': '왔어',
561 + '해봅시다': '해보자',
562 + '물어봐요': '물어봐',
563 + '생겼어요': '생겼어',
564 + '해': '해',
565 + '다녀올게요': '다녀올게',
566 + '납시다': '나자'
567 + }
568 + return my_exword
1 +function send() {
2 + /*client side */
3 + var chat = document.createElement("li");
4 + var chat_input = document.getElementById("chat_input");
5 + var chat_text = chat_input.value;
6 + chat.className = "chat-bubble mine";
7 +    chat.innerText = chat_text;
8 + document.getElementById("chat_list").appendChild(chat);
9 + chat_input.value = "";
10 +
11 + /* ajax request */
12 + var request = new XMLHttpRequest();
13 +    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
14 + request.onreadystatechange = function() {
15 +        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
16 + var bot_chat = document.createElement("li");
17 + bot_chat.className = "chat-bubble bots";
18 + bot_chat.innerText = JSON.parse(request.responseText).data;
19 + document.getElementById("chat_list").appendChild(bot_chat);
20 +
21 + };
22 +    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
23 +    request.send(JSON.stringify({"data": chat_text}));
24 +}
25 +
26 +function setDefault() {
27 + document.getElementById("chat_input").addEventListener("keyup", function(event) {
28 + let input = document.getElementById("chat_input").value;
29 + let button = document.getElementById("send_button");
30 + if(input.length>0)
31 + {
32 + button.removeAttribute("disabled");
33 + }
34 + else
35 + {
36 + button.setAttribute("disabled", "true");
37 + }
38 + // Number 13 is the "Enter" key on the keyboard
39 + if (event.keyCode === 13) {
40 + // Cancel the default action, if needed
41 + event.preventDefault();
42 + // Trigger the button element with a click
43 + button.click();
44 + }
45 + });
46 +}
1 +from flask import Flask, request, jsonify, send_from_directory
2 +import torch
3 +from torchtext import data
4 +from generation import inference, tokenizer1
5 +from Styling import make_special_token
6 +from model import Transformer
7 +
8 +app = Flask(__name__,
9 + static_url_path='',
10 + static_folder='static',)
11 +app.config['JSON_AS_ASCII'] = False
12 +device = torch.device('cpu')
13 +max_len = 40
14 +ID = data.Field(sequential=False,
15 + use_vocab=False)
16 +SA = data.Field(sequential=False,
17 + use_vocab=False)
18 +TEXT = data.Field(sequential=True,
19 + use_vocab=True,
20 + tokenize=tokenizer1,
21 + batch_first=True,
22 + fix_length=max_len,
23 + dtype=torch.int32
24 + )
25 +
26 +LABEL = data.Field(sequential=True,
27 + use_vocab=True,
28 + tokenize=tokenizer1,
29 + batch_first=True,
30 + fix_length=max_len,
31 + init_token='<sos>',
32 + eos_token='<eos>',
33 + dtype=torch.int32
34 + )
35 +text_specials, label_specials = make_special_token(False)
36 +train_data, _ = data.TabularDataset.splits(
37 + path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
38 + fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
39 +)
40 +TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
41 +LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)
42 +soft_model = Transformer(160, 2, 2, 0.1, TEXT, LABEL)
43 +# rough_model = Transformer(args, TEXT, LABEL)
44 +soft_model.to(device)
45 +# rough_model.to(device)
46 +soft_model.load_state_dict(torch.load('sorted_model-soft.pth', map_location=device)['model_state_dict'])
47 +
48 +
49 +# rough_model.load_state_dict(torch.load('sorted_model-rough.pth', map_location=device)['model_state_dict'])
50 +
51 +
52 +@app.route('/api/soft', methods=['POST'])
53 +def soft():
54 + if request.is_json:
55 + sentence = request.json["data"]
56 + return jsonify({"data": inference(device, max_len, TEXT, LABEL, soft_model, sentence)}), 200
57 + else:
58 + return jsonify({"data": "잘못된 요청입니다. Bad Request."}), 400
59 +
60 +# @app.route('/rough', methods=['POST'])
61 +# def rough():
62 +# return inference(device, max_len, TEXT, LABEL, rough_model, ), 200
63 +
64 +@app.route('/', methods=['GET'])
65 +def main_page():
66 +    return send_from_directory('static', 'main.html')
67 +
68 +if __name__ == '__main__':
69 + app.run(host='0.0.0.0', port=8080)
1 +ul.no-bullets {
2 + list-style-type: none; /* Remove bullets */
3 + padding: 0; /* Remove padding */
4 + margin: 0; /* Remove margins */
5 + }
6 +
7 +.chat-bubble {
8 + position: relative;
9 + padding: 0.5em;
10 + margin-top: 0.25em;
11 + margin-bottom: 0.25em;
12 + border-radius: 0.4em;
13 + color: white;
14 +}
15 +.mine {
16 + background: #00aabb;
17 +}
18 +.bots {
19 + background: #cc78c5;
20 +}
21 +
22 +.chat-bubble:after {
23 + content: "";
24 + position: absolute;
25 + top: 50%;
26 + width: 0;
27 + height: 0;
28 + border: 0.625em solid transparent;
29 + border-top: 0;
30 + margin-top: -0.312em;
31 +
32 +}
33 +.chat-bubble.mine:after {
34 + right: 0;
35 +
36 + border-left-color: #00aabb;
37 + border-right: 0;
38 + margin-right: -0.625em;
39 +}
40 +
41 +.chat-bubble.bots:after {
42 + left: 0;
43 +
44 + border-right-color: #cc78c5;
45 + border-left: 0;
46 + margin-left: -0.625em;
47 +}
48 +
49 +#chat_input {
50 + width: 90%;
51 +}
52 +
53 +#send_button {
54 +
55 + width: 5%;
56 + border-radius: 0.4em;
57 + color: white;
58 + background-color: rgb(15, 145, 138);
59 +}
60 +
61 +.input-holder {
62 + position: fixed;
63 + left: 0;
64 + right: 0;
65 + bottom: 0;
66 + padding: 0.25em;
67 + background-color: lightseagreen;
68 +}
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +from konlpy.tag import Mecab
3 +from chatspace import ChatSpace
4 +
5 +spacer = ChatSpace()
6 +mecab = Mecab()  # instantiate once; creating a Mecab tagger on every call is slow
7 +
8 +
9 +def tokenizer1(text: str):
10 +    # Keep only alphanumeric characters (note: this also strips whitespace), then split into morphemes.
11 +    result_text = ''.join(c for c in text if c.isalnum())
12 +    return mecab.morphs(result_text)
13 +
14 +
15 +def inference(device: torch.device, max_len: int, TEXT, LABEL, model: torch.nn.Module, sentence: str):
16 +
17 + enc_input = tokenizer1(sentence)
18 + enc_input_index = []
19 +
20 + for tok in enc_input:
21 + enc_input_index.append(TEXT.vocab.stoi[tok])
22 +
23 + for j in range(max_len - len(enc_input_index)):
24 + enc_input_index.append(TEXT.vocab.stoi['<pad>'])
25 +
26 +    enc_input_index = torch.LongTensor([enc_input_index])  # torch.autograd.Variable is deprecated
27 +
28 + dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])
29 +
30 + model.eval()
31 + pred = []
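    # Greedy decoding: append the argmax token to the decoder input each step until <eos> or max_len.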
32 + for i in range(max_len):
33 + y_pred = model(enc_input_index.to(device), dec_input.to(device))
34 + y_pred_ids = y_pred.max(dim=-1)[1]
35 + if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
36 + y_pred_ids = y_pred_ids.squeeze(0)
37 + print(">", end=" ")
38 + for idx in range(len(y_pred_ids)):
39 + if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
40 + pred_sentence = "".join(pred)
41 + pred_str = spacer.space(pred_sentence)
42 + return pred_str
43 + else:
44 + pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
45 +                return 'Error: sentence did not end with <eos>'
46 +
47 + dec_input = torch.cat(
48 + [dec_input.to(torch.device('cpu')),
49 + y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
50 +    return 'Error: no sentence was generated'
1 +import argparse
2 +import time
3 +import torch
4 +from torch import nn
5 +from torchtext import data
6 +from torchtext.data import BucketIterator
7 +from torchtext.data import TabularDataset
8 +
9 +from Styling import styling, make_special_token
10 +from generation import inference, tokenizer1
11 +from model import Transformer, GradualWarmupScheduler
12 +
13 +SEED = 1234
14 +
15 +
18 +def acc(yhat: torch.Tensor, y: torch.Tensor):
19 + with torch.no_grad():
20 + yhat = yhat.max(dim=-1)[1] # [0]: max value, [1]: index of max value
21 +        _acc = (yhat == y).float()[y != 1].mean()  # exclude padding from the accuracy
22 + return _acc
23 +
24 +
25 +def train(model: Transformer, iterator, optimizer, criterion: nn.CrossEntropyLoss, max_len: int, per_soft: bool, per_rough: bool):
26 + total_loss = 0
27 + iter_num = 0
28 + tr_acc = 0
29 + model.train()
30 +
31 + for step, batch in enumerate(iterator):
32 + optimizer.zero_grad()
33 +
34 + enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
35 + dec_output = dec_input[:, 1:]
36 + dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)
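        # dec_output drops the leading <sos>; dec_outputs is a max_len-sized buffer that styling() fills.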
37 +
38 +        # apply the emotion and speech style
39 + enc_input, dec_input, dec_outputs = \
40 + styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)
41 +
42 + y_pred = model(enc_input, dec_input)
43 +
44 + y_pred = y_pred.reshape(-1, y_pred.size(-1))
45 + dec_output = dec_outputs.view(-1).long()
46 +
47 +        # indices of the non-padding values
48 +        real_value_index = [dec_output != 1]  # <pad> == 1
49 +
50 +        # exclude padding from the loss
51 + loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
52 + loss.backward()
53 + optimizer.step()
54 +
55 + with torch.no_grad():
56 + train_acc = acc(y_pred, dec_output)
57 +
58 + total_loss += loss
59 + iter_num += 1
60 + tr_acc += train_acc
61 +
62 + return total_loss.data.cpu().numpy() / iter_num, tr_acc.data.cpu().numpy() / iter_num
63 +
64 +
65 +def test(model: Transformer, iterator, criterion: nn.CrossEntropyLoss):
66 + total_loss = 0
67 + iter_num = 0
68 + te_acc = 0
69 + model.eval()
70 +
71 + with torch.no_grad():
72 + for batch in iterator:
73 + enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
74 + dec_output = dec_input[:, 1:]
75 + dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)
76 +
77 +            # apply the emotion and speech style
78 + enc_input, dec_input, dec_outputs = \
79 + styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args.max_len, args.per_soft, args.per_rough, TEXT, LABEL)
80 +
81 + y_pred = model(enc_input, dec_input)
82 +
83 + y_pred = y_pred.reshape(-1, y_pred.size(-1))
84 + dec_output = dec_outputs.view(-1).long()
85 +
86 + real_value_index = [dec_output != 1] # <pad> == 1
87 +
88 + loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
89 +
90 + with torch.no_grad():
91 + test_acc = acc(y_pred, dec_output)
92 + total_loss += loss
93 + iter_num += 1
94 + te_acc += test_acc
95 +
96 + return total_loss.data.cpu().numpy() / iter_num, te_acc.data.cpu().numpy() / iter_num
97 +
98 +
99 +# Preprocesses the data and returns the fields and data loaders.
100 +def data_preprocessing(args, device):
101 +    # ID is unused. SA is the sentiment analysis label (0 or 1).
102 + ID = data.Field(sequential=False,
103 + use_vocab=False)
104 +
105 + TEXT = data.Field(sequential=True,
106 + use_vocab=True,
107 + tokenize=tokenizer1,
108 + batch_first=True,
109 + fix_length=args.max_len,
110 + dtype=torch.int32
111 + )
112 +
113 + LABEL = data.Field(sequential=True,
114 + use_vocab=True,
115 + tokenize=tokenizer1,
116 + batch_first=True,
117 + fix_length=args.max_len,
118 + init_token='<sos>',
119 + eos_token='<eos>',
120 + dtype=torch.int32
121 + )
122 +
123 + SA = data.Field(sequential=False,
124 + use_vocab=False)
125 +
126 + train_data, test_data = TabularDataset.splits(
127 + path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
128 + fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
129 + )
130 +
131 +    # Build the special tokens required by TEXT and LABEL.
132 + text_specials, label_specials = make_special_token(args.per_rough)
133 +
134 + TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
135 + LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)
136 +
137 + train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
138 + test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)
139 +
140 + return TEXT, LABEL, train_loader, test_loader
141 +
142 +
143 +def main(TEXT, LABEL, arguments):
144 +
145 +    # print the parsed arguments
146 +    for idx, (key, value) in enumerate(arguments.__dict__.items()):
147 +        if idx == 0:
148 +            print("\nargparse{\n", "\t", key, ":", value)
149 +        elif idx == len(arguments.__dict__) - 1:
150 +            print("\t", key, ":", value, "\n}")
151 +        else:
152 +            print("\t", key, ":", value)
153 +
154 +    model = Transformer(arguments.embedding_dim, arguments.nhead, arguments.nlayers, arguments.dropout, TEXT, LABEL)
155 +    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
156 +    optimizer = torch.optim.Adam(params=model.parameters(), lr=arguments.lr)
157 +    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=arguments.num_epochs)
158 +    if arguments.per_soft:
159 +        sorted_path = 'sorted_model-soft.pth'
160 +    else:
161 +        sorted_path = 'sorted_model-rough.pth'
162 + model.to(device)
163 + if arguments.train:
164 + best_valid_loss = float('inf')
165 + for epoch in range(arguments.num_epochs):
166 + torch.manual_seed(SEED)
167 + start_time = time.time()
168 +
169 + # train, validation
170 + train_loss, train_acc = \
171 + train(model, train_loader, optimizer, criterion, arguments.max_len, arguments.per_soft,
172 + arguments.per_rough)
173 + valid_loss, valid_acc = test(model, test_loader, criterion)
174 +
175 + scheduler.step(epoch)
176 +            # elapsed-time calculation
177 + end_time = time.time()
178 + elapsed_time = end_time - start_time
179 + epoch_mins = int(elapsed_time / 60)
180 + epoch_secs = int(elapsed_time - (epoch_mins * 60))
181 +
182 +            # torch.save(model.state_dict(), sorted_path)  # for some overfitting
183 +            # Save the model whenever the current validation loss beats the best so far.
184 + if valid_loss < best_valid_loss:
185 + best_valid_loss = valid_loss
186 + torch.save({
187 + 'epoch': epoch,
188 + 'model_state_dict': model.state_dict(),
189 + 'optimizer_state_dict': optimizer.state_dict(),
190 + 'loss': valid_loss},
191 + sorted_path)
192 + print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')
193 +
194 + # print loss and acc
195 + print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
196 + print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
197 + print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')
198 +
199 +
200 +
201 + checkpoint = torch.load(sorted_path, map_location=device)
202 + model.load_state_dict(checkpoint['model_state_dict'])
203 +
204 +    test_loss, test_acc = test(model, test_loader, criterion)
205 + print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
206 + print("\t-----------------------------")
207 + while True:
208 +        sentence = input("Enter a sentence: ")
209 +        print(inference(device, arguments.max_len, TEXT, LABEL, model, sentence))
210 + print("\n")
211 +
212 +
213 +if __name__ == '__main__':
214 +    # define the command-line arguments
215 +    parser = argparse.ArgumentParser()
216 +    parser.add_argument('--max_len', type=int, default=40)  # max_len must be large enough to avoid errors
217 + parser.add_argument('--batch_size', type=int, default=256)
218 + parser.add_argument('--num_epochs', type=int, default=22)
219 + parser.add_argument('--warming_up_epochs', type=int, default=5)
220 + parser.add_argument('--lr', type=float, default=0.0002)
221 + parser.add_argument('--embedding_dim', type=int, default=160)
222 + parser.add_argument('--nlayers', type=int, default=2)
223 + parser.add_argument('--nhead', type=int, default=2)
224 + parser.add_argument('--dropout', type=float, default=0.1)
225 + parser.add_argument('--train', action="store_true")
226 + group = parser.add_mutually_exclusive_group()
227 + group.add_argument('--per_soft', action="store_true")
228 + group.add_argument('--per_rough', action="store_true")
229 + args = parser.parse_args()
230 +    print("-Preparing-")
231 + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
232 + TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
233 + main(TEXT, LABEL, args)
1 +<!DOCTYPE html>
2 +<html>
3 + <head>
4 + <meta charset="UTF-8">
5 + <meta name="viewport" content="width=device-width, initial-scale=1">
6 + <title>Emotional Chatbot with Styler</title>
7 + <script src="app.js"></script>
8 + <link rel="stylesheet" type="text/css" href="chat.css" />
9 + </head>
10 + <body onload="setDefault()">
11 + <ul id="chat_list" class="list no-bullets">
12 +<li class="chat-bubble mine">(a sample user message)</li>
13 +<li class="chat-bubble bots">(a sample bot reply)</li>
14 + </ul>
15 + <div class="input-holder">
16 + <input type="text" id="chat_input" autofocus/>
17 + <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
18 + </div>
19 + </body>
20 +</html>
1 +import math
2 +
3 +import torch
4 +import torch.nn as nn
5 +from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
4 +
5 +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
6 +
7 +
8 +class Transformer(nn.Module):
9 + def __init__(self, embedding_dim: int, nhead: int, nlayers: int, dropout: float, SRC_vocab, TRG_vocab):
10 + super(Transformer, self).__init__()
11 + self.d_model = embedding_dim
12 + self.n_head = nhead
13 + self.num_encoder_layers = nlayers
14 + self.num_decoder_layers = nlayers
15 + self.dim_feedforward = embedding_dim
16 + self.dropout = dropout
17 +
18 + self.SRC_vo = SRC_vocab
19 + self.TRG_vo = TRG_vocab
20 +
21 + self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)
22 +
23 + self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
24 + self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)
25 +
26 + self.transformer = nn.Transformer(d_model=self.d_model,
27 + nhead=self.n_head,
28 + num_encoder_layers=self.num_encoder_layers,
29 + num_decoder_layers=self.num_decoder_layers,
30 + dim_feedforward=self.dim_feedforward,
31 + dropout=self.dropout)
32 + self.proj_vocab_layer = nn.Linear(
33 + in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))
34 +
35 +
36 + def forward(self, en_input, de_input):
37 + x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
38 + x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
39 + x_en_embed = self.pos_encoder(x_en_embed)
40 + x_de_embed = self.pos_encoder(x_de_embed)
41 +
42 + # Masking
43 + src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
44 + tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
45 + memory_key_padding_mask = src_key_padding_mask
46 + tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1))
47 +
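        # nn.Transformer in torch 1.4 has no batch_first option and expects inputs of
        # shape (seq_len, batch, d_model), so swap the first two dimensions here.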
48 + x_en_embed = torch.einsum('ijk->jik', x_en_embed)
49 + x_de_embed = torch.einsum('ijk->jik', x_de_embed)
50 +
51 + feature = self.transformer(src=x_en_embed,
52 + tgt=x_de_embed,
53 + src_key_padding_mask=src_key_padding_mask,
54 + tgt_key_padding_mask=tgt_key_padding_mask,
55 + memory_key_padding_mask=memory_key_padding_mask,
56 + tgt_mask=tgt_mask.to(device))
57 +
58 + logits = self.proj_vocab_layer(feature)
59 + logits = torch.einsum('ijk->jik', logits)
60 +
61 + return logits
62 +
63 +
64 +class PositionalEncoding(nn.Module):
65 +
66 + def __init__(self, d_model, dropout, max_len=15000):
67 + super(PositionalEncoding, self).__init__()
68 + self.dropout = nn.Dropout(p=dropout)
69 +
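        # Sinusoidal encoding: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)),
        #                      PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))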
70 + pe = torch.zeros(max_len, d_model)
71 + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
72 + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
73 + pe[:, 0::2] = torch.sin(position * div_term)
74 + pe[:, 1::2] = torch.cos(position * div_term)
75 + pe = pe.unsqueeze(0).transpose(0, 1)
76 + self.register_buffer('pe', pe)
77 +
78 + def forward(self, x):
79 + x = x + self.pe[:x.size(0), :]
80 + return self.dropout(x)
81 +
82 +
86 +
87 +class GradualWarmupScheduler(_LRScheduler):
88 +    """ Gradually warms up (increases) the learning rate in the optimizer.
89 +    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
90 +    Args:
91 +        optimizer (Optimizer): Wrapped optimizer.
92 +        multiplier: target learning rate = base lr * multiplier
93 +        total_epoch: the target learning rate is reached gradually at total_epoch
94 +        after_scheduler: the scheduler used after total_epoch (e.g. ReduceLROnPlateau)
95 +    """
96 +
97 + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
98 + self.last_epoch = 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
99 + self.multiplier = multiplier
100 + if self.multiplier <= 1.:
101 + raise ValueError('multiplier should be greater than 1.')
102 + self.total_epoch = total_epoch
103 + self.after_scheduler = after_scheduler
104 + self.finished = False
105 + super().__init__(optimizer)
106 +
107 + def get_lr(self):
108 + if self.last_epoch > self.total_epoch:
109 + if self.after_scheduler:
110 + if not self.finished:
111 + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
112 + self.finished = True
113 + return self.after_scheduler.get_lr()
114 + return [base_lr * self.multiplier for base_lr in self.base_lrs]
115 +
116 + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
117 + self.base_lrs]
118 +
119 + def step_ReduceLROnPlateau(self, metrics, epoch=None):
120 + if epoch is None:
121 + epoch = self.last_epoch + 1
122 + self.last_epoch = epoch if epoch != 0 else 1
123 + if self.last_epoch <= self.total_epoch:
124 + warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
125 + self.base_lrs]
126 + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
127 + param_group['lr'] = lr
128 + else:
129 + if epoch is None:
130 + self.after_scheduler.step(metrics, None)
131 + else:
132 + self.after_scheduler.step(metrics, epoch - self.total_epoch)
133 +
134 + def step(self, epoch=None, metrics=None):
135 + if type(self.after_scheduler) != ReduceLROnPlateau:
136 + if self.finished and self.after_scheduler:
137 + if epoch is None:
138 + self.after_scheduler.step(None)
139 + else:
140 + self.after_scheduler.step(epoch - self.total_epoch)
141 + else:
142 + return super(GradualWarmupScheduler, self).step(epoch)
143 + else:
144 + self.step_ReduceLROnPlateau(metrics, epoch)
1 +torch~=1.4.0
2 +Flask~=1.1.2
3 +torchtext~=0.6.0
4 +hgtk~=0.1.3
5 +konlpy~=0.5.2
6 +chatspace~=1.0.1
1 +function send() {
2 + /*client side */
3 + var chat = document.createElement("li");
4 + var chat_input = document.getElementById("chat_input");
5 + var chat_text = chat_input.value;
6 + chat.className = "chat-bubble mine";
7 +    chat.innerText = chat_text;
8 + document.getElementById("chat_list").appendChild(chat);
9 + chat_input.value = "";
10 +
11 + /* ajax request */
12 + var request = new XMLHttpRequest();
13 + request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
14 + request.onreadystatechange = function() {
15 +        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
16 + var bot_chat = document.createElement("li");
17 + bot_chat.className = "chat-bubble bots";
18 + bot_chat.innerText = JSON.parse(request.responseText).data;
19 + document.getElementById("chat_list").appendChild(bot_chat);
20 +
21 + };
22 +    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
23 +    request.send(JSON.stringify({"data": chat_text}));
24 +}
25 +
26 +function setDefault() {
27 + document.getElementById("chat_input").addEventListener("keyup", function(event) {
28 + let input = document.getElementById("chat_input").value;
29 + let button = document.getElementById("send_button");
30 + if(input.length>0)
31 + {
32 + button.removeAttribute("disabled");
33 + }
34 + else
35 + {
36 + button.setAttribute("disabled", "true");
37 + }
38 + // Number 13 is the "Enter" key on the keyboard
39 + if (event.keyCode === 13) {
40 + // Cancel the default action, if needed
41 + event.preventDefault();
42 + // Trigger the button element with a click
43 + button.click();
44 + }
45 + });
46 +}
1 +ul.no-bullets {
2 + list-style-type: none; /* Remove bullets */
3 + padding: 0; /* Remove padding */
4 + margin: 0; /* Remove margins */
5 + }
6 +
7 +.chat-bubble {
8 + position: relative;
9 + padding: 0.5em;
10 + margin-top: 0.25em;
11 + margin-bottom: 0.25em;
12 + border-radius: 0.4em;
13 + color: white;
14 +}
15 +.mine {
16 + background: #00aabb;
17 +}
18 +.bots {
19 + background: #cc78c5;
20 +}
21 +
22 +.chat-bubble:after {
23 + content: "";
24 + position: absolute;
25 + top: 50%;
26 + width: 0;
27 + height: 0;
28 + border: 0.625em solid transparent;
29 + border-top: 0;
30 + margin-top: -0.312em;
31 +
32 +}
33 +.chat-bubble.mine:after {
34 + right: 0;
35 +
36 + border-left-color: #00aabb;
37 + border-right: 0;
38 + margin-right: -0.625em;
39 +}
40 +
41 +.chat-bubble.bots:after {
42 + left: 0;
43 +
44 + border-right-color: #cc78c5;
45 + border-left: 0;
46 + margin-left: -0.625em;
47 +}
48 +
49 +#chat_input {
50 + width: 90%;
51 +}
52 +
53 +#send_button {
54 +
55 + width: 5%;
56 + border-radius: 0.4em;
57 + color: white;
58 + background-color: rgb(15, 145, 138);
59 +}
60 +
61 +.input-holder {
62 + position: fixed;
63 + left: 0;
64 + right: 0;
65 + bottom: 0;
66 + padding: 0.25em;
67 + background-color: lightseagreen;
68 +}
1 +<!DOCTYPE html>
2 +<html>
3 + <head>
4 + <meta charset="UTF-8">
5 + <meta name="viewport" content="width=device-width, initial-scale=1">
6 + <title>Emotional Chatbot with Styler</title>
7 + <script src="app.js"></script>
8 + <link rel="stylesheet" type="text/css" href="chat.css" />
9 + </head>
10 + <body onload="setDefault()">
11 + <ul id="chat_list" class="list no-bullets">
12 +<li class="chat-bubble mine">Ask a question like this...</li>
13 +<li class="chat-bubble bots">...and the reply appears like this!</li>
14 + </ul>
15 + <div class="input-holder">
16 + <input type="text" id="chat_input" autofocus/>
17 + <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
18 + </div>
19 + </body>
20 +</html>
@@ -10,3 +10,51 @@ Model that varies chatbot responses by Language Style and sentiment analysis:
10 - GeForce RTX 2080 Ti
11 - Python 3.6.8
12 - Pytorch 1.2.0
13 +
14 +# Code
15 +## Chatbot
16 +
17 +### Chatbot_main.py
18 +The main file used to train and test the chatbot.
19 +### model.py
20 +The Transformer model class used by the chatbot.
21 +### generation.py
22 +Performs inference, including beam search and greedy search.
23 +### metric.py
24 +A module for measuring training performance.\
25 +`acc(yhat, y)`\
26 +### Styling.py
27 +Changes the speech style according to the personality.
28 +### get_data.py
29 +Preprocesses and loads the dataset.\
30 +`tokenizer1(text)`\
31 +* text: the string to tokenize
32 +Filters out special characters, then tokenizes with Mecab.\
33 +`data_preprocessing(args, device)`\
34 +* args: the Namespace parsed by argparse
35 +* device: the PyTorch device
36 +Tokenizes the text and builds the dataset, split into id, text, label, and sentiment analysis result.
37 +
38 +## KoBERT
39 +[SKTBrain KoBERT](https://github.com/SKTBrain/KoBERT)\
40 +A model from SKT Brain that adapts BERT to Korean.\
41 +It was trained for sentiment analysis on Naver movie reviews and is used for the chatbot's sentiment analysis.\
42 +## Light_model
43 +A model slimmed down for web hosting. It does not support KoBERT.
44 +### light_chatbot.py
45 +A console program for training and testing the chatbot model.
46 +`light_chatbot.py [--train] [--per_soft|--per_rough]`
47 +
48 +* train: trains a new model and saves it.
49 +If omitted, the saved model is loaded and tested.
50 +* per_soft: trains or tests the soft speech style.
51 +* per_rough: trains or tests the rough speech style.
52 +The two options are mutually exclusive.
53 +### app.py
54 +A simple HTTP server, built with Flask, for web hosting.\
55 +`POST /api/soft`\
56 +An API that runs inference with the soft model and returns the result as JSON (an example exchange is shown below).\
57 +`GET /`\
58 +Statically serves the HTML, CSS, and JS from the static folder.
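
A hypothetical request/response exchange (the `data` field matches app.py; the actual reply text is whatever the trained model generates):

```
POST /api/soft
Content-Type: application/json

{"data": "안녕하세요"}

HTTP/1.1 200 OK
{"data": "(model reply)"}
```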
59 +### Others
60 +generation.py, styling.py, and model.py play the same roles as in Chatbot.