"""LibriSpeech audio data preparation with AudioVQEncoder targets. Dataset: https://huggingface.co/datasets/openslr/librispeech_asr Storage: ~6 GB (FLAC compressed) License: CC-BY-4.0 """ import torch, torchaudio from dataclasses import dataclass import sys, os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from arbitor.config import SPECIAL_VOCAB, AUDIO_SR @dataclass class LibriSpeechConfig: subset: str = "clean" batch_size: int = 4 shuffle_buffer: int = 5000 split: str = "train" max_sec: int = 5 audio_sr: int = AUDIO_SR class LibriSpeechStream: def __init__(self, cfg: LibriSpeechConfig): self.cfg = cfg self._ds = None self._vq_enc = None self._bos = SPECIAL_VOCAB['BOS'] self._eos = SPECIAL_VOCAB['EOS'] self._pad = SPECIAL_VOCAB['PAD'] def _lazy_init(self): if self._ds is not None: return from datasets import load_dataset conf = "clean" if self.cfg.subset == "clean" else "other" subset_map = {"train": f"train.{conf}.100", "dev": f"dev.{conf}", "test": f"test.{conf}"} ds = load_dataset("openslr/librispeech_asr", subset_map.get(self.cfg.split, "train.clean.100"), split="train", streaming=True, trust_remote_code=True) self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42) def _get_vq_encoder(self): if self._vq_enc is None: from arbitor.encoders.audio import AudioVQEncoder self._vq_enc = AudioVQEncoder().cuda() self._vq_enc.eval() return self._vq_enc def batches(self): self._lazy_init() buf = [] for example in self._ds: audio = example["audio"] wav = torch.tensor(audio["array"], dtype=torch.float32).unsqueeze(0) sr = audio["sampling_rate"] if sr != self.cfg.audio_sr: resample = torchaudio.transforms.Resample(sr, self.cfg.audio_sr) wav = resample(wav) max_samples = self.cfg.max_sec * self.cfg.audio_sr if wav.shape[1] < self.cfg.audio_sr: continue wav = wav[:, :max_samples] text = example.get("text", "") raw = text.encode("utf-8") text_tokens = [self._bos] + list(raw) + [self._eos] text_t = torch.tensor(text_tokens, dtype=torch.long) vq_enc = self._get_vq_encoder() with torch.no_grad(): _, vq_tokens = vq_enc(wav.unsqueeze(0).cuda()) vq_tokens = vq_tokens.squeeze(0).cpu() vq_t = torch.cat([text_t.new_tensor([self._bos]), vq_tokens, text_t.new_tensor([self._eos])]) buf.append((wav, vq_t)) if len(buf) >= self.cfg.batch_size: batch = buf[:self.cfg.batch_size] buf = buf[self.cfg.batch_size:] waves = torch.stack([b[0] for b in batch]) max_len = max(b[1].numel() for b in batch) targets = torch.stack([ torch.cat([b[1], b[1].new_full((max_len - b[1].numel(),), self._pad)]) for b in batch ]) yield waves, targets def num_samples(self) -> int: return 290_000