File size: 3,422 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""LibriSpeech audio data preparation with AudioVQEncoder targets.

Dataset: https://huggingface.co/datasets/openslr/librispeech_asr
Storage: ~6 GB (FLAC compressed)
License: CC-BY-4.0
"""
import torch, torchaudio
from dataclasses import dataclass
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import SPECIAL_VOCAB, AUDIO_SR


@dataclass
class LibriSpeechConfig:
    """Settings read by LibriSpeechStream when streaming LibriSpeech audio."""
    # "clean" selects the clean recordings; any other value selects "other".
    subset: str = "clean"
    # Number of examples collected before a batch is yielded.
    batch_size: int = 4
    # buffer_size handed to the streaming dataset's shuffle().
    shuffle_buffer: int = 5000
    # One of "train", "dev" or "test"; unrecognized values fall back to the
    # clean training split.
    split: str = "train"
    # Clips are truncated to max_sec * audio_sr samples.
    max_sec: int = 5
    # Target sample rate; audio at any other rate is resampled to it.
    audio_sr: int = AUDIO_SR


class LibriSpeechStream:
    """Stream LibriSpeech clips as ``(waveforms, vq_targets)`` batches.

    The dataset and the VQ encoder are both created lazily on first use, so
    constructing the stream is cheap and performs no I/O.
    """

    def __init__(self, cfg: "LibriSpeechConfig"):
        """Store the configuration and the special token ids.

        Args:
            cfg: Streaming settings (subset, split, batch size, clip length,
                target sample rate, shuffle buffer size).
        """
        self.cfg = cfg
        self._ds = None
        self._vq_enc = None
        self._bos = SPECIAL_VOCAB['BOS']
        self._eos = SPECIAL_VOCAB['EOS']
        self._pad = SPECIAL_VOCAB['PAD']

    @staticmethod
    def _select_device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _lazy_init(self):
        """Open the streaming dataset on first call; later calls are no-ops."""
        if self._ds is not None:
            return
        from datasets import load_dataset
        conf = "clean" if self.cfg.subset == "clean" else "other"
        # The hub dataset's configs are "all" / "clean" / "other"; the split
        # name belongs in `split=`, not in the config slot. Under the "all"
        # config the splits are named e.g. "train.clean.100",
        # "validation.other", "test.clean". The only "other" training split
        # is the 500-hour one.
        train_hours = "100" if conf == "clean" else "500"
        split_map = {"train": f"train.{conf}.{train_hours}",
                     "dev": f"validation.{conf}",
                     "test": f"test.{conf}"}
        # Unrecognized cfg.split values fall back to the clean training split.
        split_name = split_map.get(self.cfg.split, "train.clean.100")
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # dataset repository to run locally; kept for compatibility with the
        # existing call — confirm it is still required.
        ds = load_dataset("openslr/librispeech_asr", "all",
                          split=split_name, streaming=True,
                          trust_remote_code=True)
        # Fixed seed keeps the shuffled order reproducible across runs.
        self._ds = ds.shuffle(buffer_size=self.cfg.shuffle_buffer, seed=42)

    def _get_vq_encoder(self):
        """Return the shared AudioVQEncoder, building it on first call.

        The encoder is moved to the device chosen by ``_select_device`` and
        put in eval mode.
        """
        if self._vq_enc is None:
            from arbitor.encoders.audio import AudioVQEncoder
            self._vq_enc = AudioVQEncoder().to(self._select_device())
            self._vq_enc.eval()
        return self._vq_enc

    def _collate(self, batch):
        """Pad and stack a list of ``(wav, target)`` pairs.

        Waveforms are zero-padded on the right to the longest clip in the
        batch (clips may be shorter than the truncation length), and targets
        are right-padded with the PAD id to the longest target.

        Returns:
            ``(waves, targets)`` stacked along a new leading batch dimension.
        """
        longest_wav = max(wav.shape[-1] for wav, _ in batch)
        waves = torch.stack([
            torch.nn.functional.pad(wav, (0, longest_wav - wav.shape[-1]))
            for wav, _ in batch
        ])
        longest_tgt = max(tgt.numel() for _, tgt in batch)
        targets = torch.stack([
            torch.cat([tgt, tgt.new_full((longest_tgt - tgt.numel(),), self._pad)])
            for _, tgt in batch
        ])
        return waves, targets

    def batches(self):
        """Yield ``(waves, targets)`` batches of ``cfg.batch_size`` examples.

        Each example is resampled to ``cfg.audio_sr`` if needed, skipped when
        shorter than one second, truncated to ``cfg.max_sec`` seconds, and
        encoded to VQ tokens that are wrapped in BOS/EOS.

        A trailing group smaller than ``cfg.batch_size`` is not yielded.

        Yields:
            waves: float32 tensor, one ``(1, samples)`` waveform per example,
                zero-padded to the longest clip in the batch.
            targets: token tensor ``[BOS, vq..., EOS]`` per example,
                PAD-padded to the longest target in the batch.
        """
        self._lazy_init()
        target_sr = self.cfg.audio_sr
        max_samples = self.cfg.max_sec * target_sr
        batch_size = self.cfg.batch_size
        device = self._select_device()
        bos = torch.tensor([self._bos], dtype=torch.long)
        eos = torch.tensor([self._eos], dtype=torch.long)
        resamplers = {}  # source sample rate -> Resample transform
        buf = []
        for example in self._ds:
            audio = example["audio"]
            wav = torch.tensor(audio["array"], dtype=torch.float32).unsqueeze(0)
            sr = audio["sampling_rate"]
            if sr != target_sr:
                if sr not in resamplers:
                    resamplers[sr] = torchaudio.transforms.Resample(sr, target_sr)
                wav = resamplers[sr](wav)
            # Drop clips shorter than one second at the target rate.
            if wav.shape[1] < target_sr:
                continue
            wav = wav[:, :max_samples]

            vq_enc = self._get_vq_encoder()
            with torch.no_grad():
                # Only the second element of the encoder's return is used.
                _, vq_tokens = vq_enc(wav.unsqueeze(0).to(device))
                vq_tokens = vq_tokens.squeeze(0).cpu()
            buf.append((wav, torch.cat([bos, vq_tokens, eos])))
            if len(buf) >= batch_size:
                batch, buf = buf[:batch_size], buf[batch_size:]
                yield self._collate(batch)

    def num_samples(self) -> int:
        """Return a fixed sample-count constant (independent of ``cfg``)."""
        return 290_000