| """ |
| Inference module for deepfake audio detection. |
| |
| Wraps the Stage 2 Wav2Vec 2.0 classifier with a clean public API. |
| |
| Usage: |
| from src.inference.predict import DeepfakeDetector |
| detector = DeepfakeDetector(checkpoint_path="path/to/stage2_best.pt") |
| result = detector.predict("path/to/audio.wav") |
| print(result) |
    # {"spoof_probability": 0.84, "bonafide_probability": 0.16,
    #  "prediction": "spoof", "confidence": 0.84,
    #  "utterance_duration_sec": 3.42, "n_windows": 1, "threshold_used": 0.5}
| """ |
|
|
| import os |
| from typing import Dict, Optional, Union |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| from src.models.wav2vec_classifier import Wav2VecClassifier |
| from src.data.preprocessing import load_audio, segment_waveform, WINDOW_SAMPLES |
|
|
|
|
| |
| |
| |
DEFAULT_THRESHOLD = 0.5

# All audio handled by this module is assumed to be mono at 16 kHz
# (the Wav2Vec 2.0 pretraining rate).
SAMPLE_RATE = 16000


class DeepfakeDetector:
    """Anti-spoofing classifier wrapper for one-shot inference.

    Loads a Stage 2 Wav2Vec 2.0 checkpoint and exposes :meth:`predict`,
    which maps raw audio (path, numpy array, or tensor) to bonafide/spoof
    probabilities by averaging per-window scores over the utterance.
    """

    def __init__(
        self,
        checkpoint_path: str,
        device: Optional[str] = None,
        backbone_name: str = "facebook/wav2vec2-base",
        threshold: float = DEFAULT_THRESHOLD,
        use_mixed_precision: bool = True,
    ):
        """
        Args:
            checkpoint_path: path to a Stage 2 .pt checkpoint
            device: 'cuda', 'cpu', or None (auto-detect)
            backbone_name: HuggingFace model name for Wav2Vec backbone
            threshold: probability threshold above which we predict "spoof"
            use_mixed_precision: use fp16 inference (only takes effect on CUDA)
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.threshold = threshold
        # fp16 autocast is only meaningful on CUDA; silently disable on CPU.
        self.use_mixed_precision = use_mixed_precision and (device == "cuda")

        self.model = Wav2VecClassifier(
            backbone_name=backbone_name,
            num_classes=2,
            freeze_backbone=True,
        )
        # SECURITY: weights_only=False deserializes arbitrary pickled objects,
        # which can execute code at load time. Only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.model = self.model.to(device)
        self.model.eval()

        # Training provenance, surfaced via info().
        self.checkpoint_metadata = {
            "epoch": ckpt.get("epoch"),
            "best_eer": ckpt.get("best_eer"),
            "checkpoint_path": checkpoint_path,
        }

    @staticmethod
    def _to_waveform(audio_input: Union[str, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Coerce any supported input into a 1-D float32 waveform tensor.

        Args:
            audio_input: file path, 1-D (or squeezable, e.g. (1, N)) tensor,
                or numpy array sampled at 16 kHz.

        Returns:
            1-D float32 torch.Tensor of samples.

        Raises:
            ValueError: if the input type is unsupported.
        """
        if isinstance(audio_input, str):
            waveform = load_audio(audio_input)
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input.astype(np.float32))
        elif isinstance(audio_input, torch.Tensor):
            waveform = audio_input.float()
        else:
            raise ValueError(
                f"audio_input must be str, np.ndarray, or torch.Tensor; got {type(audio_input)}"
            )
        # Fix: squeeze channel/batch dims for ALL input kinds. The original
        # only squeezed Tensor inputs, so a (1, N) numpy array produced a
        # wrong duration and a 2-D waveform downstream.
        if waveform.dim() > 1:
            waveform = waveform.squeeze()
        return waveform

    @torch.no_grad()
    def predict(
        self,
        audio_input: Union[str, torch.Tensor, np.ndarray],
        return_per_window: bool = False,
    ) -> Dict:
        """Predict bonafide vs spoof for a single audio input.

        Args:
            audio_input: either a file path (str), a 1-D Tensor at 16 kHz, or
                a 1-D numpy array at 16 kHz.
            return_per_window: if True, include per-window probabilities in
                the result for debugging.

        Returns:
            Dict with keys:
                spoof_probability: float in [0, 1]
                bonafide_probability: float in [0, 1]
                prediction: "bonafide" or "spoof"
                confidence: float in [0, 1] (probability of the predicted class)
                utterance_duration_sec: total audio length
                n_windows: number of 4-sec windows the audio was split into
                threshold_used: the decision threshold applied
                window_scores: (only if return_per_window=True) list of
                    per-window spoof probs

        Raises:
            ValueError: on an unsupported input type, or if segmentation
                yields no windows (e.g. degenerate/empty audio).
        """
        waveform = self._to_waveform(audio_input)

        duration_sec = float(waveform.shape[0] / SAMPLE_RATE)

        # Split the utterance into fixed-size windows for the classifier.
        windows = segment_waveform(waveform)
        n_windows = len(windows)
        if n_windows == 0:
            # torch.stack([]) would raise an opaque RuntimeError; fail clearly.
            raise ValueError("segment_waveform produced no windows; audio may be empty")

        batch = torch.stack(windows, dim=0).to(self.device, non_blocking=True)

        if self.use_mixed_precision:
            with torch.amp.autocast(device_type="cuda", enabled=True):
                logits = self.model(batch)
        else:
            logits = self.model(batch)

        # .float() restores fp32 before softmax in case autocast emitted fp16.
        probs = torch.softmax(logits.float(), dim=-1).cpu().numpy()
        # Class index 1 = spoof; utterance score is the mean over windows.
        window_spoof_probs = probs[:, 1].tolist()
        utt_spoof_prob = float(np.mean(window_spoof_probs))
        utt_bonafide_prob = 1.0 - utt_spoof_prob

        prediction = "spoof" if utt_spoof_prob > self.threshold else "bonafide"
        confidence = utt_spoof_prob if prediction == "spoof" else utt_bonafide_prob

        result = {
            "spoof_probability": utt_spoof_prob,
            "bonafide_probability": utt_bonafide_prob,
            "prediction": prediction,
            "confidence": confidence,
            "utterance_duration_sec": duration_sec,
            "n_windows": n_windows,
            "threshold_used": self.threshold,
        }
        if return_per_window:
            result["window_scores"] = window_spoof_probs
        return result

    def info(self) -> Dict:
        """Return metadata about this model checkpoint and runtime config."""
        return {
            **self.checkpoint_metadata,
            "device": self.device,
            "threshold": self.threshold,
            "mixed_precision": self.use_mixed_precision,
        }
|
|