# Image for the light-weight styling chatbot: PyTorch (CPU) + KoNLPy/Mecab-Ko.
# Based on https://github.com/Beomi/deepo-nlp/blob/master/Dockerfile
FROM ufoym/deepo:pytorch-cpu

# Make `curl ... | bash` fail the build when the download fails instead of
# silently piping an empty or partial script into the shell.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# OS packages: JVM for KoNLPy, download/build tools, Korean language pack, zsh.
# `update` and `install` share one layer so the package index is never stale,
# and the index is removed in the same layer so it does not bloat the image.
# NOTE(review): `apt-get upgrade` is kept from the original; consider dropping
# it in favour of a pinned, regularly rebuilt base image.
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y --no-install-recommends \
        ca-certificates \
        curl \
        git \
        language-pack-ko \
        openjdk-8-jdk \
        python3-dev \
        wget \
        zsh && \
    rm -rf /var/lib/apt/lists/*

RUN locale-gen en_US.UTF-8 && \
    update-locale LANG=en_US.UTF-8

# Install oh-my-zsh (for root). `--unattended` skips the interactive steps
# (changing the login shell, starting zsh) that make no sense during a build.
RUN sh -c "$(curl -fsSL https://raw.github.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended

# Python tooling and the NLP libraries used by the chatbot.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir \
        autopep8 \
        chatspace \
        hgtk \
        konlpy \
        pytorch_pretrained_bert \
        torchtext

# Add Mecab-Ko
RUN curl -fsSL https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash

# Add non-root user
RUN adduser --disabled-password --gecos "" user

# Install the styling chatbot. The clone target is absolute so it always matches
# WORKDIR below, whatever working directory the base image uses. Ownership is
# handed to the non-root user in the same layer so checkpoints can be written.
RUN git clone --depth 1 https://github.com/km19809/light_model.git /light_model && \
    pip install --no-cache-dir -r /light_model/requirements.txt && \
    chown -R user:user /light_model

# app.py listens on this port (documentation only; publish it with `-p`).
EXPOSE 8080

# Reset Workdir and drop root privileges for everything run in the container.
WORKDIR /light_model
USER user
...\ No newline at end of file ...\ No newline at end of file
1 +# Light weight model of styling chatbot
2 +가벼운 모델을 웹호스팅하기 위한 레포지토리입니다.\
3 +원본 레포지토리는 다음과 같습니다. [바로 가기](https://github.com/km19809/Styling-Chatbot-with-Transformer)
4 +
5 +## 요구사항
6 +
7 +이하의 내용은 개발 중 변경될 수 있으니 requirements.txt를 참고 바랍니다.
8 +```
9 +torch~=1.4.0
10 +Flask~=1.1.2
11 +torchtext~=0.6.0
12 +hgtk~=0.1.3
13 +konlpy~=0.5.2
14 +chatspace~=1.0.1
15 +```
16 +
17 +## 사용법
18 +`light_chatbot.py [--train] [--per_soft|--per_rough]`
19 +
* train: 학습해 모델을 만들 경우에 사용합니다.\
사용하지 않으면 저장된 모델을 불러와 시험합니다.
22 +* per_soft: soft 말투를 학습 또는 시험합니다.\
23 +per_rough를 쓴 경우 rough 말투를 학습 또는 시험합니다.\
24 +두 옵션은 양립 불가능합니다.
25 +
26 +`app.py`
27 +
28 +챗봇을 시험하기 위한 간단한 플라스크 서버입니다.
...\ No newline at end of file ...\ No newline at end of file
This diff is collapsed. Click to expand it.
/**
 * Sends the current chat input to the soft-style chatbot API.
 *
 * The user's message is appended to #chat_list as a "mine" bubble right away
 * and the input is cleared. When the request completes with a 2xx status, the
 * `data` field of the JSON response is appended as a "bots" bubble.
 * Nothing happens when the input is empty.
 */
function send() {
    var chatInput = document.getElementById("chat_input");
    var chatText = chatInput.value;
    // The send button can still be enabled right after a message was sent
    // (the input is cleared here, but the button is only updated on keyup),
    // so guard against posting an empty message.
    if (!chatText) {
        return;
    }

    /* client side: echo the user's message immediately */
    var chat = document.createElement("li");
    chat.className = "chat-bubble mine";
    chat.innerText = chatText;
    document.getElementById("chat_list").appendChild(chat);
    chatInput.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    // Root-relative URL. The previous `${window.location.host}/api/soft` had no
    // scheme, so a value such as "localhost:8080/api/soft" was parsed as a URL
    // with the scheme "localhost:" and the request never reached the server.
    request.open("POST", "/api/soft", true);
    request.onreadystatechange = function () {
        // Only handle finished (readyState 4) requests with a 2xx status.
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var botChat = document.createElement("li");
        botChat.className = "chat-bubble bots";
        botChat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(botChat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chatText}));
}
25 +
/**
 * Registers the keyup handler on #chat_input.
 *
 * On every keyup the send button is enabled when the input holds text and
 * disabled when it is empty; releasing Enter additionally clicks the button.
 */
function setDefault() {
    const ENTER_KEY_CODE = 13; // "Enter" key on the keyboard

    function handleKeyUp(event) {
        const text = document.getElementById("chat_input").value;
        const sendButton = document.getElementById("send_button");

        if (text.length > 0) {
            sendButton.removeAttribute("disabled");
        } else {
            sendButton.setAttribute("disabled", "true");
        }

        if (event.keyCode !== ENTER_KEY_CODE) {
            return;
        }
        // Cancel the default action, then trigger the button with a click.
        event.preventDefault();
        sendButton.click();
    }

    document.getElementById("chat_input").addEventListener("keyup", handleKeyUp);
}
1 +from flask import Flask, request, jsonify, send_from_directory
2 +import torch
3 +from torchtext import data
4 +from generation import inference, tokenizer1
5 +from Styling import make_special_token
6 +from model import Transformer
7 +
# Flask application. Files in ./static are served directly from the URL root
# because static_url_path is the empty string.
app = Flask(__name__,
            static_url_path='',
            static_folder='static',)
# Keep non-ASCII characters (e.g. Korean) unescaped in JSON responses.
app.config['JSON_AS_ASCII'] = False
# This server always runs the model on the CPU.
device = torch.device('cpu')
# Fixed sequence length; same value as the --max_len default in light_chatbot.py.
max_len = 40
# Field definitions mirror data_preprocessing() in light_chatbot.py so the
# vocabularies built below match the ones the checkpoint was trained with.
ID = data.Field(sequential=False,
                use_vocab=False)
SA = data.Field(sequential=False,
                use_vocab=False)
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=tokenizer1,
                  batch_first=True,
                  fix_length=max_len,
                  dtype=torch.int32
                  )

LABEL = data.Field(sequential=True,
                   use_vocab=True,
                   tokenize=tokenizer1,
                   batch_first=True,
                   fix_length=max_len,
                   init_token='<sos>',
                   eos_token='<eos>',
                   dtype=torch.int32
                   )
# The argument corresponds to per_rough in light_chatbot.py; False here because
# only the soft model is served.
text_specials, label_specials = make_special_token(False)
# The training TSV is read at import time only to rebuild the vocabularies;
# the test split is loaded but discarded.
train_data, _ = data.TabularDataset.splits(
    path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
    fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
)
TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)
# Positional arguments are embedding_dim, nhead, nlayers, dropout; the values
# equal the argparse defaults in light_chatbot.py.
soft_model = Transformer(160, 2, 2, 0.1, TEXT, LABEL)
# rough_model = Transformer(args, TEXT, LABEL)
soft_model.to(device)
# rough_model.to(device)
# The checkpoint is the dict written by light_chatbot.py; only its
# 'model_state_dict' entry is needed for inference.
soft_model.load_state_dict(torch.load('sorted_model-soft.pth', map_location=device)['model_state_dict'])


# rough_model.load_state_dict(torch.load('sorted_model-rough.pth', map_location=device)['model_state_dict'])
50 +
51 +
@app.route('/api/soft', methods=['POST'])
def soft():
    """Answer ``POST /api/soft`` with the soft-style model's reply.

    Expects a JSON object of the form ``{"data": "<sentence>"}``.

    Returns:
        ``({"data": <reply>}, 200)`` on success, or
        ``({"data": <error message>}, 400)`` when the body is not JSON, is not
        a JSON object, or has no string ``"data"`` field.
    """
    # silent=True turns malformed JSON into None so that every bad request
    # gets the same JSON error body instead of an exception.
    payload = request.get_json(silent=True) if request.is_json else None
    sentence = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(sentence, str):
        return jsonify({"data": "잘못된 요청입니다. Bad Request."}), 400
    return jsonify({"data": inference(device, max_len, TEXT, LABEL, soft_model, sentence)}), 200
59 +
60 +# @app.route('/rough', methods=['POST'])
61 +# def rough():
62 +# return inference(device, max_len, TEXT, LABEL, rough_model, ), 200
63 +
@app.route('/', methods=['GET'])
def main_page():
    """Serve the chat page (``static/main.html``) for ``GET /``."""
    static_dir, page_name = 'static', 'main.html'
    return send_from_directory(static_dir, page_name)
67 +
if __name__ == '__main__':
    # Listen on all interfaces. Flask treats an empty host as "not given" and
    # falls back to 127.0.0.1, which cannot be reached from outside a container.
    app.run(host='0.0.0.0', port=8080)
/* Plain list without markers or spacing; applied to the chat log list. */
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

/* Shared look of every chat message. `position: relative` anchors the
   absolutely positioned :after element below. */
.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
/* Bubble colour for the user's own messages. */
.mine {
    background: #00aabb;
}
/* Bubble colour for the bot's replies. */
.bots {
    background: #cc78c5;
}

/* Zero-size box with transparent borders, vertically centred on the bubble.
   The side-specific rules below colour exactly one border, which renders as a
   small triangle pointing away from the bubble. */
.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;

}
/* Triangle on the right edge, in the user's bubble colour. */
.chat-bubble.mine:after {
    right: 0;

    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

/* Triangle on the left edge, in the bot's bubble colour. */
.chat-bubble.bots:after {
    left: 0;

    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

/* Text field takes most of the input bar. */
#chat_input {
    width: 90%;
}

/* Narrow send button next to the text field. */
#send_button {

    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

/* Input bar pinned to the bottom of the viewport, spanning its full width. */
.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +from konlpy.tag import Mecab
3 +from torch.autograd import Variable
4 +from chatspace import ChatSpace
5 +
# Module-wide ChatSpace instance, created once at import time; inference()
# passes the decoded string through spacer.space() before returning it.
spacer = ChatSpace()
7 +
8 +
import threading  # local to this block: needed only for the per-thread tagger cache

# One Mecab tagger per thread, created on first use. Building a tagger for every
# call (as before) repeats the same set-up for every sentence of the corpus.
_mecab_cache = threading.local()


def _get_mecab():
    """Return the calling thread's Mecab tagger, creating it on first use."""
    tagger = getattr(_mecab_cache, 'tagger', None)
    if tagger is None:
        tagger = Mecab()
        _mecab_cache.tagger = tagger
    return tagger


def tokenizer1(text: str):
    """Split ``text`` into morphemes with Mecab.

    Every character for which ``str.isalnum()`` is false is removed first;
    this drops punctuation and also whitespace.

    Args:
        text: raw input sentence.

    Returns:
        A list of morpheme strings (empty for input without alphanumerics,
        assuming ``Mecab.morphs('')`` yields nothing — TODO confirm).
    """
    result_text = ''.join(c for c in text if c.isalnum())
    return list(_get_mecab().morphs(result_text))
13 +
14 +
def inference(device: torch.device, max_len: int, TEXT, LABEL, model: torch.nn.Module, sentence: str):
    """Greedily decode the model's reply to ``sentence``.

    The sentence is tokenized, mapped to ids with ``TEXT.vocab`` and padded (or
    truncated) to exactly ``max_len`` ids. Decoding starts from ``<sos>`` and
    appends the model's last predicted id until that id is ``<eos>`` or
    ``max_len`` steps have run.

    Args:
        device: device the inputs are moved to; the model must already be on it.
        max_len: encoder input length and maximum number of decoding steps.
        TEXT: source field exposing ``vocab.stoi``.
        LABEL: target field exposing ``vocab.stoi`` and ``vocab.itos``.
        model: callable as ``model(enc_ids, dec_ids)`` returning logits whose
            last dimension is the target vocabulary.
        sentence: user input.

    Returns:
        The decoded reply passed through ``spacer.space``, or one of the
        ``'Error: ...'`` strings when no sentence could be produced.
    """
    pad_idx = TEXT.vocab.stoi['<pad>']
    sos_idx = LABEL.vocab.stoi['<sos>']
    eos_idx = LABEL.vocab.stoi['<eos>']

    # Training data is cut to max_len by the torchtext fields (fix_length), so
    # do the same here instead of feeding an over-long sequence to the model.
    enc_ids = [TEXT.vocab.stoi[tok] for tok in tokenizer1(sentence)][:max_len]
    enc_ids += [pad_idx] * (max_len - len(enc_ids))

    enc_input_index = torch.LongTensor([enc_ids]).to(device)
    dec_input = torch.LongTensor([[sos_idx]]).to(device)

    model.eval()
    # No gradients are needed for decoding; without this every step would keep
    # its autograd graph alive.
    with torch.no_grad():
        for _ in range(max_len):
            y_pred = model(enc_input_index, dec_input)
            y_pred_ids = y_pred.max(dim=-1)[1]
            if y_pred_ids[0, -1] == eos_idx:
                print(">", end=" ")
                pred = []
                # Collect the predicted tokens up to the first <eos>.
                for token_id in y_pred_ids.squeeze(0).tolist():
                    if LABEL.vocab.itos[token_id] == '<eos>':
                        return spacer.space("".join(pred))
                    pred.append(LABEL.vocab.itos[token_id])
                return 'Error: Sentence is not end'

            # Feed the newest prediction back in as the next decoder input.
            dec_input = torch.cat([dec_input, y_pred_ids[0, -1].view(1, 1)], dim=-1)
    return 'Error: Sentence is not predicted'
1 +import argparse
2 +import time
3 +import torch
4 +from torch import nn
5 +from torchtext import data
6 +from torchtext.data import BucketIterator
7 +from torchtext.data import TabularDataset
8 +
9 +from Styling import styling, make_special_token
10 +from generation import inference, tokenizer1
11 +from model import Transformer, GradualWarmupScheduler
12 +
# Seed passed to torch.manual_seed() at the start of every epoch in main().
SEED = 1234
14 +
15 +
16 +
17 +
def acc(yhat: torch.Tensor, y: torch.Tensor):
    """Accuracy of the arg-max predictions, ignoring padded targets.

    Args:
        yhat: scores whose last dimension ranges over the classes.
        y: target class ids; positions equal to 1 (the pad id) are excluded.

    Returns:
        A scalar tensor with the fraction of non-pad positions predicted
        correctly.
    """
    with torch.no_grad():
        _, predicted = yhat.max(dim=-1)  # second item of max() is the arg-max index
        is_real = y != 1  # drop padding from the accuracy
        hits = (predicted == y).float()
        return hits[is_real].mean()
23 +
24 +
def train(model: Transformer, iterator, optimizer, criterion: nn.CrossEntropyLoss, max_len: int, per_soft: bool, per_rough: bool):
    """Run one training epoch.

    Args:
        model: the Transformer being trained.
        iterator: yields batches with ``text``, ``target_text`` and ``SA``.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss applied to the non-pad positions.
        max_len: padded sequence length.
        per_soft: forwarded to ``styling``.
        per_rough: forwarded to ``styling``.

    Relies on the module-level ``TEXT`` and ``LABEL`` fields created in the
    ``__main__`` block.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the batches as Python floats;
        ``(nan, nan)`` when the iterator yields no batch.
    """
    total_loss = 0.0
    total_acc = 0.0
    iter_num = 0
    model.train()

    for batch in iterator:
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)

        # Apply emotion and speech style.
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # Mask of the non-padding targets (<pad> == 1). A plain boolean mask is
        # used; wrapping it in a list relies on legacy sequence indexing.
        real_value_index = dec_output != 1

        # Padding is excluded from the loss.
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        # .item() stores plain numbers, so no tensors are kept alive across
        # the epoch.
        total_loss += loss.item()
        total_acc += acc(y_pred, dec_output).item()
        iter_num += 1

    if iter_num == 0:
        # Nothing was trained; nan is never "better" than a previous loss.
        return float('nan'), float('nan')
    return total_loss / iter_num, total_acc / iter_num
63 +
64 +
def test(model: Transformer, iterator, criterion: nn.CrossEntropyLoss,
         max_len: int = None, per_soft: bool = None, per_rough: bool = None):
    """Evaluate ``model`` on ``iterator`` without updating it.

    Args:
        model: the Transformer to evaluate.
        iterator: yields batches with ``text``, ``target_text`` and ``SA``.
        criterion: loss applied to the non-pad positions.
        max_len: padded sequence length; defaults to the global ``args.max_len``.
        per_soft: forwarded to ``styling``; defaults to ``args.per_soft``.
        per_rough: forwarded to ``styling``; defaults to ``args.per_rough``.

    The three optional parameters mirror ``train``. When omitted, the values
    come from the module-level ``args`` as before, so existing three-argument
    calls keep working. Also relies on the module-level ``TEXT`` and ``LABEL``.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the batches as Python floats;
        ``(nan, nan)`` when the iterator yields no batch.
    """
    if max_len is None:
        max_len = args.max_len
    if per_soft is None:
        per_soft = args.per_soft
    if per_rough is None:
        per_rough = args.per_rough

    total_loss = 0.0
    total_acc = 0.0
    iter_num = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)

            # Apply emotion and speech style.
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            # Mask of the non-padding targets (<pad> == 1).
            real_value_index = dec_output != 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            total_loss += loss.item()
            total_acc += acc(y_pred, dec_output).item()
            iter_num += 1

    if iter_num == 0:
        return float('nan'), float('nan')
    return total_loss / iter_num, total_acc / iter_num
97 +
98 +
99 +# 데이터 전처리 및 loader return
def data_preprocessing(args, device):
    """Build the torchtext fields, vocabularies and batch iterators.

    Args:
        args: parsed arguments; ``max_len``, ``per_rough`` and ``batch_size``
            are read.
        device: device the iterators place their batches on.

    Returns:
        ``(TEXT, LABEL, train_loader, test_loader)``.
    """
    def _token_field(**extra):
        # Settings shared by the source and target text fields.
        return data.Field(sequential=True,
                          use_vocab=True,
                          tokenize=tokenizer1,
                          batch_first=True,
                          fix_length=args.max_len,
                          dtype=torch.int32,
                          **extra)

    def _raw_field():
        # Column used as-is, without tokenizing or a vocabulary.
        return data.Field(sequential=False, use_vocab=False)

    # ID is not used. SA is the sentiment-analysis label (0, 1).
    ID = _raw_field()
    TEXT = _token_field()
    LABEL = _token_field(init_token='<sos>', eos_token='<eos>')
    SA = _raw_field()

    columns = [('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)]
    train_data, test_data = TabularDataset.splits(
        path='.',
        train='chatbot_0325_ALLLABEL_train.txt',
        test='chatbot_0325_ALLLABEL_test.txt',
        format='tsv',
        fields=columns,
        skip_header=True,
    )

    # Special tokens required by TEXT and LABEL.
    text_specials, label_specials = make_special_token(args.per_rough)

    for field, specials in ((TEXT, text_specials), (LABEL, label_specials)):
        field.build_vocab(train_data, max_size=15000, specials=specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader
141 +
142 +
def main(TEXT, LABEL, arguments):
    """Optionally train, then evaluate the best checkpoint and chat interactively.

    Args:
        TEXT: source field with a built vocabulary.
        LABEL: target field with a built vocabulary.
        arguments: parsed command-line namespace; every option is read from
            this parameter rather than from the module-level ``args``.

    Relies on the module-level ``device``, ``train_loader`` and ``test_loader``
    created in the ``__main__`` block. The interactive prompt ends on EOF
    (Ctrl-D) or Ctrl-C.
    """
    # print argparse
    options = vars(arguments)
    for idx, (key, value) in enumerate(options.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(options) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    model = Transformer(arguments.embedding_dim, arguments.nhead, arguments.nlayers, arguments.dropout, TEXT, LABEL)
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=arguments.lr)
    # NOTE(review): the warm-up spans all num_epochs; --warming_up_epochs is
    # parsed but not read anywhere in this file. Confirm this is intended.
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=arguments.num_epochs)
    if arguments.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'
    model.to(device)
    if arguments.train:
        best_valid_loss = float('inf')
        for epoch in range(arguments.num_epochs):
            # NOTE(review): re-seeding inside the loop resets torch's RNG to the
            # same state at the start of every epoch; kept as in the original.
            torch.manual_seed(SEED)
            start_time = time.time()

            # train, validation
            train_loss, train_acc = \
                train(model, train_loader, optimizer, criterion, arguments.max_len, arguments.per_soft,
                      arguments.per_rough)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            scheduler.step(epoch)
            # time cal
            end_time = time.time()
            elapsed_time = end_time - start_time
            epoch_mins = int(elapsed_time / 60)
            epoch_secs = int(elapsed_time - (epoch_mins * 60))

            # torch.save(model.state_dict(), sorted_path) # for some overfitting
            # Save the model when the validation loss beats the best so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # print loss and acc
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')

    # Always evaluate the stored checkpoint, whether or not it was just trained.
    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")
    while True:
        try:
            sentence = input("문장을 입력하세요 : ")
        except (EOFError, KeyboardInterrupt):
            # Leave the prompt cleanly instead of ending with a traceback.
            print()
            break
        print(inference(device, arguments.max_len, TEXT, LABEL, model, sentence))
        print("\n")
211 +
212 +
if __name__ == '__main__':
    # Command-line options. Typed options are registered from a table, in the
    # same order as before; max_len must be large enough or errors occur.
    parser = argparse.ArgumentParser()
    for flag, kind, default in (
            ('--max_len', int, 40),
            ('--batch_size', int, 256),
            ('--num_epochs', int, 22),
            ('--warming_up_epochs', int, 5),
            ('--lr', float, 0.0002),
            ('--embedding_dim', int, 160),
            ('--nlayers', int, 2),
            ('--nhead', int, 2),
            ('--dropout', float, 0.1),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument('--train', action="store_true")

    # --per_soft and --per_rough cannot be combined.
    group = parser.add_mutually_exclusive_group()
    for flag in ('--per_soft', '--per_rough'):
        group.add_argument(flag, action="store_true")

    args = parser.parse_args()
    print("-준비중-")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, args)
<!DOCTYPE html>
<!-- Chat page: app.js provides send()/setDefault(), chat.css the bubble styles. -->
<html lang="ko">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <title>Emotional Chatbot with Styler</title>
        <script src="app.js"></script>
        <link rel="stylesheet" type="text/css" href="chat.css" />
    </head>
    <!-- setDefault() attaches the keyup handler once the page has loaded. -->
    <body onload="setDefault()">
        <!-- Chat log; send() appends one <li> per message. The two items below
             are placeholder bubbles showing both styles. -->
        <ul id="chat_list" class="list no-bullets">
            <li class="chat-bubble mine">(대충 적당한 대사)</li>
            <li class="chat-bubble bots">(대충 알맞은 답변)</li>
        </ul>
        <div class="input-holder">
            <!-- aria-label gives both controls an accessible name; neither has
                 a visible label, and the button text is only an arrow glyph. -->
            <input type="text" id="chat_input" aria-label="메시지 입력" autofocus/>
            <input type="button" id="send_button" class="button" value="↵" aria-label="보내기" onclick="send()" disabled>
        </div>
    </body>
</html>
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +import torch.nn as nn
3 +import math
4 +
# Module-level default device: CUDA when available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
6 +
7 +
class Transformer(nn.Module):
    """Encoder-decoder Transformer over two torchtext vocabularies.

    Args:
        embedding_dim: model width; also used as the feed-forward width.
        nhead: number of attention heads.
        nlayers: number of encoder layers and of decoder layers.
        dropout: dropout probability.
        SRC_vocab: source field; ``.vocab`` supplies the vocabulary size and
            the ``<pad>`` id.
        TRG_vocab: target field; ``.vocab`` supplies the vocabulary size and
            the ``<pad>`` id.
    """

    def __init__(self, embedding_dim: int, nhead: int, nlayers: int, dropout: float, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        self.d_model = embedding_dim
        self.n_head = nhead
        self.num_encoder_layers = nlayers
        self.num_decoder_layers = nlayers
        self.dim_feedforward = embedding_dim
        self.dropout = dropout

        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        self.transformer = nn.Transformer(d_model=self.d_model,
                                          nhead=self.n_head,
                                          num_encoder_layers=self.num_encoder_layers,
                                          num_decoder_layers=self.num_decoder_layers,
                                          dim_feedforward=self.dim_feedforward,
                                          dropout=self.dropout)
        # The transformer's output width is d_model. dim_feedforward happens to
        # have the same value here, so existing checkpoints still load.
        self.proj_vocab_layer = nn.Linear(
            in_features=self.d_model, out_features=len(self.TRG_vo.vocab))

    def forward(self, en_input, de_input):
        """Return target-vocabulary logits for every decoder position.

        Args:
            en_input: source token ids, batch first ``(batch, src_len)``.
            de_input: target token ids, batch first ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(batch, tgt_len, len(TRG_vocab.vocab))``.
        """
        x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
        x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
        # NOTE(review): the embeddings are still batch-first here, while
        # PositionalEncoding.forward slices its table by dimension 0, so the
        # encoding follows the batch index rather than the token position.
        # Left unchanged because existing checkpoints were trained this way.
        x_en_embed = self.pos_encoder(x_en_embed)
        x_de_embed = self.pos_encoder(x_de_embed)

        # Masking
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        # Build the causal mask on the inputs' own device. The module-level
        # `device` can differ from it (e.g. a CPU model on a CUDA-capable host).
        tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1)).to(de_input.device)

        # nn.Transformer expects (seq, batch, dim): swap the first two axes.
        x_en_embed = torch.einsum('ijk->jik', x_en_embed)
        x_de_embed = torch.einsum('ijk->jik', x_de_embed)

        feature = self.transformer(src=x_en_embed,
                                   tgt=x_de_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask)

        logits = self.proj_vocab_layer(feature)
        # Back to batch-first for the caller.
        logits = torch.einsum('ijk->jik', logits)

        return logits
62 +
63 +
class PositionalEncoding(nn.Module):
    """Adds a fixed sine/cosine table to its input, followed by dropout.

    The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature columns hold sines, odd columns
    cosines.

    Args:
        d_model: feature width of the inputs.
        dropout: dropout probability applied after the addition.
        max_len: number of rows precomputed in the table.
    """

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        row_index = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = row_index * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Insert a singleton middle axis so the table broadcasts over dim 1.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Add the first ``x.size(0)`` table rows to ``x`` and apply dropout."""
        rows = self.pe[:x.size(0), :]
        return self.dropout(x + rows)
81 +
82 +
83 +from torch.optim.lr_scheduler import _LRScheduler
84 +from torch.optim.lr_scheduler import ReduceLROnPlateau
85 +
86 +
class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    During warm-up the learning rate is
    ``base_lr * ((multiplier - 1) * last_epoch / total_epoch + 1)``, i.e. it
    grows linearly and reaches ``base_lr * multiplier`` at ``total_epoch``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)

    Raises:
        ValueError: if ``multiplier`` is not greater than 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        # Set before super().__init__(); NOTE(review): the base class may
        # overwrite last_epoch during its own initialisation — confirm.
        self.last_epoch = 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Becomes True once control has been handed to after_scheduler.
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        """Return one learning rate per parameter group for the current epoch."""
        if self.last_epoch > self.total_epoch:
            # Warm-up is over.
            if self.after_scheduler:
                if not self.finished:
                    # First hand-over: the follow-up scheduler starts from the
                    # warmed-up learning rates.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            # No follow-up scheduler: hold the target learning rate.
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Still warming up: linear interpolation towards base_lr * multiplier.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when ``after_scheduler`` is a ReduceLROnPlateau.

        Args:
            metrics: value forwarded to ``after_scheduler.step`` after warm-up.
            epoch: epoch number; defaults to ``last_epoch + 1``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        # Epoch 0 is stored as 1.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            # Warm-up: write the interpolated rates into the optimizer directly.
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                         self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # NOTE(review): epoch was given a value at the top of this method,
            # so this branch cannot be taken.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                # The follow-up scheduler counts epochs from the end of warm-up.
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one epoch.

        Args:
            epoch: epoch number, or None.
            metrics: only used when ``after_scheduler`` is a ReduceLROnPlateau.
        """
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                # Warm-up finished: delegate to the follow-up scheduler.
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                # Warm-up phase (or no follow-up scheduler): the base class
                # step is used, which is expected to apply get_lr().
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
1 +torch~=1.4.0
2 +Flask~=1.1.2
3 +torchtext~=0.6.0
4 +hgtk~=0.1.3
5 +konlpy~=0.5.2
6 +chatspace~=1.0.1
...\ No newline at end of file ...\ No newline at end of file
This file is too large to display.
This file is too large to display.
No preview for this file type
...@@ -10,3 +10,51 @@ Language Style과 감정 분석에 따른 챗봇 답변 변화 모델 : ...@@ -10,3 +10,51 @@ Language Style과 감정 분석에 따른 챗봇 답변 변화 모델 :
10 - Force RTX 2080 Ti 10 - Force RTX 2080 Ti
11 - Python 3.6.8 11 - Python 3.6.8
12 - Pytorch 1.2.0 12 - Pytorch 1.2.0
13 +
14 +# Code
15 +## Chatbot
16 +
17 +### Chatbot_main.py
18 +챗봇 학습 및 시험에 사용되는 메인 파일입니다.
19 +### model.py
챗봇에 이용되는 Transformer 모델 클래스 파일입니다.
21 +### generation.py
22 +추론 및 Beam search, Greedy search를 하는 파일입니다.
23 +### metric.py
학습 성능을 측정하기 위한 파일입니다.\
`acc(yhat, y)`
26 +### Styling.py
27 +성격에 따라 문체를 바꿔주는 역할을 하는 파일입니다.
28 +### get_data.py
29 +데이터셋을 전처리하고 불러오기 위한 파일입니다.\
30 +`tokenizer1(text)`\
31 +* text: 토크나이징할 문자열
32 +특수문자를 걸러낸 후 Mecab으로 토크나이징합니다.\
33 +`data_preprocessing(args, device)`\
34 +* args: argparser로 파싱한 NamedTuple
35 +* device: pytorch device
36 +텍스트를 토크나이징하고 id, 텍스트, 라벨, 감정분석 결과로 나누어 데이터셋을 구성합니다.
37 +
38 +## KoBERT
39 +[SKTBrain KoBERT](https://github.com/SKTBrain/KoBERT)\
40 +SKT Brain에서 BERT를 한국어에 응용하여 만든 모델입니다.\
41 +네이버 영화 리뷰를 통해 감정 분석을 학습했으며 챗봇 감정 분석에 사용됩니다.\
42 +## Light_model
43 +웹 호스팅을 위해 경량화한 모델입니다. KoBERT를 지원하지 않습니다.
44 +### light_chatbot.py
챗봇 모델 학습 및 시험을 할 수 있는 콘솔 프로그램입니다.\
`light_chatbot.py [--train] [--per_soft|--per_rough]`
47 +
48 +* train: 학습해 모델을 만들 경우에 사용합니다.
49 +사용하지 않으면 모델을 불러와 시험 합니다.
50 +* per_soft: soft 말투를 학습 또는 시험합니다.
51 +* per_rough: rough 말투를 학습 또는 시험합니다.
52 +두 옵션은 양립 불가능합니다.
53 +### app.py
54 +웹 호스팅을 위한, Flask로 구성된 간단한 HTTP 서버입니다.\
55 +`POST /api/soft`\
56 +soft 모델을 사용해, 추론 결과를 JSON으로 응답해주는 API를 제공합니다.\
57 +`GET /`\
58 +static 폴더의 HTML, CSS, JS를 정적으로 호스팅해 응답합니다.
59 +### 기타
60 +generation.py, styling.py, model.py의 역할은 Chatbot과 동일합니다.
...\ No newline at end of file ...\ No newline at end of file