Showing 19 changed files with 1451 additions and 0 deletions
Light_model/.gitignore
0 → 100644
Light_model/Dockerfile
0 → 100644
FROM ufoym/deepo:pytorch-cpu
# https://github.com/Beomi/deepo-nlp/blob/master/Dockerfile
# Install a JVM for KoNLPy
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y \
    openjdk-8-jdk wget curl git python3-dev \
    language-pack-ko

RUN locale-gen en_US.UTF-8 && \
    update-locale LANG=en_US.UTF-8

# Install zsh
RUN apt-get install -y zsh && \
    sh -c "$(curl -fsSL https://raw.github.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"

# Install additional packages
RUN pip install --upgrade pip
RUN pip install autopep8
RUN pip install konlpy
RUN pip install torchtext pytorch_pretrained_bert
# Install the dependencies of the styling chatbot
RUN pip install hgtk chatspace

# Add Mecab-Ko
RUN curl -L https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash
# Install the styling chatbot by BM-K
RUN git clone https://github.com/km19809/light_model.git
RUN pip install -r light_model/requirements.txt

# Add a non-root user
RUN adduser --disabled-password --gecos "" user

# Reset the workdir
WORKDIR /light_model
Light_model/README.md
0 → 100644
# Lightweight model of the styling chatbot
This repository hosts a lightweight model for web deployment.\
The original repository is here: [Styling-Chatbot-with-Transformer](https://github.com/km19809/Styling-Chatbot-with-Transformer)

## Requirements

The versions below may change during development; see requirements.txt for the authoritative list.
```
torch~=1.4.0
Flask~=1.1.2
torchtext~=0.6.0
hgtk~=0.1.3
konlpy~=0.5.2
chatspace~=1.0.1
```

## Usage
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: train a new model. \
Without this flag, the saved model is loaded and tested.
* per_soft: train or test the soft speech style.\
With per_rough instead, the rough speech style is trained or tested.\
The two options are mutually exclusive.

`app.py`

A simple Flask server for trying out the chatbot.
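Once `app.py` is running, the API can be smoke-tested with a few lines of Python. This is a minimal sketch; it assumes the server is reachable on port 8080 (the port hard-coded in app.py) and that the third-party `requests` package is installed:

```
import requests  # third-party; pip install requests

# POST a sentence to the soft-style endpoint and print the styled reply.
resp = requests.post(
    "http://localhost:8080/api/soft",
    json={"data": "오늘 날씨 어때?"},  # any Korean sentence
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["data"])
```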
Light_model/Styling.py
0 → 100644
import csv
import random

import hgtk
import torch
from konlpy.tag import Mecab

mecab = Mecab()
positive_emo = ['ㅎㅎ', '~']
negative_emo = ['...', 'ㅠㅠ']
missing_banmal_words = []  # endings that were not found in the banmal dictionary


# Morpheme analysis with Mecab.
def mecab_token_pos_flat_fn(string: str):
    tokens_ko = mecab.pos(string)
    return [str(pos[0]) + '/' + str(pos[1]) for pos in tokens_ko]


# For the rough style: find a pronoun (NP) such as 저 or 제 and replace it with 나 or 내.
def exchange_NP(target: str):
    keyword = []
    ko_sp = mecab_token_pos_flat_fn(target)
    _idx = -1  # default value on failure
    for idx, word in enumerate(ko_sp):
        if word.find('NP') > 0:
            keyword.append(word.split('/'))
            _idx = idx
            break
    if not keyword:  # no NP was found
        return '', _idx, False

    if keyword[0][0] == '저':
        keyword[0][0] = '나'
    elif keyword[0][0] == '제':
        keyword[0][0] = '내'
    else:
        return keyword[0], _idx, False

    return keyword[0][0], _idx, True


# Convert the sentence-final word into the soft or rough style.
def make_special_word(target: str, per_rough: bool, search_ec: bool):
    # Split the sentence with Mecab (example output: ['오늘/MAG', '날씨/NNG', '좋/VA', '다/EF', './SF'])
    ko_sp = mecab_token_pos_flat_fn(target)

    keyword = []
    _idx = -1  # default value on failure
    # If a token carries a final ending 'EF' (or, optionally, a connective ending 'EC'),
    # extract its index and keyword.
    for idx, word in enumerate(ko_sp):
        if word.find('EF') > 0:
            keyword.append(word.split('/'))
            _idx = idx
            break
        if search_ec:
            if ko_sp[-2].find('EC') > 0:
                keyword.append(ko_sp[-2].split('/'))
                _idx = len(ko_sp) - 1
                break
            else:
                continue

    # Return early if no 'EF' was found.
    if not keyword:
        return '', _idx
    else:
        _keyword = keyword[0]

    if per_rough:
        return _keyword[0], _idx

    # Decompose the keyword into jamo with hgtk (example output: 하ᴥ세요)
    h_separation = hgtk.text.decompose(_keyword[0])
    total_word = ''

    for idx, word in enumerate(h_separation):
        total_word += word

    # Styling: attach the final consonant 'ㅇ' to the 'EF' ending.
    total_word = replace_right(total_word, "ᴥ", "ㅇᴥ", 1)

    # Recompose the jamo, e.g. '하세요' -> '하세용'.
    h_combine = hgtk.text.compose(total_word)

    return h_combine, _idx


# Build the special tokens.
def make_special_token(per_rough: bool):
    # Special tokens used to express emotion.
    target_special_voca = []

    banmal_dict = get_rough_dic()

    # From the chatbot answers in the training set, extract the 'EF' endings,
    # attach the final consonant 'ㅇ', and collect the results as special tokens.
    with open('chatbot_0325_ALLLABEL_train.txt', 'r', encoding='utf-8') as f:
        rdr = csv.reader(f, delimiter='\t')
        for idx, line in enumerate(rdr):
            target = line[2]  # chatbot answer
            exchange_word, _ = make_special_word(target, per_rough, False)
            target_special_voca.append(str(exchange_word))
        target_special_voca = list(set(target_special_voca))

    banmal_special_voca = []
    for i in range(len(target_special_voca)):
        try:
            banmal_special_voca.append(banmal_dict[target_special_voca[i]])
        except KeyError:
            if per_rough:
                print("not included in the banmal dictionary")
            pass

    # Add the emoticon tokens.
    target_special_voca.append('ㅎㅎ')
    target_special_voca.append('~')
    target_special_voca.append('ㅠㅠ')
    target_special_voca.append('...')
    target_special_voca = target_special_voca + banmal_special_voca

    # '<posi>' marks positive, '<nega>' marks negative.
    return ['<posi>', '<nega>'], target_special_voca


# Like Python's str.replace, but replacing from the right.
def replace_right(original: str, old: str, new: str, count_right: int):
    text = original

    count_find = original.count(old)
    # If count_right exceeds the number of occurrences in the string, replace all of
    # them (count_find); otherwise replace count_right occurrences.
    repeat = count_find if count_right > count_find else count_right
    for _ in range(repeat):
        find_index = text.rfind(old)  # rfind locates the rightmost occurrence
        # Note: this splice assumes len(old) == 1, which holds for the jamo separator 'ᴥ' used here.
        text = text[:find_index] + new + text[find_index + 1:]

    return text


# Apply styling to the tensors that go into the transformer as input and output.
def styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len: int, per_soft: bool, per_rough: bool, TEXT, LABEL):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    pad_tensor = torch.tensor([LABEL.vocab.stoi['<pad>']]).type(dtype=torch.int32).to(device)

    temp_enc = enc_input.data.cpu().numpy()
    batch_sentiment_list = []

    # Soft personality
    if per_soft:
        # encoder input: rewrite into the form  나는 너를 좋아해 <posi> <pad> <pad> ...
        for i in range(len(temp_enc)):
            for j in range(max_len):
                if temp_enc[i][j] == 1 and enc_label[i] == 0:
                    temp_enc[i][j] = TEXT.vocab.stoi["<nega>"]
                    batch_sentiment_list.append(0)
                    break
                elif temp_enc[i][j] == 1 and enc_label[i] == 1:
                    temp_enc[i][j] = TEXT.vocab.stoi["<posi>"]
                    batch_sentiment_list.append(1)
                    break

        enc_input = torch.tensor(temp_enc, dtype=torch.int32).to(device)

        for i in range(len(dec_outputs)):
            dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)

        temp_dec = dec_outputs.data.cpu().numpy()

        dec_outputs_sentiment_list = []  # emotion tokens appended to the decoder outputs

        # decoder outputs: rewrite into the form  저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...
        for i in range(len(temp_dec)):  # i iterates over the batch
            temp_sentence = ''
            sa_ = batch_sentiment_list[i]
            if sa_ == 0:
                sa_ = random.choice(negative_emo)
            elif sa_ == 1:
                sa_ = random.choice(positive_emo)
            dec_outputs_sentiment_list.append(sa_)

            for ix, token_i in enumerate(temp_dec[i]):
                if LABEL.vocab.itos[token_i] in ['<sos>', '<eos>', '<pad>']:
                    continue
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the morpheme analysis differs depending on the trailing period
            exchange_word, idx = make_special_word(temp_sentence, per_rough, True)

            if exchange_word == '':
                for j in range(len(temp_dec[i])):
                    if temp_dec[i][j] == LABEL.vocab.stoi['<eos>']:
                        temp_dec[i][j] = LABEL.vocab.stoi[sa_]
                        temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
                        break
                continue

            for j in range(len(temp_dec[i])):
                if LABEL.vocab.itos[temp_dec[i][j]] == '<eos>':
                    temp_dec[i][j - 1] = LABEL.vocab.stoi[exchange_word]
                    temp_dec[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                    temp_dec[i][j + 1] = LABEL.vocab.stoi['<eos>']
                    break
                elif temp_dec[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec[i]):
                    print("\t-ERROR- No <EOS> token")
                    exit()

        dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).to(device)

        temp_dec_input = dec_input.data.cpu().numpy()
        # decoder input: rewrite into the form  <sos> 저도 좋아용 ㅎㅎ <eos> <pad> <pad> ...
        for i in range(len(temp_dec_input)):
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec_input[i]):
                if LABEL.vocab.itos[token_i] in ['<sos>', '<eos>', '<pad>']:
                    continue
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the morpheme analysis differs depending on the trailing period
            exchange_word, idx = make_special_word(temp_sentence, per_rough, True)

            if exchange_word == '':
                for j in range(len(temp_dec_input[i])):
                    if temp_dec_input[i][j] == LABEL.vocab.stoi['<eos>']:
                        temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                        temp_dec_input[i][j + 1] = LABEL.vocab.stoi['<eos>']
                        break
                continue

            for j in range(len(temp_dec_input[i])):
                if LABEL.vocab.itos[temp_dec_input[i][j]] == '<eos>':
                    temp_dec_input[i][j - 1] = LABEL.vocab.stoi[exchange_word]
                    temp_dec_input[i][j] = LABEL.vocab.stoi[dec_outputs_sentiment_list[i]]
                    temp_dec_input[i][j + 1] = LABEL.vocab.stoi['<eos>']
                    break
                elif temp_dec_input[i][j] != LABEL.vocab.stoi['<eos>'] and j + 1 == len(temp_dec_input[i]):
                    print("\t-ERROR- No <EOS> token")
                    exit()

        dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).to(device)

    # Rough personality
    elif per_rough:
        banmal_dic = get_rough_dic()

        for i in range(len(dec_outputs)):
            dec_outputs[i] = torch.cat([dec_output[i], pad_tensor], dim=-1)

        temp_dec = dec_outputs.data.cpu().numpy()

        # decoder outputs: rewrite into the form  나도 좋아 <eos> <pad> <pad> ...
        for i in range(len(temp_dec)):  # i iterates over the batch
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec[i]):
                if LABEL.vocab.itos[token_i] == '<eos>':
                    break
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the morpheme analysis differs depending on the trailing period
            exchange_word, idx = make_special_word(temp_sentence, per_rough, True)
            exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence)

            if exist:
                temp_dec[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]

            if exchange_word == '':
                continue
            try:
                exchange_word = banmal_dic[exchange_word]
            except KeyError:
                missing_banmal_words.append(exchange_word)
                print("not included in the banmal dictionary")
                pass

            temp_dec[i][idx] = LABEL.vocab.stoi[exchange_word]
            temp_dec[i][idx + 1] = LABEL.vocab.stoi['<eos>']
            for k in range(idx + 2, max_len):
                temp_dec[i][k] = LABEL.vocab.stoi['<pad>']

            # (debug) uncomment to print the styled sentence:
            # for j in range(len(temp_dec[i])):
            #     if LABEL.vocab.itos[temp_dec[i][j]] == '<eos>':
            #         break
            #     print(LABEL.vocab.itos[temp_dec[i][j]], end='')
            # print()

        dec_outputs = torch.tensor(temp_dec, dtype=torch.int32).to(device)

        temp_dec_input = dec_input.data.cpu().numpy()
        # decoder input: rewrite into the form  <sos> 나도 좋아 <eos> <pad> <pad> ...
        for i in range(len(temp_dec_input)):
            temp_sentence = ''
            for ix, token_i in enumerate(temp_dec_input[i]):
                if ix == 0:
                    continue  # skip the <sos> token
                if LABEL.vocab.itos[token_i] == '<eos>':
                    break
                temp_sentence = temp_sentence + LABEL.vocab.itos[token_i]
            temp_sentence = temp_sentence + '.'  # the morpheme analysis differs depending on the trailing period
            exchange_word, idx = make_special_word(temp_sentence, per_rough, True)
            exchange_NP_word, NP_idx, exist = exchange_NP(temp_sentence)
            idx = idx + 1  # offset for the <sos> token
            NP_idx = NP_idx + 1

            if exist:
                temp_dec_input[i][NP_idx] = LABEL.vocab.stoi[exchange_NP_word]

            if exchange_word == '':
                continue

            try:
                exchange_word = banmal_dic[exchange_word]
            except KeyError:
                print("not included in the banmal dictionary")
                pass

            temp_dec_input[i][idx] = LABEL.vocab.stoi[exchange_word]
            temp_dec_input[i][idx + 1] = LABEL.vocab.stoi['<eos>']

            for k in range(idx + 2, max_len):
                temp_dec_input[i][k] = LABEL.vocab.stoi['<pad>']

            # (debug) uncomment to print the styled sentence:
            # for j in range(len(temp_dec_input[i])):
            #     if LABEL.vocab.itos[temp_dec_input[i][j]] == '<eos>':
            #         break
            #     print(LABEL.vocab.itos[temp_dec_input[i][j]], end='')
            # print()

        dec_input = torch.tensor(temp_dec_input, dtype=torch.int32).to(device)

    return enc_input, dec_input, dec_outputs
# Dictionary for converting polite endings to banmal (informal speech).
# Note: a few values (e.g. '뱐헤', '어떄') look like typos but are kept verbatim,
# since the bundled model checkpoints were presumably trained with this vocabulary.
def get_rough_dic():
    my_exword = {
        '돌아와요': '돌아와',
        '으세요': '으셈',
        '잊어버려요': '잊어버려',
        '나온대요': '나온대',
        '될까요': '될까',
        '할텐데': '할텐데',
        '옵니다': '온다',
        '봅니다': '본다',
        '네요': '네',
        '된답니다': '된대',
        '데요': '데',
        '봐요': '봐',
        '부러워요': '부러워',
        '바랄게요': '바랄게',
        '지나갑니다': '지가간다',
        '이뻐요': '이뻐',
        '지요': '지',
        '사세요': '사라',
        '던가요': '던가',
        '모릅니다': '몰라',
        '은가요': '은가',
        '심해요': '심해',
        '몰라요': '몰라',
        '라요': '라',
        '더라고요': '더라고',
        '입니다': '이라고',
        '는다면요': '는다면',
        '멋져요': '멋져',
        '다면요': '다면',
        '다니': '다나',
        '져요': '져',
        '만드세요': '만들어',
        '야죠': '야지',
        '죠': '지',
        '해줄게요': '해줄게',
        '대요': '대',
        '돌아갑시다': '돌아가자',
        '해보여요': '해봐',
        '라뇨': '라니',
        '편합니다': '편해',
        '합시다': '하자',
        '드세요': '먹어',
        '아름다워요': '아름답네',
        '드립니다': '줄게',
        '받아들여요': '받아들여',
        '건가요': '간기',
        '쏟아진다': '쏟아지네',
        '슬퍼요': '슬퍼',
        '해서요': '해서',
        '다릅니다': '다르다',
        '니다': '니',
        '내려요': '내려',
        '마셔요': '마셔',
        '아세요': '아냐',
        '변해요': '뱐헤',
        '드려요': '드려',
        '아요': '아',
        '어서요': '어서',
        '뜁니다': '뛴다',
        '속상해요': '속상해',
        '래요': '래',
        '까요': '까',
        '어야죠': '어야지',
        '라니': '라니',
        '해집니다': '해진다',
        '으련만': '으련만',
        '지워져요': '지워져',
        '잘라요': '잘라',
        '고요': '고',
        '셔야죠': '셔야지',
        '다쳐요': '다쳐',
        '는구나': '는구만',
        '은데요': '은데',
        '일까요': '일까',
        '인가요': '인가',
        '아닐까요': '아닐까',
        '텐데요': '텐데',
        '할게요': '할게',
        '보입니다': '보이네',
        '에요': '야',
        '걸요': '걸',
        '한답니다': '한대',
        '을까요': '을까',
        '못해요': '못해',
        '베푸세요': '베풀어',
        '어때요': '어떄',
        '더라구요': '더라구',
        '노라': '노라',
        '반가워요': '반가워',
        '군요': '군',
        '만납시다': '만나자',
        '어떠세요': '어때',
        '달라져요': '달라져',
        '예뻐요': '예뻐',
        '됩니다': '된다',
        '봅시다': '보자',
        '한대요': '한대',
        '싸워요': '싸워',
        '와요': '와',
        '인데요': '인데',
        '야': '야',
        '줄게요': '줄게',
        '기에요': '기',
        '던데요': '던데',
        '걸까요': '걸까',
        '신가요': '신가',
        '어요': '어',
        '따져요': '따져',
        '갈게요': '갈게',
        '봐': '봐',
        '나요': '나',
        '니까요': '니까',
        '마요': '마',
        '씁니다': '쓴다',
        '집니다': '진다',
        '건데요': '건데',
        '지웁시다': '지우자',
        '바랍니다': '바래',
        '는데요': '는데',
        '으니까요': '으니까',
        '셔요': '셔',
        '네여': '네',
        '달라요': '달라',
        '거려요': '거려',
        '보여요': '보여',
        '겁니다': '껄',
        '다': '다',
        '그래요': '그래',
        '한가요': '한가',
        '잖아요': '잖아',
        '한데요': '한데',
        '우세요': '우셈',
        '해야죠': '해야지',
        '세요': '셈',
        '걸려요': '걸려',
        '텐데': '텐데',
        '어딘가': '어딘가',
        '요': '',
        '흘러갑니다': '흘러간다',
        '줘요': '줘',
        '편해요': '편해',
        '거예요': '거야',
        '예요': '야',
        '습니다': '어',
        '아닌가요': '아닌가',
        '합니다': '한다',
        '사라집니다': '사라져',
        '드릴게요': '줄게',
        '다면': '다면',
        '그럴까요': '그럴까',
        '해요': '해',
        '답니다': '다',
        '주무세요': '자라',
        '마세요': '마라',
        '아픈가요': '아프냐',
        '그런가요': '그런가',
        '했잖아요': '했잖아',
        '버려요': '버려',
        '갑니다': '간다',
        '가요': '가',
        '라면요': '라면',
        '아야죠': '아야지',
        '살펴봐요': '살펴봐',
        '남겨요': '남겨',
        '내려놔요': '내려놔',
        '떨려요': '떨려',
        '랍니다': '란다',
        '돼요': '돼',
        '버텨요': '버텨',
        '만나': '만나',
        '일러요': '일러',
        '을게요': '을게',
        '갑시다': '가자',
        '나아요': '나아',
        '어려요': '어려',
        '온대요': '온대',
        '다고요': '다고',
        '할래요': '할래',
        '된대요': '된대',
        '어울려요': '어울려',
        '는군요': '는군',
        '볼까요': '볼까',
        '드릴까요': '줄까',
        '라던데요': '라던데',
        '올게요': '올게',
        '기뻐요': '기뻐',
        '아닙니다': '아냐',
        '둬요': '둬',
        '십니다': '십',
        '아파요': '아파',
        '생겨요': '생겨',
        '해줘요': '해줘',
        '로군요': '로군요',
        '시켜요': '시켜',
        '느껴져요': '느껴져',
        '가재요': '가재',
        '어 ': ' ',
        '느려요': '느려',
        '볼게요': '볼게',
        '쉬워요': '쉬워',
        '나빠요': '나빠',
        '불러줄게요': '불러줄게',
        '살쪄요': '살쪄',
        '봐야겠어요': '봐야겠어',
        '네': '네',
        '어': '어',
        '든지요': '든지',
        '드신다': '드심',
        '가져요': '가져',
        '할까요': '할까',
        '졸려요': '졸려',
        '그럴게요': '그럴게',
        '': '',
        '어린가': '어린가',
        '나와요': '나와',
        '빨라요': '빨라',
        '겠죠': '겠지',
        '졌어요': '졌어',
        '해봐요': '해봐',
        '게요': '게',
        '해드릴까요': '해줄까',
        '인걸요': '인걸',
        '했어요': '했어',
        '원해요': '원해',
        '는걸요': '는걸',
        '좋아합니다': '좋아해',
        '했으면': '했으면',
        '나갑니다': '나간다',
        '왔어요': '왔어',
        '해봅시다': '해보자',
        '물어봐요': '물어봐',
        '생겼어요': '생겼어',
        '해': '해',
        '다녀올게요': '다녀올게',
        '납시다': '나자'
    }
    return my_exword
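To make the soft-style transformation concrete, here is a small sketch of calling make_special_word directly. It assumes Mecab and hgtk are installed; the exact POS tags, and therefore the returned index, can vary with the installed Mecab dictionary:

```
from Styling import make_special_word

# '안녕하세요.' ends in the final ending (EF) '세요'; the soft style
# decomposes it into jamo, appends 'ㅇ', and recomposes it.
word, idx = make_special_word('안녕하세요.', per_rough=False, search_ec=False)
print(word, idx)  # e.g. 세용 2  (the styled ending and its token position)
```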
Light_model/app.js
0 → 100644
function send() {
    /* client side */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    // The protocol and "//" are required; without them the URL is treated as relative.
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function() {
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function(event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // Number 13 is the "Enter" key on the keyboard
        if (event.keyCode === 13) {
            // Cancel the default action, if needed
            event.preventDefault();
            // Trigger the button element with a click
            button.click();
        }
    });
}
Light_model/app.py
0 → 100644
from flask import Flask, request, jsonify, send_from_directory
import torch
from torchtext import data
from generation import inference, tokenizer1
from Styling import make_special_token
from model import Transformer

app = Flask(__name__,
            static_url_path='',
            static_folder='static')
app.config['JSON_AS_ASCII'] = False
device = torch.device('cpu')
max_len = 40
ID = data.Field(sequential=False,
                use_vocab=False)
SA = data.Field(sequential=False,
                use_vocab=False)
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=tokenizer1,
                  batch_first=True,
                  fix_length=max_len,
                  dtype=torch.int32)

LABEL = data.Field(sequential=True,
                   use_vocab=True,
                   tokenize=tokenizer1,
                   batch_first=True,
                   fix_length=max_len,
                   init_token='<sos>',
                   eos_token='<eos>',
                   dtype=torch.int32)
# Rebuild the vocabulary from the training data so it matches the saved checkpoint.
text_specials, label_specials = make_special_token(False)
train_data, _ = data.TabularDataset.splits(
    path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
    fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
)
TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)
soft_model = Transformer(160, 2, 2, 0.1, TEXT, LABEL)
# rough_model = Transformer(args, TEXT, LABEL)
soft_model.to(device)
# rough_model.to(device)
soft_model.load_state_dict(torch.load('sorted_model-soft.pth', map_location=device)['model_state_dict'])
# rough_model.load_state_dict(torch.load('sorted_model-rough.pth', map_location=device)['model_state_dict'])


@app.route('/api/soft', methods=['POST'])
def soft():
    if request.is_json:
        sentence = request.json["data"]
        return jsonify({"data": inference(device, max_len, TEXT, LABEL, soft_model, sentence)}), 200
    else:
        return jsonify({"data": "잘못된 요청입니다. Bad Request."}), 400


# @app.route('/rough', methods=['POST'])
# def rough():
#     return inference(device, max_len, TEXT, LABEL, rough_model, ), 200


@app.route('/', methods=['GET'])
def main_page():
    return send_from_directory('static', 'main.html')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
Light_model/chat.css
0 → 100644
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
.mine {
    background: #00aabb;
}
.bots {
    background: #cc78c5;
}

.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;
}
.chat-bubble.mine:after {
    right: 0;
    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

.chat-bubble.bots:after {
    left: 0;
    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

#chat_input {
    width: 90%;
}

#send_button {
    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
Light_model/generation.py
0 → 100644
import torch
from konlpy.tag import Mecab
from torch.autograd import Variable
from chatspace import ChatSpace

spacer = ChatSpace()


def tokenizer1(text: str):
    # Strip every non-alphanumeric character (this also removes whitespace),
    # then split the remainder into morphemes with Mecab.
    result_text = ''.join(c for c in text if c.isalnum())
    return Mecab().morphs(result_text)


# Greedy decoding: feed the growing decoder input back in until <eos> is produced.
def inference(device: torch.device, max_len: int, TEXT, LABEL, model: torch.nn.Module, sentence: str):

    enc_input = tokenizer1(sentence)
    enc_input_index = []

    for tok in enc_input:
        enc_input_index.append(TEXT.vocab.stoi[tok])

    # Pad the encoder input up to max_len.
    for j in range(max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])

    enc_input_index = Variable(torch.LongTensor([enc_input_index]))

    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])

    model.eval()
    pred = []
    for i in range(max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]
        if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    pred_sentence = "".join(pred)
                    # Restore the word spacing with ChatSpace before returning.
                    pred_str = spacer.space(pred_sentence)
                    return pred_str
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 'Error: sentence did not end'

        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 'Error: no sentence was predicted'
Light_model/light_chatbot.py
0 → 100644
import argparse
import time

import torch
from torch import nn
from torchtext import data
from torchtext.data import BucketIterator, TabularDataset

from Styling import styling, make_special_token
from generation import inference, tokenizer1
from model import Transformer, GradualWarmupScheduler

SEED = 1234


def acc(yhat: torch.Tensor, y: torch.Tensor):
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value
        _acc = (yhat == y).float()[y != 1].mean()  # exclude padding from the accuracy
    return _acc


# Note: train() and test() rely on the module-level globals (TEXT, LABEL, args, ...)
# defined in the __main__ block below.
def train(model: Transformer, iterator, optimizer, criterion: nn.CrossEntropyLoss, max_len: int, per_soft: bool, per_rough: bool):
    total_loss = 0
    iter_num = 0
    tr_acc = 0
    model.train()

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)

        # Apply the emotion and speech style.
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # Indices of the non-padding values.
        real_value_index = [dec_output != 1]  # <pad> == 1

        # Exclude padding from the loss.
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_acc = acc(y_pred, dec_output)

        total_loss += loss
        iter_num += 1
        tr_acc += train_acc

    return total_loss.data.cpu().numpy() / iter_num, tr_acc.data.cpu().numpy() / iter_num


def test(model: Transformer, iterator, criterion: nn.CrossEntropyLoss):
    total_loss = 0
    iter_num = 0
    te_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

            # Apply the emotion and speech style.
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args.max_len, args.per_soft, args.per_rough, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            real_value_index = [dec_output != 1]  # <pad> == 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            with torch.no_grad():
                test_acc = acc(y_pred, dec_output)
            total_loss += loss
            iter_num += 1
            te_acc += test_acc

    return total_loss.data.cpu().numpy() / iter_num, te_acc.data.cpu().numpy() / iter_num


# Preprocess the data and return the loaders.
def data_preprocessing(args, device):
    # ID is unused; SA is the sentiment-analysis label (0 or 1).
    ID = data.Field(sequential=False,
                    use_vocab=False)

    TEXT = data.Field(sequential=True,
                      use_vocab=True,
                      tokenize=tokenizer1,
                      batch_first=True,
                      fix_length=args.max_len,
                      dtype=torch.int32)

    LABEL = data.Field(sequential=True,
                       use_vocab=True,
                       tokenize=tokenizer1,
                       batch_first=True,
                       fix_length=args.max_len,
                       init_token='<sos>',
                       eos_token='<eos>',
                       dtype=torch.int32)

    SA = data.Field(sequential=False,
                    use_vocab=False)

    train_data, test_data = TabularDataset.splits(
        path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
        fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
    )

    # Build the special tokens needed by TEXT and LABEL.
    text_specials, label_specials = make_special_token(args.per_rough)

    TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
    LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader


def main(TEXT, LABEL, arguments):

    # Print the parsed arguments.
    for idx, (key, value) in enumerate(args.__dict__.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(args.__dict__) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    model = Transformer(args.embedding_dim, args.nhead, args.nlayers, args.dropout, TEXT, LABEL)
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=arguments.lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=arguments.num_epochs)
    if args.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'
    model.to(device)
    if arguments.train:
        best_valid_loss = float('inf')
        for epoch in range(arguments.num_epochs):
            torch.manual_seed(SEED)
            start_time = time.time()

            # Train and validate.
            train_loss, train_acc = \
                train(model, train_loader, optimizer, criterion, arguments.max_len, arguments.per_soft,
                      arguments.per_rough)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            scheduler.step(epoch)
            # Measure the epoch time.
            end_time = time.time()
            elapsed_time = end_time - start_time
            epoch_mins = int(elapsed_time / 60)
            epoch_secs = int(elapsed_time - (epoch_mins * 60))

            # torch.save(model.state_dict(), sorted_path)  # for some overfitting
            # Save the model whenever the validation loss improves.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # Print the loss and accuracy.
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')

    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")
    while True:
        sentence = input("문장을 입력하세요 : ")
        print(inference(device, args.max_len, TEXT, LABEL, model, sentence))
        print("\n")


if __name__ == '__main__':
    # Define the command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_len', type=int, default=40)  # max_len must be large enough to avoid errors
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_epochs', type=int, default=22)
    parser.add_argument('--warming_up_epochs', type=int, default=5)  # currently unused
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--embedding_dim', type=int, default=160)
    parser.add_argument('--nlayers', type=int, default=2)
    parser.add_argument('--nhead', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--train', action="store_true")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--per_soft', action="store_true")
    group.add_argument('--per_rough', action="store_true")
    args = parser.parse_args()
    print("-준비중-")  # "preparing..."
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, args)
Light_model/main.html
0 → 100644
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <li class="chat-bubble mine">(대충 적당한 대사)</li>
      <li class="chat-bubble bots">(대충 알맞은 답변)</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
Light_model/model.py
0 → 100644
import math

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class Transformer(nn.Module):
    def __init__(self, embedding_dim: int, nhead: int, nlayers: int, dropout: float, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        self.d_model = embedding_dim
        self.n_head = nhead
        self.num_encoder_layers = nlayers
        self.num_decoder_layers = nlayers
        self.dim_feedforward = embedding_dim
        self.dropout = dropout

        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        self.transformer = nn.Transformer(d_model=self.d_model,
                                          nhead=self.n_head,
                                          num_encoder_layers=self.num_encoder_layers,
                                          num_decoder_layers=self.num_decoder_layers,
                                          dim_feedforward=self.dim_feedforward,
                                          dropout=self.dropout)
        self.proj_vocab_layer = nn.Linear(
            in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))

    def forward(self, en_input, de_input):
        x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
        x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
        x_en_embed = self.pos_encoder(x_en_embed)
        x_de_embed = self.pos_encoder(x_de_embed)

        # Masking
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1))

        # nn.Transformer expects (seq_len, batch, d_model), so swap the first two axes.
        x_en_embed = torch.einsum('ijk->jik', x_en_embed)
        x_de_embed = torch.einsum('ijk->jik', x_de_embed)

        feature = self.transformer(src=x_en_embed,
                                   tgt=x_de_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask.to(device))

        logits = self.proj_vocab_layer(feature)
        logits = torch.einsum('ijk->jik', logits)

        return logits


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: the target learning rate is reached gradually at total_epoch
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.last_epoch = 1  # ReduceLROnPlateau is called at the end of an epoch, whereas others are called at the beginning
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                         self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
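A minimal sketch of how light_chatbot.py drives this scheduler. The single parameter here is a placeholder; the real script passes the Transformer's parameters and its --lr/--num_epochs arguments:

```
import torch
from model import GradualWarmupScheduler

# Placeholder parameter; light_chatbot.py passes model.parameters() instead.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.0002)

# The LR warms up from 0.0002 toward 0.0002 * 8 over 22 epochs, matching the defaults.
scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=22)

for epoch in range(22):
    # ... one epoch of train() and test() goes here ...
    scheduler.step(epoch)
    print(epoch, optimizer.param_groups[0]['lr'])
```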
Light_model/requirements.txt
0 → 100644
Light_model/sorted_model-rough.pth
0 → 100644
This file is too large to display.
Light_model/sorted_model-soft.pth
0 → 100644
This file is too large to display.
Light_model/static/app.js
0 → 100644
function send() {
    /* client side */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function() {
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function(event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // Number 13 is the "Enter" key on the keyboard
        if (event.keyCode === 13) {
            // Cancel the default action, if needed
            event.preventDefault();
            // Trigger the button element with a click
            button.click();
        }
    });
}
Light_model/static/chat.css
0 → 100644
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
.mine {
    background: #00aabb;
}
.bots {
    background: #cc78c5;
}

.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;
}
.chat-bubble.mine:after {
    right: 0;
    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

.chat-bubble.bots:after {
    left: 0;
    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

#chat_input {
    width: 90%;
}

#send_button {
    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
Light_model/static/favicon.ico
0 → 100644
No preview for this file type
Light_model/static/main.html
0 → 100644
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <li class="chat-bubble mine">이렇게 질문을 하면...</li>
      <li class="chat-bubble bots">이렇게 답변이 옵니다!</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
@@ -10,3 +10,51 @@ A chatbot model whose answers change with language style and sentiment analysis:
- GeForce RTX 2080 Ti
- Python 3.6.8
- Pytorch 1.2.0

# Code
## Chatbot

### Chatbot_main.py
The main file used to train and test the chatbot.
### model.py
The Transformer model class used by the chatbot.
### generation.py
Performs inference, beam search, and greedy search.
### metric.py
Measures training performance.\
`acc(yhat, y)`\
### Styling.py
Changes the speech style of a sentence according to the personality.
### get_data.py
Preprocesses and loads the dataset.\
`tokenizer1(text)`\
* text: the string to tokenize
Filters out special characters, then tokenizes with Mecab (see the sketch after this section).\
`data_preprocessing(args, device)`\
* args: a NamedTuple parsed by argparser
* device: the pytorch device
Tokenizes the text and builds the dataset, split into id, text, label, and sentiment-analysis result.
38 | +## KoBERT | ||
39 | +[SKTBrain KoBERT](https://github.com/SKTBrain/KoBERT)\ | ||
40 | +SKT Brain에서 BERT를 한국어에 응용하여 만든 모델입니다.\ | ||
41 | +네이버 영화 리뷰를 통해 감정 분석을 학습했으며 챗봇 감정 분석에 사용됩니다.\ | ||
42 | +## Light_model | ||
43 | +웹 호스팅을 위해 경량화한 모델입니다. KoBERT를 지원하지 않습니다. | ||
44 | +### light_chatbot.py | ||
45 | +챗봇 모델 학습 및 시험을 할수 있는 콘솔 프로그램입니다. | ||
46 | +`light_chatbot.py [--train] [--per_soft|--per_rough]` | ||
47 | + | ||
48 | +* train: 학습해 모델을 만들 경우에 사용합니다. | ||
49 | +사용하지 않으면 모델을 불러와 시험 합니다. | ||
50 | +* per_soft: soft 말투를 학습 또는 시험합니다. | ||
51 | +* per_rough: rough 말투를 학습 또는 시험합니다. | ||
52 | +두 옵션은 양립 불가능합니다. | ||
53 | +### app.py | ||
54 | +웹 호스팅을 위한, Flask로 구성된 간단한 HTTP 서버입니다.\ | ||
55 | +`POST /api/soft`\ | ||
56 | +soft 모델을 사용해, 추론 결과를 JSON으로 응답해주는 API를 제공합니다.\ | ||
57 | +`GET /`\ | ||
58 | +static 폴더의 HTML, CSS, JS를 정적으로 호스팅해 응답합니다. | ||
59 | +### 기타 | ||
60 | +generation.py, styling.py, model.py의 역할은 Chatbot과 동일합니다. | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |