"""
Inference module for deepfake audio detection.
Wraps the Stage 2 Wav2Vec 2.0 classifier with a clean public API.
Usage:

    from src.inference.predict import DeepfakeDetector

    detector = DeepfakeDetector(checkpoint_path="path/to/stage2_best.pt")
    result = detector.predict("path/to/audio.wav")
    print(result)
    # {"spoof_probability": 0.84, "bonafide_probability": 0.16, "prediction": "spoof",
    #  "confidence": 0.84, "utterance_duration_sec": 3.42, "n_windows": 1,
    #  "threshold_used": 0.5}
"""
from typing import Dict, Optional, Union

import numpy as np
import torch

from src.data.preprocessing import load_audio, segment_waveform
from src.models.wav2vec_classifier import Wav2VecClassifier
# Default decision threshold on the spoof probability. 0.5 is the naive choice;
# a value tuned during evaluation can be passed via the `threshold` argument.
# Lower values are more sensitive (more false alarms on bonafide audio);
# higher values are more conservative (more missed spoofs).
DEFAULT_THRESHOLD = 0.5
class DeepfakeDetector:
"""Anti-spoofing classifier wrapper for one-shot inference."""
def __init__(
self,
checkpoint_path: str,
device: Optional[str] = None,
backbone_name: str = "facebook/wav2vec2-base",
threshold: float = DEFAULT_THRESHOLD,
use_mixed_precision: bool = True,
):
"""
Args:
checkpoint_path: path to a Stage 2 .pt checkpoint
device: 'cuda', 'cpu', or None (auto-detect)
backbone_name: HuggingFace model name for Wav2Vec backbone
threshold: probability threshold above which we predict "spoof"
use_mixed_precision: use fp16 inference (faster on GPU)
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.threshold = threshold
        self.use_mixed_precision = use_mixed_precision and device.startswith("cuda")
# Build model and load weights
self.model = Wav2VecClassifier(
backbone_name=backbone_name,
num_classes=2,
freeze_backbone=True,
)
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
self.model.load_state_dict(ckpt["model_state_dict"])
self.model = self.model.to(device)
self.model.eval()
# Store metadata for transparency
self.checkpoint_metadata = {
"epoch": ckpt.get("epoch"),
"best_eer": ckpt.get("best_eer"),
"checkpoint_path": checkpoint_path,
}
@torch.no_grad()
def predict(
self,
audio_input: Union[str, torch.Tensor, np.ndarray],
return_per_window: bool = False,
) -> Dict:
"""Predict bonafide vs spoof for a single audio input.
Args:
audio_input: either a file path (str), a 1-D Tensor at 16 kHz, or
a 1-D numpy array at 16 kHz.
return_per_window: if True, include per-window probabilities in
the result for debugging.
Returns:
Dict with keys:
spoof_probability: float in [0, 1]
bonafide_probability: float in [0, 1]
prediction: "bonafide" or "spoof"
confidence: float in [0, 1] (probability of the predicted class)
                utterance_duration_sec: total audio length in seconds
                n_windows: number of 4-sec windows the audio was split into
                threshold_used: the decision threshold applied to spoof_probability
                window_scores: (only if return_per_window=True) list of per-window spoof probs
"""
        # Step 1: Load audio from disk (resampled to 16 kHz by load_audio), or
        # accept an in-memory array/tensor that is assumed to already be at 16 kHz.
        if isinstance(audio_input, str):
            waveform = load_audio(audio_input)  # returns a 1-D tensor at 16 kHz
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input.astype(np.float32))
        elif isinstance(audio_input, torch.Tensor):
            waveform = audio_input.float()
        else:
            raise ValueError(
                f"audio_input must be str, np.ndarray, or torch.Tensor; got {type(audio_input)}"
            )
        # Collapse any extra channel dimension so downstream code sees a 1-D signal.
        if waveform.dim() > 1:
            waveform = waveform.squeeze()
duration_sec = float(waveform.shape[0] / 16000)
# Step 2: Segment into 4-sec windows
windows = segment_waveform(waveform) # list of 1-D tensors of length 64000
n_windows = len(windows)
# Step 3: Stack into a batch and run inference
batch = torch.stack(windows, dim=0).to(self.device, non_blocking=True)
if self.use_mixed_precision:
with torch.amp.autocast(device_type="cuda", enabled=True):
logits = self.model(batch)
else:
logits = self.model(batch)
# Step 4: Compute per-window probabilities, then aggregate (mean)
probs = torch.softmax(logits.float(), dim=-1).cpu().numpy() # (n_windows, 2)
window_spoof_probs = probs[:, 1].tolist()
utt_spoof_prob = float(np.mean(window_spoof_probs))
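        # Note on aggregation: the mean treats all windows equally, so one
        # suspicious window is diluted in a long utterance. A max over window
        # scores would be more sensitive to localized spoofing; that variant is
        # not implemented here.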
utt_bonafide_prob = 1.0 - utt_spoof_prob
# Step 5: Apply threshold for hard prediction
prediction = "spoof" if utt_spoof_prob > self.threshold else "bonafide"
confidence = utt_spoof_prob if prediction == "spoof" else utt_bonafide_prob
result = {
"spoof_probability": utt_spoof_prob,
"bonafide_probability": utt_bonafide_prob,
"prediction": prediction,
"confidence": confidence,
"utterance_duration_sec": duration_sec,
"n_windows": n_windows,
"threshold_used": self.threshold,
}
if return_per_window:
result["window_scores"] = window_spoof_probs
return result
def info(self) -> Dict:
"""Return metadata about this model checkpoint."""
return {
**self.checkpoint_metadata,
"device": self.device,
"threshold": self.threshold,
"mixed_precision": self.use_mixed_precision,
}
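

# Minimal command-line sketch. Assumptions: the checkpoint and audio paths are
# placeholders, and the module is executed from the repository root so the
# `src.*` imports above resolve (e.g. `python -m src.inference.predict ...`).
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="Score one audio file for spoofing.")
    parser.add_argument("audio_path", help="Path to the audio file to score")
    parser.add_argument("--checkpoint", default="checkpoints/stage2_best.pt",
                        help="Stage 2 checkpoint path (placeholder default)")
    parser.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD,
                        help="Spoof-probability decision threshold")
    parser.add_argument("--per-window", action="store_true",
                        help="Include per-window spoof probabilities in the output")
    args = parser.parse_args()

    detector = DeepfakeDetector(checkpoint_path=args.checkpoint, threshold=args.threshold)
    output = detector.predict(args.audio_path, return_per_window=args.per_window)
    print(json.dumps(output, indent=2))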