김민수

Add light_model and expand the README

FROM ufoym/deepo:pytorch-cpu
# https://github.com/Beomi/deepo-nlp/blob/master/Dockerfile
# Install a JVM for KoNLPy
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y \
    openjdk-8-jdk wget curl git python3-dev \
    language-pack-ko

RUN locale-gen en_US.UTF-8 && \
    update-locale LANG=en_US.UTF-8

# Install zsh
RUN apt-get install -y zsh && \
    sh -c "$(curl -fsSL https://raw.github.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"

# Install other packages
RUN pip install --upgrade pip
RUN pip install autopep8
RUN pip install konlpy
RUN pip install torchtext pytorch_pretrained_bert
# Install dependencies of the styling chatbot
RUN pip install hgtk chatspace

# Add Mecab-Ko
RUN curl -L https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash
# Install the styling chatbot by BM-K
RUN git clone https://github.com/km19809/light_model.git
RUN pip install -r light_model/requirements.txt

# Add non-root user
RUN adduser --disabled-password --gecos "" user

# Reset workdir
WORKDIR /light_model
# Lightweight model of the styling chatbot
This repository hosts a lightweight model of the chatbot for the web.\
The original repository is here: [link](https://github.com/km19809/Styling-Chatbot-with-Transformer)

## Requirements

The list below may change during development, so please refer to requirements.txt.
```
torch~=1.4.0
Flask~=1.1.2
torchtext~=0.6.0
hgtk~=0.1.3
konlpy~=0.5.2
chatspace~=1.0.1
```

## Usage
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: trains a model and saves it. \
Without this flag, the saved model is loaded and tested.
* per_soft: trains or tests the soft speech style.\
With per_rough instead, the rough speech style is trained or tested.\
The two options are mutually exclusive.

`app.py`

A simple Flask server for trying out the chatbot; a client sketch follows below.
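Once the server is running (`python app.py`, port 8080 as set in app.py), you can smoke-test the API from Python. A minimal sketch, assuming the `requests` package is installed:
```
import requests

# POST a sentence to the soft-style endpoint defined in app.py
resp = requests.post("http://localhost:8080/api/soft",
                     json={"data": "안녕하세요"})
resp.raise_for_status()
print(resp.json()["data"])  # the styled reply
```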
function send() {
    /* client side: append my message to the chat list */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function() {
        // only handle completed 2xx responses
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function(event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        // enable the send button only when there is text to send
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // key code 13 is the "Enter" key
        if (event.keyCode === 13) {
            // cancel the default action, if needed
            event.preventDefault();
            // trigger the send button with a click
            button.click();
        }
    });
}
from flask import Flask, request, jsonify, send_from_directory
import torch
from torchtext import data
from generation import inference, tokenizer1
from Styling import make_special_token
from model import Transformer

app = Flask(__name__,
            static_url_path='',
            static_folder='static',)
app.config['JSON_AS_ASCII'] = False
device = torch.device('cpu')
max_len = 40

# Fields mirror the training setup in light_chatbot.py;
# ID is unused and SA is the sentiment-analysis label.
ID = data.Field(sequential=False,
                use_vocab=False)
SA = data.Field(sequential=False,
                use_vocab=False)
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=tokenizer1,
                  batch_first=True,
                  fix_length=max_len,
                  dtype=torch.int32
                  )

LABEL = data.Field(sequential=True,
                   use_vocab=True,
                   tokenize=tokenizer1,
                   batch_first=True,
                   fix_length=max_len,
                   init_token='<sos>',
                   eos_token='<eos>',
                   dtype=torch.int32
                   )

# Rebuild the vocabularies from the training data so token ids
# match the ones the saved model was trained with.
text_specials, label_specials = make_special_token(False)
train_data, _ = data.TabularDataset.splits(
    path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
    fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
)
TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

soft_model = Transformer(160, 2, 2, 0.1, TEXT, LABEL)
# rough_model = Transformer(args, TEXT, LABEL)
soft_model.to(device)
# rough_model.to(device)
soft_model.load_state_dict(torch.load('sorted_model-soft.pth', map_location=device)['model_state_dict'])
# rough_model.load_state_dict(torch.load('sorted_model-rough.pth', map_location=device)['model_state_dict'])


@app.route('/api/soft', methods=['POST'])
def soft():
    if request.is_json:
        sentence = request.json["data"]
        return jsonify({"data": inference(device, max_len, TEXT, LABEL, soft_model, sentence)}), 200
    else:
        return jsonify({"data": "잘못된 요청입니다. Bad Request."}), 400


# @app.route('/rough', methods=['POST'])
# def rough():
#     return inference(device, max_len, TEXT, LABEL, rough_model, ), 200


@app.route('/', methods=['GET'])
def main_page():
    return send_from_directory('static', 'main.html')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
.mine {
    background: #00aabb;
}
.bots {
    background: #cc78c5;
}

.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;
}
.chat-bubble.mine:after {
    right: 0;
    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

.chat-bubble.bots:after {
    left: 0;
    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

#chat_input {
    width: 90%;
}

#send_button {
    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
import torch
from konlpy.tag import Mecab
from chatspace import ChatSpace

spacer = ChatSpace()
mecab = Mecab()


def tokenizer1(text: str):
    # keep only alphanumeric characters (drops spaces and punctuation),
    # then split the result into morphemes
    result_text = ''.join(c for c in text if c.isalnum())
    return mecab.morphs(result_text)


def inference(device: torch.device, max_len: int, TEXT, LABEL, model: torch.nn.Module, sentence: str):
    # encode the input sentence and pad it to max_len
    enc_input = tokenizer1(sentence)
    enc_input_index = [TEXT.vocab.stoi[tok] for tok in enc_input]
    for _ in range(max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])
    enc_input_index = torch.LongTensor([enc_input_index])

    # greedy decoding: start from <sos> and append the best token each step
    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])

    model.eval()
    pred = []
    for i in range(max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]
        if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    # re-space the concatenated morphemes with ChatSpace
                    pred_sentence = "".join(pred)
                    pred_str = spacer.space(pred_sentence)
                    return pred_str
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 'Error: sentence did not end with <eos>'

        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 'Error: no sentence predicted'
import argparse
import time

import torch
from torch import nn
from torchtext import data
from torchtext.data import BucketIterator
from torchtext.data import TabularDataset

from Styling import styling, make_special_token
from generation import inference, tokenizer1
from model import Transformer, GradualWarmupScheduler

SEED = 1234


def acc(yhat: torch.Tensor, y: torch.Tensor):
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value
        _acc = (yhat == y).float()[y != 1].mean()  # padding is excluded from the accuracy
    return _acc


# NOTE: train() and test() read TEXT, LABEL (and test() also reads args) from
# module scope; both are defined in the __main__ block below.
def train(model: Transformer, iterator, optimizer, criterion: nn.CrossEntropyLoss, max_len: int, per_soft: bool, per_rough: bool):
    total_loss = 0
    iter_num = 0
    tr_acc = 0
    model.train()

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)

        # apply the emotion label and the speech style
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # mask of the non-padding positions (<pad> == 1)
        real_value_index = dec_output != 1

        # padding is excluded from the loss
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_acc = acc(y_pred, dec_output)

        total_loss += loss
        iter_num += 1
        tr_acc += train_acc

    return total_loss.data.cpu().numpy() / iter_num, tr_acc.data.cpu().numpy() / iter_num


def test(model: Transformer, iterator, criterion: nn.CrossEntropyLoss):
    total_loss = 0
    iter_num = 0
    te_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

            # apply the emotion label and the speech style
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args.max_len, args.per_soft, args.per_rough, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            real_value_index = dec_output != 1  # <pad> == 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            test_acc = acc(y_pred, dec_output)
            total_loss += loss
            iter_num += 1
            te_acc += test_acc

    return total_loss.data.cpu().numpy() / iter_num, te_acc.data.cpu().numpy() / iter_num


# preprocess the data and return the loaders
def data_preprocessing(args, device):
    # ID is unused. SA is the sentiment-analysis label (0 or 1).
    ID = data.Field(sequential=False,
                    use_vocab=False)

    TEXT = data.Field(sequential=True,
                      use_vocab=True,
                      tokenize=tokenizer1,
                      batch_first=True,
                      fix_length=args.max_len,
                      dtype=torch.int32
                      )

    LABEL = data.Field(sequential=True,
                       use_vocab=True,
                       tokenize=tokenizer1,
                       batch_first=True,
                       fix_length=args.max_len,
                       init_token='<sos>',
                       eos_token='<eos>',
                       dtype=torch.int32
                       )

    SA = data.Field(sequential=False,
                    use_vocab=False)

    train_data, test_data = TabularDataset.splits(
        path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
        fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
    )

    # build the special tokens TEXT and LABEL need
    text_specials, label_specials = make_special_token(args.per_rough)

    TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
    LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader


def main(TEXT, LABEL, arguments):
    # print the parsed arguments
    for idx, (key, value) in enumerate(arguments.__dict__.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(arguments.__dict__) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    model = Transformer(arguments.embedding_dim, arguments.nhead, arguments.nlayers, arguments.dropout, TEXT, LABEL)
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=arguments.lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=arguments.num_epochs)

    if arguments.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'
    model.to(device)

    if arguments.train:
        best_valid_loss = float('inf')
        for epoch in range(arguments.num_epochs):
            torch.manual_seed(SEED)
            start_time = time.time()

            # train, validation
            train_loss, train_acc = \
                train(model, train_loader, optimizer, criterion, arguments.max_len, arguments.per_soft,
                      arguments.per_rough)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            scheduler.step(epoch)

            # elapsed time
            end_time = time.time()
            elapsed_time = end_time - start_time
            epoch_mins = int(elapsed_time / 60)
            epoch_secs = int(elapsed_time - (epoch_mins * 60))

            # torch.save(model.state_dict(), sorted_path)  # for some overfitting
            # save the model whenever the validation loss improves on the best so far
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # print loss and acc
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')

    # load the best checkpoint and evaluate it
    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")

    # interactive loop: type a sentence, get the styled reply back
    while True:
        sentence = input("문장을 입력하세요 : ")
        print(inference(device, arguments.max_len, TEXT, LABEL, model, sentence))
        print("\n")


if __name__ == '__main__':
    # define the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_len', type=int, default=40)  # max_len must be large enough, or errors occur
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_epochs', type=int, default=22)
    parser.add_argument('--warming_up_epochs', type=int, default=5)  # note: currently unused
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--embedding_dim', type=int, default=160)
    parser.add_argument('--nlayers', type=int, default=2)
    parser.add_argument('--nhead', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--train', action="store_true")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--per_soft', action="store_true")
    group.add_argument('--per_rough', action="store_true")
    args = parser.parse_args()
    print("-준비중-")
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, args)
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <li class="chat-bubble mine">(대충 적당한 대사)</li>
      <li class="chat-bubble bots">(대충 알맞은 답변)</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
import math

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class Transformer(nn.Module):
    def __init__(self, embedding_dim: int, nhead: int, nlayers: int, dropout: float, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        self.d_model = embedding_dim
        self.n_head = nhead
        self.num_encoder_layers = nlayers
        self.num_decoder_layers = nlayers
        self.dim_feedforward = embedding_dim
        self.dropout = dropout

        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        self.transformer = nn.Transformer(d_model=self.d_model,
                                          nhead=self.n_head,
                                          num_encoder_layers=self.num_encoder_layers,
                                          num_decoder_layers=self.num_decoder_layers,
                                          dim_feedforward=self.dim_feedforward,
                                          dropout=self.dropout)
        self.proj_vocab_layer = nn.Linear(
            in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))

    def forward(self, en_input, de_input):
        x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
        x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
        x_en_embed = self.pos_encoder(x_en_embed)
        x_de_embed = self.pos_encoder(x_de_embed)

        # Masking: hide padding from attention and make the decoder causal
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1))

        # nn.Transformer expects (seq_len, batch, d_model), so swap the first two dims
        x_en_embed = torch.einsum('ijk->jik', x_en_embed)
        x_de_embed = torch.einsum('ijk->jik', x_de_embed)

        feature = self.transformer(src=x_en_embed,
                                   tgt=x_de_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask.to(device))

        logits = self.proj_vocab_layer(feature)
        logits = torch.einsum('ijk->jik', logits)  # back to (batch, seq_len, vocab)

        return logits


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # precompute sin/cos positional encodings for up to max_len positions
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # add the precomputed encoding along dim 0 of x
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: the target learning rate is reached gradually at total_epoch
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """
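    # Example derived from the defaults in light_chatbot.py (lr=2e-4,
    # multiplier=8, total_epoch=num_epochs=22): get_lr() scales the base lr by
    # (multiplier - 1) * epoch / total_epoch + 1, so the lr ramps linearly from
    # ~2.6e-4 after epoch 1 up to 8 * 2e-4 = 1.6e-3 at epoch 22.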
96 +
97 + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
98 + self.last_epoch = 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
99 + self.multiplier = multiplier
100 + if self.multiplier <= 1.:
101 + raise ValueError('multiplier should be greater than 1.')
102 + self.total_epoch = total_epoch
103 + self.after_scheduler = after_scheduler
104 + self.finished = False
105 + super().__init__(optimizer)
106 +
107 + def get_lr(self):
108 + if self.last_epoch > self.total_epoch:
109 + if self.after_scheduler:
110 + if not self.finished:
111 + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
112 + self.finished = True
113 + return self.after_scheduler.get_lr()
114 + return [base_lr * self.multiplier for base_lr in self.base_lrs]
115 +
116 + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
117 + self.base_lrs]
118 +
119 + def step_ReduceLROnPlateau(self, metrics, epoch=None):
120 + if epoch is None:
121 + epoch = self.last_epoch + 1
122 + self.last_epoch = epoch if epoch != 0 else 1
123 + if self.last_epoch <= self.total_epoch:
124 + warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
125 + self.base_lrs]
126 + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
127 + param_group['lr'] = lr
128 + else:
129 + if epoch is None:
130 + self.after_scheduler.step(metrics, None)
131 + else:
132 + self.after_scheduler.step(metrics, epoch - self.total_epoch)
133 +
134 + def step(self, epoch=None, metrics=None):
135 + if type(self.after_scheduler) != ReduceLROnPlateau:
136 + if self.finished and self.after_scheduler:
137 + if epoch is None:
138 + self.after_scheduler.step(None)
139 + else:
140 + self.after_scheduler.step(epoch - self.total_epoch)
141 + else:
142 + return super(GradualWarmupScheduler, self).step(epoch)
143 + else:
144 + self.step_ReduceLROnPlateau(metrics, epoch)
torch~=1.4.0
Flask~=1.1.2
torchtext~=0.6.0
hgtk~=0.1.3
konlpy~=0.5.2
chatspace~=1.0.1
function send() {
    /* client side: append my message to the chat list */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function() {
        // only handle completed 2xx responses
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function(event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        // enable the send button only when there is text to send
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // key code 13 is the "Enter" key
        if (event.keyCode === 13) {
            // cancel the default action, if needed
            event.preventDefault();
            // trigger the send button with a click
            button.click();
        }
    });
}
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <li class="chat-bubble mine">이렇게 질문을 하면...</li>
      <li class="chat-bubble bots">이렇게 답변이 옵니다!</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
@@ -10,3 +10,51 @@ A model that varies chatbot responses according to Language Style and sentiment analysis:
- Force RTX 2080 Ti
- Python 3.6.8
- Pytorch 1.2.0

# Code
## Chatbot

### Chatbot_main.py
The main file used to train and test the chatbot.
### model.py
The Transformer model class the chatbot uses.
### generation.py
Runs inference, using beam search and greedy search.
### metric.py
Measures training performance.\
`acc(yhat, y)`\
Token accuracy of the predictions against the targets, ignoring padding; a sketch follows below.
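For reference, a minimal sketch matching the padding-aware accuracy implemented in light_model's light_chatbot.py (`<pad>` is index 1 in the torchtext vocabularies used here):
```
import torch

def acc(yhat: torch.Tensor, y: torch.Tensor):
    # token-level accuracy; positions where y == 1 (<pad>) are ignored
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1]  # logits -> predicted token indices
        return (yhat == y).float()[y != 1].mean()
```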
### Styling.py
Changes the speech style of a sentence according to the personality.
### get_data.py
Preprocesses and loads the dataset.\
`tokenizer1(text)`\
* text: the string to tokenize

Filters out special characters, then tokenizes with Mecab; see the sketch after this section.\
`data_preprocessing(args, device)`\
* args: the NamedTuple parsed by argparser
* device: the pytorch device

Builds the dataset by tokenizing the text and splitting each row into id, text, label, and sentiment-analysis result.

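A minimal sketch of the tokenizer (the exact morphemes depend on the installed Mecab dictionary, so the output shown is only illustrative):
```
from konlpy.tag import Mecab

def tokenizer1(text: str):
    # keep only alphanumeric characters, then split into morphemes
    cleaned = ''.join(c for c in text if c.isalnum())
    return Mecab().morphs(cleaned)

# e.g. tokenizer1("안녕하세요!") may return ['안녕', '하', '세요']
```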
## KoBERT
[SKTBrain KoBERT](https://github.com/SKTBrain/KoBERT)\
A model SKT Brain built by adapting BERT to Korean.\
It was trained for sentiment analysis on Naver movie reviews and handles the chatbot's sentiment analysis.

## Light_model
A model slimmed down for web hosting. It does not support KoBERT.
### light_chatbot.py
A console program that trains and tests the chatbot model.
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: trains a model and saves it.
Without this flag, the saved model is loaded and tested.
* per_soft: trains or tests the soft speech style.
* per_rough: trains or tests the rough speech style.

The two options are mutually exclusive, e.g. `python light_chatbot.py --train --per_rough`.
### app.py
A simple Flask HTTP server for web hosting.\
`POST /api/soft`\
An API that runs inference with the soft model and returns the result as JSON; a client sketch follows below.\
`GET /`\
Serves the HTML, CSS, and JS in the static folder statically.
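A minimal client sketch of the request/response contract (host and port assume the default `app.run(host='0.0.0.0', port=8080)` in app.py):
```
import requests

r = requests.post("http://localhost:8080/api/soft",
                  json={"data": "밥 먹었어?"})  # any input sentence
print(r.status_code)     # 200 on success, 400 if the body is not JSON
print(r.json()["data"])  # the soft-style reply from inference()
```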
### Miscellaneous
generation.py, Styling.py, and model.py play the same roles as in Chatbot.