"""
Multimodal Deepfake Detection Model
====================================

Architecture:
- Visual Branch: EfficientNet-B0 (pretrained) for image/video frame classification
- Text Branch: RoBERTa-base for AI-generated text detection
- Fusion Layer: Learnable weighted ensemble with late fusion
- Explainability: GradCAM on EfficientNet convolutional layers
- Output: Confidence scores [0,1] + explainability heatmaps

Based on:
- AWARE-NET Two-Tier Ensemble (arxiv:2505.00312)
- CLIP-ViT LN-Tuning (arxiv:2503.19683)
- DeTeCtive RoBERTa text detection (arxiv:2410.20964)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from transformers import AutoModel, AutoTokenizer
import numpy as np


class GradCAM:
    """Generate class activation maps for visual branch explainability.

    Construction registers a forward hook and a full backward hook on
    ``target_layer``. They stay attached (and fire on *every* forward/backward
    pass through that layer) until :meth:`remove_hooks` is called.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        self._hooks = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def _save_activations(self, module, input, output):
        self.activations = output.detach()

    def _save_gradients(self, module, grad_in, grad_out):
        # grad_out[0] is the gradient of the backpropagated score w.r.t. the
        # target layer's output.
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Return a GradCAM heatmap for each sample in ``input_tensor``.

        Args:
            input_tensor: image batch fed to ``self.model``. ``shape[2:]`` is
                used as the heatmap size, so a (B, C, H, W) layout is assumed.
            class_idx: class to explain. ``None`` uses each sample's argmax.
                Otherwise an int (same class for every sample) or a
                sequence/tensor with one index per sample.

        Returns:
            Tensor of shape (B, 1, H, W), min-max normalised per sample to [0, 1].

        Raises:
            RuntimeError: if the hooks did not capture activations/gradients
                (e.g. hooks already removed or target layer not on the path).

        Note:
            Gradients are accumulated into ``self.model``'s parameters (they are
            zeroed first), so do not call this between a training backward
            pass and its optimizer step.
        """
        was_training = self.model.training
        # Reset buffers so a failed hook can't silently reuse a previous call's data.
        self.gradients = None
        self.activations = None
        self.model.eval()
        try:
            # Gradients are required even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                output = self.model(input_tensor)
                batch = output.size(0)
                if class_idx is None:
                    target = output.argmax(dim=1)
                else:
                    target = torch.as_tensor(class_idx, dtype=torch.long, device=output.device)
                    # Broadcast a single index (int or length-1) to every sample.
                    target = target.expand(batch)
                one_hot = torch.zeros_like(output)
                one_hot.scatter_(1, target.reshape(-1, 1), 1.0)
                self.model.zero_grad()
                # The graph is not needed afterwards, so let autograd free it.
                output.backward(gradient=one_hot)
        finally:
            # Restore the caller's mode instead of leaving the model in eval().
            self.model.train(was_training)

        if self.gradients is None or self.activations is None:
            raise RuntimeError("GradCAM hooks did not capture activations/gradients")

        # Channel weights = spatially averaged gradients (Grad-CAM).
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam_min = cam.amin(dim=(2, 3), keepdim=True)
        cam_max = cam.amax(dim=(2, 3), keepdim=True)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        return F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)

    def remove_hooks(self):
        """Detach both hooks from the target layer; safe to call more than once."""
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()


class VisualDeepfakeDetector(nn.Module):
    """EfficientNet-B0 feature extractor + global average pool + linear head."""

    def __init__(self, num_classes=2, pretrained=True, dropout=0.3):
        super().__init__()
        # num_classes=0 and global_pool='' make timm return the unpooled feature map.
        self.backbone = timm.create_model('efficientnet_b0', pretrained=pretrained, num_classes=0, global_pool='')
        # Read the channel count from the backbone (1280 for EfficientNet-B0)
        # rather than hard-coding it.
        self.feature_dim = self.backbone.num_features
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def get_features(self, x):
        """Return the backbone's spatial feature map for ``x``."""
        return self.backbone(x)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features = self.get_features(x)
        pooled = self.global_pool(features).flatten(1)
        # L2-normalise the pooled embedding before the linear head.
        pooled = F.normalize(pooled, p=2, dim=-1)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)

    def get_gradcam_target_layer(self):
        """Return the last EfficientNet block, used as the GradCAM target."""
        return self.backbone.blocks[-1]


class TextDeepfakeDetector(nn.Module):
    """Transformer encoder with masked mean pooling and an MLP classification head."""

    def __init__(self, model_name='roberta-base', num_classes=2, dropout=0.3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )

    def mean_pooling(self, model_output, attention_mask):
        """Average ``last_hidden_state`` over positions where ``attention_mask`` is non-zero."""
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp avoids division by zero for an all-padding row.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask=None):
        """Return class logits; a missing ``attention_mask`` means "attend to every token"."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(outputs, attention_mask)
        pooled = F.normalize(pooled, p=2, dim=-1)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


class MultimodalDeepfakeDetector(nn.Module):
    """Late-fusion ensemble of the visual and text branches.

    ``forward`` returns a dict with:
      - ``'modality_scores'``: per-branch softmax probability of class index 0.
      - ``'logits'``: branch logits (single modality) or log of the fused
        probabilities (multimodal).
      - ``'confidence'``: probability of class index 0 for the final prediction.
    NOTE(review): which label class index 0 denotes is set by the training
    data, not by this file — confirm before interpreting ``confidence``.
    """

    def __init__(self, visual_pretrained=True, text_model_name='roberta-base', dropout=0.3):
        super().__init__()
        self.visual_branch = VisualDeepfakeDetector(num_classes=2, pretrained=visual_pretrained, dropout=dropout)
        self.text_branch = TextDeepfakeDetector(model_name=text_model_name, num_classes=2, dropout=dropout)
        # forward() softmax-normalises these, so store log-weights:
        # softmax(log([0.6, 0.4])) == [0.6, 0.4]. The raw values [0.6, 0.4]
        # would have started training at roughly a 0.55 / 0.45 split.
        self.fusion_weights = nn.Parameter(torch.log(torch.tensor([0.6, 0.4])))
        # NOTE: the layers below are not used by forward(). They are kept so
        # existing checkpoints (state_dict keys) continue to load.
        self.cross_attention = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
        self.visual_proj = nn.Linear(self.visual_branch.feature_dim, 128)
        self.text_proj = nn.Linear(self.text_branch.hidden_dim, 128)
        self.fusion_classifier = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Dropout(dropout), nn.Linear(64, 2))

    def forward(self, images=None, input_ids=None, attention_mask=None, modality='auto'):
        """Run the requested branch(es) and fuse their predictions.

        Args:
            images: input for the visual branch, or None.
            input_ids: token ids for the text branch, or None.
            attention_mask: optional mask for the text branch (defaults to all ones).
            modality: 'auto' (inferred from which inputs are given), 'visual',
                'text' or 'multimodal'. 'multimodal' with only one input falls
                back to that single branch.

        Raises:
            ValueError: no input given, unknown ``modality``, or the requested
                modality's input is missing.
        """
        has_visual = images is not None
        has_text = input_ids is not None

        if modality == 'auto':
            if has_visual and has_text:
                modality = 'multimodal'
            elif has_visual:
                modality = 'visual'
            elif has_text:
                modality = 'text'
            else:
                raise ValueError("At least one modality input required")
        elif modality not in ('visual', 'text', 'multimodal'):
            raise ValueError(f"Unknown modality: {modality!r}")

        results = {'modality_scores': {}}
        visual_logits = text_logits = None
        visual_probs = text_probs = None

        if modality in ('visual', 'multimodal') and has_visual:
            visual_logits = self.visual_branch(images)
            visual_probs = F.softmax(visual_logits, dim=-1)
            results['modality_scores']['visual'] = visual_probs[:, 0]
        if modality in ('text', 'multimodal') and has_text:
            text_logits = self.text_branch(input_ids, attention_mask)
            text_probs = F.softmax(text_logits, dim=-1)
            results['modality_scores']['text'] = text_probs[:, 0]

        if visual_probs is not None and text_probs is not None:
            # Both branches can only have run when modality == 'multimodal'.
            weights = F.softmax(self.fusion_weights, dim=0)
            fused = weights[0] * visual_probs + weights[1] * text_probs
            # Log-probabilities, so downstream NLL-style losses can consume them.
            results['logits'] = torch.log(fused + 1e-8)
            results['confidence'] = fused[:, 0]
        elif visual_probs is not None:
            results['logits'] = visual_logits
            results['confidence'] = visual_probs[:, 0]
        elif text_probs is not None:
            results['logits'] = text_logits
            results['confidence'] = text_probs[:, 0]
        else:
            # Previously this returned a dict without 'logits'/'confidence'.
            raise ValueError(f"modality={modality!r} requested but its input was not provided")
        return results

    def get_visual_gradcam(self):
        """Return a GradCAM bound to the visual branch; call ``remove_hooks()`` when finished."""
        return GradCAM(self.visual_branch, self.visual_branch.get_gradcam_target_layer())


def aggregate_video_predictions(frame_confidences, method='mean', threshold=0.5):
    """Collapse per-frame confidence scores into one video-level score.

    Args:
        frame_confidences: tensor, list, tuple or array of per-frame scores.
        method: 'mean' (average score), 'max' (highest score) or 'voting'
            (fraction of frames whose score exceeds ``threshold``).
        threshold: cut-off used by 'voting'; defaults to 0.5.

    Returns:
        The aggregated score as a Python float.

    Raises:
        ValueError: unknown ``method`` or no frame scores supplied.
    """
    aggregators = {
        'mean': lambda scores: scores.mean(),
        'max': lambda scores: scores.max(),
        'voting': lambda scores: (scores > threshold).float().mean(),
    }
    if method not in aggregators:
        raise ValueError(f"Unknown method: {method}")

    if not isinstance(frame_confidences, torch.Tensor):
        frame_confidences = torch.as_tensor(frame_confidences, dtype=torch.float32)
    elif not torch.is_floating_point(frame_confidences):
        # mean() is undefined for integer tensors.
        frame_confidences = frame_confidences.float()
    if frame_confidences.numel() == 0:
        raise ValueError("frame_confidences must contain at least one score")

    return aggregators[method](frame_confidences).item()