"""
Multimodal Deepfake Detection Model
====================================
Architecture:
- Visual Branch: EfficientNet-B0 (pretrained) for image/video frame classification
- Text Branch: RoBERTa-base for AI-generated text detection
- Fusion Layer: Learnable weighted ensemble with late fusion
- Explainability: GradCAM on EfficientNet convolutional layers
- Output: Confidence scores [0,1] + explainability heatmaps
Based on:
- AWARE-NET Two-Tier Ensemble (arxiv:2505.00312)
- CLIP-ViT LN-Tuning (arxiv:2503.19683)
- DeTeCtive RoBERTa text detection (arxiv:2410.20964)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from transformers import AutoModel, AutoTokenizer
import numpy as np
class GradCAM:
    """Generate Grad-CAM heatmaps for visual-branch explainability.

    Registers a forward hook and a full backward hook on ``target_layer``
    to capture its activations and the gradients flowing back into it,
    then weights the activation channels by their spatially averaged
    gradients (Grad-CAM, Selvaraju et al.).

    Call :meth:`remove_hooks` when finished so the hooks do not outlive
    this object.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None    # d(score)/d(target_layer output), set by backward hook
        self.activations = None  # target_layer output, set by forward hook
        self._hooks = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def _save_activations(self, module, input, output):
        self.activations = output.detach()

    def _save_gradients(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Return per-sample class activation maps normalized to [0, 1].

        Args:
            input_tensor: (B, C, H, W) batch fed to ``model``.
            class_idx: target class — a plain int applied to every sample,
                a (B,) tensor of per-sample class indices, or ``None`` to
                use each sample's argmax prediction.

        Returns:
            (B, 1, H, W) tensor of heatmaps, bilinearly upsampled to the
            input's spatial size.

        Raises:
            RuntimeError: if the hooks captured nothing (the target layer
                was not used in the forward pass).
        """
        self.model.eval()
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1)
        elif not isinstance(class_idx, torch.Tensor):
            # Broadcast a single class index over the whole batch.
            class_idx = torch.full(
                (output.size(0),), int(class_idx),
                dtype=torch.long, device=output.device,
            )
        self.model.zero_grad(set_to_none=True)
        # Backprop only the selected class score for each sample
        # (vectorized one-hot; no per-sample Python loop, and no
        # retain_graph — each call rebuilds the graph anyway).
        one_hot = torch.zeros_like(output).scatter_(1, class_idx.view(-1, 1), 1.0)
        output.backward(gradient=one_hot)
        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAM hooks captured nothing; was the target layer used in the forward pass?"
            )
        # Channel weights = globally average-pooled gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        # Per-sample min-max normalization to [0, 1]; eps avoids 0/0 on flat maps.
        batch = cam.size(0)
        flat = cam.view(batch, -1)
        lo = flat.min(dim=1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
        hi = flat.max(dim=1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
        cam = (cam - lo) / (hi - lo + 1e-8)
        return F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)

    def remove_hooks(self):
        """Detach all hooks; the instance must not be used afterwards."""
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
class VisualDeepfakeDetector(nn.Module):
    """EfficientNet-B0 visual branch for image/frame deepfake classification.

    The timm backbone is configured to return the raw (B, C, H, W) feature
    map (``num_classes=0``, ``global_pool=''``); pooling, L2 normalization,
    dropout, and the linear head are applied here.
    """

    def __init__(self, num_classes=2, pretrained=True, dropout=0.3):
        super().__init__()
        self.backbone = timm.create_model('efficientnet_b0', pretrained=pretrained, num_classes=0, global_pool='')
        # Read the channel count from the backbone instead of hard-coding
        # 1280, so swapping in a different timm model cannot silently
        # mismatch the classifier head. (efficientnet_b0 -> 1280.)
        self.feature_dim = getattr(self.backbone, 'num_features', 1280)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def get_features(self, x):
        """Return the backbone's spatial feature map for input batch ``x``."""
        return self.backbone(x)

    def forward(self, x):
        """Return (B, num_classes) logits for an image batch ``x``."""
        features = self.get_features(x)
        pooled = self.global_pool(features).flatten(1)
        # L2-normalize the pooled embedding before the head (matches the
        # text branch's pooling convention).
        pooled = F.normalize(pooled, p=2, dim=-1)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)

    def get_gradcam_target_layer(self):
        """Return the last backbone block — the GradCAM hook target."""
        return self.backbone.blocks[-1]
class TextDeepfakeDetector(nn.Module):
    """Transformer-based detector for AI-generated text.

    Encodes tokens with a pretrained encoder (RoBERTa-base by default),
    mean-pools the token embeddings under the attention mask,
    L2-normalizes the pooled vector, and classifies it with a small
    two-layer MLP head.
    """

    def __init__(self, model_name='roberta-base', num_classes=2, dropout=0.3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, counting only unmasked positions."""
        tokens = model_output.last_hidden_state
        mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)  # guard against all-masked rows
        return summed / counts

    def forward(self, input_ids, attention_mask):
        """Return (B, num_classes) logits for a tokenized text batch."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(encoded, attention_mask)
        pooled = self.dropout(F.normalize(pooled, p=2, dim=-1))
        return self.classifier(pooled)
class MultimodalDeepfakeDetector(nn.Module):
    """Late-fusion ensemble of the visual (EfficientNet) and text (RoBERTa) branches.

    Fusion is a learnable softmax-weighted average of the two branches'
    class probabilities. NOTE(review): ``cross_attention``, ``visual_proj``,
    ``text_proj``, and ``fusion_classifier`` are registered here but never
    used in ``forward`` — presumably reserved for a mid-fusion variant;
    they cannot be removed without changing the state_dict. Confirm intent.
    """
    def __init__(self, visual_pretrained=True, text_model_name='roberta-base', dropout=0.3):
        super().__init__()
        self.visual_branch = VisualDeepfakeDetector(num_classes=2, pretrained=visual_pretrained, dropout=dropout)
        self.text_branch = TextDeepfakeDetector(model_name=text_model_name, num_classes=2, dropout=dropout)
        # Learnable branch weights, initialized to favor the visual branch;
        # softmaxed in forward() so they always sum to 1.
        self.fusion_weights = nn.Parameter(torch.tensor([0.6, 0.4]))
        # Unused in forward() — see class NOTE(review) above.
        self.cross_attention = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
        self.visual_proj = nn.Linear(1280, 128)
        self.text_proj = nn.Linear(768, 128)
        self.fusion_classifier = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Dropout(dropout), nn.Linear(64, 2))
    def forward(self, images=None, input_ids=None, attention_mask=None, modality='auto'):
        """Run whichever branches have inputs and return scores.

        Args:
            images: optional (B, C, H, W) image batch for the visual branch.
            input_ids: optional (B, T) token ids for the text branch.
            attention_mask: (B, T) mask paired with ``input_ids``.
            modality: 'visual', 'text', 'multimodal', or 'auto' to infer
                from which inputs are provided.

        Returns:
            dict with 'logits', 'confidence', and per-branch
            'modality_scores'. Confidence is the class-0 probability —
            presumably "fake"; confirm the label convention with training code.

        Raises:
            ValueError: if ``modality='auto'`` and no input is given.
        """
        results = {'modality_scores': {}}
        has_visual = images is not None
        has_text = input_ids is not None
        if modality == 'auto':
            if has_visual and has_text: modality = 'multimodal'
            elif has_visual: modality = 'visual'
            elif has_text: modality = 'text'
            else: raise ValueError("At least one modality input required")
        visual_logits = text_logits = None
        if modality in ('visual', 'multimodal') and has_visual:
            visual_logits = self.visual_branch(images)
            results['modality_scores']['visual'] = F.softmax(visual_logits, dim=-1)[:, 0]
        if modality in ('text', 'multimodal') and has_text:
            text_logits = self.text_branch(input_ids, attention_mask)
            results['modality_scores']['text'] = F.softmax(text_logits, dim=-1)[:, 0]
        if modality == 'multimodal' and visual_logits is not None and text_logits is not None:
            # Late fusion: weighted average of branch probabilities, then
            # log to get logit-like values ('logits' here are log-probs,
            # not raw pre-softmax scores). eps keeps log finite.
            weights = F.softmax(self.fusion_weights, dim=0)
            fused = weights[0] * F.softmax(visual_logits, -1) + weights[1] * F.softmax(text_logits, -1)
            results['logits'] = torch.log(fused + 1e-8)
            results['confidence'] = fused[:, 0]
        elif visual_logits is not None:
            # Single-modality fallback: pass branch logits through unchanged.
            results['logits'] = visual_logits
            results['confidence'] = F.softmax(visual_logits, dim=-1)[:, 0]
        elif text_logits is not None:
            results['logits'] = text_logits
            results['confidence'] = F.softmax(text_logits, dim=-1)[:, 0]
        return results
    def get_visual_gradcam(self):
        """Return a GradCAM helper hooked onto the visual branch's last block."""
        return GradCAM(self.visual_branch, self.visual_branch.get_gradcam_target_layer())
def aggregate_video_predictions(frame_confidences, method='mean'):
    """Collapse per-frame confidence scores into one video-level score.

    Args:
        frame_confidences: per-frame scores as a list/tuple, numpy array,
            or 1-D torch tensor. Converted to float32, so integer input
            no longer breaks ``mean()`` (int tensors have no mean).
        method: 'mean' (average score), 'max' (most confident frame), or
            'voting' (fraction of frames scoring above 0.5).

    Returns:
        float: the aggregated confidence.

    Raises:
        ValueError: if ``frame_confidences`` is empty (previously a silent
            NaN) or ``method`` is unknown.
    """
    scores = torch.as_tensor(frame_confidences, dtype=torch.float32)
    if scores.numel() == 0:
        raise ValueError("frame_confidences is empty; nothing to aggregate")
    if method == 'mean':
        return scores.mean().item()
    if method == 'max':
        return scores.max().item()
    if method == 'voting':
        return (scores > 0.5).float().mean().item()
    raise ValueError(f"Unknown method: {method}")
|