| import torch |
| import torch.nn as nn |
| from transformers import BertModel |
|
|
class MultimodalClassifier(nn.Module):
    """Late-fusion classifier over BERT text features and precomputed image features.

    Each modality is projected to a 256-d embedding (Linear -> BatchNorm ->
    ReLU -> Dropout), the two embeddings are concatenated, and a small MLP
    produces class logits.
    """

    def __init__(self, text_hidden_size=768, image_feat_size=2048, num_classes=2):
        """
        Args:
            text_hidden_size: size of BERT's pooled output (768 for bert-base).
            image_feat_size: dimensionality of the precomputed image feature
                vector (e.g. 2048 — presumably a ResNet-50 pool feature; the
                extractor is not visible here).
            num_classes: number of output classes.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # Per-modality projection heads.
        self.text_fc = nn.Sequential(
            nn.Linear(text_hidden_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.image_fc = nn.Sequential(
            nn.Linear(image_feat_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Fusion MLP over the concatenated (256 + 256)-d embedding.
        self.fusion_fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def _project(self, branch, x):
        """Apply a modality projection head, tolerating single-sample batches.

        BatchNorm1d raises "Expected more than 1 value per channel" only when
        *training* on a batch of size 1; in eval mode it uses running
        statistics and handles any batch size. The previous implementation
        skipped BN for every size-1 batch — including at inference — which
        made predictions depend on batch size. BN is bypassed only in the one
        case where it would actually fail.
        """
        if self.training and x.size(0) == 1:
            # Linear (index 0), skip BatchNorm (index 1), then ReLU + Dropout.
            return branch[2:](branch[0](x))
        return branch(x)

    def forward(self, input_ids, attention_mask, image_vector):
        """Return (batch, num_classes) logits for the fused modalities.

        Args:
            input_ids: (batch, seq_len) token ids for BERT.
            attention_mask: (batch, seq_len) attention mask for BERT.
            image_vector: (batch, image_feat_size) precomputed image features.
        """
        text_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = self._project(self.text_fc, text_output.pooler_output)
        image_feat = self._project(self.image_fc, image_vector)
        fused = torch.cat((text_feat, image_feat), dim=1)
        return self.fusion_fc(fused)
|
|