alianassmaaa committed on
Commit
cd95a81
·
verified ·
1 Parent(s): e08d744

Add inference pipeline

Browse files
Files changed (1) hide show
  1. inference.py +142 -0
inference.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Pipeline for Multimodal Deepfake Detection
3
+ =====================================================
4
+ Supports:
5
+ - Single image classification with confidence + GradCAM heatmap
6
+ - Video classification (frame-by-frame → aggregated score)
7
+ - Text classification (human vs AI-generated)
8
+ - Multimodal (image + text combined)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from PIL import Image
15
+ import json
16
+ import os
17
+
18
+
19
def load_model(checkpoint_path, device='cpu'):
    """Restore a trained detector and its stored config from a checkpoint.

    The checkpoint must be a mapping with a ``'config'`` entry (which in
    turn provides ``'text_model_name'``) and a ``'model_state_dict'`` entry.

    Args:
        checkpoint_path: Checkpoint location, handed to ``torch.load``.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        ``(model, config)`` where the model has been switched to eval mode.
    """
    from model import MultimodalDeepfakeDetector

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this should only be pointed at trusted checkpoint files.
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = state['config']

    # Pretrained visual weights and dropout are disabled here; the trained
    # weights come from the state dict loaded just below.
    detector = MultimodalDeepfakeDetector(
        visual_pretrained=False,
        text_model_name=config['text_model_name'],
        dropout=0.0,
    )
    detector.load_state_dict(state['model_state_dict'])
    detector = detector.to(device)
    detector.eval()
    return detector, config
28
+
29
+
30
def classify_image(model, image_path_or_pil, device='cpu', return_gradcam=True):
    """Classify one image as real or fake, optionally with a GradCAM heatmap.

    Args:
        model: Detector that is callable as
            ``model(images=..., modality='visual')`` and exposes a
            ``visual_branch`` with ``get_gradcam_target_layer()``.
        image_path_or_pil: A file path (``str`` or ``os.PathLike``) or an
            object with a PIL-style ``convert`` method.
        device: Device the input tensor is moved to.
        return_gradcam: When true, also compute a GradCAM map for class
            index 1 and store it under ``'gradcam'``.

    Returns:
        Dict with ``'prediction'`` (``'fake'`` when confidence > 0.5, else
        ``'real'``), ``'confidence'``, ``'visual_score'`` and, if requested,
        ``'gradcam'`` as a NumPy array.
    """
    from preprocessing import get_image_transforms
    from model import GradCAM

    if isinstance(image_path_or_pil, (str, os.PathLike)):
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(image_path_or_pil) as opened:
            image = opened.convert('RGB')
    else:
        image = image_path_or_pil.convert('RGB')

    transform = get_image_transforms('eval', 224)
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Match the other classify_* helpers, which force eval mode before
    # running the forward pass.
    model.eval()
    with torch.no_grad():
        results = model(images=image_tensor, modality='visual')

    confidence = results['confidence'].item()
    output = {
        'prediction': 'fake' if confidence > 0.5 else 'real',
        'confidence': confidence,
        'visual_score': results['modality_scores']['visual'].item(),
    }

    if return_gradcam:
        model.visual_branch.eval()
        gradcam = GradCAM(model.visual_branch, model.visual_branch.get_gradcam_target_layer())
        try:
            # GradCAM needs gradients, so it runs outside no_grad on a fresh
            # copy of the input that tracks them.
            image_tensor_grad = image_tensor.clone().requires_grad_(True)
            cam = gradcam.generate(image_tensor_grad, class_idx=1)
            # detach() guards against a heatmap still attached to the graph,
            # which would make .numpy() raise.
            output['gradcam'] = cam.detach().squeeze().cpu().numpy()
        finally:
            # Always unregister the hooks, even if generate() fails.
            gradcam.remove_hooks()
    return output
52
+
53
+
54
def classify_video(model, video_path, device='cpu', num_frames=32, aggregation='mean'):
    """Score sampled video frames individually and aggregate the results.

    Args:
        model: Detector callable as ``model(images=..., modality='visual')``.
        video_path: Video location, passed to ``extract_video_frames``.
        device: Device each frame tensor is moved to.
        num_frames: Frame count requested from ``extract_video_frames``.
        aggregation: Method name forwarded to ``aggregate_video_predictions``.

    Returns:
        Dict with ``'prediction'`` (``'fake'`` when the aggregated confidence
        is > 0.5, else ``'real'``), ``'confidence'``, the per-frame
        ``'frame_scores'``, ``'num_frames_analyzed'`` and
        ``'aggregation_method'``.
    """
    from preprocessing import extract_video_frames, get_image_transforms
    from model import aggregate_video_predictions

    frames = extract_video_frames(video_path, num_frames=num_frames)
    to_tensor = get_image_transforms('eval', 224)

    def score_frame(frame):
        # One frame -> batch of size 1 -> scalar confidence.
        batch = to_tensor(frame.convert('RGB')).unsqueeze(0).to(device)
        return model(images=batch, modality='visual')['confidence'].item()

    model.eval()
    with torch.no_grad():
        frame_scores = [score_frame(frame) for frame in frames]

    video_confidence = aggregate_video_predictions(
        torch.tensor(frame_scores),
        method=aggregation,
    )
    verdict = 'fake' if video_confidence > 0.5 else 'real'
    return {
        'prediction': verdict,
        'confidence': video_confidence,
        'frame_scores': frame_scores,
        'num_frames_analyzed': len(frames),
        'aggregation_method': aggregation,
    }
74
+
75
+
76
def classify_text(model, text, tokenizer=None, device='cpu', max_length=512):
    """Classify a piece of text as human-written or AI-generated.

    Args:
        model: Detector callable as
            ``model(input_ids=..., attention_mask=..., modality='text')``.
        text: Input forwarded to the tokenizer.
        tokenizer: Tokenizer to use; when ``None`` the ``'roberta-base'``
            tokenizer is loaded via ``AutoTokenizer``.
        device: Device the encoded tensors are moved to.
        max_length: Length the encoding is padded/truncated to.

    Returns:
        Dict with ``'prediction'`` (``'ai_generated'`` when confidence > 0.5,
        else ``'human'``), ``'confidence'`` and ``'text_score'``.
    """
    from transformers import AutoTokenizer

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('roberta-base')

    encoded = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    model_inputs = {
        field: encoded[field].to(device)
        for field in ('input_ids', 'attention_mask')
    }

    model.eval()
    with torch.no_grad():
        results = model(**model_inputs, modality='text')

    confidence = results['confidence'].item()
    if confidence > 0.5:
        label = 'ai_generated'
    else:
        label = 'human'
    return {
        'prediction': label,
        'confidence': confidence,
        'text_score': results['modality_scores']['text'].item(),
    }
92
+
93
+
94
def classify_multimodal(model, image_path_or_pil=None, text=None, tokenizer=None, device='cpu', max_length=512):
    """Classify an image, a text, or both with the fused detector.

    Args:
        model: Detector callable with ``images``, ``input_ids``,
            ``attention_mask`` and ``modality='auto'``; it must expose a
            ``fusion_weights`` tensor with at least two entries.
        image_path_or_pil: Optional file path (``str`` or ``os.PathLike``)
            or object with a PIL-style ``convert`` method.
        text: Optional text forwarded to the tokenizer.
        tokenizer: Tokenizer to use for ``text``; when ``None`` the
            ``'roberta-base'`` tokenizer is loaded via ``AutoTokenizer``.
        device: Device the input tensors are moved to.
        max_length: Length the text encoding is padded/truncated to
            (defaults to 512, the value previously hard-coded here).

    Returns:
        Dict with ``'prediction'`` (``'fake/ai_generated'`` when
        confidence > 0.5, else ``'real/human'``), ``'confidence'``,
        ``'modality_scores'`` and the softmax-normalised
        ``'fusion_weights'`` for ``'visual'`` and ``'text'``.

    Raises:
        ValueError: If neither ``image_path_or_pil`` nor ``text`` is given.
    """
    if image_path_or_pil is None and text is None:
        # Fail early with a clear message instead of passing an all-None
        # input set into the model.
        raise ValueError(
            "classify_multimodal requires at least one of "
            "image_path_or_pil or text"
        )

    images = input_ids = attention_mask = None

    if image_path_or_pil is not None:
        from preprocessing import get_image_transforms

        if isinstance(image_path_or_pil, (str, os.PathLike)):
            # convert() returns a loaded copy, so the file handle can be
            # closed right after.
            with Image.open(image_path_or_pil) as opened:
                image = opened.convert('RGB')
        else:
            image = image_path_or_pil.convert('RGB')
        transform = get_image_transforms('eval', 224)
        images = transform(image).unsqueeze(0).to(device)

    if text is not None:
        if tokenizer is None:
            # Imported only when a default tokenizer actually has to be built.
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        encoding = tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        results = model(
            images=images,
            input_ids=input_ids,
            attention_mask=attention_mask,
            modality='auto',
        )
        # Normalise the fusion weights so they can be reported as shares.
        fusion_weights = F.softmax(model.fusion_weights, dim=0)

    confidence = results['confidence'].item()
    return {
        'prediction': 'fake/ai_generated' if confidence > 0.5 else 'real/human',
        'confidence': confidence,
        'modality_scores': {k: v.item() for k, v in results['modality_scores'].items()},
        'fusion_weights': {
            'visual': fusion_weights[0].item(),
            'text': fusion_weights[1].item(),
        },
    }
124
+
125
+
126
def visualize_gradcam(image_path, gradcam_heatmap, confidence, save_path=None):
    """Render the image, its GradCAM heatmap and an overlay side by side.

    Args:
        image_path: Path of the image, opened with PIL and resized to
            224x224 for display.
        gradcam_heatmap: 2-D array drawn with the ``'jet'`` colormap.
        confidence: Score in the title; values > 0.5 are labelled "FAKE"
            (red), everything else "REAL" (green).
        save_path: Optional output path; the figure is written only when
            this is truthy.

    Returns:
        None. The figure is always closed before returning.
    """
    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported.
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # convert() yields a loaded copy, so the file handle is released here.
    with Image.open(image_path) as opened:
        image = opened.convert('RGB')
    image_np = np.array(image.resize((224, 224))) / 255.0

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        axes[0].imshow(image_np)
        axes[0].set_title('Original')
        axes[0].axis('off')

        axes[1].imshow(gradcam_heatmap, cmap='jet')
        axes[1].set_title('GradCAM Heatmap')
        axes[1].axis('off')

        axes[2].imshow(image_np)
        axes[2].imshow(gradcam_heatmap, cmap='jet', alpha=0.4)
        axes[2].set_title('Overlay (Explanation)')
        axes[2].axis('off')

        is_fake = confidence > 0.5
        label = "FAKE" if is_fake else "REAL"
        color = 'red' if is_fake else 'green'
        fig.suptitle(f'{label} — Confidence: {confidence:.2%}', fontsize=16, fontweight='bold', color=color)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)