|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModelForTokenClassification |
|
|
| class CNNForNER(nn.Module): |
| def __init__(self, pretrained_model_name, num_classes, max_length=128): |
| super(CNNForNER, self).__init__() |
| self.transformer = AutoModelForTokenClassification.from_pretrained(pretrained_model_name) |
| self.max_length = max_length |
|
|
| |
| pretrained_num_labels = self.transformer.num_labels |
|
|
| self.conv1 = nn.Conv1d(in_channels=pretrained_num_labels, out_channels=256, kernel_size=3, padding=1) |
| self.conv2 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding=1) |
| self.dropout = nn.Dropout(0.3) |
| self.fc = nn.Linear(in_features=128, out_features=num_classes) |
|
|
| def forward(self, input_ids, attention_mask): |
| outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask) |
| logits = outputs.logits |
| |
| |
| logits = logits.permute(0, 2, 1) |
| conv1_out = F.relu(self.conv1(logits)) |
| conv2_out = F.relu(self.conv2(conv1_out)) |
| conv2_out = self.dropout(conv2_out) |
| conv2_out = conv2_out.permute(0, 2, 1) |
| final_logits = self.fc(conv2_out) |
| return final_logits |
|
|