import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel


class CNNBranch(nn.Module):
    """Parallel 1D convolutions over embedded token ids, max-pooled to a fixed-size vector."""

    def __init__(self, vocab=30522, edim=128, nf=128, kernels=(2, 3, 4)):
        super().__init__()
        self.emb = nn.Embedding(vocab, edim, padding_idx=0)
        # One conv per n-gram width; padding=k//2 keeps short sequences valid.
        self.convs = nn.ModuleList([nn.Conv1d(edim, nf, k, padding=k // 2) for k in kernels])
        self.drop = nn.Dropout(0.3)
        self.out_dim = nf * len(kernels)

    def forward(self, ids):
        x = self.emb(ids).permute(0, 2, 1)  # (B, L, E) -> (B, E, L) for Conv1d
        # GELU, global max-pool each branch to (B, nf), then concatenate.
        feats = [F.adaptive_max_pool1d(F.gelu(conv(x)), 1).squeeze(2) for conv in self.convs]
        return self.drop(torch.cat(feats, dim=1))


class HybridClassifier(nn.Module):
    """Fuses DistilBERT's [CLS] representation with the CNN branch and
    classifies the concatenated vector with a two-layer MLP head."""

    def __init__(self, n_labels=6, vocab=30522, edim=128, nf=128, kernels=(2, 3, 4), drop=0.3):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.cnn = CNNBranch(vocab, edim, nf, kernels)
        fused = self.bert.config.dim + self.cnn.out_dim  # 768 + nf*len(kernels) by default
        self.head = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(fused, 512),
            nn.GELU(),
            nn.Dropout(drop * 0.7),
            nn.Linear(512, n_labels),
        )
        self.n_labels = n_labels

    def forward(self, input_ids, attention_mask):
        # [CLS] token embedding from DistilBERT's last hidden layer.
        cls = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        cnn = self.cnn(input_ids)  # CNN branch reuses the same token ids
        return self.head(torch.cat([cls, cnn], dim=1))
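

# Usage sketch: a minimal smoke test, assuming the matching Hugging Face
# tokenizer ("distilbert-base-uncased") and two illustrative input sentences;
# both branches consume the same input_ids, so one tokenizer pass suffices.
if __name__ == "__main__":
    from transformers import DistilBertTokenizerFast

    tok = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    model = HybridClassifier(n_labels=6).eval()

    batch = tok(
        ["an example sentence", "a second, longer example sentence"],
        padding=True, truncation=True, max_length=128, return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # expected: torch.Size([2, 6])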