Showing 2 changed files with 33 additions and 0 deletions
KoBERT/Bert_model.py
0 → 100644
+import torch
+from torch import nn
+
+class BERTClassifier(nn.Module):
+    def __init__(self,
+                 bert,
+                 hidden_size=768,
+                 num_classes=2,
+                 dr_rate=None):
+        super(BERTClassifier, self).__init__()
+        self.bert = bert
+        self.dr_rate = dr_rate
+
+        self.classifier = nn.Linear(hidden_size, num_classes)
+        if dr_rate:
+            self.dropout = nn.Dropout(p=dr_rate)
+
+    def gen_attention_mask(self, token_ids, valid_length):
+        attention_mask = torch.zeros_like(token_ids)
+        for i, v in enumerate(valid_length):
+            attention_mask[i][:v] = 1
+        return attention_mask.float()
+
+    def forward(self, token_ids, valid_length, segment_ids):
+
+        attention_mask = self.gen_attention_mask(token_ids, valid_length)
+
+        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(),
+                              attention_mask=attention_mask.float().to(token_ids.device))
+
+        if self.dr_rate:
+            pooler = self.dropout(pooler)
+        return self.classifier(pooler)
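
The classifier wraps a pretrained KoBERT encoder: gen_attention_mask marks the first valid_length positions of each padded sequence as attendable, and the pooled output is passed through optional dropout into a linear layer over num_classes labels. The unpacking "_, pooler = self.bert(...)" assumes the encoder returns a (sequence_output, pooled_output) tuple; with recent Hugging Face transformers versions you would need return_dict=False or outputs.pooler_output instead. Below is a minimal usage sketch under that assumption, using the SKTBrain KoBERT helper get_pytorch_kobert_model and assuming the repo root is on PYTHONPATH; the dummy batch and all names in it are illustrative only, not part of this commit.

    # Minimal usage sketch (illustrative, not part of this commit).
    import torch
    from kobert.pytorch_kobert import get_pytorch_kobert_model  # assumed KoBERT helper
    from KoBERT.Bert_model import BERTClassifier                # assumes repo root on PYTHONPATH

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pretrained KoBERT encoder and its vocabulary.
    bertmodel, vocab = get_pytorch_kobert_model()
    model = BERTClassifier(bertmodel, num_classes=2, dr_rate=0.5).to(device)

    # Dummy batch: padded token ids, true (non-pad) lengths, and all-zero segment ids,
    # matching what the KoBERT examples produce with gluonnlp's BERTSentenceTransform.
    batch_size, max_len = 2, 64
    token_ids = torch.randint(0, len(vocab), (batch_size, max_len), dtype=torch.long, device=device)
    valid_length = torch.tensor([40, 25])
    segment_ids = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)

    # Note: the forward pass assumes the encoder returns a tuple; with newer transformers
    # this typically requires bertmodel.config.return_dict = False.
    model.eval()
    with torch.no_grad():
        logits = model(token_ids, valid_length, segment_ids)  # shape: (batch_size, num_classes)
    print(logits.shape)  # torch.Size([2, 2])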
면담확인서/캡디면담확인서03.jpg
0 → 100644
1.13 MB