pytorch_kobert.py
# coding=utf-8
# Copyright 2019 SK T-Brain Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import requests
import hashlib

import torch
from transformers import BertModel, BertConfig
import gluonnlp as nlp

from .utils import download as _download
from .utils import tokenizer

# Remote location of the pretrained KoBERT weights; `chksum` is the short
# hash prefix that the download helper verifies against the cached file.
pytorch_kobert = {
    'url':
    'https://kobert.blob.core.windows.net/models/kobert/pytorch/pytorch_kobert_2439f391a6.params',
    'fname': 'pytorch_kobert_2439f391a6.params',
    'chksum': '2439f391a6'
}
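

# `_download` comes from .utils, whose source is not shown in this file. The
# sketch below is illustrative only: it assumes the helper's contract is to
# fetch `url` into `cachedir/fname`, skip the fetch when a cached copy already
# matches the hash prefix `chksum`, and return the local path. Both the name
# `_download_sketch` (chosen to avoid shadowing the real import) and the use
# of SHA-256 here are assumptions, not the actual .utils implementation.
def _download_sketch(url, fname, chksum, cachedir='~/kobert/'):
    cachedir = os.path.expanduser(cachedir)
    os.makedirs(cachedir, exist_ok=True)
    path = os.path.join(cachedir, fname)

    def _matches(p):
        # Compare only the leading hex digits, mirroring the short checksum
        # string stored in `pytorch_kobert` above.
        with open(p, 'rb') as f:
            return hashlib.sha256(f.read()).hexdigest()[:len(chksum)] == chksum

    if not (os.path.exists(path) and _matches(path)):
        resp = requests.get(url, stream=True)
        resp.raise_for_status()
        with open(path, 'wb') as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
        if not _matches(path):
            raise IOError('checksum mismatch for {}'.format(fname))
    return path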

# KoBERT uses the standard BERT-base architecture (12 layers, 12 heads,
# hidden size 768) with an 8,002-token SentencePiece vocabulary.
bert_config = {
    'attention_probs_dropout_prob': 0.1,
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'hidden_size': 768,
    'initializer_range': 0.02,
    'intermediate_size': 3072,
    'max_position_embeddings': 512,
    'num_attention_heads': 12,
    'num_hidden_layers': 12,
    'type_vocab_size': 2,
    'vocab_size': 8002
}


def get_pytorch_kobert_model(ctx='cpu', cachedir='~/kobert/'):
    """Fetch (or reuse cached) KoBERT weights and vocab, then load them."""
    # download model
    model_info = pytorch_kobert
    model_path = _download(model_info['url'],
                           model_info['fname'],
                           model_info['chksum'],
                           cachedir=cachedir)
    # download vocab (the SentencePiece model file doubles as the vocab source)
    vocab_info = tokenizer
    vocab_path = _download(vocab_info['url'],
                           vocab_info['fname'],
                           vocab_info['chksum'],
                           cachedir=cachedir)
    return get_kobert_model(model_path, vocab_path, ctx)


def get_kobert_model(model_file, vocab_file, ctx='cpu'):
    """Build a BertModel from local weight and SentencePiece vocab files."""
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    device = torch.device(ctx)
    # map_location keeps the load working even when the checkpoint was saved
    # on a different device than `ctx` (e.g. GPU-saved weights on a CPU box).
    bertmodel.load_state_dict(torch.load(model_file, map_location=device))
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')
    return bertmodel, vocab_b_obj
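

# A minimal smoke test, assuming the classic KoBERT usage: token ids are fed
# directly as LongTensors (the id values below are arbitrary placeholders,
# not real tokenized text). Depending on the installed transformers version
# the forward pass returns a plain tuple or a ModelOutput; indexing with [0]
# yields the sequence output either way.
if __name__ == '__main__':
    model, vocab = get_pytorch_kobert_model()
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    with torch.no_grad():
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
    sequence_output = outputs[0]  # (batch, seq_len, hidden) == (2, 3, 768)
    print(sequence_output.shape)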