alianassmaaa committed on
Commit
8189b22
·
verified ·
1 Parent(s): f8e7507

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +320 -0
model.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal Deepfake Detection Model
3
+ ====================================
4
+ Architecture:
5
+ - Visual Branch: EfficientNet-B0 (pretrained) for image/video frame classification
6
+ - Text Branch: RoBERTa-base for AI-generated text detection
7
+ - Fusion Layer: Learnable weighted ensemble with late fusion
8
+ - Explainability: GradCAM on EfficientNet convolutional layers
9
+ - Output: Confidence scores [0,1] + explainability heatmaps
10
+
11
+ Based on:
12
+ - AWARE-NET Two-Tier Ensemble (arxiv:2505.00312)
13
+ - CLIP-ViT LN-Tuning (arxiv:2503.19683)
14
+ - DeTeCtive RoBERTa text detection (arxiv:2410.20964)
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import timm
21
+ from transformers import AutoModel, AutoTokenizer
22
+ import numpy as np
23
+
24
+
25
+ # ============================================================
26
+ # GradCAM Explainability Module
27
+ # ============================================================
28
class GradCAM:
    """Generate class activation maps for visual branch explainability.

    On construction two hooks are attached to ``target_layer``: a forward
    hook that stores the layer's output and a full backward hook that stores
    the gradient arriving at that output. Call :meth:`remove_hooks` once the
    instance is no longer needed.
    """

    def __init__(self, model, target_layer):
        """Attach the capture hooks.

        Args:
            model: Module whose forward output is a (B, num_classes) tensor.
            target_layer: Sub-module of ``model`` whose activations and
                gradients are used to build the heatmap.
        """
        self.model = model
        self.gradients = None
        self.activations = None
        self._hooks = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def _save_activations(self, module, input, output):
        # Forward hook: keep a detached copy of the target layer's output.
        self.activations = output.detach()

    def _save_gradients(self, module, grad_in, grad_out):
        # Backward hook: keep the gradient w.r.t. the target layer's output.
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Generate GradCAM heatmap.

        The model is switched to eval mode and its parameter gradients are
        cleared before the backward pass; the backward pass itself leaves
        fresh gradients on the parameters.

        Args:
            input_tensor: (B, C, H, W) image tensor
            class_idx: Target class (None = predicted class). May be an int,
                a sequence of B ints, or a tensor holding 1 or B indices.

        Returns:
            cam: (B, 1, H, W) heatmap normalized to [0, 1]

        Raises:
            ValueError: If ``class_idx`` holds neither 1 nor B indices.
            RuntimeError: If the hooks captured nothing (for example after
                :meth:`remove_hooks`).
        """
        self.model.eval()

        # Forget captures from a previous call so stale tensors can never be
        # combined with the current input.
        self.activations = None
        self.gradients = None

        # GradCAM needs a backward pass, so force autograd on even when the
        # caller wrapped this call in torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            batch_size = output.size(0)

            if class_idx is None:
                class_idx = output.argmax(dim=1)
            class_idx = torch.as_tensor(
                class_idx, dtype=torch.long, device=output.device
            ).reshape(-1)
            if class_idx.numel() not in (1, batch_size):
                raise ValueError(
                    f"class_idx must hold 1 or {batch_size} indices, "
                    f"got {class_idx.numel()}"
                )

            self.model.zero_grad()
            # One-hot target: a single index is broadcast over the batch.
            one_hot = torch.zeros_like(output)
            rows = torch.arange(batch_size, device=output.device)
            one_hot[rows, class_idx] = 1.0

            output.backward(gradient=one_hot, retain_graph=True)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks did not capture activations/gradients; "
                "check that the hooks are still registered on a layer used "
                "in the forward pass"
            )

        # Weighted combination of activation maps
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # (B, 1, h, w)
        cam = F.relu(cam)

        # Min-max normalize each sample independently; the epsilon keeps an
        # all-constant map from dividing by zero.
        flat = cam.view(batch_size, -1)
        cam_min = flat.min(dim=1)[0].view(batch_size, 1, 1, 1)
        cam_max = flat.max(dim=1)[0].view(batch_size, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)

        # Upscale to input resolution
        return F.interpolate(
            cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False
        )

    def remove_hooks(self):
        """Detach every hook registered by this instance."""
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()
94
+
95
+
96
+ # ============================================================
97
+ # Visual Branch: EfficientNet-B0 Based Deepfake Detector
98
+ # ============================================================
99
class VisualDeepfakeDetector(nn.Module):
    """EfficientNet-B0 based binary classifier for real/fake images.

    Features:
    - Pretrained EfficientNet-B0 backbone (timm)
    - L2-normalized features (inspired by CLIP deepfake detection)
    - GradCAM-compatible architecture
    """

    def __init__(self, num_classes=2, pretrained=True, dropout=0.3):
        """Build the backbone and the classification head.

        Args:
            num_classes: Number of output logits.
            pretrained: Passed to ``timm.create_model`` as ``pretrained``.
            dropout: Dropout probability applied before the linear head.
        """
        super().__init__()
        # EfficientNet-B0 backbone with timm's classifier head and global
        # pooling disabled, so it returns a spatial feature map.
        self.backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=pretrained,
            num_classes=0,   # Remove classifier head
            global_pool='',  # Remove global pooling
        )
        # Take the channel count from the backbone rather than repeating a
        # magic number; 1280 (EfficientNet-B0) is kept as the fallback.
        self.feature_dim = getattr(self.backbone, 'num_features', 1280)

        # Custom head with L2 normalization
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def get_features(self, x):
        """Extract features before classification.

        Returns the raw backbone output for ``x`` (a spatial feature map,
        since pooling is disabled on the backbone).
        """
        return self.backbone(x)

    def forward(self, x):
        """Return classification logits of shape (B, num_classes) for ``x``."""
        feature_map = self.get_features(x)
        pooled = self.global_pool(feature_map).flatten(1)  # (B, feature_dim)
        pooled = F.normalize(pooled, p=2, dim=-1)          # L2 normalize
        pooled = self.dropout(pooled)
        return self.classifier(pooled)

    def get_gradcam_target_layer(self):
        """Return the target layer for GradCAM (last entry of ``backbone.blocks``)."""
        return self.backbone.blocks[-1]
141
+
142
+
143
+ # ============================================================
144
+ # Text Branch: RoBERTa Based AI Text Detector
145
+ # ============================================================
146
class TextDeepfakeDetector(nn.Module):
    """RoBERTa-based binary classifier for human vs AI-generated text.

    Features:
    - Pretrained RoBERTa-base backbone
    - Mean pooling over token embeddings (more robust than CLS)
    - Dropout regularization
    """

    def __init__(self, model_name='roberta-base', num_classes=2, dropout=0.3):
        """Load the encoder and build the classification head.

        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            num_classes: Number of output logits.
            dropout: Dropout probability used before and inside the head.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.encoder.config.hidden_size  # 768 for roberta-base

        self.dropout = nn.Dropout(p=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )

    def mean_pooling(self, model_output, attention_mask):
        """Mean pooling over non-padded tokens.

        Args:
            model_output: Encoder output exposing ``last_hidden_state`` of
                shape (B, seq_len, hidden).
            attention_mask: (B, seq_len) mask, non-zero for real tokens.

        Returns:
            (B, hidden) tensor: the mask-weighted average of the token
            embeddings. A fully masked row yields zeros, because the
            denominator is clamped to a small positive value.
        """
        token_embeddings = model_output.last_hidden_state  # (B, seq_len, hidden)
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits of shape (B, num_classes).

        Args:
            input_ids: (B, seq_len) token IDs.
            attention_mask: (B, seq_len) mask; when omitted every token is
                treated as real (an all-ones mask is used).
        """
        if attention_mask is None:
            # Without a mask, pool over every position instead of failing.
            attention_mask = torch.ones_like(input_ids)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(outputs, attention_mask)  # (B, hidden)
        pooled = F.normalize(pooled, p=2, dim=-1)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)
183
+
184
+
185
+ # ============================================================
186
+ # Multimodal Fusion: Ensemble Classifier
187
+ # ============================================================
188
class MultimodalDeepfakeDetector(nn.Module):
    """Multimodal ensemble for deepfake detection.

    Combines visual (image/video frame) and text modalities with
    learnable weighted late fusion. Supports single-modality inference.

    Architecture (inspired by AWARE-NET two-tier ensemble):
    - Visual: EfficientNet-B0 → logits
    - Text: RoBERTa-base → logits
    - Fusion: Learnable weighted average of probabilities

    Output: confidence score [0, 1] where 1 = AI-generated/fake
    """

    _MODALITIES = ('visual', 'text', 'multimodal')

    def __init__(self, visual_pretrained=True, text_model_name='roberta-base', dropout=0.3):
        """Build both branches and the fusion parameters.

        Args:
            visual_pretrained: Forwarded to the visual branch as ``pretrained``.
            text_model_name: Forwarded to the text branch as ``model_name``.
            dropout: Dropout probability shared by both branches and the
                fusion classifier.
        """
        super().__init__()
        self.visual_branch = VisualDeepfakeDetector(
            num_classes=2, pretrained=visual_pretrained, dropout=dropout
        )
        self.text_branch = TextDeepfakeDetector(
            model_name=text_model_name, num_classes=2, dropout=dropout
        )

        # Learnable fusion weights (AWARE-NET style); softmax-normalized in
        # forward(), so these are unnormalized scores.
        self.fusion_weights = nn.Parameter(torch.tensor([0.6, 0.4]))  # [visual, text]

        # Cross-modal attention modules. NOTE: they are constructed here but
        # forward() in this class does not call them.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=128, num_heads=4, batch_first=True
        )
        self.visual_proj = nn.Linear(1280, 128)
        self.text_proj = nn.Linear(768, 128)
        self.fusion_classifier = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

    def forward(self, images=None, input_ids=None, attention_mask=None,
                modality='auto'):
        """
        Forward pass supporting single or multi-modal input.

        Args:
            images: (B, C, H, W) image tensor, optional
            input_ids: (B, seq_len) text token IDs, optional
            attention_mask: (B, seq_len) attention mask, optional; defaults
                to all ones when text is used without a mask
            modality: 'visual', 'text', 'multimodal', or 'auto'. With
                'multimodal' and only one input present, that single
                modality is used.

        Returns:
            dict with:
                - logits: (B, 2) raw logits (log of the fused probabilities
                  when both modalities are fused)
                - confidence: (B,) probability of being fake/AI-generated
                - modality_scores: dict of per-modality confidence scores

        Raises:
            ValueError: If ``modality`` is not recognised, or no input is
                available for the requested modality.
        """
        has_visual = images is not None
        has_text = input_ids is not None

        if modality == 'auto':
            if has_visual and has_text:
                modality = 'multimodal'
            elif has_visual:
                modality = 'visual'
            elif has_text:
                modality = 'text'
            else:
                raise ValueError("At least one modality input required")
        elif modality not in self._MODALITIES:
            raise ValueError(f"Unknown modality: {modality!r}")

        use_visual = has_visual and modality in ('visual', 'multimodal')
        use_text = has_text and modality in ('text', 'multimodal')
        if not (use_visual or use_text):
            # Fail loudly instead of returning a dict without predictions.
            raise ValueError(f"No input provided for modality {modality!r}")

        results = {'modality_scores': {}}
        visual_logits = visual_probs = None
        text_logits = text_probs = None

        if use_visual:
            visual_logits = self.visual_branch(images)
            visual_probs = F.softmax(visual_logits, dim=-1)
            results['modality_scores']['visual'] = visual_probs[:, 1]  # P(fake)

        if use_text:
            if attention_mask is None:
                # No mask supplied: treat every token as real.
                attention_mask = torch.ones_like(input_ids)
            text_logits = self.text_branch(input_ids, attention_mask)
            text_probs = F.softmax(text_logits, dim=-1)
            results['modality_scores']['text'] = text_probs[:, 1]  # P(fake)

        if modality == 'multimodal' and use_visual and use_text:
            # Late fusion: learnable weighted average of class probabilities.
            weights = F.softmax(self.fusion_weights, dim=0)
            fused_probs = weights[0] * visual_probs + weights[1] * text_probs
            results['logits'] = torch.log(fused_probs + 1e-8)
            results['confidence'] = fused_probs[:, 1]  # P(fake)
        elif use_visual:
            results['logits'] = visual_logits
            results['confidence'] = visual_probs[:, 1]  # P(fake)
        else:
            results['logits'] = text_logits
            results['confidence'] = text_probs[:, 1]  # P(fake)

        return results

    def get_visual_gradcam(self):
        """Get GradCAM instance for visual branch."""
        target_layer = self.visual_branch.get_gradcam_target_layer()
        return GradCAM(self.visual_branch, target_layer)
294
+
295
+
296
+ # ============================================================
297
+ # Helper: Video Frame Aggregation
298
+ # ============================================================
299
def aggregate_video_predictions(frame_confidences, method='mean'):
    """Aggregate per-frame predictions to video-level score.

    Args:
        frame_confidences: list/tuple/tensor of per-frame P(fake) scores
        method: 'mean', 'max', 'voting' (majority vote at 0.5 threshold;
            returns the fraction of frames scoring above 0.5)

    Returns:
        video_confidence: scalar P(fake) for the whole video

    Raises:
        ValueError: If ``frame_confidences`` is empty or ``method`` is unknown.
    """
    if not isinstance(frame_confidences, torch.Tensor):
        frame_confidences = torch.as_tensor(frame_confidences)
    if not frame_confidences.is_floating_point():
        # Integer/bool scores (e.g. hard 0/1 labels) cannot be averaged
        # directly; promote them to float first.
        frame_confidences = frame_confidences.float()

    if frame_confidences.numel() == 0:
        raise ValueError("frame_confidences must contain at least one score")

    if method == 'mean':
        return frame_confidences.mean().item()
    if method == 'max':
        return frame_confidences.max().item()
    if method == 'voting':
        votes = (frame_confidences > 0.5).float()
        return votes.mean().item()
    raise ValueError(f"Unknown aggregation method: {method}")