alianassmaaa commited on
Commit
a1bd556
·
verified ·
1 Parent(s): 45d5acb

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +314 -0
inference.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Pipeline for Multimodal Deepfake Detection
3
+ =====================================================
4
+ Supports:
5
+ - Single image classification with confidence + GradCAM heatmap
6
+ - Video classification (frame-by-frame → aggregated score)
7
+ - Text classification (human vs AI-generated)
8
+ - Multimodal (image + text combined)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from PIL import Image
15
+ import json
16
+ import os
17
+
18
+
19
+ def load_model(checkpoint_path, device='cpu'):
20
+ """Load the trained multimodal ensemble model.
21
+
22
+ Args:
23
+ checkpoint_path: Path to multimodal_ensemble.pt
24
+ device: 'cpu' or 'cuda'
25
+
26
+ Returns:
27
+ model: MultimodalDeepfakeDetector
28
+ config: training configuration dict
29
+ """
30
+ from model import MultimodalDeepfakeDetector
31
+
32
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
33
+ config = checkpoint['config']
34
+
35
+ model = MultimodalDeepfakeDetector(
36
+ visual_pretrained=False,
37
+ text_model_name=config['text_model_name'],
38
+ dropout=0.0, # No dropout at inference
39
+ )
40
+ model.load_state_dict(checkpoint['model_state_dict'])
41
+ model = model.to(device)
42
+ model.eval()
43
+
44
+ return model, config
45
+
46
+
47
+ def classify_image(model, image_path_or_pil, device='cpu', return_gradcam=True):
48
+ """Classify a single image as real or AI-generated/deepfake.
49
+
50
+ Args:
51
+ model: MultimodalDeepfakeDetector
52
+ image_path_or_pil: Path to image or PIL Image
53
+ device: computation device
54
+ return_gradcam: whether to generate explainability map
55
+
56
+ Returns:
57
+ dict with:
58
+ - prediction: 'real' or 'fake'
59
+ - confidence: float [0, 1] (probability of being fake)
60
+ - gradcam: numpy array (H, W) heatmap if return_gradcam=True
61
+ """
62
+ from preprocessing import get_image_transforms
63
+ from model import GradCAM
64
+
65
+ if isinstance(image_path_or_pil, str):
66
+ image = Image.open(image_path_or_pil).convert('RGB')
67
+ else:
68
+ image = image_path_or_pil.convert('RGB')
69
+
70
+ transform = get_image_transforms('eval', 224)
71
+ image_tensor = transform(image).unsqueeze(0).to(device)
72
+
73
+ with torch.no_grad():
74
+ results = model(images=image_tensor, modality='visual')
75
+
76
+ confidence = results['confidence'].item()
77
+ prediction = 'fake' if confidence > 0.5 else 'real'
78
+
79
+ output = {
80
+ 'prediction': prediction,
81
+ 'confidence': confidence,
82
+ 'visual_score': results['modality_scores']['visual'].item(),
83
+ }
84
+
85
+ if return_gradcam:
86
+ # Enable gradients for GradCAM
87
+ model.visual_branch.eval()
88
+ gradcam = GradCAM(model.visual_branch, model.visual_branch.get_gradcam_target_layer())
89
+
90
+ image_tensor_grad = image_tensor.clone().requires_grad_(True)
91
+ cam = gradcam.generate(image_tensor_grad, class_idx=1) # Heatmap for "fake" class
92
+ output['gradcam'] = cam.squeeze().cpu().numpy()
93
+ gradcam.remove_hooks()
94
+
95
+ return output
96
+
97
+
98
+ def classify_video(model, video_path, device='cpu', num_frames=32, aggregation='mean'):
99
+ """Classify a video as real or deepfake.
100
+
101
+ Extracts frames uniformly, classifies each, and aggregates.
102
+
103
+ Args:
104
+ model: MultimodalDeepfakeDetector
105
+ video_path: Path to video file
106
+ device: computation device
107
+ num_frames: number of frames to sample
108
+ aggregation: 'mean', 'max', or 'voting'
109
+
110
+ Returns:
111
+ dict with video-level prediction and per-frame scores
112
+ """
113
+ from preprocessing import extract_video_frames, get_image_transforms
114
+ from model import aggregate_video_predictions
115
+
116
+ frames = extract_video_frames(video_path, num_frames=num_frames)
117
+ transform = get_image_transforms('eval', 224)
118
+
119
+ frame_scores = []
120
+ model.eval()
121
+
122
+ with torch.no_grad():
123
+ for frame in frames:
124
+ image_tensor = transform(frame.convert('RGB')).unsqueeze(0).to(device)
125
+ results = model(images=image_tensor, modality='visual')
126
+ frame_scores.append(results['confidence'].item())
127
+
128
+ video_confidence = aggregate_video_predictions(
129
+ torch.tensor(frame_scores), method=aggregation
130
+ )
131
+
132
+ return {
133
+ 'prediction': 'fake' if video_confidence > 0.5 else 'real',
134
+ 'confidence': video_confidence,
135
+ 'frame_scores': frame_scores,
136
+ 'num_frames_analyzed': len(frames),
137
+ 'aggregation_method': aggregation,
138
+ }
139
+
140
+
141
+ def classify_text(model, text, tokenizer=None, device='cpu', max_length=512):
142
+ """Classify text as human-written or AI-generated.
143
+
144
+ Args:
145
+ model: MultimodalDeepfakeDetector
146
+ text: input text string
147
+ tokenizer: optional pre-loaded tokenizer
148
+ device: computation device
149
+ max_length: max sequence length
150
+
151
+ Returns:
152
+ dict with prediction and confidence
153
+ """
154
+ from transformers import AutoTokenizer
155
+
156
+ if tokenizer is None:
157
+ tokenizer = AutoTokenizer.from_pretrained('roberta-base')
158
+
159
+ encoding = tokenizer(
160
+ text,
161
+ max_length=max_length,
162
+ padding='max_length',
163
+ truncation=True,
164
+ return_tensors='pt'
165
+ )
166
+
167
+ input_ids = encoding['input_ids'].to(device)
168
+ attention_mask = encoding['attention_mask'].to(device)
169
+
170
+ model.eval()
171
+ with torch.no_grad():
172
+ results = model(input_ids=input_ids, attention_mask=attention_mask, modality='text')
173
+
174
+ confidence = results['confidence'].item()
175
+
176
+ return {
177
+ 'prediction': 'ai_generated' if confidence > 0.5 else 'human',
178
+ 'confidence': confidence,
179
+ 'text_score': results['modality_scores']['text'].item(),
180
+ }
181
+
182
+
183
+ def classify_multimodal(model, image_path_or_pil=None, text=None,
184
+ tokenizer=None, device='cpu'):
185
+ """Combined multimodal classification.
186
+
187
+ Uses both image and text when available, with learned fusion weights.
188
+
189
+ Returns:
190
+ dict with combined prediction, confidence, and per-modality scores
191
+ """
192
+ from preprocessing import get_image_transforms
193
+ from transformers import AutoTokenizer
194
+
195
+ images = None
196
+ input_ids = None
197
+ attention_mask = None
198
+
199
+ if image_path_or_pil is not None:
200
+ if isinstance(image_path_or_pil, str):
201
+ image = Image.open(image_path_or_pil).convert('RGB')
202
+ else:
203
+ image = image_path_or_pil.convert('RGB')
204
+ transform = get_image_transforms('eval', 224)
205
+ images = transform(image).unsqueeze(0).to(device)
206
+
207
+ if text is not None:
208
+ if tokenizer is None:
209
+ tokenizer = AutoTokenizer.from_pretrained('roberta-base')
210
+ encoding = tokenizer(text, max_length=512, padding='max_length',
211
+ truncation=True, return_tensors='pt')
212
+ input_ids = encoding['input_ids'].to(device)
213
+ attention_mask = encoding['attention_mask'].to(device)
214
+
215
+ model.eval()
216
+ with torch.no_grad():
217
+ results = model(
218
+ images=images,
219
+ input_ids=input_ids,
220
+ attention_mask=attention_mask,
221
+ modality='auto'
222
+ )
223
+
224
+ confidence = results['confidence'].item()
225
+
226
+ output = {
227
+ 'prediction': 'fake/ai_generated' if confidence > 0.5 else 'real/human',
228
+ 'confidence': confidence,
229
+ 'modality_scores': {k: v.item() for k, v in results['modality_scores'].items()},
230
+ }
231
+
232
+ # Show fusion weights
233
+ with torch.no_grad():
234
+ fusion_weights = F.softmax(model.fusion_weights, dim=0)
235
+ output['fusion_weights'] = {
236
+ 'visual': fusion_weights[0].item(),
237
+ 'text': fusion_weights[1].item(),
238
+ }
239
+
240
+ return output
241
+
242
+
243
+ def visualize_gradcam(image_path, gradcam_heatmap, confidence, save_path=None):
244
+ """Visualize GradCAM overlay on the original image.
245
+
246
+ Args:
247
+ image_path: Path to original image
248
+ gradcam_heatmap: (H, W) numpy array from classify_image
249
+ confidence: fake confidence score
250
+ save_path: optional path to save visualization
251
+ """
252
+ import matplotlib
253
+ matplotlib.use('Agg')
254
+ import matplotlib.pyplot as plt
255
+
256
+ image = Image.open(image_path).convert('RGB')
257
+ image_np = np.array(image.resize((224, 224))) / 255.0
258
+
259
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
260
+
261
+ axes[0].imshow(image_np)
262
+ axes[0].set_title('Original')
263
+ axes[0].axis('off')
264
+
265
+ axes[1].imshow(gradcam_heatmap, cmap='jet')
266
+ axes[1].set_title('GradCAM Heatmap')
267
+ axes[1].axis('off')
268
+
269
+ axes[2].imshow(image_np)
270
+ axes[2].imshow(gradcam_heatmap, cmap='jet', alpha=0.4)
271
+ axes[2].set_title('Overlay (Explanation)')
272
+ axes[2].axis('off')
273
+
274
+ label = "FAKE" if confidence > 0.5 else "REAL"
275
+ color = 'red' if confidence > 0.5 else 'green'
276
+ fig.suptitle(f'{label} — Confidence: {confidence:.2%}', fontsize=16,
277
+ fontweight='bold', color=color)
278
+
279
+ plt.tight_layout()
280
+ if save_path:
281
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
282
+ plt.show()
283
+ plt.close()
284
+
285
+
286
+ # ============================================================
287
+ # Demo / Usage Example
288
+ # ============================================================
289
+ if __name__ == '__main__':
290
+ print("=" * 60)
291
+ print("Multimodal Deepfake Detection - Inference Demo")
292
+ print("=" * 60)
293
+ print()
294
+ print("Usage:")
295
+ print(" from inference import load_model, classify_image, classify_text, classify_multimodal")
296
+ print()
297
+ print(" # Load model")
298
+ print(" model, config = load_model('output/multimodal_ensemble.pt', device='cuda')")
299
+ print()
300
+ print(" # Image classification")
301
+ print(" result = classify_image(model, 'face.jpg', device='cuda')")
302
+ print(" print(f'Prediction: {result[\"prediction\"]} (confidence: {result[\"confidence\"]:.2%})')")
303
+ print()
304
+ print(" # Text classification")
305
+ print(" result = classify_text(model, 'This text was generated by AI...')")
306
+ print(" print(f'Prediction: {result[\"prediction\"]} (confidence: {result[\"confidence\"]:.2%})')")
307
+ print()
308
+ print(" # Video classification")
309
+ print(" result = classify_video(model, 'video.mp4', device='cuda')")
310
+ print(" print(f'Prediction: {result[\"prediction\"]} (confidence: {result[\"confidence\"]:.2%})')")
311
+ print()
312
+ print(" # Multimodal (image + text)")
313
+ print(" result = classify_multimodal(model, image_path_or_pil='face.jpg', text='Caption...')")
314
+ print(" print(f'Combined: {result[\"prediction\"]} — Scores: {result[\"modality_scores\"]}')")