"""
Multimodal Deepfake Detection Model
====================================

Architecture:
- Visual Branch: EfficientNet-B0 (pretrained) for image/video frame classification
- Text Branch: RoBERTa-base for AI-generated text detection
- Fusion Layer: Learnable weighted ensemble with late fusion
- Explainability: GradCAM on EfficientNet convolutional layers
- Output: Confidence scores [0, 1] + explainability heatmaps

Based on:
- AWARE-NET Two-Tier Ensemble (arxiv:2505.00312)
- CLIP-ViT LN-Tuning (arxiv:2503.19683)
- DeTeCtive RoBERTa text detection (arxiv:2410.20964)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from transformers import AutoModel, AutoTokenizer
import numpy as np


# ============================================================
# GradCAM Explainability Module
# ============================================================

class GradCAM:
    """Generate class activation maps for visual branch explainability.

    A forward hook on ``target_layer`` captures its output activations and a
    full backward hook captures the gradient of the selected class score with
    respect to that output. Call :meth:`remove_hooks` when finished so the
    hooks no longer fire on every forward/backward pass of ``model``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        self._hooks = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def _save_activations(self, module, input, output):
        self.activations = output.detach()

    def _save_gradients(self, module, grad_in, grad_out):
        # grad_out[0] is the gradient w.r.t. the target layer's output.
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Generate GradCAM heatmap.

        The model is run in eval mode for the computation and its previous
        train/eval mode is restored afterwards. Parameter gradients produced by
        the GradCAM backward pass are cleared before returning.

        Args:
            input_tensor: (B, C, H, W) image tensor
            class_idx: Target class. None = predicted (argmax) class per sample;
                an int applies the same class to every sample; a sequence or
                tensor of length B gives one class per sample.

        Returns:
            cam: (B, 1, H, W) heatmap normalized to [0, 1] per sample

        Raises:
            ValueError: if ``class_idx`` has neither 1 nor B entries.
            RuntimeError: if the hooks on the target layer did not fire.
        """
        was_training = self.model.training
        self.model.eval()
        self.gradients = None
        self.activations = None
        try:
            # GradCAM needs autograd even when called from inside torch.no_grad().
            with torch.enable_grad():
                # A grad-requiring leaf keeps the graph alive even if the model's
                # parameters are frozen (requires_grad=False).
                inputs = input_tensor.detach().requires_grad_(True)
                output = self.model(inputs)  # (B, num_classes)
                batch = output.size(0)

                if class_idx is None:
                    class_idx = output.argmax(dim=1)
                target = torch.as_tensor(
                    class_idx, dtype=torch.long, device=output.device
                ).reshape(-1)
                if target.numel() == 1:
                    target = target.repeat(batch)
                if target.numel() != batch:
                    raise ValueError(
                        f"class_idx must have 1 or {batch} entries, got {target.numel()}"
                    )

                # One-hot selector: backprop only the chosen class score per sample.
                one_hot = torch.zeros_like(output)
                one_hot.scatter_(1, target.unsqueeze(1), 1.0)

                self.model.zero_grad()
                output.backward(gradient=one_hot)
        finally:
            # Do not leave GradCAM gradients on the parameters or clobber the
            # caller's train/eval mode.
            self.model.zero_grad()
            self.model.train(was_training)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks did not fire; is target_layer part of the model's forward pass?"
            )

        # Channel weights = spatially averaged gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))  # (B, 1, h, w)

        # Min-max normalize each sample independently; epsilon avoids 0/0 on flat maps.
        B = cam.size(0)
        cam_flat = cam.reshape(B, -1)
        cam_min = cam_flat.min(dim=1)[0].view(B, 1, 1, 1)
        cam_max = cam_flat.max(dim=1)[0].view(B, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)

        # Upscale to input resolution.
        return F.interpolate(
            cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False
        )

    def remove_hooks(self):
        """Detach both hooks from the target layer (safe to call repeatedly)."""
        for h in self._hooks:
            h.remove()
        self._hooks = []


# ============================================================
# Visual Branch: EfficientNet-B0 Based Deepfake Detector
# ============================================================

class VisualDeepfakeDetector(nn.Module):
    """EfficientNet-B0 based binary classifier for real/fake images.

    Features:
    - Pretrained EfficientNet-B0 backbone (timm)
    - L2-normalized pooled features (inspired by CLIP deepfake detection)
    - GradCAM-compatible architecture
    """

    def __init__(self, num_classes=2, pretrained=True, dropout=0.3):
        super().__init__()
        # EfficientNet-B0 backbone with timm's classifier head and pooling removed,
        # so it returns the spatial feature map.
        self.backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=pretrained,
            num_classes=0,     # Remove classifier head
            global_pool='',    # Remove global pooling
        )
        self.feature_dim = 1280  # EfficientNet-B0 output channels

        # Custom head: pool -> L2 normalize -> dropout -> linear.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def get_features(self, x):
        """Extract the backbone feature map before pooling/classification."""
        return self.backbone(x)  # (B, 1280, H', W')

    def forward(self, x):
        features = self.get_features(x)                    # (B, 1280, H', W')
        pooled = self.global_pool(features).flatten(1)     # (B, 1280)
        pooled = F.normalize(pooled, p=2, dim=-1)          # L2 normalize
        pooled = self.dropout(pooled)
        return self.classifier(pooled)                     # (B, num_classes)

    def get_gradcam_target_layer(self):
        """Return the target layer for GradCAM (last EfficientNet block stage)."""
        return self.backbone.blocks[-1]


# ============================================================
# Text Branch: RoBERTa Based AI Text Detector
# ============================================================

class TextDeepfakeDetector(nn.Module):
    """RoBERTa-based binary classifier for human vs AI-generated text.

    Features:
    - Pretrained RoBERTa-base backbone
    - Mean pooling over non-padded token embeddings (more robust than CLS)
    - Dropout regularization
    """

    def __init__(self, model_name='roberta-base', num_classes=2, dropout=0.3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.encoder.config.hidden_size  # 768 for roberta-base

        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )

    def mean_pooling(self, model_output, attention_mask):
        """Mean pooling over non-padded tokens (mask value 1 = real token)."""
        token_embeddings = model_output.last_hidden_state  # (B, seq_len, hidden)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp avoids division by zero for rows with an all-zero mask.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def forward(self, input_ids, attention_mask=None):
        """Classify token sequences.

        Args:
            input_ids: (B, seq_len) token IDs
            attention_mask: (B, seq_len) mask; None means every token is attended
                to and included in the mean pooling.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(outputs, attention_mask)  # (B, hidden)
        pooled = F.normalize(pooled, p=2, dim=-1)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)  # (B, num_classes)


# ============================================================
# Multimodal Fusion: Ensemble Classifier
# ============================================================

class MultimodalDeepfakeDetector(nn.Module):
    """Multimodal ensemble for deepfake detection.

    Combines visual (image/video frame) and text modalities with learnable
    weighted late fusion. Supports single-modality inference.

    Architecture (inspired by AWARE-NET two-tier ensemble):
    - Visual: EfficientNet-B0 -> logits
    - Text: RoBERTa-base -> logits
    - Fusion: Learnable weighted average of probabilities

    Output: confidence score [0, 1] where 1 = AI-generated/fake
    """

    _MODALITIES = ('visual', 'text', 'multimodal', 'auto')

    def __init__(self, visual_pretrained=True, text_model_name='roberta-base', dropout=0.3):
        super().__init__()
        self.visual_branch = VisualDeepfakeDetector(
            num_classes=2, pretrained=visual_pretrained, dropout=dropout
        )
        self.text_branch = TextDeepfakeDetector(
            model_name=text_model_name, num_classes=2, dropout=dropout
        )

        # Learnable fusion weights (AWARE-NET style), order [visual, text].
        # forward() applies softmax, so store log-weights: the initial effective
        # weights are then exactly 0.6 / 0.4.
        self.fusion_weights = nn.Parameter(torch.log(torch.tensor([0.6, 0.4])))

        # NOTE: the cross-modal attention / projection / fusion-classifier modules
        # below are not used by forward(). They are kept so existing checkpoints
        # still load with strict=True. Because they receive no gradients, DDP
        # training may need find_unused_parameters=True.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=128, num_heads=4, batch_first=True
        )
        self.visual_proj = nn.Linear(1280, 128)
        self.text_proj = nn.Linear(768, 128)
        self.fusion_classifier = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

    def forward(self, images=None, input_ids=None, attention_mask=None, modality='auto'):
        """
        Forward pass supporting single or multi-modal input.

        Args:
            images: (B, C, H, W) image tensor, optional
            input_ids: (B, seq_len) text token IDs, optional
            attention_mask: (B, seq_len) attention mask, optional (None = attend
                to all tokens)
            modality: 'visual', 'text', 'multimodal', or 'auto'. With
                'multimodal' and only one input present, the result falls back
                to that single modality.

        Returns:
            dict with:
                - logits: (B, 2) raw logits for single-modality input; for fused
                  input, log of the fused probabilities (softmax of these
                  recovers the fused probabilities)
                - confidence: (B,) probability of being fake/AI-generated
                - modality_scores: dict of per-modality confidence scores

        Raises:
            ValueError: for an unknown ``modality``, or when none of the inputs
                required by ``modality`` were supplied.
        """
        if modality not in self._MODALITIES:
            raise ValueError(
                f"Unknown modality {modality!r}; expected one of {self._MODALITIES}"
            )

        results = {'modality_scores': {}}
        has_visual = images is not None
        has_text = input_ids is not None

        if modality == 'auto':
            if has_visual and has_text:
                modality = 'multimodal'
            elif has_visual:
                modality = 'visual'
            elif has_text:
                modality = 'text'
            else:
                raise ValueError("At least one modality input required")

        visual_logits = visual_probs = None
        text_logits = text_probs = None

        if modality in ('visual', 'multimodal') and has_visual:
            visual_logits = self.visual_branch(images)
            visual_probs = F.softmax(visual_logits, dim=-1)
            results['modality_scores']['visual'] = visual_probs[:, 1]  # P(fake)

        if modality in ('text', 'multimodal') and has_text:
            text_logits = self.text_branch(input_ids, attention_mask)
            text_probs = F.softmax(text_logits, dim=-1)
            results['modality_scores']['text'] = text_probs[:, 1]  # P(fake)

        # Fusion logic
        if modality == 'multimodal' and visual_logits is not None and text_logits is not None:
            # Late fusion: learnable convex combination of class probabilities.
            weights = F.softmax(self.fusion_weights, dim=0)
            fused_probs = weights[0] * visual_probs + weights[1] * text_probs
            results['logits'] = torch.log(fused_probs + 1e-8)
            results['confidence'] = fused_probs[:, 1]  # P(fake)
        elif visual_logits is not None:
            results['logits'] = visual_logits
            results['confidence'] = visual_probs[:, 1]  # P(fake)
        elif text_logits is not None:
            results['logits'] = text_logits
            results['confidence'] = text_probs[:, 1]  # P(fake)
        else:
            # Previously this silently returned a dict with no logits/confidence.
            raise ValueError(
                f"modality={modality!r} requires its input (images and/or input_ids)"
            )

        return results

    def get_visual_gradcam(self):
        """Get GradCAM instance for visual branch."""
        target_layer = self.visual_branch.get_gradcam_target_layer()
        return GradCAM(self.visual_branch, target_layer)


# ============================================================
# Helper: Video Frame Aggregation
# ============================================================

def aggregate_video_predictions(frame_confidences, method='mean'):
    """Aggregate per-frame predictions to video-level score.

    Args:
        frame_confidences: list/tuple/array/tensor of per-frame P(fake) scores
            (any shape; all elements are treated as frame scores)
        method: 'mean', 'max', or 'voting'. 'voting' returns the fraction of
            frames with P(fake) > 0.5, so a value above 0.5 means a majority of
            frames voted fake.

    Returns:
        video_confidence: scalar P(fake) for the whole video

    Raises:
        ValueError: for an unknown ``method`` or when no frame scores are given.
    """
    if not isinstance(frame_confidences, torch.Tensor):
        frame_confidences = torch.as_tensor(frame_confidences)
    scores = frame_confidences.detach().flatten()
    # mean() is undefined for integer dtypes (e.g. 0/1 labels); keep existing
    # floating dtypes as-is to avoid changing precision.
    if not scores.is_floating_point():
        scores = scores.float()

    if method not in ('mean', 'max', 'voting'):
        raise ValueError(f"Unknown aggregation method: {method}")
    if scores.numel() == 0:
        raise ValueError("frame_confidences must contain at least one frame score")

    if method == 'mean':
        return scores.mean().item()
    if method == 'max':
        return scores.max().item()
    # 'voting'
    votes = (scores > 0.5).float()
    return votes.mean().item()