# NOTE: page-scrape residue ("Spaces: / Runtime error" banner) removed; not part of the module.
| """ | |
| Multimodal Deepfake Detection Model | |
| ==================================== | |
| Architecture: | |
| - Visual Branch: EfficientNet-B0 (pretrained) for image/video frame classification | |
| - Text Branch: RoBERTa-base for AI-generated text detection | |
| - Fusion Layer: Learnable weighted ensemble with late fusion | |
| - Explainability: GradCAM on EfficientNet convolutional layers | |
| - Output: Confidence scores [0,1] + explainability heatmaps | |
| Based on: | |
| - AWARE-NET Two-Tier Ensemble (arxiv:2505.00312) | |
| - CLIP-ViT LN-Tuning (arxiv:2503.19683) | |
| - DeTeCtive RoBERTa text detection (arxiv:2410.20964) | |
| """ | |
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# ============================================================
# GradCAM Explainability Module
# ============================================================
class GradCAM:
    """Class-activation-map generator for visual-branch explainability.

    Hooks a convolutional layer to capture its forward activations and
    backward gradients, then combines them into a per-sample heatmap.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Keep hook handles so they can be detached via remove_hooks().
        self._hooks = [
            target_layer.register_forward_hook(self._capture_activations),
            target_layer.register_full_backward_hook(self._capture_gradients),
        ]

    def _capture_activations(self, module, inputs, output):
        self.activations = output.detach()

    def _capture_gradients(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Compute a GradCAM heatmap for ``input_tensor``.

        Args:
            input_tensor: (B, C, H, W) batch of images.
            class_idx: target class (int or per-sample tensor);
                defaults to each sample's predicted class.

        Returns:
            (B, 1, H, W) heatmap, min-max normalized to [0, 1] per sample
            and upsampled to the input resolution.
        """
        self.model.eval()
        logits = self.model(input_tensor)
        if class_idx is None:
            class_idx = logits.argmax(dim=1)

        self.model.zero_grad()
        # One-hot target selecting the class to explain for each sample.
        target = torch.zeros_like(logits)
        per_sample = isinstance(class_idx, torch.Tensor)
        for b in range(logits.size(0)):
            target[b, class_idx[b] if per_sample else class_idx] = 1.0
        logits.backward(gradient=target, retain_graph=True)

        # Channel weights = global-average-pooled gradients (B, C, 1, 1).
        channel_weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        heatmap = F.relu(
            (channel_weights * self.activations).sum(dim=1, keepdim=True)
        )

        # Per-sample min-max normalization to [0, 1].
        flat = heatmap.flatten(1)
        lo = flat.min(dim=1, keepdim=True)[0].view(-1, 1, 1, 1)
        hi = flat.max(dim=1, keepdim=True)[0].view(-1, 1, 1, 1)
        heatmap = (heatmap - lo) / (hi - lo + 1e-8)

        # Upscale to the original spatial resolution.
        return F.interpolate(
            heatmap, size=input_tensor.shape[2:],
            mode='bilinear', align_corners=False,
        )

    def remove_hooks(self):
        """Detach the registered forward/backward hooks."""
        while self._hooks:
            self._hooks.pop().remove()
# ============================================================
# Visual Branch: EfficientNet-B0 Based Deepfake Detector
# ============================================================
class VisualDeepfakeDetector(nn.Module):
    """EfficientNet-B0 binary real/fake image classifier.

    A pretrained timm backbone with its head removed feeds global
    average pooling, L2 feature normalization (inspired by CLIP
    deepfake detection), dropout, and a linear classifier. The
    architecture is GradCAM-compatible.
    """

    def __init__(self, num_classes=2, pretrained=True, dropout=0.3):
        super().__init__()
        # Headless EfficientNet-B0: keeps the (B, 1280, H, W) feature map.
        self.backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=pretrained,
            num_classes=0,   # strip the classifier head
            global_pool=''   # strip global pooling
        )
        self.feature_dim = 1280  # EfficientNet-B0 final channel count
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def get_features(self, x):
        """Return backbone feature maps of shape (B, 1280, H, W)."""
        return self.backbone(x)

    def forward(self, x):
        """Classify a batch of images; returns (B, num_classes) logits."""
        fmap = self.get_features(x)
        vec = self.global_pool(fmap).flatten(1)   # (B, 1280)
        vec = F.normalize(vec, p=2, dim=-1)       # unit-length features
        return self.classifier(self.dropout(vec))

    def get_gradcam_target_layer(self):
        """Layer whose activations/gradients GradCAM should hook."""
        # Last convolutional block of the EfficientNet backbone.
        return self.backbone.blocks[-1]
# ============================================================
# Text Branch: RoBERTa Based AI Text Detector
# ============================================================
class TextDeepfakeDetector(nn.Module):
    """RoBERTa-based binary classifier: human vs AI-generated text.

    Mean-pools token embeddings over non-padded positions (more robust
    than the CLS token), L2-normalizes the pooled vector, then applies
    a small dropout-regularized MLP head.
    """

    def __init__(self, model_name='roberta-base', num_classes=2, dropout=0.3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.encoder.config.hidden_size  # 768 for roberta-base
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padded positions."""
        embeddings = model_output.last_hidden_state  # (B, seq_len, hidden)
        mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
        summed = torch.sum(embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)  # guard divide-by-zero
        return summed / counts

    def forward(self, input_ids, attention_mask):
        """Return (B, num_classes) logits for a batch of token sequences."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(encoded, attention_mask)      # (B, 768)
        pooled = F.normalize(pooled, p=2, dim=-1)
        return self.classifier(self.dropout(pooled))
# ============================================================
# Multimodal Fusion: Ensemble Classifier
# ============================================================
class MultimodalDeepfakeDetector(nn.Module):
    """Multimodal ensemble for deepfake detection.

    Combines visual (image/video frame) and text modalities with
    learnable weighted late fusion. Supports single-modality inference.

    Architecture (inspired by AWARE-NET two-tier ensemble):
    - Visual: EfficientNet-B0 -> logits
    - Text: RoBERTa-base -> logits
    - Fusion: Learnable weighted average of probabilities

    Output: confidence score [0, 1] where 1 = AI-generated/fake.
    """

    def __init__(self, visual_pretrained=True, text_model_name='roberta-base', dropout=0.3):
        super().__init__()
        self.visual_branch = VisualDeepfakeDetector(
            num_classes=2, pretrained=visual_pretrained, dropout=dropout
        )
        self.text_branch = TextDeepfakeDetector(
            model_name=text_model_name, num_classes=2, dropout=dropout
        )
        # Learnable fusion weights [visual, text], softmaxed at use time
        # (AWARE-NET style late fusion).
        self.fusion_weights = nn.Parameter(torch.tensor([0.6, 0.4]))
        # Cross-modal attention components for richer fusion. NOTE(review):
        # currently unused by forward(); kept so existing checkpoints and
        # state_dicts remain compatible.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=128, num_heads=4, batch_first=True
        )
        self.visual_proj = nn.Linear(1280, 128)
        self.text_proj = nn.Linear(768, 128)
        self.fusion_classifier = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2)
        )

    def forward(self, images=None, input_ids=None, attention_mask=None,
                modality='auto'):
        """
        Forward pass supporting single or multi-modal input.

        Args:
            images: (B, C, H, W) image tensor, optional
            input_ids: (B, seq_len) text token IDs, optional
            attention_mask: (B, seq_len) attention mask, optional
            modality: 'visual', 'text', 'multimodal', or 'auto'

        Returns:
            dict with:
                - logits: (B, 2) raw logits (log-probabilities in the
                  fused multimodal case)
                - confidence: (B,) probability of being fake/AI-generated
                - modality_scores: dict of per-modality confidence scores

        Raises:
            ValueError: if no inputs are given, or if the explicitly
                requested modality has no matching inputs (previously
                this silently returned a dict missing 'logits' and
                'confidence', causing KeyErrors far from the cause).
        """
        results = {'modality_scores': {}}
        has_visual = images is not None
        has_text = input_ids is not None

        if modality == 'auto':
            if has_visual and has_text:
                modality = 'multimodal'
            elif has_visual:
                modality = 'visual'
            elif has_text:
                modality = 'text'
            else:
                raise ValueError("At least one modality input required")

        visual_logits = None
        text_logits = None
        visual_probs = None
        text_probs = None

        if modality in ('visual', 'multimodal') and has_visual:
            visual_logits = self.visual_branch(images)
            visual_probs = F.softmax(visual_logits, dim=-1)
            results['modality_scores']['visual'] = visual_probs[:, 1]  # P(fake)

        if modality in ('text', 'multimodal') and has_text:
            text_logits = self.text_branch(input_ids, attention_mask)
            text_probs = F.softmax(text_logits, dim=-1)
            results['modality_scores']['text'] = text_probs[:, 1]  # P(fake)

        if visual_logits is None and text_logits is None:
            # e.g. modality='visual' was requested but images was None.
            raise ValueError(
                f"No inputs available for requested modality '{modality}'"
            )

        # Fusion logic: weighted average when both branches ran,
        # otherwise pass through the single available branch.
        if visual_logits is not None and text_logits is not None:
            weights = F.softmax(self.fusion_weights, dim=0)
            fused_probs = weights[0] * visual_probs + weights[1] * text_probs
            # log of fused probabilities serves as pseudo-logits.
            results['logits'] = torch.log(fused_probs + 1e-8)
            results['confidence'] = fused_probs[:, 1]  # P(fake)
        elif visual_logits is not None:
            results['logits'] = visual_logits
            results['confidence'] = visual_probs[:, 1]  # P(fake)
        else:
            results['logits'] = text_logits
            results['confidence'] = text_probs[:, 1]  # P(fake)
        return results

    def get_visual_gradcam(self):
        """Return a GradCAM wired to the visual branch's last conv block."""
        target_layer = self.visual_branch.get_gradcam_target_layer()
        return GradCAM(self.visual_branch, target_layer)
# ============================================================
# Helper: Video Frame Aggregation
# ============================================================
def aggregate_video_predictions(frame_confidences, method='mean'):
    """Aggregate per-frame P(fake) scores into a video-level score.

    Args:
        frame_confidences: per-frame P(fake) scores as a list, numpy
            array, or 1-D tensor. Must be non-empty.
        method: 'mean' (average score), 'max' (most suspicious frame),
            or 'voting' (fraction of frames scoring above 0.5 — a soft
            majority vote, not a binary verdict).

    Returns:
        float: video-level P(fake) score.

    Raises:
        ValueError: if ``frame_confidences`` is empty (previously this
            silently returned NaN via mean of an empty tensor) or if
            ``method`` is unknown.
    """
    # as_tensor accepts lists, numpy arrays, and tensors alike.
    scores = torch.as_tensor(frame_confidences, dtype=torch.float32)
    if scores.numel() == 0:
        raise ValueError("frame_confidences must contain at least one score")
    if method == 'mean':
        return scores.mean().item()
    if method == 'max':
        return scores.max().item()
    if method == 'voting':
        votes = (scores > 0.5).float()
        return votes.mean().item()
    raise ValueError(f"Unknown aggregation method: {method}")