Spaces:
Runtime error
Runtime error
| """ | |
| Inference Pipeline for Multimodal Deepfake Detection | |
| ===================================================== | |
| Supports: | |
| - Single image classification with confidence + GradCAM heatmap | |
| - Video classification (frame-by-frame → aggregated score) | |
| - Text classification (human vs AI-generated) | |
| - Multimodal (image + text combined) | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image | |
| import json | |
| import os | |
def load_model(checkpoint_path, device='cpu'):
    """Restore the trained multimodal ensemble from a checkpoint file.

    The checkpoint is read as a dict and must provide a ``'config'`` entry
    (containing ``'text_model_name'``) and a ``'model_state_dict'`` entry.

    Args:
        checkpoint_path: Path to multimodal_ensemble.pt
        device: 'cpu' or 'cuda'

    Returns:
        model: MultimodalDeepfakeDetector on ``device``, switched to eval mode
        config: training configuration dict taken from the checkpoint
    """
    from model import MultimodalDeepfakeDetector

    # NOTE(security): weights_only=False performs a full unpickle of the
    # file, which can execute arbitrary code. Only load trusted checkpoints.
    saved = torch.load(checkpoint_path, map_location=device, weights_only=False)
    train_config = saved['config']

    detector = MultimodalDeepfakeDetector(
        visual_pretrained=False,
        text_model_name=train_config['text_model_name'],
        dropout=0.0,  # No dropout at inference
    )
    detector.load_state_dict(saved['model_state_dict'])
    detector = detector.to(device)
    detector.eval()

    return detector, train_config
def classify_image(model, image_path_or_pil, device='cpu', return_gradcam=True):
    """Classify a single image as real or AI-generated/deepfake.

    Args:
        model: MultimodalDeepfakeDetector
        image_path_or_pil: Path to image (``str`` or ``os.PathLike``) or a
            PIL Image
        device: computation device
        return_gradcam: whether to generate explainability map

    Returns:
        dict with:
        - prediction: 'real' or 'fake' ('fake' when confidence > 0.5)
        - confidence: float [0, 1] (probability of being fake)
        - visual_score: float, the model's 'visual' modality score
        - gradcam: numpy heatmap for the "fake" class, only present when
          return_gradcam=True
    """
    from preprocessing import get_image_transforms

    # Accept pathlib.Path and other path-like objects as well as plain
    # strings; anything else is treated as an already-opened PIL image.
    if isinstance(image_path_or_pil, (str, os.PathLike)):
        with Image.open(image_path_or_pil) as opened:
            image = opened.convert('RGB')
    else:
        image = image_path_or_pil.convert('RGB')

    transform = get_image_transforms('eval', 224)
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Same as the other classify_* functions: never score in training mode.
    model.eval()
    with torch.no_grad():
        results = model(images=image_tensor, modality='visual')

    confidence = results['confidence'].item()
    output = {
        'prediction': 'fake' if confidence > 0.5 else 'real',
        'confidence': confidence,
        'visual_score': results['modality_scores']['visual'].item(),
    }

    if return_gradcam:
        from model import GradCAM

        model.visual_branch.eval()
        gradcam = GradCAM(
            model.visual_branch,
            model.visual_branch.get_gradcam_target_layer(),
        )
        try:
            # GradCAM needs gradients even if the caller wrapped this call
            # in torch.no_grad(), so re-enable them explicitly.
            with torch.enable_grad():
                image_tensor_grad = image_tensor.clone().requires_grad_(True)
                # Heatmap for "fake" class
                cam = gradcam.generate(image_tensor_grad, class_idx=1)
            # detach() so the conversion works even if the map is still
            # attached to the autograd graph.
            output['gradcam'] = cam.detach().squeeze().cpu().numpy()
        finally:
            # Always unregister the hooks, even when generation fails,
            # so they do not accumulate on the visual branch.
            gradcam.remove_hooks()

    return output
def classify_video(model, video_path, device='cpu', num_frames=32, aggregation='mean'):
    """Classify a video as real or deepfake.

    Extracts frames via ``extract_video_frames``, scores each frame with the
    visual branch, and aggregates the per-frame scores.

    Args:
        model: MultimodalDeepfakeDetector
        video_path: Path to video file
        device: computation device
        num_frames: number of frames to sample
        aggregation: 'mean', 'max', or 'voting' (passed through to
            ``aggregate_video_predictions``)

    Returns:
        dict with:
        - prediction: 'fake' when the aggregated confidence > 0.5, else 'real'
        - confidence: aggregated video-level score
        - frame_scores: list of per-frame confidence floats
        - num_frames_analyzed: number of frames actually scored
        - aggregation_method: the aggregation argument

    Raises:
        ValueError: if no frames could be extracted from the video.
    """
    from preprocessing import extract_video_frames, get_image_transforms
    from model import aggregate_video_predictions

    frames = extract_video_frames(video_path, num_frames=num_frames)
    # len() rather than truthiness: works for lists and array-like returns.
    if len(frames) == 0:
        # Aggregating an empty score list would not yield a meaningful
        # verdict, so fail loudly instead of reporting one.
        raise ValueError(f"No frames could be extracted from video: {video_path!r}")

    transform = get_image_transforms('eval', 224)

    frame_scores = []
    model.eval()
    with torch.no_grad():
        for frame in frames:
            image_tensor = transform(frame.convert('RGB')).unsqueeze(0).to(device)
            results = model(images=image_tensor, modality='visual')
            frame_scores.append(results['confidence'].item())

    video_confidence = aggregate_video_predictions(
        torch.tensor(frame_scores), method=aggregation
    )
    # The helper may hand back a tensor; normalise to a plain float so that
    # 'confidence' has the same type as in the other classify_* results.
    if isinstance(video_confidence, torch.Tensor):
        video_confidence = video_confidence.item()

    return {
        'prediction': 'fake' if video_confidence > 0.5 else 'real',
        'confidence': video_confidence,
        'frame_scores': frame_scores,
        'num_frames_analyzed': len(frames),
        'aggregation_method': aggregation,
    }
def classify_text(model, text, tokenizer=None, device='cpu', max_length=512,
                  tokenizer_name='roberta-base'):
    """Classify text as human-written or AI-generated.

    Args:
        model: MultimodalDeepfakeDetector
        text: input text string
        tokenizer: optional pre-loaded tokenizer; pass one to avoid
            re-loading it on every call
        device: computation device
        max_length: max sequence length (input is padded/truncated to it)
        tokenizer_name: name of the tokenizer to load when ``tokenizer`` is
            None. Should match the text model the checkpoint was trained
            with (e.g. the ``text_model_name`` from the training config).

    Returns:
        dict with:
        - prediction: 'ai_generated' when confidence > 0.5, else 'human'
        - confidence: float confidence returned by the model
        - text_score: float, the model's 'text' modality score
    """
    if tokenizer is None:
        # Imported lazily: callers that supply their own tokenizer do not
        # need the transformers import at all.
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        results = model(input_ids=input_ids, attention_mask=attention_mask, modality='text')

    confidence = results['confidence'].item()
    return {
        'prediction': 'ai_generated' if confidence > 0.5 else 'human',
        'confidence': confidence,
        'text_score': results['modality_scores']['text'].item(),
    }
def classify_multimodal(model, image_path_or_pil=None, text=None,
                        tokenizer=None, device='cpu', max_length=512,
                        tokenizer_name='roberta-base'):
    """Combined multimodal classification.

    Uses both image and text when available, with learned fusion weights.

    Args:
        model: MultimodalDeepfakeDetector
        image_path_or_pil: optional path (``str`` or ``os.PathLike``) or
            PIL Image
        text: optional input text string
        tokenizer: optional pre-loaded tokenizer
        device: computation device
        max_length: max sequence length used when tokenizing ``text``
        tokenizer_name: name of the tokenizer to load when ``tokenizer`` is
            None. Should match the text model the checkpoint was trained
            with.

    Returns:
        dict with:
        - prediction: 'fake/ai_generated' when confidence > 0.5, else
          'real/human'
        - confidence: float confidence returned by the model
        - modality_scores: dict of per-modality float scores
        - fusion_weights: softmax of ``model.fusion_weights`` as
          {'visual': float, 'text': float}

    Raises:
        ValueError: if neither an image nor a text is provided.
    """
    if image_path_or_pil is None and text is None:
        raise ValueError(
            "classify_multimodal requires at least one of "
            "image_path_or_pil or text"
        )

    images = None
    input_ids = None
    attention_mask = None

    if image_path_or_pil is not None:
        # Imported lazily so text-only calls do not need the image pipeline.
        from preprocessing import get_image_transforms

        if isinstance(image_path_or_pil, (str, os.PathLike)):
            with Image.open(image_path_or_pil) as opened:
                image = opened.convert('RGB')
        else:
            image = image_path_or_pil.convert('RGB')
        transform = get_image_transforms('eval', 224)
        images = transform(image).unsqueeze(0).to(device)

    if text is not None:
        if tokenizer is None:
            # Imported lazily: only needed when no tokenizer was supplied.
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        encoding = tokenizer(text, max_length=max_length, padding='max_length',
                             truncation=True, return_tensors='pt')
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        results = model(
            images=images,
            input_ids=input_ids,
            attention_mask=attention_mask,
            modality='auto',
        )
        # Show fusion weights
        fusion_weights = F.softmax(model.fusion_weights, dim=0)

    confidence = results['confidence'].item()
    return {
        'prediction': 'fake/ai_generated' if confidence > 0.5 else 'real/human',
        'confidence': confidence,
        'modality_scores': {k: v.item() for k, v in results['modality_scores'].items()},
        'fusion_weights': {
            'visual': fusion_weights[0].item(),
            'text': fusion_weights[1].item(),
        },
    }
def visualize_gradcam(image_path, gradcam_heatmap, confidence, save_path=None):
    """Visualize GradCAM overlay on the original image.

    Draws three panels (original, heatmap, overlay) with a title showing
    the FAKE/REAL verdict. The figure is rendered with the non-interactive
    'Agg' backend, so it is only persisted when ``save_path`` is given.

    Args:
        image_path: Path to original image, or an already-loaded PIL Image
        gradcam_heatmap: (H, W) numpy array from classify_image
        confidence: fake confidence score
        save_path: optional path to save visualization
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    side = 224
    if isinstance(image_path, Image.Image):
        image = image_path.convert('RGB')
    else:
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
    image_np = np.array(image.resize((side, side))) / 255.0

    # Pixel-centred extent of the resized image. Passing it to the overlay
    # stretches the heatmap over the picture even when its resolution is
    # not 224x224; for a 224x224 heatmap this equals imshow's default.
    image_extent = (-0.5, side - 0.5, side - 0.5, -0.5)

    is_fake = confidence > 0.5
    label = "FAKE" if is_fake else "REAL"
    color = 'red' if is_fake else 'green'

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        axes[0].imshow(image_np)
        axes[0].set_title('Original')
        axes[0].axis('off')

        axes[1].imshow(gradcam_heatmap, cmap='jet')
        axes[1].set_title('GradCAM Heatmap')
        axes[1].axis('off')

        axes[2].imshow(image_np)
        axes[2].imshow(gradcam_heatmap, cmap='jet', alpha=0.4, extent=image_extent)
        axes[2].set_title('Overlay (Explanation)')
        axes[2].axis('off')

        fig.suptitle(f'{label} — Confidence: {confidence:.2%}', fontsize=16,
                     fontweight='bold', color=color)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        # No plt.show(): the Agg backend selected above cannot display
        # windows, so the call would do nothing useful.
    finally:
        # Close this specific figure even if drawing or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
# ============================================================
# Demo / Usage Example
# ============================================================
if __name__ == '__main__':
    # Running the module directly only prints a usage cheat sheet;
    # no model is loaded and nothing is classified.
    banner = "=" * 60
    show_prediction = (
        " print(f'Prediction: {result[\"prediction\"]} "
        "(confidence: {result[\"confidence\"]:.2%})')"
    )
    usage_lines = [
        banner,
        "Multimodal Deepfake Detection - Inference Demo",
        banner,
        "",
        "Usage:",
        " from inference import load_model, classify_image, classify_text, classify_multimodal",
        "",
        " # Load model",
        " model, config = load_model('output/multimodal_ensemble.pt', device='cuda')",
        "",
        " # Image classification",
        " result = classify_image(model, 'face.jpg', device='cuda')",
        show_prediction,
        "",
        " # Text classification",
        " result = classify_text(model, 'This text was generated by AI...')",
        show_prediction,
        "",
        " # Video classification",
        " result = classify_video(model, 'video.mp4', device='cuda')",
        show_prediction,
        "",
        " # Multimodal (image + text)",
        " result = classify_multimodal(model, image_path_or_pil='face.jpg', text='Caption...')",
        " print(f'Combined: {result[\"prediction\"]} — Scores: {result[\"modality_scores\"]}')",
    ]
    print("\n".join(usage_lines))