OmniVoice / omnivoice /data /batching.py
zhu-han's picture
Upload 48 files
aa79b9c verified
#!/usr/bin/env python3
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching strategies for streaming/iterable datasets.
Provides length-based grouping and packing for efficient training with
variable-length audio.
Key classes:
- ``PackingIterableDataset``: Packs multiple samples into fixed-length sequences
for training. Used by ``omnivoice.training.builder``.
- ``StreamLengthGroupDataset``: Groups samples by length into buckets. Used by
data processing scripts (e.g. ``omnivoice/scripts/``).
"""
import bisect
import logging
import reprlib
from typing import Any, Dict, Iterator, List, Optional

import numpy as np

from omnivoice.data.dataset import IterableDataReader, WrappedIterableDataset
class StreamLengthGroupDataset(WrappedIterableDataset):
    """Stream samples and emit them as batches of similar audio duration.

    Each incoming sample is placed into one of ``num_buckets`` equal-width
    duration buckets covering ``[min_length, max_length]`` seconds (both ends
    inclusive). A bucket is emitted as a batch as soon as its longest duration
    multiplied by ``(current size + 1)`` reaches ``batch_duration``, or as soon
    as it holds ``max_sample`` samples. Only audio data is supported for now.

    Args:
        dataset: Source reader. It must expose ``sample_rate`` and
            ``set_epoch`` and yield samples indexable by ``audio_key``.
        batch_duration: Duration budget (seconds) that triggers emitting a
            bucket.
        min_length: Samples shorter than this many seconds are dropped.
        max_length: Samples longer than this many seconds are dropped.
        num_buckets: Number of duration buckets.
        audio_key: Key of the audio entry in each sample. The entry is assumed
            to be tensor-like with time on the last axis (it is read via
            ``.size(-1)``); TODO confirm against the reader.
        drop_last: If True, partially filled buckets remaining at the end of
            the stream are discarded instead of yielded.
        max_sample: Optional cap on samples per batch; ``None`` means no cap.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        batch_duration: float,
        min_length: float = 0.5,
        max_length: float = 30.0,
        num_buckets: int = 20,
        audio_key: str = "audio",
        drop_last: bool = False,
        max_sample: Optional[int] = None,
    ):
        self.dataset = dataset
        self.audio_key = audio_key
        self.batch_duration = batch_duration
        self.drop_last = drop_last
        # An absent cap becomes +inf so the size check below never fires.
        self.max_sample = float("inf") if max_sample is None else max_sample
        self.min_length, self.max_length = min_length, max_length
        self.num_buckets = num_buckets
        # Keep only the upper edge of every bucket: the leading edge equals
        # min_length, and dropping it lets bisect_left return the bucket index
        # directly.
        edges = np.linspace(min_length, max_length, num_buckets + 1)
        self.boundaries = edges[1:]

    def set_epoch(self, epoch: int):
        """Forward the epoch number to the wrapped reader (used for shuffling)."""
        self.dataset.set_epoch(epoch)

    def _get_bucket_id(self, length: float) -> int:
        """Return the index of the first bucket whose upper edge is >= ``length``."""
        return bisect.bisect_left(self.boundaries, length)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        pending: List[List[Dict[str, Any]]] = [[] for _ in range(self.num_buckets)]
        longest = [0.0 for _ in range(self.num_buckets)]

        for item in self.dataset:
            seconds = item[self.audio_key].size(-1) / self.dataset.sample_rate
            if seconds < self.min_length or seconds > self.max_length:
                continue

            idx = self._get_bucket_id(seconds)
            bucket = pending[idx]
            bucket.append(item)
            longest[idx] = max(longest[idx], seconds)

            # Emit once one more sample of the current longest duration would
            # reach the budget, or once the per-batch sample cap is reached.
            over_budget = longest[idx] * (len(bucket) + 1) >= self.batch_duration
            if over_budget or len(bucket) >= self.max_sample:
                yield bucket
                pending[idx] = []
                longest[idx] = 0.0

        if self.drop_last:
            return
        # Flush whatever is left, in bucket order.
        for leftover in pending:
            if leftover:
                yield leftover
class PackingIterableDataset(WrappedIterableDataset):
    """
    An IterableDataset that dynamically processes samples using a processor
    and packs them into batches based on the real token count.

    Samples are consumed in stream order and appended to the current batch
    until adding the next one would push the summed ``"length"`` above
    ``batch_tokens``. The current batch is then yielded and a new one started.
    Samples that the processor raises on are logged and skipped, as are
    samples whose own length already exceeds ``batch_tokens``.

    Args:
        dataset (IterableDataReader): The raw dataset to process; it must also
            provide ``set_epoch``.
        processor (Callable): Called on each raw sample; must return a mapping
            with a numeric ``"length"`` entry (the sample's token count).
        batch_tokens (int): Maximum number of tokens per batch.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        processor: Any,
        batch_tokens: int,
    ):
        self.dataset = dataset
        self.processor = processor
        self.batch_tokens = batch_tokens
        # Not read anywhere in this class; presumably set/consumed by the
        # training loop when resuming mid-epoch -- TODO confirm against callers.
        self.skip_batches = 0

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling (forwarded to the wrapped dataset).
        """
        self.dataset.set_epoch(epoch)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        current_batch: List[Dict[str, Any]] = []
        current_token_count = 0
        for raw_sample in self.dataset:
            # Process the sample using the processor.
            try:
                processed_sample = self.processor(raw_sample)
            except Exception as e:
                # Best-effort: skip samples the processor cannot handle rather
                # than aborting the epoch. reprlib abbreviates the raw sample
                # so large payloads (e.g. audio data) cannot flood the log;
                # the message itself is assembled lazily by logging.
                logging.warning(
                    "Error processing sample %s: %s", reprlib.repr(raw_sample), e
                )
                continue
            sample_length = processed_sample["length"]
            if sample_length > self.batch_tokens:
                # Could never fit into any batch on its own; drop it.
                continue
            # Check if adding this sample exceeds the batch token limit.
            if current_token_count + sample_length > self.batch_tokens:
                # Yield the current batch and start a new one. The batch is
                # never empty here: an empty batch has a count of 0 and
                # sample_length <= batch_tokens was checked above.
                yield current_batch
                current_batch = []
                current_token_count = 0
            # Add the processed sample to the current batch.
            current_batch.append(processed_sample)
            current_token_count += sample_length
        # Yield the last batch if it's not empty.
        if current_batch:
            yield current_batch