File size: 3,422 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | """LibriSpeech audio data preparation with AudioVQEncoder targets.
Dataset: https://huggingface.co/datasets/openslr/librispeech_asr
Storage: ~6 GB (FLAC compressed)
License: CC-BY-4.0
"""
import torch, torchaudio
from dataclasses import dataclass
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB, AUDIO_SR
@dataclass
class LibriSpeechConfig:
    """Settings consumed by the LibriSpeech streaming loader in this module."""

    # "clean" selects the clean recordings; any other value selects "other".
    subset: str = "clean"

    # Number of examples collected before a batch is yielded.
    batch_size: int = 4

    # Buffer size handed to the streaming dataset's shuffle.
    shuffle_buffer: int = 5000

    # Logical split name: "train", "dev" or "test".
    split: str = "train"

    # Clips are truncated to at most this many seconds of audio.
    max_sec: int = 5

    # Target sample rate in Hz; audio at any other rate is resampled to it.
    audio_sr: int = AUDIO_SR
class LibriSpeechStream:
    """Stream LibriSpeech clips and yield (waveform, VQ-token target) batches.

    Both the streaming dataset and the AudioVQEncoder are created lazily on
    first use, so constructing the stream performs no I/O and needs no GPU.
    """

    def __init__(self, cfg: "LibriSpeechConfig"):
        """Store the configuration and cache the special token ids.

        Args:
            cfg: Stream settings (subset, split, batch size, clip length,
                sample rate, shuffle buffer).
        """
        self.cfg = cfg
        self._ds = None      # streaming dataset, created by _lazy_init()
        self._vq_enc = None  # VQ encoder, created by _get_vq_encoder()
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']

    @staticmethod
    def _device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _lazy_init(self):
        """Open the shuffled streaming dataset on first call; no-op afterwards.

        The Hub dataset is organised as a config ("clean" / "other") plus a
        split name inside that config, so ``cfg.subset`` selects the config
        and ``cfg.split`` is mapped onto the split names the dataset exposes.
        An unrecognised ``cfg.split`` falls back to the training split of the
        selected subset.
        """
        if self._ds is not None:
            return
        from datasets import load_dataset

        conf = "clean" if self.cfg.subset == "clean" else "other"
        # "clean" ships a 100 h training split, "other" only a 500 h one.
        train_split = "train.100" if conf == "clean" else "train.500"
        split_map = {
            "train": train_split,
            "dev": "validation",
            "test": "test",
        }
        ds = load_dataset(
            "openslr/librispeech_asr",
            conf,
            split=split_map.get(self.cfg.split, train_split),
            streaming=True,
            trust_remote_code=True,
        )
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vq_encoder(self):
        """Return the shared AudioVQEncoder, building it on first use.

        The encoder is placed on the GPU when CUDA is available (CPU
        otherwise) and switched to eval mode.
        """
        if self._vq_enc is None:
            from arbitor.encoders.audio import AudioVQEncoder
            self._vq_enc = AudioVQEncoder().to(self._device())
            self._vq_enc.eval()
        return self._vq_enc

    def _encode_targets(self, wav):
        """Encode one waveform into a BOS + VQ tokens + EOS target sequence.

        Args:
            wav: Waveform tensor for a single clip (channel dimension first).

        Returns:
            1-D CPU tensor: ``[BOS, *vq_tokens, EOS]``.
        """
        vq_enc = self._get_vq_encoder()
        with torch.no_grad():
            # The encoder is fed a batch of one; its first output is unused.
            _, vq_tokens = vq_enc(wav.unsqueeze(0).to(self._device()))
        vq_tokens = vq_tokens.squeeze(0).cpu()
        bos = torch.tensor([self._bos], dtype=torch.long)
        eos = torch.tensor([self._eos], dtype=torch.long)
        return torch.cat([bos, vq_tokens, eos])

    def _collate(self, items):
        """Pad a list of (waveform, target) pairs to common lengths and stack.

        Waveforms are right-padded with zeros to the longest clip in the
        batch (clips are truncated but not fixed-length, so stacking them
        unpadded would fail); targets are right-padded with the PAD id.
        """
        longest_wav = max(wav.shape[-1] for wav, _ in items)
        waves = torch.stack([
            torch.nn.functional.pad(wav, (0, longest_wav - wav.shape[-1]))
            for wav, _ in items
        ])
        longest_tgt = max(tgt.numel() for _, tgt in items)
        targets = torch.stack([
            torch.cat([tgt, tgt.new_full((longest_tgt - tgt.numel(),), self._pad)])
            for _, tgt in items
        ])
        return waves, targets

    def batches(self):
        """Yield ``(waves, targets)`` batches of ``cfg.batch_size`` clips.

        Each clip is resampled to ``cfg.audio_sr`` when needed, skipped if it
        is shorter than one second, and truncated to ``cfg.max_sec`` seconds.
        The transcript is not used; targets come from the VQ encoder only.

        Yields:
            waves: Float tensor, one zero-padded waveform per clip.
            targets: Token tensor, one PAD-padded ``BOS ... EOS`` row per clip.

        A trailing partial batch (fewer than ``batch_size`` clips left when
        the dataset is exhausted) is dropped.
        """
        self._lazy_init()
        target_sr = self.cfg.audio_sr
        max_samples = int(self.cfg.max_sec * target_sr)
        resamplers = {}  # source sample rate -> reusable Resample transform
        pending = []
        for example in self._ds:
            audio = example["audio"]
            wav = torch.tensor(audio["array"], dtype=torch.float32).unsqueeze(0)
            sr = audio["sampling_rate"]
            if sr != target_sr:
                resample = resamplers.get(sr)
                if resample is None:
                    resample = torchaudio.transforms.Resample(sr, target_sr)
                    resamplers[sr] = resample
                wav = resample(wav)
            # Drop clips shorter than one second at the target rate.
            if wav.shape[1] < target_sr:
                continue
            wav = wav[:, :max_samples]
            pending.append((wav, self._encode_targets(wav)))
            if len(pending) >= self.cfg.batch_size:
                yield self._collate(pending)
                pending = []

    def num_samples(self) -> int:
        """Return a fixed, hard-coded size estimate for the dataset.

        NOTE(review): the value does not depend on the configured subset or
        split — confirm it matches the data actually being streamed.
        """
        return 290_000
|