| """LibriSpeech audio data preparation with AudioVQEncoder targets. |
| |
| Dataset: https://huggingface.co/datasets/openslr/librispeech_asr |
| Storage: ~6 GB (FLAC compressed) |
| License: CC-BY-4.0 |
| """ |
| import torch, torchaudio |
| from dataclasses import dataclass |
| import sys, os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| from arbitor.config import SPECIAL_VOCAB, AUDIO_SR |
|
|
|
|
@dataclass
class LibriSpeechConfig:
    """Settings consumed by ``LibriSpeechStream``."""

    # "clean" selects the clean recordings; any other value selects "other".
    subset: str = "clean"
    # Number of examples per yielded batch.
    batch_size: int = 4
    # Buffer size handed to the streaming dataset's shuffle.
    shuffle_buffer: int = 5000
    # Logical split name: "train", "dev" or "test".
    split: str = "train"
    # Waveforms are truncated to this many seconds.
    max_sec: int = 5
    # Target sampling rate in Hz; audio at another rate is resampled to it.
    audio_sr: int = AUDIO_SR
|
|
|
|
class LibriSpeechStream:
    """Stream LibriSpeech utterances as (waveform, VQ-token target) batches.

    The Hugging Face dataset and the ``AudioVQEncoder`` are both created
    lazily on first use, so constructing the stream is cheap and performs
    no I/O.
    """

    def __init__(self, cfg: "LibriSpeechConfig"):
        """Store the configuration and look up the special token ids."""
        self.cfg = cfg
        self._ds = None      # streaming dataset, built by _lazy_init()
        self._vq_enc = None  # encoder, built by _get_vq_encoder()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']

    @staticmethod
    def _device():
        """Return the CUDA device when available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _lazy_init(self):
        """Open the shuffled streaming dataset on first call; no-op afterwards.

        The Hub dataset is organised as a config (``"clean"`` / ``"other"``)
        plus a split inside that config (``"train.100"``, ``"train.500"``,
        ``"validation"``, ``"test"``), so the logical ``cfg.split`` is
        translated into that pair. An unrecognised ``cfg.split`` falls back
        to the clean 100-hour training split.
        """
        if self._ds is not None:
            return
        from datasets import load_dataset

        conf = "clean" if self.cfg.subset == "clean" else "other"
        split_map = {
            # "other" ships a single 500-hour training split, no 100-hour one.
            "train": "train.100" if conf == "clean" else "train.500",
            "dev": "validation",
            "test": "test",
        }
        if self.cfg.split in split_map:
            config_name, split_name = conf, split_map[self.cfg.split]
        else:
            config_name, split_name = "clean", "train.100"

        ds = load_dataset("openslr/librispeech_asr", config_name,
                          split=split_name, streaming=True,
                          trust_remote_code=True)
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vq_encoder(self):
        """Return the shared ``AudioVQEncoder``, creating it on first use.

        The encoder is moved to the device chosen by ``_device()`` and put
        in eval mode.
        """
        if self._vq_enc is None:
            from arbitor.encoders.audio import AudioVQEncoder
            self._vq_enc = AudioVQEncoder().to(self._device())
            self._vq_enc.eval()
        return self._vq_enc

    def _collate(self, batch):
        """Pad and stack a list of ``(wav, vq_t)`` pairs into two tensors.

        Waveforms are right-padded with zeros and targets with the PAD id,
        each up to the longest item in the batch, so clips of different
        durations can be stacked.
        """
        longest_wav = max(wav.shape[1] for wav, _ in batch)
        waves = torch.stack([
            torch.nn.functional.pad(wav, (0, longest_wav - wav.shape[1]))
            for wav, _ in batch
        ])
        longest_tgt = max(tgt.numel() for _, tgt in batch)
        targets = torch.stack([
            torch.cat([tgt, tgt.new_full((longest_tgt - tgt.numel(),), self._pad)])
            for _, tgt in batch
        ])
        return waves, targets

    def batches(self):
        """Yield ``(waves, targets)`` batches of ``cfg.batch_size`` examples.

        Each utterance is resampled to ``cfg.audio_sr`` if needed, skipped
        when shorter than one second, truncated to ``cfg.max_sec`` seconds
        and encoded into VQ tokens wrapped in BOS/EOS.

        Yields:
            waves: float32 tensor of shape (batch, 1, samples), zero-padded
                to the longest clip in the batch.
            targets: token tensor padded with the PAD id to the longest
                target in the batch (assumes the encoder's tokens are 1-D
                after ``squeeze(0)`` -- TODO confirm against AudioVQEncoder).

        A trailing group smaller than ``cfg.batch_size`` is not yielded.
        """
        self._lazy_init()
        target_sr = self.cfg.audio_sr
        max_samples = self.cfg.max_sec * target_sr
        batch_size = self.cfg.batch_size
        device = self._device()
        resamplers = {}  # source rate -> Resample transform, built once each
        buf = []
        for example in self._ds:
            audio = example["audio"]
            wav = torch.tensor(audio["array"], dtype=torch.float32).unsqueeze(0)
            sr = audio["sampling_rate"]
            if sr != target_sr:
                resample = resamplers.get(sr)
                if resample is None:
                    resample = torchaudio.transforms.Resample(sr, target_sr)
                    resamplers[sr] = resample
                wav = resample(wav)
            # target_sr samples equal one second of audio.
            if wav.shape[1] < target_sr:
                continue
            wav = wav[:, :max_samples]

            vq_enc = self._get_vq_encoder()
            with torch.no_grad():
                _, vq_tokens = vq_enc(wav.unsqueeze(0).to(device))
            vq_tokens = vq_tokens.squeeze(0).cpu()
            vq_t = torch.cat([
                torch.tensor([self._bos], dtype=torch.long),
                vq_tokens,
                torch.tensor([self._eos], dtype=torch.long),
            ])
            buf.append((wav, vq_t))
            if len(buf) >= batch_size:
                batch, buf = buf[:batch_size], buf[batch_size:]
                yield self._collate(batch)

    def num_samples(self) -> int:
        """Return a fixed, hard-coded example count.

        The value does not depend on the configured subset or split.
        """
        return 290_000
|
|