|
|
| import transformers |
| import torch.nn as nn |
|
|
class BertClassificationModel(nn.Module):
    """Sequence classifier: a pretrained BERT encoder followed by a linear head.

    The linear head is applied to element ``[1]`` of the encoder's output (the
    pooled output) and returns unnormalized logits; no softmax is applied here.

    Args:
        pretrained_weights: Checkpoint name or path passed to
            ``transformers.BertModel.from_pretrained``.
        num_labels: Number of output classes, i.e. the size of the logit vector.
        freeze_bert: If True, the encoder's parameters get
            ``requires_grad = False`` so only the linear head is trained.
            The default (False) fine-tunes the whole encoder, matching the
            original behaviour.
    """

    def __init__(self, pretrained_weights="bert-base-chinese", num_labels=3, freeze_bert=False):
        super(BertClassificationModel, self).__init__()
        self.bert = transformers.BertModel.from_pretrained(pretrained_weights)
        # Set encoder trainability explicitly; requires_grad is True by default,
        # so with freeze_bert=False this is equivalent to full fine-tuning.
        for param in self.bert.parameters():
            param.requires_grad = not freeze_bert
        # Take the head's input width from the loaded encoder's config instead of
        # hard-coding 768 (bert-base), so checkpoints with another hidden size
        # (e.g. BERT-large, 1024) also work.
        self.dense = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Compute classification logits.

        Args:
            input_ids: Token id tensor, presumably of shape (batch, seq_len);
                TODO confirm against the data pipeline.
            token_type_ids: Optional segment ids, forwarded unchanged to BERT.
            attention_mask: Optional padding mask, forwarded unchanged to BERT.

        Returns:
            The output of the linear head: one logit per label for each example.
        """
        bert_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Index 1 is BertModel's pooled output. Per the transformers docs, this is
        # the final [CLS] hidden state passed through BERT's pooler (dense + tanh),
        # not the raw [CLS] state. Positional indexing works for both tuple and
        # ModelOutput return types.
        pooled_output = bert_output[1]
        return self.dense(pooled_output)