KoBERT_model
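
Defines `BERTClassifier`, a sentence classifier that feeds the pooled output of a KoBERT encoder through an optional dropout layer and a linear head (two classes by default).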

import torch
from torch import nn


class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=2,
                 dr_rate=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # 1 for the first valid_length[i] (real) tokens of each sequence,
        # 0 for the padding that follows. zeros_like already places the
        # mask on the same device as token_ids.
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        # return_dict=False makes recent transformers versions return a
        # (sequence_output, pooled_output) tuple, matching the unpacking here.
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask,
                              return_dict=False)

        # Apply dropout only when a rate was given; without this fallback,
        # `out` would be undefined for dr_rate=None.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)
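
For a quick shape check, here is a minimal sketch of how the classifier might be driven. It substitutes a randomly initialized HuggingFace `BertModel` for the pretrained KoBERT encoder; the checkpoint name mentioned in the comment, the batch shapes, and the `valid_length` values are illustrative assumptions, not part of the original file.

import torch
from transformers import BertConfig, BertModel

# Stand-in encoder with random weights; in practice the pretrained KoBERT
# weights would be loaded instead (e.g. via BertModel.from_pretrained with
# the skt/kobert-base-v1 checkpoint -- an assumed setup, not from this file).
bert = BertModel(BertConfig())  # default config: hidden_size=768

model = BERTClassifier(bert, dr_rate=0.5)
model.eval()  # disable dropout for the smoke test

batch_size, max_len = 2, 64
token_ids = torch.randint(0, bert.config.vocab_size, (batch_size, max_len))
valid_length = torch.tensor([40, 25])  # real-token count per sample (assumed)
segment_ids = torch.zeros(batch_size, max_len, dtype=torch.long)

with torch.no_grad():
    logits = model(token_ids, valid_length, segment_ids)
print(logits.shape)  # torch.Size([2, 2]): one logit per class per sample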