# Uploaded by Sarjinkhan2003 to the Hugging Face Hub
# ("Upload folder using huggingface_hub"), commit d8c6bd7 (verified).
import torch, torch.nn as nn, torch.nn.functional as F
from transformers import DistilBertModel
class CNNBranch(nn.Module):
    """Multi-width 1-D convolutional encoder over token ids.

    Each token id is embedded and passed through one ``Conv1d`` per kernel
    width. Each output is activated with GELU and globally max-pooled over
    the sequence axis. The pooled features of all widths are concatenated,
    and dropout is applied to the result.

    Args:
        vocab: Number of rows in the embedding table. Every id passed to
            ``forward`` must lie in ``[0, vocab)``.
        edim: Embedding dimension, i.e. the input channels of each conv.
        nf: Number of filters (output channels) per kernel width.
        kernels: Convolution kernel widths, one conv branch per entry.
        drop: Dropout probability applied to the concatenated features.
            Defaults to 0.3, the value that was previously hard-coded.

    Attributes:
        out_dim: Feature size of ``forward``'s output, ``nf * len(kernels)``.
    """

    def __init__(self, vocab=30522, edim=128, nf=128, kernels=(2, 3, 4), drop=0.3):
        super().__init__()
        # Row 0 is the padding row: zero-initialised and excluded from
        # gradient updates.
        # NOTE(review): presumably 0 is the tokenizer's pad id — confirm.
        self.emb = nn.Embedding(vocab, edim, padding_idx=0)
        # padding=k//2 keeps the length L for odd k and yields L+1 for even k.
        # The global max-pool in forward makes that difference irrelevant.
        self.convs = nn.ModuleList(
            [nn.Conv1d(edim, nf, k, padding=k // 2) for k in kernels]
        )
        self.drop = nn.Dropout(drop)
        self.out_dim = nf * len(kernels)

    def forward(self, ids):
        """Encode a batch of token ids.

        Args:
            ids: 2-D integer tensor of token ids, indexed as (batch, seq_len).

        Returns:
            Float tensor of shape (batch, out_dim).

        Note:
            No mask is applied. Positions holding the padding id still take
            part in the max-pool, through the conv bias on zero embeddings.
        """
        # The embedding gives (batch, seq_len, edim). Conv1d wants channels on
        # dim 1, so the tensor is permuted to (batch, edim, seq_len).
        x = self.emb(ids).permute(0, 2, 1)
        # Each branch gives (batch, nf, L') -> pool to (batch, nf, 1) -> squeeze.
        pooled = [
            F.adaptive_max_pool1d(F.gelu(conv(x)), 1).squeeze(2)
            for conv in self.convs
        ]
        return self.drop(torch.cat(pooled, dim=1))
class HybridClassifier(nn.Module):
    """DistilBERT + CNN hybrid sequence classifier.

    The model combines two views of the same input:

    - the first-position vector of DistilBERT's last hidden layer
      (presumably the [CLS] token with the standard tokenizer);
    - the ``CNNBranch`` features, computed from the same ``input_ids``.

    The two are concatenated and passed to an MLP head that returns
    ``n_labels`` unnormalised scores. No softmax or sigmoid is applied here.

    Args:
        n_labels: Number of output scores.
        vocab, edim, nf, kernels: Forwarded to ``CNNBranch``. ``vocab`` must
            cover every id produced by the tokenizer paired with ``bert_name``.
        drop: Dropout before the head's first Linear layer. The second head
            dropout uses ``drop * 0.7``. This does not affect the CNN
            branch's own dropout.
        bert_name: Checkpoint passed to ``DistilBertModel.from_pretrained``.
            Defaults to ``"distilbert-base-uncased"``.
    """

    def __init__(
        self,
        n_labels=6,
        vocab=30522,
        edim=128,
        nf=128,
        kernels=(2, 3, 4),
        drop=0.3,
        bert_name="distilbert-base-uncased",
    ):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(bert_name)
        self.cnn = CNNBranch(vocab, edim, nf, kernels)
        # Take the encoder width from its config instead of hard-coding 768.
        # This keeps the head consistent with the checkpoint that was loaded.
        # The default checkpoint is 768 wide, so existing weights still fit.
        fused = self.bert.config.hidden_size + self.cnn.out_dim
        # Keep this exact layer order: the state_dict keys (head.1.*, head.4.*)
        # depend on the Sequential indices.
        self.head = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(fused, 512),
            nn.GELU(),
            nn.Dropout(drop * 0.7),
            nn.Linear(512, n_labels),
        )
        self.n_labels = n_labels

    def forward(self, input_ids, attention_mask):
        """Compute classification scores for a batch.

        Args:
            input_ids: Token ids, indexed as (batch, seq_len). They are fed
                to both DistilBERT and the CNN branch.
            attention_mask: Mask passed only to DistilBERT.

        Returns:
            Tensor of shape (batch, n_labels) with unnormalised scores.
        """
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        cls_vec = hidden[:, 0, :]
        # The CNN branch sees the raw ids only. attention_mask is not applied
        # there, so padded positions still contribute to its features.
        cnn_feats = self.cnn(input_ids)
        return self.head(torch.cat([cls_vec, cnn_feats], dim=1))