import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel


class CNNBranch(nn.Module):
    """Multi-kernel 1-D CNN over token embeddings, max-pooled to a fixed-size vector.

    Args:
        vocab: Embedding table size. Index 0 is the padding row
            (``padding_idx=0``); presumably this matches the tokenizer's pad
            id -- confirm if the tokenizer changes.
        edim: Embedding dimension (Conv1d input channels).
        nf: Number of filters per kernel width.
        kernels: Kernel widths; one Conv1d is built per width.
        drop: Dropout probability applied to the pooled features.

    Attributes:
        out_dim: Width of the vector returned by ``forward``
            (``nf * len(kernels)``).
    """

    def __init__(self, vocab=30522, edim=128, nf=128, kernels=(2, 3, 4), drop=0.3):
        super().__init__()
        self.emb = nn.Embedding(vocab, edim, padding_idx=0)
        # padding=k//2 keeps edge tokens covered; for even k the conv output
        # is one step longer than the input.
        self.convs = nn.ModuleList(
            [nn.Conv1d(edim, nf, k, padding=k // 2) for k in kernels]
        )
        self.drop = nn.Dropout(drop)
        self.out_dim = nf * len(kernels)

    def forward(self, ids, mask=None):
        """Encode token ids into max-pooled n-gram features.

        Args:
            ids: Integer token ids of shape ``(batch, seq_len)``.
            mask: Optional ``(batch, seq_len)`` tensor, nonzero for real
                tokens and zero for padding. When given, padded positions are
                zeroed and conv windows containing no real token are excluded
                from the max-pool. When ``None`` every position is pooled.

        Returns:
            Float tensor of shape ``(batch, out_dim)``.
        """
        x = self.emb(ids).transpose(1, 2)  # (batch, edim, seq_len)
        valid = None
        if mask is not None:
            valid = (mask != 0).to(x.dtype).unsqueeze(1)  # (batch, 1, seq_len)
            x = x * valid

        feats = []
        for conv in self.convs:
            h = F.gelu(conv(x))  # (batch, nf, out_len)
            if valid is not None:
                k, p = conv.kernel_size[0], conv.padding[0]
                # Same window geometry as the conv: a position is kept iff its
                # window covers at least one real token.
                window_ok = F.max_pool1d(valid, k, stride=1, padding=p) > 0
                h = h.masked_fill(~window_ok, float("-inf"))
            feats.append(h.amax(dim=2))

        out = torch.cat(feats, dim=1)
        if valid is not None:
            # Rows with no real tokens pool to -inf; map them to 0 so the head
            # does not see infinities (no gradient flows through those slots).
            out = out.masked_fill(torch.isneginf(out), 0.0)
        return self.drop(out)


class HybridClassifier(nn.Module):
    """DistilBERT first-token embedding fused with CNN n-gram features, then an MLP head.

    Args:
        n_labels: Number of output logits.
        vocab, edim, nf, kernels: Forwarded to ``CNNBranch``.
        drop: Dropout rate for the CNN branch and the first head layer; the
            second head dropout uses ``0.7 * drop``.
        pretrained: Model id/path passed to ``DistilBertModel.from_pretrained``.
    """

    def __init__(
        self,
        n_labels=6,
        vocab=30522,
        edim=128,
        nf=128,
        kernels=(2, 3, 4),
        drop=0.3,
        pretrained="distilbert-base-uncased",
    ):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(pretrained)
        self.cnn = CNNBranch(vocab, edim, nf, kernels, drop)
        # Take the encoder width from its config rather than hard-coding 768.
        fused = self.bert.config.hidden_size + self.cnn.out_dim
        self.head = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(fused, 512),
            nn.GELU(),
            nn.Dropout(drop * 0.7),
            nn.Linear(512, n_labels),
        )
        self.n_labels = n_labels

    def forward(self, input_ids, attention_mask):
        """Return unnormalized logits of shape ``(batch, n_labels)``.

        Args:
            input_ids: Token ids, ``(batch, seq_len)``.
            attention_mask: Mask passed to DistilBERT and to the CNN branch
                (nonzero = real token).
        """
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Position 0 -- assumes the tokenizer places [CLS] there (standard for
        # BERT-style tokenizers; confirm if that changes).
        cls = hidden[:, 0, :]
        cnn = self.cnn(input_ids, attention_mask)
        return self.head(torch.cat([cls, cnn], dim=1))