#!/usr/bin/env python3
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching strategies for streaming/iterable datasets.
Provides length-based grouping and packing for efficient training with
variable-length audio.
Key classes:
- ``PackingIterableDataset``: Packs multiple samples into fixed-length sequences
for training. Used by ``omnivoice.training.builder``.
- ``StreamLengthGroupDataset``: Groups samples by length into buckets. Used by
data processing scripts (e.g. ``omnivoice/scripts/``).
"""
import bisect
import logging
from typing import Any, Dict, Iterator, List, Optional
import numpy as np
from omnivoice.data.dataset import IterableDataReader, WrappedIterableDataset
class StreamLengthGroupDataset(WrappedIterableDataset):
    """Stream batches of samples that have similar audio durations.

    Every sample read from ``dataset`` is routed to one of ``num_buckets``
    buckets according to its duration. Bucket upper edges are evenly spaced
    between ``min_length`` and ``max_length``. A bucket is emitted as one batch
    (a list of samples) as soon as it is considered full. Only audio data is
    supported for now.

    Args:
        dataset: Upstream sample source. This class reads its ``sample_rate``
            attribute and forwards ``set_epoch`` calls to it.
        batch_duration: Duration budget of one batch, in the same unit as the
            lengths below (presumably seconds -- confirm against
            ``dataset.sample_rate``).
        min_length: Samples shorter than this are skipped.
        max_length: Samples longer than this are skipped.
        num_buckets: Number of duration buckets.
        audio_key: Key under which each sample stores its audio; the stored
            object must support ``.size(-1)``.
        drop_last: If true, partially filled buckets are discarded when the
            upstream dataset is exhausted instead of being emitted.
        max_sample: Optional cap on the number of samples per batch; ``None``
            means no cap.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        batch_duration: float,
        min_length: float = 0.5,
        max_length: float = 30.0,
        num_buckets: int = 20,
        audio_key: str = "audio",
        drop_last: bool = False,
        max_sample: Optional[int] = None,
    ):
        self.dataset = dataset
        self.batch_duration = batch_duration
        self.min_length = min_length
        self.max_length = max_length
        self.num_buckets = num_buckets
        self.audio_key = audio_key
        self.drop_last = drop_last
        # "No cap" is stored as +inf so the per-batch count check never fires.
        self.max_sample = float("inf") if max_sample is None else max_sample
        # Keep only the upper edge of each bucket: dropping the first of the
        # num_buckets + 1 evenly spaced points leaves exactly num_buckets
        # edges, the last one being max_length.
        edges = np.linspace(min_length, max_length, num_buckets + 1)
        self.boundaries = edges[1:]

    def set_epoch(self, epoch: int):
        """Forward the epoch number to the wrapped dataset."""
        self.dataset.set_epoch(epoch)

    def _get_bucket_id(self, length: float) -> int:
        """Return the index of the first bucket whose upper edge is >= ``length``."""
        return bisect.bisect_left(self.boundaries, length)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        """Yield lists of samples grouped by duration bucket.

        Samples whose duration falls outside ``[min_length, max_length]`` are
        skipped silently.
        """
        pending = [[] for _ in range(self.num_buckets)]
        longest = [0.0] * self.num_buckets

        for sample in self.dataset:
            duration = sample[self.audio_key].size(-1) / self.dataset.sample_rate
            if duration < self.min_length:
                continue
            if duration > self.max_length:
                continue

            slot = self._get_bucket_id(duration)
            group = pending[slot]
            group.append(sample)
            longest[slot] = max(longest[slot], duration)

            # The bucket counts as full when one more sample as long as its
            # current longest member would reach the duration budget, or when
            # the per-batch sample cap has been hit.
            no_room_for_another = longest[slot] * (len(group) + 1) >= self.batch_duration
            hit_sample_cap = len(group) >= self.max_sample
            if no_room_for_another or hit_sample_cap:
                yield group
                pending[slot] = []
                longest[slot] = 0.0

        if self.drop_last:
            return

        # Upstream is exhausted: emit whatever is left, in bucket order.
        for leftover in pending:
            if leftover:
                yield leftover
class PackingIterableDataset(WrappedIterableDataset):
    """Process samples on the fly and pack them into token-budgeted batches.

    Each raw sample from ``dataset`` is passed through ``processor``; the
    processed samples are accumulated into a list until adding the next one
    would push the summed ``"length"`` values over ``batch_tokens``, at which
    point the list is emitted and a new one is started.

    Args:
        dataset (Iterable): The raw dataset to process.
        processor (Callable): Called with one raw sample; its result must be
            indexable with the key ``"length"``.
        batch_tokens (int): Maximum total ``"length"`` per emitted batch.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        processor: Any,
        batch_tokens: int,
    ):
        self.dataset = dataset
        self.processor = processor
        self.batch_tokens = batch_tokens
        # Not read anywhere in this class; presumably consumed by external
        # training/resume code -- verify against callers.
        self.skip_batches = 0

    def set_epoch(self, epoch: int):
        """Forward the epoch number to the wrapped dataset."""
        self.dataset.set_epoch(epoch)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        """Yield lists of processed samples that fit within ``batch_tokens``.

        Samples for which the processor raises are logged and skipped. Samples
        whose own length exceeds ``batch_tokens`` are skipped without logging.
        """
        packed = []
        used_tokens = 0

        for raw_sample in self.dataset:
            try:
                item = self.processor(raw_sample)
            except Exception as e:
                # Best effort: one bad sample must not stop the stream.
                logging.warning(f"Error processing sample {raw_sample}: {e}")
                continue

            n_tokens = item["length"]
            if n_tokens > self.batch_tokens:
                # Can never fit into any batch, so drop it.
                continue

            # Close the current batch first if this sample would overflow it.
            if used_tokens + n_tokens > self.batch_tokens:
                yield packed
                packed, used_tokens = [], 0

            packed.append(item)
            used_tokens += n_tokens

        # Emit the trailing, partially filled batch.
        if packed:
            yield packed
|