Showing 19 changed files with 883 additions and 0 deletions
Light_model/.gitignore
0 → 100644
Light_model/Dockerfile
0 → 100644
FROM ufoym/deepo:pytorch-cpu
# https://github.com/Beomi/deepo-nlp/blob/master/Dockerfile
# Install a JVM for KoNLPy
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y \
    openjdk-8-jdk wget curl git python3-dev \
    language-pack-ko

RUN locale-gen en_US.UTF-8 && \
    update-locale LANG=en_US.UTF-8

# Install zsh
RUN apt-get install -y zsh && \
    sh -c "$(curl -fsSL https://raw.github.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"

# Install other packages
RUN pip install --upgrade pip
RUN pip install autopep8
RUN pip install konlpy
RUN pip install torchtext pytorch_pretrained_bert
# Install the dependencies of the styling chatbot
RUN pip install hgtk chatspace

# Add Mecab-Ko
RUN curl -L https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash
# Install the styling chatbot by BM-K
RUN git clone https://github.com/km19809/light_model.git
RUN pip install -r light_model/requirements.txt

# Add a non-root user
RUN adduser --disabled-password --gecos "" user

# Reset the workdir
WORKDIR /light_model
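A minimal sketch of building the image and starting the Flask server inside it; the image tag is illustrative, and it assumes the clone above lands in /light_model (the WORKDIR) and that app.py listens on port 8080 as configured:
```
docker build -t light-chatbot .
docker run -it -p 8080:8080 light-chatbot python3 app.py
```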
Light_model/README.md
0 → 100644
# Light weight model of styling chatbot
This repository hosts a lightweight version of the styling chatbot for web deployment.\
The original repository is here: [Styling-Chatbot-with-Transformer](https://github.com/km19809/Styling-Chatbot-with-Transformer)

## Requirements

The list below may change during development, so see requirements.txt for the current versions.
```
torch~=1.4.0
Flask~=1.1.2
torchtext~=0.6.0
hgtk~=0.1.3
konlpy~=0.5.2
chatspace~=1.0.1
```

## Usage
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: trains and saves a model. \
Without this flag, the saved model is loaded and tested interactively.
* per_soft: trains or tests the soft speech style.\
If per_rough is given instead, the rough speech style is trained or tested.\
The two options are mutually exclusive.

`app.py`

A simple Flask server for trying out the chatbot (see the request example below).
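For reference, a minimal client sketch using only the Python standard library; it assumes the server is running locally on port 8080, as configured in app.py:
```
import json
import urllib.request

# POST a sentence to /api/soft; app.py expects a JSON body of the form {"data": <sentence>}.
req = urllib.request.Request(
    "http://localhost:8080/api/soft",
    data=json.dumps({"data": "안녕하세요"}).encode("utf-8"),
    headers={"Content-Type": "application/json;charset=UTF-8"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8"))["data"])  # the styled reply
```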
Light_model/Styling.py
0 → 100644
(Diff collapsed; contents not shown.)
Light_model/app.js
0 → 100644
function send() {
    /* client side */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    // Include the protocol; a bare host would be treated as a relative URL
    // (the fixed version also lives in static/app.js below).
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function () {
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function (event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // Number 13 is the "Enter" key on the keyboard
        if (event.keyCode === 13) {
            // Cancel the default action, if needed
            event.preventDefault();
            // Trigger the button element with a click
            button.click();
        }
    });
}
Light_model/app.py
0 → 100644
from flask import Flask, request, jsonify, send_from_directory
import torch
from torchtext import data
from generation import inference, tokenizer1
from Styling import make_special_token
from model import Transformer

app = Flask(__name__,
            static_url_path='',
            static_folder='static')
app.config['JSON_AS_ASCII'] = False
device = torch.device('cpu')
max_len = 40

# ID is unused at inference time; SA is the sentiment analysis label (0 or 1).
ID = data.Field(sequential=False, use_vocab=False)
SA = data.Field(sequential=False, use_vocab=False)
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=tokenizer1,
                  batch_first=True,
                  fix_length=max_len,
                  dtype=torch.int32)

LABEL = data.Field(sequential=True,
                   use_vocab=True,
                   tokenize=tokenizer1,
                   batch_first=True,
                   fix_length=max_len,
                   init_token='<sos>',
                   eos_token='<eos>',
                   dtype=torch.int32)

# Rebuild the vocabulary from the training data so token ids match the checkpoint.
text_specials, label_specials = make_special_token(False)
train_data, _ = data.TabularDataset.splits(
    path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
    fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
)
TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

soft_model = Transformer(160, 2, 2, 0.1, TEXT, LABEL)
# rough_model = Transformer(args, TEXT, LABEL)
soft_model.to(device)
# rough_model.to(device)
soft_model.load_state_dict(torch.load('sorted_model-soft.pth', map_location=device)['model_state_dict'])
# rough_model.load_state_dict(torch.load('sorted_model-rough.pth', map_location=device)['model_state_dict'])


@app.route('/api/soft', methods=['POST'])
def soft():
    if request.is_json:
        sentence = request.json["data"]
        return jsonify({"data": inference(device, max_len, TEXT, LABEL, soft_model, sentence)}), 200
    else:
        return jsonify({"data": "잘못된 요청입니다. Bad Request."}), 400


# @app.route('/rough', methods=['POST'])
# def rough():
#     return inference(device, max_len, TEXT, LABEL, rough_model, ), 200


@app.route('/', methods=['GET'])
def main_page():
    return send_from_directory('static', 'main.html')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
Light_model/chat.css
0 → 100644
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
.mine {
    background: #00aabb;
}
.bots {
    background: #cc78c5;
}

/* Speech-bubble tail */
.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;
}
.chat-bubble.mine:after {
    right: 0;
    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

.chat-bubble.bots:after {
    left: 0;
    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

#chat_input {
    width: 90%;
}

#send_button {
    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
Light_model/generation.py
0 → 100644
import torch
from konlpy.tag import Mecab
from torch.autograd import Variable
from chatspace import ChatSpace

spacer = ChatSpace()


def tokenizer1(text: str):
    # Keep only alphanumeric characters (this also drops spaces), then tokenize with Mecab.
    result_text = ''.join(c for c in text if c.isalnum())
    return Mecab().morphs(result_text)


def inference(device: torch.device, max_len: int, TEXT, LABEL, model: torch.nn.Module, sentence: str):
    # Index and pad the encoder input.
    enc_input = tokenizer1(sentence)
    enc_input_index = []

    for tok in enc_input:
        enc_input_index.append(TEXT.vocab.stoi[tok])

    for j in range(max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])

    enc_input_index = Variable(torch.LongTensor([enc_input_index]))

    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])

    # Greedy decoding: append the best token at each step until <eos> appears.
    model.eval()
    pred = []
    for i in range(max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]
        if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    pred_sentence = "".join(pred)
                    pred_str = spacer.space(pred_sentence)  # restore word spacing with ChatSpace
                    return pred_str
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 'Error: sentence did not end'

        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 'Error: no sentence was predicted'
Light_model/light_chatbot.py
0 → 100644
import argparse
import time

import torch
from torch import nn
from torchtext import data
from torchtext.data import BucketIterator, TabularDataset

from Styling import styling, make_special_token
from generation import inference, tokenizer1
from model import Transformer, GradualWarmupScheduler

SEED = 1234


def acc(yhat: torch.Tensor, y: torch.Tensor):
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value
        _acc = (yhat == y).float()[y != 1].mean()  # exclude padding from the accuracy
        return _acc


def train(model: Transformer, iterator, optimizer, criterion: nn.CrossEntropyLoss, max_len: int, per_soft: bool, per_rough: bool):
    total_loss = 0
    iter_num = 0
    tr_acc = 0
    model.train()

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)

        # Apply the emotion and the speech style.
        # NOTE: TEXT and LABEL are module-level globals set in __main__.
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # Indices of the non-padding values (<pad> == 1)
        real_value_index = [dec_output != 1]

        # Padding is excluded from the loss.
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_acc = acc(y_pred, dec_output)

        total_loss += loss
        iter_num += 1
        tr_acc += train_acc

    return total_loss.data.cpu().numpy() / iter_num, tr_acc.data.cpu().numpy() / iter_num


def test(model: Transformer, iterator, criterion: nn.CrossEntropyLoss):
    total_loss = 0
    iter_num = 0
    te_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

            # Apply the emotion and the speech style (args is the global Namespace).
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args.max_len, args.per_soft, args.per_rough, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            real_value_index = [dec_output != 1]  # <pad> == 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            with torch.no_grad():
                test_acc = acc(y_pred, dec_output)
            total_loss += loss
            iter_num += 1
            te_acc += test_acc

    return total_loss.data.cpu().numpy() / iter_num, te_acc.data.cpu().numpy() / iter_num


# Preprocess the data and return the loaders.
def data_preprocessing(args, device):
    # ID is unused; SA is the sentiment analysis label (0 or 1).
    ID = data.Field(sequential=False, use_vocab=False)

    TEXT = data.Field(sequential=True,
                      use_vocab=True,
                      tokenize=tokenizer1,
                      batch_first=True,
                      fix_length=args.max_len,
                      dtype=torch.int32)

    LABEL = data.Field(sequential=True,
                       use_vocab=True,
                       tokenize=tokenizer1,
                       batch_first=True,
                       fix_length=args.max_len,
                       init_token='<sos>',
                       eos_token='<eos>',
                       dtype=torch.int32)

    SA = data.Field(sequential=False, use_vocab=False)

    train_data, test_data = TabularDataset.splits(
        path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
        fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
    )

    # Build the special tokens needed by TEXT and LABEL.
    text_specials, label_specials = make_special_token(args.per_rough)

    TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
    LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader


def main(TEXT, LABEL, arguments):
    # Print the parsed arguments. NOTE: `arguments` and the global `args` are the same Namespace.
    for idx, (key, value) in enumerate(args.__dict__.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(args.__dict__) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    model = Transformer(args.embedding_dim, args.nhead, args.nlayers, args.dropout, TEXT, LABEL)
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=arguments.lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=arguments.num_epochs)
    if args.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'
    model.to(device)
    if arguments.train:
        best_valid_loss = float('inf')
        for epoch in range(arguments.num_epochs):
            torch.manual_seed(SEED)
            start_time = time.time()

            # train, validation
            train_loss, train_acc = \
                train(model, train_loader, optimizer, criterion, arguments.max_len, arguments.per_soft,
                      arguments.per_rough)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            scheduler.step(epoch)
            # Epoch timing
            end_time = time.time()
            elapsed_time = end_time - start_time
            epoch_mins = int(elapsed_time / 60)
            epoch_secs = int(elapsed_time - (epoch_mins * 60))

            # torch.save(model.state_dict(), sorted_path)  # for some overfitting
            # Save the model whenever the validation loss improves on the best so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # Print the loss and accuracy.
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')

    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")
    while True:
        sentence = input("문장을 입력하세요 : ")
        print(inference(device, args.max_len, TEXT, LABEL, model, sentence))
        print("\n")


if __name__ == '__main__':
    # Define the argparse arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_len', type=int, default=40)  # max_len must be large enough, or errors occur.
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_epochs', type=int, default=22)
    parser.add_argument('--warming_up_epochs', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--embedding_dim', type=int, default=160)
    parser.add_argument('--nlayers', type=int, default=2)
    parser.add_argument('--nhead', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--train', action="store_true")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--per_soft', action="store_true")
    group.add_argument('--per_rough', action="store_true")
    args = parser.parse_args()
    print("-준비중-")  # "loading..."
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, args)
Light_model/main.html
0 → 100644
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <li class="chat-bubble mine">(대충 적당한 대사)</li>
      <li class="chat-bubble bots">(대충 알맞은 답변)</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
Light_model/model.py
0 → 100644
import math

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class Transformer(nn.Module):
    def __init__(self, embedding_dim: int, nhead: int, nlayers: int, dropout: float, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        self.d_model = embedding_dim
        self.n_head = nhead
        self.num_encoder_layers = nlayers
        self.num_decoder_layers = nlayers
        self.dim_feedforward = embedding_dim
        self.dropout = dropout

        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        self.transformer = nn.Transformer(d_model=self.d_model,
                                          nhead=self.n_head,
                                          num_encoder_layers=self.num_encoder_layers,
                                          num_decoder_layers=self.num_decoder_layers,
                                          dim_feedforward=self.dim_feedforward,
                                          dropout=self.dropout)
        self.proj_vocab_layer = nn.Linear(
            in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))

    def forward(self, en_input, de_input):
        x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
        x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
        x_en_embed = self.pos_encoder(x_en_embed)
        x_de_embed = self.pos_encoder(x_de_embed)

        # Masking: hide <pad> tokens and future target positions.
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1))

        # (batch, seq, dim) -> (seq, batch, dim), as nn.Transformer expects.
        x_en_embed = torch.einsum('ijk->jik', x_en_embed)
        x_de_embed = torch.einsum('ijk->jik', x_de_embed)

        feature = self.transformer(src=x_en_embed,
                                   tgt=x_de_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask.to(device))

        logits = self.proj_vocab_layer(feature)
        logits = torch.einsum('ijk->jik', logits)  # back to (batch, seq, vocab)

        return logits


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Fixed sine/cosine positional encodings, as in "Attention Is All You Need".
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: the target learning rate is reached gradually at total_epoch
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.last_epoch = 1  # ReduceLROnPlateau is called at the end of an epoch, whereas others are called at the beginning
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                         self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
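For reference, a minimal sketch of how GradualWarmupScheduler is meant to be driven, mirroring the call pattern in light_chatbot.py under the torch~=1.4 API pinned in requirements.txt; the dummy model is illustrative only:
```
import torch
from torch import nn
from model import GradualWarmupScheduler

# Any parameters will do for watching the schedule.
opt = torch.optim.Adam(nn.Linear(4, 4).parameters(), lr=0.0002)
sched = GradualWarmupScheduler(opt, multiplier=8, total_epoch=22)

for epoch in range(22):
    # ... one epoch of training and validation would go here ...
    sched.step(epoch)
    print(epoch, opt.param_groups[0]['lr'])  # ramps linearly from lr toward 8 * lr
```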
Light_model/requirements.txt
0 → 100644
Light_model/sorted_model-rough.pth
0 → 100644
This file is too large to display.
Light_model/sorted_model-soft.pth
0 → 100644
This file is too large to display.
Light_model/static/app.js
0 → 100644
function send() {
    /* client side */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function () {
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function (event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // Number 13 is the "Enter" key on the keyboard
        if (event.keyCode === 13) {
            // Cancel the default action, if needed
            event.preventDefault();
            // Trigger the button element with a click
            button.click();
        }
    });
}
Light_model/static/chat.css
0 → 100644
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
.mine {
    background: #00aabb;
}
.bots {
    background: #cc78c5;
}

/* Speech-bubble tail */
.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;
}
.chat-bubble.mine:after {
    right: 0;
    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

.chat-bubble.bots:after {
    left: 0;
    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

#chat_input {
    width: 90%;
}

#send_button {
    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
Light_model/static/favicon.ico
0 → 100644
No preview for this file type
Light_model/static/main.html
0 → 100644
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <li class="chat-bubble mine">이렇게 질문을 하면...</li>
      <li class="chat-bubble bots">이렇게 답변이 옵니다!</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
@@ -10,3 +10,51 @@ A model that varies chatbot replies according to language style and sentiment analysis:
- GeForce RTX 2080 Ti
- Python 3.6.8
- Pytorch 1.2.0

# Code
## Chatbot

### Chatbot_main.py
The main file for training and testing the chatbot.
### model.py
The Transformer model class used by the chatbot.
### generation.py
Runs inference, including beam search and greedy search.
### metric.py
A module for measuring training performance.\
`acc(yhat, y)`\
### Styling.py
Rewrites replies in a different speech style according to the chosen persona.
### get_data.py
Preprocesses and loads the dataset (see the tokenizer sketch below).\
`tokenizer1(text)`\
* text: the string to tokenize
Filters out special characters, then tokenizes with Mecab.\
`data_preprocessing(args, device)`\
* args: the Namespace parsed by argparse
* device: the pytorch device
Tokenizes the text and builds the dataset, split into id, text, label, and sentiment analysis result.
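For reference, a sketch of what `tokenizer1` does, mirroring Light_model/generation.py; it assumes the Mecab-Ko dictionary is installed, and the exact morphemes depend on that dictionary:
```
from konlpy.tag import Mecab

def tokenizer1(text: str):
    # Keep only alphanumeric characters (this also drops spaces), then tokenize.
    filtered = ''.join(c for c in text if c.isalnum())
    return Mecab().morphs(filtered)

print(tokenizer1("안녕, 챗봇아!"))  # e.g. ['안녕', '챗봇', '아'], depending on the dictionary
```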

## KoBERT
[SKTBrain KoBERT](https://github.com/SKTBrain/KoBERT)\
A model built by SKT Brain by adapting BERT to Korean.\
It was trained for sentiment analysis on Naver movie reviews and supplies the chatbot's sentiment labels.\
## Light_model
A model slimmed down for web hosting. It does not support KoBERT.
### light_chatbot.py
A console program for training and testing the chatbot model; an example invocation follows below.
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: trains and saves a model.
Without this flag, the saved model is loaded and tested interactively.
* per_soft: trains or tests the soft speech style.
* per_rough: trains or tests the rough speech style.
The two options are mutually exclusive.
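For example, to train the soft-style model from scratch:
```
python light_chatbot.py --train --per_soft
```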
### app.py
A simple HTTP server, built with Flask, for web hosting; a request sketch follows below.\
`POST /api/soft`\
An API that runs the soft model and returns the inference result as JSON.\
`GET /`\
Statically serves the HTML, CSS, and JS from the static folder.
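A request/response sketch (the reply text is illustrative):
```
POST /api/soft
Content-Type: application/json;charset=UTF-8

{"data": "안녕하세요"}

HTTP/1.1 200 OK
{"data": "<styled reply>"}
```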
### Others
generation.py, Styling.py, and model.py play the same roles as in Chatbot.