| """ |
| PyTorch Dataset for ASVspoof 2019 LA. |
| |
| One Dataset entry = one 4-second audio window + its binary label. |
| Long utterances expand into multiple entries (one per window). |
| """ |
|
|
from typing import List, Optional, Tuple

import torch
from torch.utils.data import Dataset

from src.data.preprocessing import (
    HOP_SAMPLES,
    WINDOW_SAMPLES,
    load_audio,
    segment_waveform,
)
from src.data.protocols import Utterance
|
|
|
|
| def _count_windows(duration_samples: int) -> int: |
| if duration_samples <= WINDOW_SAMPLES: |
| return 1 |
| n = (duration_samples - WINDOW_SAMPLES + HOP_SAMPLES - 1) // HOP_SAMPLES + 1 |
| return max(1, n) |
|
|
|
|
class ASVspoofDataset(Dataset):
    """One sample = one 4-second window + its binary label.

    Long utterances expand into multiple dataset entries, one per window.
    Each ``__getitem__`` returns ``(window_tensor, label_int, utterance_id)``.
    """

    def __init__(
        self,
        utterances: List[Utterance],
        durations_samples: Optional[List[int]] = None,
    ) -> None:
        """Build the flat (utterance, window) index.

        Args:
            utterances: Protocol entries to serve samples from.
            durations_samples: Optional per-utterance length in samples,
                parallel to ``utterances``. When omitted, each utterance
                contributes exactly one entry (its first window).

        Raises:
            ValueError: If ``durations_samples`` is given but its length
                does not match ``utterances``.
        """
        self.utterances = utterances
        # Flat index: dataset position -> (utterance index, window index).
        self.index: List[Tuple[int, int]] = []
        if durations_samples is None:
            # Durations unknown: expose only the first window of each utterance.
            self.index = [(i, 0) for i in range(len(utterances))]
        else:
            if len(durations_samples) != len(utterances):
                raise ValueError(
                    "durations_samples must be parallel to utterances: got "
                    f"{len(durations_samples)} durations for "
                    f"{len(utterances)} utterances"
                )
            for i, d in enumerate(durations_samples):
                for w in range(_count_windows(d)):
                    self.index.append((i, w))

    def __len__(self) -> int:
        """Total number of (utterance, window) entries."""
        return len(self.index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, str]:
        """Load one window.

        Audio is loaded and segmented lazily on each access; no waveform
        caching is done here.
        """
        utt_idx, window_idx = self.index[idx]
        utt = self.utterances[utt_idx]
        waveform = load_audio(utt.flac_path)
        windows = segment_waveform(waveform)
        # Precomputed durations may disagree with the actual file on disk;
        # clamp so a stale index never raises IndexError.
        window_idx = min(window_idx, len(windows) - 1)
        return windows[window_idx], utt.label_int, utt.utterance_id
|
|