"""
Inference module for deepfake audio detection.

Wraps the Stage 2 Wav2Vec 2.0 classifier with a clean public API.

Usage:
    from src.inference.predict import DeepfakeDetector
    detector = DeepfakeDetector(checkpoint_path="path/to/stage2_best.pt")
    result = detector.predict("path/to/audio.wav")
    print(result)
    # {"spoof_probability": 0.84, "prediction": "spoof", "confidence": 0.84,
    #  "utterance_duration_sec": 3.42, "n_windows": 1, "model_version": "stage2"}
"""

from typing import Dict, Optional, Union

import numpy as np
import torch

from src.models.wav2vec_classifier import Wav2VecClassifier
from src.data.preprocessing import load_audio, segment_waveform


# Default decision threshold on the spoof probability. 0.5 is the naive default;
# pass the threshold tuned during evaluation via the `threshold` argument when
# available. Lower values are more sensitive (more false alarms); higher values
# are more conservative (more misses).
DEFAULT_THRESHOLD = 0.5


class DeepfakeDetector:
    """Anti-spoofing classifier wrapper for one-shot inference."""

    def __init__(
        self,
        checkpoint_path: str,
        device: Optional[str] = None,
        backbone_name: str = "facebook/wav2vec2-base",
        threshold: float = DEFAULT_THRESHOLD,
        use_mixed_precision: bool = True,
    ):
        """
        Args:
            checkpoint_path: path to a Stage 2 .pt checkpoint
            device: 'cuda', 'cpu', or None (auto-detect)
            backbone_name: HuggingFace model name for Wav2Vec backbone
            threshold: probability threshold above which we predict "spoof"
            use_mixed_precision: use fp16 inference (faster on GPU)
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.threshold = threshold
        self.use_mixed_precision = use_mixed_precision and device.startswith("cuda")

        # Build model and load weights
        self.model = Wav2VecClassifier(
            backbone_name=backbone_name,
            num_classes=2,
            freeze_backbone=True,
        )
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.model = self.model.to(device)
        self.model.eval()

        # Store metadata for transparency
        self.checkpoint_metadata = {
            "epoch": ckpt.get("epoch"),
            "best_eer": ckpt.get("best_eer"),
            "checkpoint_path": checkpoint_path,
        }

    @torch.no_grad()
    def predict(
        self,
        audio_input: Union[str, torch.Tensor, np.ndarray],
        return_per_window: bool = False,
    ) -> Dict:
        """Predict bonafide vs spoof for a single audio input.

        Args:
            audio_input: either a file path (str), a 1-D Tensor at 16 kHz, or
                         a 1-D numpy array at 16 kHz.
            return_per_window: if True, include per-window probabilities in
                               the result for debugging.

        Returns:
            Dict with keys:
                spoof_probability: float in [0, 1]
                bonafide_probability: float in [0, 1]
                prediction: "bonafide" or "spoof"
                confidence: float in [0, 1] (probability of the predicted class)
                utterance_duration_sec: total audio length
                n_windows: number of 4-sec windows the audio was split into
                threshold_used: the decision threshold applied to spoof_probability
                window_scores: (only if return_per_window=True) list of per-window spoof probs
        """
        # Step 1: Load and resample audio if needed
        if isinstance(audio_input, str):
            waveform = load_audio(audio_input)  # returns 1-D tensor at 16 kHz
        elif isinstance(audio_input, np.ndarray):
            waveform = torch.from_numpy(audio_input.astype(np.float32))
        elif isinstance(audio_input, torch.Tensor):
            waveform = audio_input.float()
            if waveform.dim() > 1:
                waveform = waveform.squeeze()
        else:
            raise ValueError(
                f"audio_input must be str, np.ndarray, or torch.Tensor; got {type(audio_input)}"
            )

        duration_sec = float(waveform.shape[0] / 16000)

        # Step 2: Segment into 4-sec windows
        windows = segment_waveform(waveform)  # list of 1-D tensors of length 64000
        n_windows = len(windows)

        # Step 3: Stack into a batch and run inference
        batch = torch.stack(windows, dim=0).to(self.device, non_blocking=True)

        if self.use_mixed_precision:
            with torch.amp.autocast(device_type="cuda", enabled=True):
                logits = self.model(batch)
        else:
            logits = self.model(batch)

        # Step 4: Compute per-window probabilities, then aggregate (mean)
        probs = torch.softmax(logits.float(), dim=-1).cpu().numpy()  # (n_windows, 2)
        window_spoof_probs = probs[:, 1].tolist()
        utt_spoof_prob = float(np.mean(window_spoof_probs))
        utt_bonafide_prob = 1.0 - utt_spoof_prob
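        # NOTE: a plain mean weights every window equally, so a short spoofed
        # segment inside a long bonafide clip gets diluted; pass
        # return_per_window=True to inspect the individual window scores
        # when that matters.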

        # Step 5: Apply threshold for hard prediction
        prediction = "spoof" if utt_spoof_prob > self.threshold else "bonafide"
        confidence = utt_spoof_prob if prediction == "spoof" else utt_bonafide_prob

        result = {
            "spoof_probability": utt_spoof_prob,
            "bonafide_probability": utt_bonafide_prob,
            "prediction": prediction,
            "confidence": confidence,
            "utterance_duration_sec": duration_sec,
            "n_windows": n_windows,
            "threshold_used": self.threshold,
        }
        if return_per_window:
            result["window_scores"] = window_spoof_probs
        return result

    def info(self) -> Dict:
        """Return metadata about this model checkpoint."""
        return {
            **self.checkpoint_metadata,
            "device": self.device,
            "threshold": self.threshold,
            "mixed_precision": self.use_mixed_precision,
        }
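

# Minimal usage sketch (illustrative only): the checkpoint and audio paths
# below are placeholders, not files shipped with this module; point them at
# your own Stage 2 checkpoint and a 16 kHz wav before running.
if __name__ == "__main__":
    detector = DeepfakeDetector(checkpoint_path="checkpoints/stage2_best.pt")
    print(detector.info())

    # Per-window scores help when debugging long or partially spoofed clips.
    result = detector.predict("samples/example.wav", return_per_window=True)
    print(result["prediction"], round(result["spoof_probability"], 3))
    print(result.get("window_scores"))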