""" Multimodal Fraudulent Paper Detection - Core Model Architecture """ import torch import torch.nn as nn from transformers import AutoModel, AutoTokenizer, ViTModel, AutoConfig from typing import Dict, Optional class TextEncoder(nn.Module): """SciBERT-based text encoder with layer freezing.""" def __init__(self, model_name="allenai/scibert_scivocab_uncased", freeze_layers=6): super().__init__() self.config = AutoConfig.from_pretrained(model_name) self.encoder = AutoModel.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name) if freeze_layers > 0: for layer in self.encoder.encoder.layer[:freeze_layers]: for param in layer.parameters(): param.requires_grad = False def forward(self, input_ids, attention_mask): outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) return outputs.last_hidden_state[:, 0, :] def get_embedding_dim(self): return self.config.hidden_size class ImageEncoder(nn.Module): """ViT + forensic CNN for scientific figure analysis.""" def __init__(self, model_name="google/vit-base-patch16-224", forensic_dim=64): super().__init__() self.vit = ViTModel.from_pretrained(model_name) self.forensic_cnn = nn.Sequential( nn.Conv2d(3, 16, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(), nn.Linear(32 * 8 * 8, forensic_dim) ) vit_dim = self.vit.config.hidden_size self.fusion_proj = nn.Linear(vit_dim + forensic_dim, vit_dim) def forward(self, pixel_values, forensic_features=None): vit_out = self.vit(pixel_values).last_hidden_state[:, 0, :] if forensic_features is not None: forensic_out = self.forensic_cnn(forensic_features) combined = torch.cat([vit_out, forensic_out], dim=-1) return self.fusion_proj(combined) return vit_out def get_embedding_dim(self): return self.vit.config.hidden_size class TabularEncoder(nn.Module): """FT-Transformer style encoder for tabular data.""" def __init__(self, num_features, hidden_dim=256, num_layers=4, num_heads=8): super().__init__() self.input_proj = nn.Linear(num_features, hidden_dim) encoder_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim*4, dropout=0.1, batch_first=True ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.hidden_dim = hidden_dim def forward(self, tabular_features): x = self.input_proj(tabular_features).unsqueeze(1) x = self.transformer(x) return x.squeeze(1) def get_embedding_dim(self): return self.hidden_dim class MetadataEncoder(nn.Module): """MLP encoder for metadata (author, journal, citation patterns).""" def __init__(self, metadata_dim, hidden_dim=128): super().__init__() self.mlp = nn.Sequential( nn.Linear(metadata_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, hidden_dim) ) self.hidden_dim = hidden_dim def forward(self, metadata): return self.mlp(metadata) def get_embedding_dim(self): return self.hidden_dim class CrossModalFusion(nn.Module): """Cross-modal attention fusion layer.""" def __init__(self, embed_dims, fused_dim=512, num_heads=8, num_layers=2): super().__init__() self.modalities = list(embed_dims.keys()) self.projections = nn.ModuleDict({ mod: nn.Linear(dim, fused_dim) for mod, dim in embed_dims.items() }) encoder_layer = nn.TransformerEncoderLayer( d_model=fused_dim, nhead=num_heads, dim_feedforward=fused_dim*4, dropout=0.1, batch_first=True ) self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.modality_embeddings = nn.ParameterDict({ mod: nn.Parameter(torch.randn(1, 1, fused_dim) * 0.02) for mod in self.modalities }) def forward(self, embeddings, mask=None): batch_size = next(iter(embeddings.values())).size(0) projected = [] for mod in self.modalities: if mod in embeddings: proj = self.projections[mod](embeddings[mod]).unsqueeze(1) proj = proj + self.modality_embeddings[mod] projected.append(proj) stacked = torch.cat(projected, dim=1) if mask is not None: padding_mask = torch.ones(batch_size, len(self.modalities), dtype=torch.bool, device=stacked.device) for i, mod in enumerate(self.modalities): if mod in mask: padding_mask[:, i] = ~mask[mod] else: padding_mask = None fused = self.fusion_transformer(stacked, src_key_padding_mask=padding_mask) if padding_mask is not None: mask_expanded = (~padding_mask).unsqueeze(-1).float() fused = (fused * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1) else: fused = fused.mean(dim=1) return fused class FraudDetectionHead(nn.Module): """Classification head with explainability and anomaly scoring.""" def __init__(self, input_dim, num_classes=2, explanation_dim=256): super().__init__() self.classifier = nn.Sequential( nn.Linear(input_dim, input_dim//2), nn.ReLU(), nn.Dropout(0.3), nn.Linear(input_dim//2, input_dim//4), nn.ReLU(), nn.Dropout(0.2), nn.Linear(input_dim//4, num_classes) ) self.explanation_proj = nn.Sequential( nn.Linear(input_dim, explanation_dim), nn.ReLU(), nn.Linear(explanation_dim, 4) ) self.anomaly_proj = nn.Linear(input_dim, 1) def forward(self, fused_embedding): logits = self.classifier(fused_embedding) modality_scores = torch.sigmoid(self.explanation_proj(fused_embedding)) anomaly_score = torch.sigmoid(self.anomaly_proj(fused_embedding)) return logits, modality_scores, anomaly_score class MultimodalFraudDetector(nn.Module): """ Complete multimodal fraudulent paper detection system. Combines text, image, tabular, and metadata modalities. """ def __init__(self, text_model="allenai/scibert_scivocab_uncased", image_model="google/vit-base-patch16-224", tabular_features=20, metadata_features=15, fused_dim=512, freeze_text_layers=6): super().__init__() self.text_encoder = TextEncoder(text_model, freeze_text_layers) self.image_encoder = ImageEncoder(image_model) self.tabular_encoder = TabularEncoder(tabular_features) self.metadata_encoder = MetadataEncoder(metadata_features) embed_dims = { 'text': self.text_encoder.get_embedding_dim(), 'image': self.image_encoder.get_embedding_dim(), 'tabular': self.tabular_encoder.get_embedding_dim(), 'metadata': self.metadata_encoder.get_embedding_dim() } self.fusion = CrossModalFusion(embed_dims, fused_dim) self.head = FraudDetectionHead(fused_dim) self.fused_dim = fused_dim def forward(self, text_input_ids=None, text_attention_mask=None, image_pixels=None, image_forensic=None, tabular_features=None, metadata_features=None, modality_mask=None): embeddings = {} mask = modality_mask or {} if text_input_ids is not None: embeddings['text'] = self.text_encoder(text_input_ids, text_attention_mask) mask['text'] = torch.ones(text_input_ids.size(0), dtype=torch.bool, device=text_input_ids.device) if image_pixels is not None: embeddings['image'] = self.image_encoder(image_pixels, image_forensic) mask['image'] = torch.ones(image_pixels.size(0), dtype=torch.bool, device=image_pixels.device) if tabular_features is not None: embeddings['tabular'] = self.tabular_encoder(tabular_features) mask['tabular'] = torch.ones(tabular_features.size(0), dtype=torch.bool, device=tabular_features.device) if metadata_features is not None: embeddings['metadata'] = self.metadata_encoder(metadata_features) mask['metadata'] = torch.ones(metadata_features.size(0), dtype=torch.bool, device=metadata_features.device) fused = self.fusion(embeddings, mask) logits, modality_scores, anomaly_score = self.head(fused) return { 'logits': logits, 'fused_embedding': fused, 'modality_scores': modality_scores, 'anomaly_score': anomaly_score, 'embeddings': embeddings }