"""
Inference Pipeline for Multimodal Deepfake Detection
=====================================================
Supports:
- Single image classification with confidence + GradCAM heatmap
- Video classification (frame-by-frame → aggregated score)
- Text classification (human vs AI-generated)
- Multimodal (image + text combined)
"""

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image


def load_model(checkpoint_path, device='cpu'):
    """Load a trained detector checkpoint and return (model, config)."""
    from model import MultimodalDeepfakeDetector

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint['config']
    # No pretrained backbone weights needed: the checkpoint supplies the full
    # state dict. Dropout is disabled for inference.
    model = MultimodalDeepfakeDetector(
        visual_pretrained=False,
        text_model_name=config['text_model_name'],
        dropout=0.0,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model, config


def classify_image(model, image_path_or_pil, device='cpu', return_gradcam=True):
    """Classify a single image as real or fake, optionally with a GradCAM heatmap."""
    from preprocessing import get_image_transforms
    from model import GradCAM

    if isinstance(image_path_or_pil, str):
        image = Image.open(image_path_or_pil).convert('RGB')
    else:
        image = image_path_or_pil.convert('RGB')
    transform = get_image_transforms('eval', 224)
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        results = model(images=image_tensor, modality='visual')
    confidence = results['confidence'].item()
    prediction = 'fake' if confidence > 0.5 else 'real'
    output = {
        'prediction': prediction,
        'confidence': confidence,
        'visual_score': results['modality_scores']['visual'].item(),
    }

    if return_gradcam:
        # GradCAM needs gradients, so this second pass runs outside no_grad.
        model.visual_branch.eval()
        gradcam = GradCAM(model.visual_branch, model.visual_branch.get_gradcam_target_layer())
        image_tensor_grad = image_tensor.clone().requires_grad_(True)
        cam = gradcam.generate(image_tensor_grad, class_idx=1)  # explain the 'fake' class (index 1)
        output['gradcam'] = cam.squeeze().cpu().numpy()
        gradcam.remove_hooks()
    return output


def classify_video(model, video_path, device='cpu', num_frames=32, aggregation='mean'):
    """Classify a video by scoring sampled frames and aggregating the scores."""
    from preprocessing import extract_video_frames, get_image_transforms
    from model import aggregate_video_predictions

    frames = extract_video_frames(video_path, num_frames=num_frames)
    transform = get_image_transforms('eval', 224)
    frame_scores = []
    model.eval()
    with torch.no_grad():
        for frame in frames:
            image_tensor = transform(frame.convert('RGB')).unsqueeze(0).to(device)
            results = model(images=image_tensor, modality='visual')
            frame_scores.append(results['confidence'].item())

    # Cast to float in case the aggregator returns a 0-dim tensor.
    video_confidence = float(aggregate_video_predictions(torch.tensor(frame_scores), method=aggregation))
    return {
        'prediction': 'fake' if video_confidence > 0.5 else 'real',
        'confidence': video_confidence,
        'frame_scores': frame_scores,
        'num_frames_analyzed': len(frames),
        'aggregation_method': aggregation,
    }


def classify_text(model, text, tokenizer=None, device='cpu', max_length=512):
    """Classify text as human-written or AI-generated."""
    from transformers import AutoTokenizer

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    encoding = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        results = model(input_ids=input_ids, attention_mask=attention_mask, modality='text')
    confidence = results['confidence'].item()
    return {
        'prediction': 'ai_generated' if confidence > 0.5 else 'human',
        'confidence': confidence,
        'text_score': results['modality_scores']['text'].item(),
    }


def classify_multimodal(model, image_path_or_pil=None, text=None, tokenizer=None, device='cpu'):
    """Classify using any combination of image and text inputs (modality='auto')."""
    from preprocessing import get_image_transforms
    from transformers import AutoTokenizer

    images = input_ids = attention_mask = None
    if image_path_or_pil is not None:
        if isinstance(image_path_or_pil, str):
            image = Image.open(image_path_or_pil).convert('RGB')
        else:
            image = image_path_or_pil.convert('RGB')
        transform = get_image_transforms('eval', 224)
        images = transform(image).unsqueeze(0).to(device)
    if text is not None:
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        encoding = tokenizer(text, max_length=512, padding='max_length', truncation=True, return_tensors='pt')
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        results = model(images=images, input_ids=input_ids, attention_mask=attention_mask, modality='auto')
    confidence = results['confidence'].item()
    output = {
        'prediction': 'fake/ai_generated' if confidence > 0.5 else 'real/human',
        'confidence': confidence,
        'modality_scores': {k: v.item() for k, v in results['modality_scores'].items()},
    }
    # Report the learned fusion weights so callers can see each modality's contribution.
    with torch.no_grad():
        fusion_weights = F.softmax(model.fusion_weights, dim=0)
    output['fusion_weights'] = {'visual': fusion_weights[0].item(), 'text': fusion_weights[1].item()}
    return output


def visualize_gradcam(image_path, gradcam_heatmap, confidence, save_path=None):
    """Render the original image, GradCAM heatmap, and overlay side by side."""
    import matplotlib
    matplotlib.use('Agg')  # headless backend so this works without a display
    import matplotlib.pyplot as plt

    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image.resize((224, 224))) / 255.0

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image_np)
    axes[0].set_title('Original')
    axes[1].imshow(gradcam_heatmap, cmap='jet')
    axes[1].set_title('GradCAM Heatmap')
    axes[2].imshow(image_np)
    axes[2].imshow(gradcam_heatmap, cmap='jet', alpha=0.4)
    axes[2].set_title('Overlay (Explanation)')
    for ax in axes:
        ax.axis('off')

    label = 'FAKE' if confidence > 0.5 else 'REAL'
    color = 'red' if confidence > 0.5 else 'green'
    fig.suptitle(f'{label} — Confidence: {confidence:.2%}', fontsize=16, fontweight='bold', color=color)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
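

# Minimal usage sketch exercising each entry point. The paths below
# ('checkpoint.pt', 'example.jpg') are hypothetical placeholders, not files
# shipped with this repo; substitute your own trained checkpoint and inputs.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, config = load_model('checkpoint.pt', device=device)

    # Image: prediction, confidence, and a GradCAM heatmap rendered to disk.
    result = classify_image(model, 'example.jpg', device=device)
    print(result['prediction'], f"{result['confidence']:.2%}")
    visualize_gradcam('example.jpg', result['gradcam'], result['confidence'], save_path='gradcam.png')

    # Text: human vs AI-generated.
    print(classify_text(model, 'Some text to check.', device=device))

    # Multimodal: image + text, fused via the learned softmax weights.
    print(classify_multimodal(model, image_path_or_pil='example.jpg',
                              text='Caption accompanying the image.', device=device))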