"""
Multimodal Fraudulent Paper Detection - Core Model Architecture
"""
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, ViTModel, AutoConfig
from typing import Dict, Optional


class TextEncoder(nn.Module):
    """SciBERT-based text encoder with layer freezing."""

    def __init__(self, model_name="allenai/scibert_scivocab_uncased", freeze_layers=6):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if freeze_layers > 0:
            for layer in self.encoder.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def get_embedding_dim(self):
        return self.config.hidden_size
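
# Illustrative usage of TextEncoder (a sketch, not part of the model definition;
# the example texts and max_length=512 are assumptions):
#   enc = TextEncoder()
#   batch = enc.tokenizer(["first abstract", "second abstract"], padding=True,
#                         truncation=True, max_length=512, return_tensors="pt")
#   emb = enc(batch["input_ids"], batch["attention_mask"])  # (2, 768) for SciBERT-base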


class ImageEncoder(nn.Module):
    """ViT + forensic CNN for scientific figure analysis."""

    def __init__(self, model_name="google/vit-base-patch16-224", forensic_dim=64):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # Small CNN branch for low-level forensic cues in the figure.
        self.forensic_cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(),
            nn.Linear(32 * 8 * 8, forensic_dim)
        )
        vit_dim = self.vit.config.hidden_size
        self.fusion_proj = nn.Linear(vit_dim + forensic_dim, vit_dim)

    def forward(self, pixel_values, forensic_features=None):
        # [CLS] token of the ViT as the global figure representation.
        vit_out = self.vit(pixel_values).last_hidden_state[:, 0, :]
        if forensic_features is not None:
            forensic_out = self.forensic_cnn(forensic_features)
            combined = torch.cat([vit_out, forensic_out], dim=-1)
            return self.fusion_proj(combined)
        # Without forensic input, fall back to the plain ViT embedding (same dim).
        return vit_out

    def get_embedding_dim(self):
        return self.vit.config.hidden_size
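
# Note on ImageEncoder: `forensic_features` is expected to be an image-like
# (batch, 3, H, W) tensor for the forensic CNN branch; what it contains (for
# example an error-level-analysis or noise-residual map of the figure) is an
# assumption left to the data pipeline, not fixed by this module.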


class TabularEncoder(nn.Module):
    """Transformer encoder for tabular data (FT-Transformer-inspired:
    the feature vector is projected to a single token before self-attention)."""

    def __init__(self, num_features, hidden_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        self.input_proj = nn.Linear(num_features, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads,
            dim_feedforward=hidden_dim * 4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.hidden_dim = hidden_dim

    def forward(self, tabular_features):
        x = self.input_proj(tabular_features).unsqueeze(1)  # (batch, 1, hidden_dim)
        x = self.transformer(x)
        return x.squeeze(1)

    def get_embedding_dim(self):
        return self.hidden_dim
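
# Note on TabularEncoder: `tabular_features` is a (batch, num_features) float
# tensor of pre-extracted per-paper statistics; the exact feature set is
# defined by the data pipeline, not by this module.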


class MetadataEncoder(nn.Module):
    """MLP encoder for metadata (author, journal, citation patterns)."""

    def __init__(self, metadata_dim, hidden_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(metadata_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.hidden_dim = hidden_dim

    def forward(self, metadata):
        return self.mlp(metadata)

    def get_embedding_dim(self):
        return self.hidden_dim


class CrossModalFusion(nn.Module):
    """Cross-modal attention fusion layer."""

    def __init__(self, embed_dims, fused_dim=512, num_heads=8, num_layers=2):
        super().__init__()
        self.modalities = list(embed_dims.keys())
        self.projections = nn.ModuleDict({
            mod: nn.Linear(dim, fused_dim) for mod, dim in embed_dims.items()
        })
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=fused_dim, nhead=num_heads,
            dim_feedforward=fused_dim * 4, dropout=0.1, batch_first=True
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Learned per-modality embeddings, added to each projected token so the
        # fusion transformer can tell the modalities apart.
        self.modality_embeddings = nn.ParameterDict({
            mod: nn.Parameter(torch.randn(1, 1, fused_dim) * 0.02)
            for mod in self.modalities
        })

    def forward(self, embeddings, mask=None):
        batch_size = next(iter(embeddings.values())).size(0)
        # Project each available modality and record which modalities are present,
        # so the padding mask always matches the length of the stacked sequence
        # (the previous version sized the mask over all modalities and crashed
        # whenever a modality was missing from `embeddings`).
        projected, present = [], []
        for mod in self.modalities:
            if mod in embeddings:
                proj = self.projections[mod](embeddings[mod]).unsqueeze(1)
                proj = proj + self.modality_embeddings[mod]
                projected.append(proj)
                present.append(mod)
        stacked = torch.cat(projected, dim=1)
        if mask is not None:
            # src_key_padding_mask semantics: True marks positions to ignore.
            # Present modalities default to "keep" unless the caller masks them.
            padding_mask = torch.zeros(batch_size, len(present), dtype=torch.bool, device=stacked.device)
            for i, mod in enumerate(present):
                if mod in mask:
                    padding_mask[:, i] = ~mask[mod]
        else:
            padding_mask = None
        fused = self.fusion_transformer(stacked, src_key_padding_mask=padding_mask)
        if padding_mask is not None:
            # Masked mean over the modality tokens that are actually present.
            mask_expanded = (~padding_mask).unsqueeze(-1).float()
            fused = (fused * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1)
        else:
            fused = fused.mean(dim=1)
        return fused
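
# Illustrative use of CrossModalFusion (a sketch; the shapes, embedding dims and
# the `has_figure` mask below are assumptions): `embeddings` maps modality name
# to a (batch, dim) tensor, `mask` maps modality name to a (batch,) bool tensor
# where True means the sample really has that modality:
#   fusion = CrossModalFusion({'text': 768, 'image': 768}, fused_dim=512)
#   fused = fusion({'text': text_emb, 'image': img_emb},
#                  mask={'text': torch.ones(4, dtype=torch.bool),
#                        'image': has_figure})  # -> (4, 512)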


class FraudDetectionHead(nn.Module):
    """Classification head with explainability and anomaly scoring."""

    def __init__(self, input_dim, num_classes=2, explanation_dim=256):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(input_dim // 2, input_dim // 4), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(input_dim // 4, num_classes)
        )
        # Four explanation scores, one per input modality in the full model.
        self.explanation_proj = nn.Sequential(
            nn.Linear(input_dim, explanation_dim), nn.ReLU(),
            nn.Linear(explanation_dim, 4)
        )
        self.anomaly_proj = nn.Linear(input_dim, 1)

    def forward(self, fused_embedding):
        logits = self.classifier(fused_embedding)
        modality_scores = torch.sigmoid(self.explanation_proj(fused_embedding))
        anomaly_score = torch.sigmoid(self.anomaly_proj(fused_embedding))
        return logits, modality_scores, anomaly_score


class MultimodalFraudDetector(nn.Module):
    """
    Complete multimodal fraudulent paper detection system.
    Combines text, image, tabular, and metadata modalities.
    """

    def __init__(self, text_model="allenai/scibert_scivocab_uncased",
                 image_model="google/vit-base-patch16-224", tabular_features=20,
                 metadata_features=15, fused_dim=512, freeze_text_layers=6):
        super().__init__()
        self.text_encoder = TextEncoder(text_model, freeze_text_layers)
        self.image_encoder = ImageEncoder(image_model)
        self.tabular_encoder = TabularEncoder(tabular_features)
        self.metadata_encoder = MetadataEncoder(metadata_features)
        embed_dims = {
            'text': self.text_encoder.get_embedding_dim(),
            'image': self.image_encoder.get_embedding_dim(),
            'tabular': self.tabular_encoder.get_embedding_dim(),
            'metadata': self.metadata_encoder.get_embedding_dim()
        }
        self.fusion = CrossModalFusion(embed_dims, fused_dim)
        self.head = FraudDetectionHead(fused_dim)
        self.fused_dim = fused_dim

    def forward(self, text_input_ids: Optional[torch.Tensor] = None,
                text_attention_mask: Optional[torch.Tensor] = None,
                image_pixels: Optional[torch.Tensor] = None,
                image_forensic: Optional[torch.Tensor] = None,
                tabular_features: Optional[torch.Tensor] = None,
                metadata_features: Optional[torch.Tensor] = None,
                modality_mask: Optional[Dict[str, torch.Tensor]] = None):
        embeddings = {}
        # Copy the caller-provided mask so it is not mutated in place; setdefault
        # below keeps any per-sample mask the caller supplied for a modality.
        mask = dict(modality_mask) if modality_mask else {}
        if text_input_ids is not None:
            embeddings['text'] = self.text_encoder(text_input_ids, text_attention_mask)
            mask.setdefault('text', torch.ones(text_input_ids.size(0), dtype=torch.bool, device=text_input_ids.device))
        if image_pixels is not None:
            embeddings['image'] = self.image_encoder(image_pixels, image_forensic)
            mask.setdefault('image', torch.ones(image_pixels.size(0), dtype=torch.bool, device=image_pixels.device))
        if tabular_features is not None:
            embeddings['tabular'] = self.tabular_encoder(tabular_features)
            mask.setdefault('tabular', torch.ones(tabular_features.size(0), dtype=torch.bool, device=tabular_features.device))
        if metadata_features is not None:
            embeddings['metadata'] = self.metadata_encoder(metadata_features)
            mask.setdefault('metadata', torch.ones(metadata_features.size(0), dtype=torch.bool, device=metadata_features.device))
        fused = self.fusion(embeddings, mask)
        logits, modality_scores, anomaly_score = self.head(fused)
        return {
            'logits': logits, 'fused_embedding': fused,
            'modality_scores': modality_scores, 'anomaly_score': anomaly_score,
            'embeddings': embeddings
        }
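

if __name__ == "__main__":
    # Minimal smoke test with random inputs, a sketch rather than a real
    # evaluation: the sequence length, batch size, and image sizes below are
    # assumptions for illustration, while the 20 tabular and 15 metadata
    # features match the constructor defaults. Running this downloads the
    # pretrained SciBERT and ViT backbones from the Hugging Face Hub.
    model = MultimodalFraudDetector()
    model.eval()
    batch_size = 2
    dummy = {
        'text_input_ids': torch.randint(0, model.text_encoder.config.vocab_size, (batch_size, 128)),
        'text_attention_mask': torch.ones(batch_size, 128, dtype=torch.long),
        'image_pixels': torch.randn(batch_size, 3, 224, 224),
        'image_forensic': torch.randn(batch_size, 3, 224, 224),
        'tabular_features': torch.randn(batch_size, 20),
        'metadata_features': torch.randn(batch_size, 15),
    }
    with torch.no_grad():
        out = model(**dummy)
    # Expected shapes: logits (2, 2), fused_embedding (2, 512),
    # modality_scores (2, 4), anomaly_score (2, 1).
    print({k: tuple(v.shape) for k, v in out.items() if torch.is_tensor(v)})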