| """ |
| Multimodal Deepfake Detection Model |
| ==================================== |
| Architecture: |
| - Visual Branch: EfficientNet-B0 (pretrained) for image/video frame classification |
| - Text Branch: RoBERTa-base for AI-generated text detection |
| - Fusion Layer: Learnable weighted ensemble with late fusion |
| - Explainability: GradCAM on EfficientNet convolutional layers |
- Output: Confidence scores in [0, 1] (softmax probability of class 0, the positive class) + explainability heatmaps
| |
| Based on: |
| - AWARE-NET Two-Tier Ensemble (arxiv:2505.00312) |
| - CLIP-ViT LN-Tuning (arxiv:2503.19683) |
| - DeTeCtive RoBERTa text detection (arxiv:2410.20964) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import timm |
| from transformers import AutoModel, AutoTokenizer |
|
|
|
|
| class GradCAM: |
| """Generate class activation maps for visual branch explainability.""" |
|
|
    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Hooks capture the target layer's activations (forward pass) and
        # gradients (backward pass); call remove_hooks() when done.
        self._hooks = []
        self._hooks.append(target_layer.register_forward_hook(self._save_activations))
        self._hooks.append(target_layer.register_full_backward_hook(self._save_gradients))
|
|
| def _save_activations(self, module, input, output): |
| self.activations = output.detach() |
|
|
| def _save_gradients(self, module, grad_in, grad_out): |
| self.gradients = grad_out[0].detach() |
|
|
    def generate(self, input_tensor, class_idx=None):
        """Return per-image heatmaps in [0, 1], upsampled to the input size."""
        self.model.eval()
        with torch.enable_grad():  # GradCAM needs gradients even at inference time
            output = self.model(input_tensor)
            if class_idx is None:
                class_idx = output.argmax(dim=1)
            self.model.zero_grad()
            # Backpropagate a one-hot signal for each sample's target class.
            one_hot = torch.zeros_like(output)
            for i in range(output.size(0)):
                target = class_idx[i] if isinstance(class_idx, torch.Tensor) else class_idx
                one_hot[i, target] = 1.0
            output.backward(gradient=one_hot, retain_graph=True)
        # Channel weights are the globally averaged gradients; the CAM is the
        # ReLU of the weighted sum over channels (Grad-CAM, Selvaraju et al.).
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        # Min-max normalise each map to [0, 1].
        B = cam.size(0)
        cam_flat = cam.view(B, -1)
        cam_min = cam_flat.min(dim=1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
        cam_max = cam_flat.max(dim=1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        # Upsample to the input resolution for overlaying on the image.
        return F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)
|
|
    def remove_hooks(self):
        """Detach the hooks; call once the helper is no longer needed."""
        for h in self._hooks:
            h.remove()
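

# Usage sketch (illustrative only): attaching GradCAM to a small stand-in CNN.
# The toy network, shapes, and function name are assumptions for demonstration,
# not part of this module's API.
def _gradcam_usage_example():
    toy = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2),
    )
    cam = GradCAM(toy, toy[0])                     # hook the conv layer
    heatmaps = cam.generate(torch.randn(2, 3, 64, 64))
    cam.remove_hooks()                             # detach hooks when done
    return heatmaps                                # (2, 1, 64, 64), values in [0, 1]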
|
|
|
|
class VisualDeepfakeDetector(nn.Module):
    """EfficientNet-B0 binary classifier for deepfake images / video frames."""

    def __init__(self, num_classes=2, pretrained=True, dropout=0.3):
        super().__init__()
        # num_classes=0 and global_pool='' make timm return the raw feature map,
        # so GradCAM can hook convolutional activations before pooling.
        self.backbone = timm.create_model('efficientnet_b0', pretrained=pretrained,
                                          num_classes=0, global_pool='')
        self.feature_dim = self.backbone.num_features  # 1280 for EfficientNet-B0
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(self.feature_dim, num_classes)
|
|
    def get_features(self, x):
        """Return the unpooled backbone feature map, shape (B, feature_dim, H', W')."""
        return self.backbone(x)
|
|
    def forward(self, x):
        features = self.get_features(x)                 # (B, 1280, H', W')
        pooled = self.global_pool(features).flatten(1)  # (B, 1280)
        pooled = F.normalize(pooled, p=2, dim=-1)       # unit-normalise before the head
        pooled = self.dropout(pooled)
        return self.classifier(pooled)
|
|
    def get_gradcam_target_layer(self):
        # Last convolutional stage of the backbone; GradCAM hooks its output.
        return self.backbone.blocks[-1]
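

# Usage sketch (illustrative only): scoring a batch of frames with the visual
# branch and rendering GradCAM heatmaps. Tensor shapes, pretrained=False, and
# the function name are placeholder assumptions.
def _visual_branch_usage_example():
    model = VisualDeepfakeDetector(pretrained=False)   # pretrained=False avoids a download
    frames = torch.randn(2, 3, 224, 224)               # stand-in for preprocessed frames
    with torch.no_grad():
        fake_prob = F.softmax(model(frames), dim=-1)[:, 0]  # class 0 = positive class here
    cam = GradCAM(model, model.get_gradcam_target_layer())
    heatmaps = cam.generate(frames)                    # (2, 1, 224, 224)
    cam.remove_hooks()
    return fake_prob, heatmaps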
|
|
|
|
class TextDeepfakeDetector(nn.Module):
    """RoBERTa encoder with a small MLP head for AI-generated-text detection."""

    def __init__(self, model_name='roberta-base', num_classes=2, dropout=0.3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.encoder.config.hidden_size  # 768 for roberta-base
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )
|
|
    def mean_pooling(self, model_output, attention_mask):
        """Attention-masked mean over token embeddings (ignores padding)."""
        token_embeddings = model_output.last_hidden_state
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
|
|
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(outputs, attention_mask)  # (B, hidden_dim)
        pooled = F.normalize(pooled, p=2, dim=-1)            # unit-normalise before the head
        pooled = self.dropout(pooled)
        return self.classifier(pooled)
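

# Usage sketch (illustrative only): tokenising a passage and scoring it with the
# text branch. The sample sentence, max_length, and function name are arbitrary
# assumptions; downloading roberta-base weights is required.
def _text_branch_usage_example():
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    model = TextDeepfakeDetector()
    batch = tokenizer(
        ["A short passage whose provenance we want to score."],
        return_tensors='pt', padding=True, truncation=True, max_length=512,
    )
    with torch.no_grad():
        logits = model(batch['input_ids'], batch['attention_mask'])
    return F.softmax(logits, dim=-1)[:, 0]  # probability of class 0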
|
|
|
|
class MultimodalDeepfakeDetector(nn.Module):
    """Late-fusion ensemble of the visual and text branches.

    Fusion is a learnable softmax-weighted combination of the per-branch
    softmax probabilities, matching the weighted-ensemble late fusion named
    in the module header.
    """

    def __init__(self, visual_pretrained=True, text_model_name='roberta-base', dropout=0.3):
        super().__init__()
        self.visual_branch = VisualDeepfakeDetector(num_classes=2, pretrained=visual_pretrained, dropout=dropout)
        self.text_branch = TextDeepfakeDetector(model_name=text_model_name, num_classes=2, dropout=dropout)
        # Learnable branch weights (softmax-normalised in forward); the init
        # values bias the ensemble toward the visual branch.
        self.fusion_weights = nn.Parameter(torch.tensor([0.6, 0.4]))
|
|
    def forward(self, images=None, input_ids=None, attention_mask=None, modality='auto'):
        """Run whichever branches have inputs and return logits, a confidence
        score (probability of class 0, the positive class), and per-branch scores."""
        results = {'modality_scores': {}}
        has_visual = images is not None
        has_text = input_ids is not None
        if modality == 'auto':
            if has_visual and has_text:
                modality = 'multimodal'
            elif has_visual:
                modality = 'visual'
            elif has_text:
                modality = 'text'
            else:
                raise ValueError("At least one modality input is required")
        visual_logits = text_logits = None
        if modality in ('visual', 'multimodal') and has_visual:
            visual_logits = self.visual_branch(images)
            results['modality_scores']['visual'] = F.softmax(visual_logits, dim=-1)[:, 0]
        if modality in ('text', 'multimodal') and has_text:
            if attention_mask is None:
                # Assume no padding when the caller omits the mask.
                attention_mask = torch.ones_like(input_ids)
            text_logits = self.text_branch(input_ids, attention_mask)
            results['modality_scores']['text'] = F.softmax(text_logits, dim=-1)[:, 0]
        if modality == 'multimodal' and visual_logits is not None and text_logits is not None:
            # Late fusion: convex combination of the branch probabilities.
            weights = F.softmax(self.fusion_weights, dim=0)
            fused = weights[0] * F.softmax(visual_logits, dim=-1) + weights[1] * F.softmax(text_logits, dim=-1)
            results['logits'] = torch.log(fused + 1e-8)  # log-probabilities, usable with nn.NLLLoss
            results['confidence'] = fused[:, 0]
        elif visual_logits is not None:
            results['logits'] = visual_logits
            results['confidence'] = F.softmax(visual_logits, dim=-1)[:, 0]
        elif text_logits is not None:
            results['logits'] = text_logits
            results['confidence'] = F.softmax(text_logits, dim=-1)[:, 0]
        return results
|
|
    def get_visual_gradcam(self):
        """Build a GradCAM helper bound to the visual branch (caller must remove_hooks())."""
        return GradCAM(self.visual_branch, self.visual_branch.get_gradcam_target_layer())
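

# Usage sketch (illustrative only): running both branches and letting the
# learnable weights fuse them. Inputs are random placeholders, not real data,
# and the function name is an assumption.
def _multimodal_usage_example():
    model = MultimodalDeepfakeDetector(visual_pretrained=False)  # still downloads roberta-base
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    images = torch.randn(1, 3, 224, 224)
    batch = tokenizer(["Caption accompanying the image."], return_tensors='pt')
    with torch.no_grad():
        out = model(images=images, input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'])  # modality='auto' -> 'multimodal'
    return out['confidence'], out['modality_scores']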
|
|
|
|
def aggregate_video_predictions(frame_confidences, method='mean'):
    """Collapse per-frame confidences into one video-level score via
    'mean', 'max', or 'voting' (fraction of frames above 0.5)."""
    if isinstance(frame_confidences, list):
        frame_confidences = torch.tensor(frame_confidences)
    if method == 'mean':
        return frame_confidences.mean().item()
    if method == 'max':
        return frame_confidences.max().item()
    if method == 'voting':
        return (frame_confidences > 0.5).float().mean().item()
    raise ValueError(f"Unknown method: {method}")
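

# Usage sketch (illustrative only): the per-frame confidences below are made up
# to show how the three aggregation strategies differ on the same video.
def _aggregation_usage_example():
    frame_confidences = [0.91, 0.85, 0.42, 0.78]  # hypothetical per-frame scores
    return {m: aggregate_video_predictions(frame_confidences, method=m)
            for m in ('mean', 'max', 'voting')}   # e.g. voting -> 0.75 (3 of 4 frames > 0.5)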
|
|