Sarcasm
Phase 2: add preprocessing and dataset modules
b95ff5a
"""
PyTorch Dataset for ASVspoof 2019 LA.
One Dataset entry = one 4-second audio window + its binary label.
Long utterances expand into multiple entries (one per window).
"""
from typing import List, Optional, Tuple

import torch
from torch.utils.data import Dataset

from src.data.preprocessing import (
    HOP_SAMPLES,
    WINDOW_SAMPLES,
    load_audio,
    segment_waveform,
)
from src.data.protocols import Utterance
def _count_windows(duration_samples: int) -> int:
if duration_samples <= WINDOW_SAMPLES:
return 1
n = (duration_samples - WINDOW_SAMPLES + HOP_SAMPLES - 1) // HOP_SAMPLES + 1
return max(1, n)
class ASVspoofDataset(Dataset):
    """One sample = one 4-sec window + its binary label.

    Each dataset entry maps to a ``(utterance, window)`` pair. When
    per-utterance durations are supplied, long utterances expand into one
    entry per hop window; otherwise each utterance contributes exactly one
    entry (its first window).
    """

    def __init__(
        self,
        utterances: List[Utterance],
        durations_samples: Optional[List[int]] = None,
    ):
        """Build the flat ``(utterance_index, window_index)`` lookup table.

        Args:
            utterances: parsed protocol entries, one per audio file.
            durations_samples: optional per-utterance lengths in samples,
                aligned index-for-index with ``utterances``. When given, each
                utterance expands into ``_count_windows(duration)`` entries.

        Raises:
            ValueError: if ``durations_samples`` is given but its length does
                not match ``utterances``.
        """
        self.utterances = utterances
        # Flat index: dataset position -> (utterance index, window index).
        self.index: List[Tuple[int, int]] = []
        if durations_samples is None:
            # No duration info: one window per utterance.
            for i in range(len(utterances)):
                self.index.append((i, 0))
        else:
            # Raise (not assert) so the check survives `python -O`.
            if len(durations_samples) != len(utterances):
                raise ValueError(
                    f"durations_samples has {len(durations_samples)} entries, "
                    f"expected {len(utterances)}"
                )
            for i, d in enumerate(durations_samples):
                for w in range(_count_windows(d)):
                    self.index.append((i, w))

    def __len__(self) -> int:
        """Total number of (utterance, window) entries."""
        return len(self.index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, str]:
        """Load, segment, and return one 4-second window.

        Returns:
            ``(window_waveform, binary_label, utterance_id)``.
        """
        utt_idx, window_idx = self.index[idx]
        utt = self.utterances[utt_idx]
        waveform = load_audio(utt.flac_path)
        windows = segment_waveform(waveform)
        # The duration passed to __init__ may overestimate the real window
        # count (e.g. header-based length); clamp to the last real window
        # rather than raising IndexError.
        window_idx = min(window_idx, len(windows) - 1)
        return windows[window_idx], utt.label_int, utt.utterance_id