Showing
4 changed files
with
284 additions
and
0 deletions
KoBERT/kobert/__init__.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2019 SK T-Brain Authors. | ||
3 | +# | ||
4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +# you may not use this file except in compliance with the License. | ||
6 | +# You may obtain a copy of the License at | ||
7 | +# | ||
8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | +# | ||
10 | +# Unless required by applicable law or agreed to in writing, software | ||
11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +# See the License for the specific language governing permissions and | ||
14 | +# limitations under the License. | ||
# Package version of the kobert distribution.
__version__ = '0.1.1'
... | \ No newline at end of file | ... | \ No newline at end of file |
KoBERT/kobert/mxnet_kobert.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2019 SK T-Brain Authors. | ||
3 | +# | ||
4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +# you may not use this file except in compliance with the License. | ||
6 | +# You may obtain a copy of the License at | ||
7 | +# | ||
8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | +# | ||
10 | +# Unless required by applicable law or agreed to in writing, software | ||
11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +# See the License for the specific language governing permissions and | ||
14 | +# limitations under the License. | ||
15 | + | ||
16 | +import os | ||
17 | +import sys | ||
18 | +import requests | ||
19 | +import hashlib | ||
20 | + | ||
21 | +import mxnet as mx | ||
22 | +import gluonnlp as nlp | ||
23 | +from gluonnlp.model import BERTModel, BERTEncoder | ||
24 | + | ||
25 | +from .utils import download as _download | ||
26 | +from .utils import tokenizer | ||
27 | + | ||
# Location and integrity metadata for the pretrained KoBERT MXNet weights,
# consumed by get_mxnet_kobert_model() via _download().
mxnet_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/mxnet/mxnet_kobert_45b6957552.params',
    'fname': 'mxnet_kobert_45b6957552.params',  # filename used in the local cache dir
    'chksum': '45b6957552'  # first 10 hex chars of the file's MD5 digest
}
34 | + | ||
35 | + | ||
def get_mxnet_kobert_model(use_pooler=True,
                           use_decoder=True,
                           use_classifier=True,
                           ctx=mx.cpu(0),
                           cachedir='~/kobert/'):
    """Download (or reuse cached) KoBERT MXNet weights and vocab, then build the model.

    Args:
        use_pooler, use_decoder, use_classifier: forwarded to BERTModel construction.
        ctx: MXNet context the parameters are loaded onto.
        cachedir: directory where downloaded artifacts are cached.

    Returns:
        (BERTModel, BERTVocab) tuple from get_kobert_model().
    """
    def _fetch(info):
        # Resolve one artifact (weights or vocab) to a local file path.
        return _download(info['url'], info['fname'], info['chksum'],
                         cachedir=cachedir)

    model_path = _fetch(mxnet_kobert)
    vocab_path = _fetch(tokenizer)
    return get_kobert_model(model_path, vocab_path, use_pooler, use_decoder,
                            use_classifier, ctx)
55 | + | ||
56 | + | ||
def get_kobert_model(model_file,
                     vocab_file,
                     use_pooler=True,
                     use_decoder=True,
                     use_classifier=True,
                     ctx=mx.cpu(0)):
    """Construct a gluonnlp BERTModel with KoBERT's architecture and load weights.

    Args:
        model_file: path to the MXNet parameter file.
        vocab_file: path to the SentencePiece vocab file.
        use_pooler, use_decoder, use_classifier: which BERT heads to attach.
        ctx: MXNet context for initialization and parameter loading.

    Returns:
        (BERTModel, BERTVocab) tuple.
    """
    vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                   padding_token='[PAD]')

    # KoBERT uses the BERT-base geometry: 12 layers, 12 heads, 768 units,
    # 3072-wide FFN, 512 max positions.
    encoder = BERTEncoder(
        attention_cell='multi_head',
        num_layers=12,
        units=768,
        hidden_size=3072,
        max_length=512,
        num_heads=12,
        scaled=True,
        dropout=0.1,
        output_attention=False,
        output_all_encodings=False,
        use_residual=True,
    )

    model = BERTModel(
        encoder,
        len(vocab.idx_to_token),
        token_type_vocab_size=2,
        units=768,
        embed_size=768,
        embed_dropout=0.1,
        word_embed=None,
        use_pooler=use_pooler,
        use_decoder=use_decoder,
        use_classifier=use_classifier,
    )
    model.initialize(ctx=ctx)
    # ignore_extra lets the full checkpoint load even when some heads
    # (decoder/classifier) were not attached above.
    model.load_parameters(model_file, ctx, ignore_extra=True)
    return (model, vocab)
KoBERT/kobert/pytorch_kobert.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2019 SK T-Brain Authors. | ||
3 | +# | ||
4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +# you may not use this file except in compliance with the License. | ||
6 | +# You may obtain a copy of the License at | ||
7 | +# | ||
8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | +# | ||
10 | +# Unless required by applicable law or agreed to in writing, software | ||
11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +# See the License for the specific language governing permissions and | ||
14 | +# limitations under the License. | ||
15 | + | ||
16 | +import os | ||
17 | +import sys | ||
18 | +import requests | ||
19 | +import hashlib | ||
20 | + | ||
21 | +import torch | ||
22 | + | ||
23 | +from transformers import BertModel, BertConfig | ||
24 | +import gluonnlp as nlp | ||
25 | + | ||
26 | +from .utils import download as _download | ||
27 | +from .utils import tokenizer | ||
28 | + | ||
# Location and integrity metadata for the pretrained KoBERT PyTorch weights,
# consumed by get_pytorch_kobert_model() via _download().
pytorch_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/pytorch/pytorch_kobert_2439f391a6.params',
    'fname': 'pytorch_kobert_2439f391a6.params',  # filename used in the local cache dir
    'chksum': '2439f391a6'  # first 10 hex chars of the file's MD5 digest
}
35 | + | ||
# Hyperparameters for KoBERT's BERT architecture; fed to
# BertConfig.from_dict() in get_kobert_model(). BERT-base geometry with a
# Korean SentencePiece vocabulary of 8002 tokens.
bert_config = {
    'attention_probs_dropout_prob': 0.1,
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'hidden_size': 768,
    'initializer_range': 0.02,
    'intermediate_size': 3072,
    'max_position_embeddings': 512,
    'num_attention_heads': 12,
    'num_hidden_layers': 12,
    'type_vocab_size': 2,
    'vocab_size': 8002
}
49 | + | ||
50 | + | ||
def get_pytorch_kobert_model(ctx='cpu', cachedir='~/kobert/'):
    """Download (or reuse cached) KoBERT PyTorch weights and vocab, then build the model.

    Args:
        ctx: torch device string, e.g. 'cpu' or 'cuda:0'.
        cachedir: directory where downloaded artifacts are cached.

    Returns:
        (BertModel, BERTVocab) tuple from get_kobert_model().
    """
    weights_path = _download(pytorch_kobert['url'],
                             pytorch_kobert['fname'],
                             pytorch_kobert['chksum'],
                             cachedir=cachedir)
    spiece_path = _download(tokenizer['url'],
                            tokenizer['fname'],
                            tokenizer['chksum'],
                            cachedir=cachedir)
    return get_kobert_model(weights_path, spiece_path, ctx)
65 | + | ||
66 | + | ||
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    """Build a KoBERT ``BertModel`` from local weight and vocab files.

    Args:
        model_file: path to the PyTorch state-dict (.params) file.
        vocab_file: path to the SentencePiece vocab file.
        ctx: torch device string, e.g. "cpu" or "cuda:0".

    Returns:
        (bertmodel, vocab_b_obj) tuple; the model is moved to *ctx* and set
        to eval mode.
    """
    device = torch.device(ctx)
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    # map_location keeps a checkpoint saved from GPU loadable on CPU-only
    # hosts (plain torch.load would try to deserialize onto the saving
    # device and fail).
    bertmodel.load_state_dict(torch.load(model_file, map_location=device))
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')
    return bertmodel, vocab_b_obj
KoBERT/kobert/utils.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2019 SK T-Brain Authors. | ||
3 | +# | ||
4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +# you may not use this file except in compliance with the License. | ||
6 | +# You may obtain a copy of the License at | ||
7 | +# | ||
8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | +# | ||
10 | +# Unless required by applicable law or agreed to in writing, software | ||
11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +# See the License for the specific language governing permissions and | ||
14 | +# limitations under the License. | ||
15 | + | ||
16 | +import os | ||
17 | +import sys | ||
18 | +import requests | ||
19 | +import hashlib | ||
20 | + | ||
# Location and integrity metadata for the KoBERT ONNX export,
# consumed by get_onnx() via download().
onnx_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/onnx/onnx_kobert_44529811f0.onnx',
    'fname': 'onnx_kobert_44529811f0.onnx',  # filename used in the local cache dir
    'chksum': '44529811f0'  # first 10 hex chars of the file's MD5 digest
}
27 | + | ||
# Location and integrity metadata for the KoBERT SentencePiece tokenizer
# model, consumed by get_tokenizer() and the model loaders via download().
# NOTE(review): the hash embedded in 'fname' (1087f8699e) does not match the
# one in 'url'/'chksum' (ae5711deb3) — the integrity check uses 'chksum'
# against the downloaded content, so this still works, but the cached
# filename looks stale; confirm whether 'fname' should be updated upstream.
tokenizer = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/tokenizer/kobert_news_wiki_ko_cased-ae5711deb3.spiece',
    'fname': 'kobert_news_wiki_ko_cased-1087f8699e.spiece',
    'chksum': 'ae5711deb3'  # first 10 hex chars of the file's MD5 digest
}
34 | + | ||
35 | + | ||
def _md5_prefix(file_path):
    """Return the first 10 hex digits of the file's MD5 digest."""
    # with-block closes the handle (the original leaked it via open().read()).
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()[:10]


def download(url, filename, chksum, cachedir='~/kobert/'):
    """Download *url* into *cachedir*/*filename*, verifying an MD5 prefix.

    A previously downloaded file whose MD5 prefix matches *chksum* is reused
    without touching the network. A textual progress bar is printed while
    downloading when the server reports a content length.

    Args:
        url: source URL of the artifact.
        filename: name of the file inside the cache directory.
        chksum: expected first 10 hex chars of the file's MD5 digest.
        cachedir: cache directory; '~' is expanded.

    Returns:
        Absolute path of the verified local file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        IOError: if the downloaded content fails the checksum.
    """
    f_cachedir = os.path.expanduser(cachedir)
    os.makedirs(f_cachedir, exist_ok=True)
    file_path = os.path.join(f_cachedir, filename)
    # Fast path: reuse the cached copy when its checksum prefix matches.
    if os.path.isfile(file_path) and _md5_prefix(file_path) == chksum:
        print('using cached model')
        return file_path
    # Issue the request BEFORE truncating the cache file, so a failed
    # request does not destroy an existing (stale but usable) copy, and
    # fail fast instead of caching an HTTP error page as the model.
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total = response.headers.get('content-length')
    with open(file_path, 'wb') as f:
        if total is None:
            f.write(response.content)
        else:
            downloaded = 0
            total = int(total)
            # Chunk so the bar advances ~1000 times, but never below 1 MiB.
            for data in response.iter_content(
                    chunk_size=max(int(total / 1000), 1024 * 1024)):
                downloaded += len(data)
                f.write(data)
                done = int(50 * downloaded / total)
                sys.stdout.write('\r[{}{}]'.format('█' * done,
                                                   '.' * (50 - done)))
                sys.stdout.flush()
            sys.stdout.write('\n')
    # Explicit raise instead of assert: asserts vanish under `python -O`,
    # which would silently skip the integrity check.
    if _md5_prefix(file_path) != chksum:
        raise IOError('corrupted file!')
    return file_path
66 | + | ||
67 | + | ||
def get_onnx(cachedir='~/kobert/'):
    """Get KoBERT ONNX file path after downloading.

    Downloads the ONNX export into *cachedir* (or reuses a cached copy)
    and returns the local file path.
    """
    url, fname, chksum = (onnx_kobert[key]
                          for key in ('url', 'fname', 'chksum'))
    return download(url, fname, chksum, cachedir=cachedir)
76 | + | ||
77 | + | ||
def get_tokenizer(cachedir='~/kobert/'):
    """Get KoBERT Tokenizer file path after downloading.

    Downloads the SentencePiece tokenizer model into *cachedir* (or reuses
    a cached copy) and returns the local file path.
    """
    url, fname, chksum = (tokenizer[key]
                          for key in ('url', 'fname', 'chksum'))
    return download(url, fname, chksum, cachedir=cachedir)
-
Please register or login to post a comment