"""
Multimodal Fraudulent Paper Detection - Core Model Architecture
"""
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, ViTModel, AutoConfig
from typing import Dict, Optional


class TextEncoder(nn.Module):
    """SciBERT-based text encoder with layer freezing."""

    def __init__(self, model_name="allenai/scibert_scivocab_uncased", freeze_layers=6):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Freeze the lowest `freeze_layers` transformer blocks to reduce
        # the number of trainable parameters during fine-tuning.
        if freeze_layers > 0:
            for layer in self.encoder.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token embedding as the document-level representation.
        return outputs.last_hidden_state[:, 0, :]

    def get_embedding_dim(self):
        return self.config.hidden_size


class ImageEncoder(nn.Module):
    """ViT + forensic CNN for scientific figure analysis."""

    def __init__(self, model_name="google/vit-base-patch16-224", forensic_dim=64):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # Lightweight CNN branch over forensic inputs (e.g. noise residual
        # maps), producing a compact `forensic_dim`-dimensional descriptor.
        self.forensic_cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(),
            nn.Linear(32 * 8 * 8, forensic_dim)
        )
        vit_dim = self.vit.config.hidden_size
        self.fusion_proj = nn.Linear(vit_dim + forensic_dim, vit_dim)

    def forward(self, pixel_values, forensic_features=None):
        # [CLS] token of the ViT encoder summarises the figure.
        vit_out = self.vit(pixel_values).last_hidden_state[:, 0, :]
        if forensic_features is not None:
            forensic_out = self.forensic_cnn(forensic_features)
            combined = torch.cat([vit_out, forensic_out], dim=-1)
            return self.fusion_proj(combined)
        return vit_out

    def get_embedding_dim(self):
        return self.vit.config.hidden_size


class TabularEncoder(nn.Module):
    """Transformer encoder for tabular data (a simplified FT-Transformer:
    the feature vector is projected to a single token rather than one token
    per feature)."""

    def __init__(self, num_features, hidden_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        self.input_proj = nn.Linear(num_features, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads,
            dim_feedforward=hidden_dim * 4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.hidden_dim = hidden_dim

    def forward(self, tabular_features):
        # Project the feature vector to a single token: (B, F) -> (B, 1, H).
        x = self.input_proj(tabular_features).unsqueeze(1)
        x = self.transformer(x)
        return x.squeeze(1)

    def get_embedding_dim(self):
        return self.hidden_dim


class MetadataEncoder(nn.Module):
    """MLP encoder for metadata (author, journal, citation patterns)."""

    def __init__(self, metadata_dim, hidden_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(metadata_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.hidden_dim = hidden_dim

    def forward(self, metadata):
        return self.mlp(metadata)

    def get_embedding_dim(self):
        return self.hidden_dim


class CrossModalFusion(nn.Module):
    """Cross-modal attention fusion layer."""

    def __init__(self, embed_dims, fused_dim=512, num_heads=8, num_layers=2):
        super().__init__()
        self.modalities = list(embed_dims.keys())
        # Project each modality embedding into the shared fusion space.
        self.projections = nn.ModuleDict({
            mod: nn.Linear(dim, fused_dim) for mod, dim in embed_dims.items()
        })
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=fused_dim, nhead=num_heads,
            dim_feedforward=fused_dim * 4, dropout=0.1, batch_first=True
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Learned per-modality embeddings, analogous to positional embeddings.
        self.modality_embeddings = nn.ParameterDict({
            mod: nn.Parameter(torch.randn(1, 1, fused_dim) * 0.02)
            for mod in self.modalities
        })

    def forward(self, embeddings: Dict[str, torch.Tensor],
                mask: Optional[Dict[str, torch.Tensor]] = None):
        batch_size = next(iter(embeddings.values())).size(0)
        projected, present = [], []
        for mod in self.modalities:
            if mod in embeddings:
                proj = self.projections[mod](embeddings[mod]).unsqueeze(1)
                proj = proj + self.modality_embeddings[mod]
                projected.append(proj)
                present.append(mod)
        stacked = torch.cat(projected, dim=1)  # (B, num_present, fused_dim)
        if mask is not None:
            # Build the key-padding mask only over modalities that are actually
            # present, so its width matches `stacked`. True means "ignore";
            # present modalities without a mask entry default to False (attend).
            padding_mask = torch.zeros(batch_size, len(present),
                                       dtype=torch.bool, device=stacked.device)
            for i, mod in enumerate(present):
                if mod in mask:
                    padding_mask[:, i] = ~mask[mod]
        else:
            padding_mask = None
        fused = self.fusion_transformer(stacked, src_key_padding_mask=padding_mask)
        if padding_mask is not None:
            # Masked mean pool over the modality tokens.
            mask_expanded = (~padding_mask).unsqueeze(-1).float()
            fused = (fused * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1)
        else:
            fused = fused.mean(dim=1)
        return fused


class FraudDetectionHead(nn.Module):
    """Classification head with explainability and anomaly scoring."""

    def __init__(self, input_dim, num_classes=2, explanation_dim=256):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(input_dim // 2, input_dim // 4), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(input_dim // 4, num_classes)
        )
        # One suspicion score per modality (text, image, tabular, metadata).
        self.explanation_proj = nn.Sequential(
            nn.Linear(input_dim, explanation_dim), nn.ReLU(),
            nn.Linear(explanation_dim, 4)
        )
        self.anomaly_proj = nn.Linear(input_dim, 1)

    def forward(self, fused_embedding):
        logits = self.classifier(fused_embedding)
        modality_scores = torch.sigmoid(self.explanation_proj(fused_embedding))
        anomaly_score = torch.sigmoid(self.anomaly_proj(fused_embedding))
        return logits, modality_scores, anomaly_score


class MultimodalFraudDetector(nn.Module):
    """
    Complete multimodal fraudulent paper detection system.
    Combines text, image, tabular, and metadata modalities.
    """

    def __init__(self, text_model="allenai/scibert_scivocab_uncased",
                 image_model="google/vit-base-patch16-224", tabular_features=20,
                 metadata_features=15, fused_dim=512, freeze_text_layers=6):
        super().__init__()
        self.text_encoder = TextEncoder(text_model, freeze_text_layers)
        self.image_encoder = ImageEncoder(image_model)
        self.tabular_encoder = TabularEncoder(tabular_features)
        self.metadata_encoder = MetadataEncoder(metadata_features)
        embed_dims = {
            'text': self.text_encoder.get_embedding_dim(),
            'image': self.image_encoder.get_embedding_dim(),
            'tabular': self.tabular_encoder.get_embedding_dim(),
            'metadata': self.metadata_encoder.get_embedding_dim()
        }
        self.fusion = CrossModalFusion(embed_dims, fused_dim)
        self.head = FraudDetectionHead(fused_dim)
        self.fused_dim = fused_dim

    def forward(self, text_input_ids=None, text_attention_mask=None,
                image_pixels=None, image_forensic=None, tabular_features=None,
                metadata_features=None, modality_mask=None):
        embeddings = {}
        # Copy the caller's mask so it is not mutated in place; setdefault
        # preserves any per-sample availability mask the caller provided.
        mask = dict(modality_mask) if modality_mask else {}
        if text_input_ids is not None:
            embeddings['text'] = self.text_encoder(text_input_ids, text_attention_mask)
            mask.setdefault('text', torch.ones(text_input_ids.size(0),
                                               dtype=torch.bool, device=text_input_ids.device))
        if image_pixels is not None:
            embeddings['image'] = self.image_encoder(image_pixels, image_forensic)
            mask.setdefault('image', torch.ones(image_pixels.size(0),
                                                dtype=torch.bool, device=image_pixels.device))
        if tabular_features is not None:
            embeddings['tabular'] = self.tabular_encoder(tabular_features)
            mask.setdefault('tabular', torch.ones(tabular_features.size(0),
                                                  dtype=torch.bool, device=tabular_features.device))
        if metadata_features is not None:
            embeddings['metadata'] = self.metadata_encoder(metadata_features)
            mask.setdefault('metadata', torch.ones(metadata_features.size(0),
                                                   dtype=torch.bool, device=metadata_features.device))
        fused = self.fusion(embeddings, mask)
        logits, modality_scores, anomaly_score = self.head(fused)
        return {
            'logits': logits, 'fused_embedding': fused,
            'modality_scores': modality_scores, 'anomaly_score': anomaly_score,
            'embeddings': embeddings
        }
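

if __name__ == "__main__":
    # Minimal smoke-test sketch of how the detector is meant to be called.
    # The batch size, sequence length, and feature dimensions below are
    # illustrative assumptions matching the constructor defaults, not values
    # prescribed by this module; instantiating the model downloads the
    # pretrained SciBERT and ViT weights from the Hugging Face Hub.
    model = MultimodalFraudDetector(tabular_features=20, metadata_features=15)
    model.eval()
    batch_size, seq_len = 2, 128
    inputs = {
        'text_input_ids': torch.randint(0, 1000, (batch_size, seq_len)),
        'text_attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
        'image_pixels': torch.randn(batch_size, 3, 224, 224),
        'tabular_features': torch.randn(batch_size, 20),
        'metadata_features': torch.randn(batch_size, 15),
    }
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs['logits'].shape)           # torch.Size([2, 2])
    print(outputs['modality_scores'].shape)  # torch.Size([2, 4])
    print(outputs['anomaly_score'].shape)    # torch.Size([2, 1])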