Showing 4 changed files with 284 additions and 0 deletions.
KoBERT/kobert/__init__.py
0 → 100644
| 1 | +# coding=utf-8 | ||
| 2 | +# Copyright 2019 SK T-Brain Authors. | ||
| 3 | +# | ||
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | +# you may not use this file except in compliance with the License. | ||
| 6 | +# You may obtain a copy of the License at | ||
| 7 | +# | ||
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | +# | ||
| 10 | +# Unless required by applicable law or agreed to in writing, software | ||
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | +# See the License for the specific language governing permissions and | ||
| 14 | +# limitations under the License. | ||
# Version of the kobert package distribution.
__version__ = '0.1.1'
| ... | \ No newline at end of file | ... | \ No newline at end of file |
KoBERT/kobert/mxnet_kobert.py
0 → 100644
| 1 | +# coding=utf-8 | ||
| 2 | +# Copyright 2019 SK T-Brain Authors. | ||
| 3 | +# | ||
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | +# you may not use this file except in compliance with the License. | ||
| 6 | +# You may obtain a copy of the License at | ||
| 7 | +# | ||
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | +# | ||
| 10 | +# Unless required by applicable law or agreed to in writing, software | ||
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | +# See the License for the specific language governing permissions and | ||
| 14 | +# limitations under the License. | ||
| 15 | + | ||
| 16 | +import os | ||
| 17 | +import sys | ||
| 18 | +import requests | ||
| 19 | +import hashlib | ||
| 20 | + | ||
| 21 | +import mxnet as mx | ||
| 22 | +import gluonnlp as nlp | ||
| 23 | +from gluonnlp.model import BERTModel, BERTEncoder | ||
| 24 | + | ||
| 25 | +from .utils import download as _download | ||
| 26 | +from .utils import tokenizer | ||
| 27 | + | ||
# Download descriptor for the pretrained KoBERT MXNet checkpoint.
#   url:    remote location of the parameter file
#   fname:  local cache file name
#   chksum: first 10 hex digits of the file's md5, used by download()
mxnet_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/mxnet/mxnet_kobert_45b6957552.params',
    'fname': 'mxnet_kobert_45b6957552.params',
    'chksum': '45b6957552'
}
| 34 | + | ||
| 35 | + | ||
def get_mxnet_kobert_model(use_pooler=True,
                           use_decoder=True,
                           use_classifier=True,
                           ctx=mx.cpu(0),
                           cachedir='~/kobert/'):
    """Download the KoBERT MXNet checkpoint and vocab, then build the model.

    Args:
        use_pooler: include the pooler head in the returned BERTModel.
        use_decoder: include the masked-LM decoder head.
        use_classifier: include the next-sentence classifier head.
        ctx: MXNet context the parameters are loaded onto.
        cachedir: directory used to cache the downloaded files.

    Returns:
        Tuple of (BERTModel, gluonnlp BERTVocab).
    """
    # Fetch (or reuse from cache) the parameter file and sentencepiece vocab.
    model_path = _download(mxnet_kobert['url'],
                           mxnet_kobert['fname'],
                           mxnet_kobert['chksum'],
                           cachedir=cachedir)
    vocab_path = _download(tokenizer['url'],
                           tokenizer['fname'],
                           tokenizer['chksum'],
                           cachedir=cachedir)
    return get_kobert_model(model_path, vocab_path, use_pooler, use_decoder,
                            use_classifier, ctx)
| 55 | + | ||
| 56 | + | ||
def get_kobert_model(model_file,
                     vocab_file,
                     use_pooler=True,
                     use_decoder=True,
                     use_classifier=True,
                     ctx=mx.cpu(0)):
    """Build a KoBERT BERTModel from a parameter file and sentencepiece vocab.

    Args:
        model_file: path to the pretrained .params file.
        vocab_file: path to the sentencepiece vocab file.
        use_pooler: include the pooler head.
        use_decoder: include the masked-LM decoder head.
        use_classifier: include the next-sentence classifier head.
        ctx: MXNet context the parameters are loaded onto.

    Returns:
        Tuple of (BERTModel, gluonnlp BERTVocab).
    """
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')

    # Fixed KoBERT architecture (BERT-base shape) shared by encoder and model.
    encoder_args = dict(num_layers=12,
                        units=768,
                        hidden_size=3072,
                        max_length=512,
                        num_heads=12,
                        scaled=True,
                        dropout=0.1,
                        use_residual=True)

    encoder = BERTEncoder(attention_cell='multi_head',
                          output_attention=False,
                          output_all_encodings=False,
                          **encoder_args)

    # Assemble the full BERT network around the encoder.
    net = BERTModel(encoder,
                    len(vocab_b_obj.idx_to_token),
                    token_type_vocab_size=2,
                    units=768,
                    embed_size=768,
                    embed_dropout=0.1,
                    word_embed=None,
                    use_pooler=use_pooler,
                    use_decoder=use_decoder,
                    use_classifier=use_classifier)
    net.initialize(ctx=ctx)
    # ignore_extra tolerates checkpoint params for heads that were disabled.
    net.load_parameters(model_file, ctx, ignore_extra=True)
    return (net, vocab_b_obj)
KoBERT/kobert/pytorch_kobert.py
0 → 100644
| 1 | +# coding=utf-8 | ||
| 2 | +# Copyright 2019 SK T-Brain Authors. | ||
| 3 | +# | ||
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | +# you may not use this file except in compliance with the License. | ||
| 6 | +# You may obtain a copy of the License at | ||
| 7 | +# | ||
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | +# | ||
| 10 | +# Unless required by applicable law or agreed to in writing, software | ||
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | +# See the License for the specific language governing permissions and | ||
| 14 | +# limitations under the License. | ||
| 15 | + | ||
| 16 | +import os | ||
| 17 | +import sys | ||
| 18 | +import requests | ||
| 19 | +import hashlib | ||
| 20 | + | ||
| 21 | +import torch | ||
| 22 | + | ||
| 23 | +from transformers import BertModel, BertConfig | ||
| 24 | +import gluonnlp as nlp | ||
| 25 | + | ||
| 26 | +from .utils import download as _download | ||
| 27 | +from .utils import tokenizer | ||
| 28 | + | ||
# Download descriptor for the pretrained KoBERT PyTorch checkpoint.
#   url:    remote location of the state-dict file
#   fname:  local cache file name
#   chksum: first 10 hex digits of the file's md5, used by download()
pytorch_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/pytorch/pytorch_kobert_2439f391a6.params',
    'fname': 'pytorch_kobert_2439f391a6.params',
    'chksum': '2439f391a6'
}
| 35 | + | ||
# transformers BertConfig values for KoBERT (BERT-base shape, 8002-token
# sentencepiece vocab). Passed to BertConfig.from_dict in get_kobert_model.
bert_config = {
    'attention_probs_dropout_prob': 0.1,
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'hidden_size': 768,
    'initializer_range': 0.02,
    'intermediate_size': 3072,
    'max_position_embeddings': 512,
    'num_attention_heads': 12,
    'num_hidden_layers': 12,
    'type_vocab_size': 2,
    'vocab_size': 8002
}
| 49 | + | ||
| 50 | + | ||
def get_pytorch_kobert_model(ctx='cpu', cachedir='~/kobert/'):
    """Download the KoBERT PyTorch checkpoint and vocab, then build the model.

    Args:
        ctx: torch device string, e.g. 'cpu' or 'cuda:0'.
        cachedir: directory used to cache the downloaded files.

    Returns:
        Tuple of (BertModel, gluonnlp BERTVocab).
    """
    # Fetch (or reuse from cache) the state dict and sentencepiece vocab.
    model_path = _download(pytorch_kobert['url'],
                           pytorch_kobert['fname'],
                           pytorch_kobert['chksum'],
                           cachedir=cachedir)
    vocab_path = _download(tokenizer['url'],
                           tokenizer['fname'],
                           tokenizer['chksum'],
                           cachedir=cachedir)
    return get_kobert_model(model_path, vocab_path, ctx)
| 65 | + | ||
| 66 | + | ||
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    """Load KoBERT weights into a transformers BertModel.

    Args:
        model_file: path to the pretrained state-dict (.params) file.
        vocab_file: path to the sentencepiece vocab file.
        ctx: torch device string, e.g. "cpu" or "cuda:0".

    Returns:
        Tuple of (BertModel in eval mode on *ctx*, gluonnlp BERTVocab).
    """
    device = torch.device(ctx)
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    # map_location remaps checkpoint tensors onto *ctx*, so a checkpoint
    # saved on a GPU machine still loads on a CPU-only host.
    bertmodel.load_state_dict(torch.load(model_file, map_location=device))
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')
    return bertmodel, vocab_b_obj
KoBERT/kobert/utils.py
0 → 100644
| 1 | +# coding=utf-8 | ||
| 2 | +# Copyright 2019 SK T-Brain Authors. | ||
| 3 | +# | ||
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | +# you may not use this file except in compliance with the License. | ||
| 6 | +# You may obtain a copy of the License at | ||
| 7 | +# | ||
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | +# | ||
| 10 | +# Unless required by applicable law or agreed to in writing, software | ||
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | +# See the License for the specific language governing permissions and | ||
| 14 | +# limitations under the License. | ||
| 15 | + | ||
| 16 | +import os | ||
| 17 | +import sys | ||
| 18 | +import requests | ||
| 19 | +import hashlib | ||
| 20 | + | ||
# Download descriptor for the KoBERT ONNX export.
#   url:    remote location of the .onnx file
#   fname:  local cache file name
#   chksum: first 10 hex digits of the file's md5, used by download()
onnx_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/onnx/onnx_kobert_44529811f0.onnx',
    'fname': 'onnx_kobert_44529811f0.onnx',
    'chksum': '44529811f0'
}
| 27 | + | ||
# Download descriptor for the KoBERT sentencepiece tokenizer model.
# NOTE(review): the hash embedded in 'fname' (1087f8699e) does not match the
# hash in 'url'/'chksum' (ae5711deb3) — the file is cached under a name whose
# hash differs from its content checksum. Verify which is intended; renaming
# would invalidate existing caches, so this is flagged rather than changed.
tokenizer = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/tokenizer/kobert_news_wiki_ko_cased-ae5711deb3.spiece',
    'fname': 'kobert_news_wiki_ko_cased-1087f8699e.spiece',
    'chksum': 'ae5711deb3'
}
| 34 | + | ||
| 35 | + | ||
def _md5_prefix(path, nbytes=10):
    """Return the first *nbytes* hex digits of the md5 of *path*'s content."""
    md5 = hashlib.md5()
    # Hash in 1 MiB chunks so large model files are not read into memory
    # at once; the file handle is closed deterministically by the with-block.
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            md5.update(chunk)
    return md5.hexdigest()[:nbytes]


def download(url, filename, chksum, cachedir='~/kobert/'):
    """Download *url* into *cachedir* as *filename*, verifying an md5 prefix.

    A matching cached file is reused without hitting the network. Progress is
    drawn to stdout when the server reports a content length.

    Args:
        url: remote location of the file.
        filename: local file name to store the download under.
        chksum: first 10 hex digits of the expected md5 of the content.
        cachedir: cache directory; ``~`` is expanded.

    Returns:
        Path of the verified local file.

    Raises:
        AssertionError: if the downloaded content fails the checksum.
    """
    f_cachedir = os.path.expanduser(cachedir)
    os.makedirs(f_cachedir, exist_ok=True)
    file_path = os.path.join(f_cachedir, filename)
    if os.path.isfile(file_path) and _md5_prefix(file_path) == chksum:
        print('using cached model')
        return file_path
    with open(file_path, 'wb') as f:
        response = requests.get(url, stream=True)
        total = response.headers.get('content-length')

        if total is None:
            f.write(response.content)
        else:
            downloaded = 0
            total = int(total)
            # Chunk size is ~0.1% of the file (>= 1 MiB) so the progress bar
            # updates at most ~1000 times.
            for data in response.iter_content(
                    chunk_size=max(int(total / 1000), 1024 * 1024)):
                downloaded += len(data)
                f.write(data)
                done = int(50 * downloaded / total)
                sys.stdout.write('\r[{}{}]'.format('█' * done,
                                                   '.' * (50 - done)))
                sys.stdout.flush()
            sys.stdout.write('\n')
    assert chksum == _md5_prefix(file_path), 'corrupted file!'
    return file_path
| 66 | + | ||
| 67 | + | ||
def get_onnx(cachedir='~/kobert/'):
    """Get KoBERT ONNX file path after downloading.

    Args:
        cachedir: directory used to cache the downloaded file.

    Returns:
        Local path of the verified ONNX model file.
    """
    return download(onnx_kobert['url'],
                    onnx_kobert['fname'],
                    onnx_kobert['chksum'],
                    cachedir=cachedir)
| 76 | + | ||
| 77 | + | ||
def get_tokenizer(cachedir='~/kobert/'):
    """Get KoBERT Tokenizer file path after downloading.

    Args:
        cachedir: directory used to cache the downloaded file.

    Returns:
        Local path of the verified sentencepiece tokenizer file.
    """
    return download(tokenizer['url'],
                    tokenizer['fname'],
                    tokenizer['chksum'],
                    cachedir=cachedir)
-
Please register or login to post a comment