| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Batch inference CLI for OmniVoice. |
| |
| Distributes TTS generation across multiple GPUs for large-scale tasks. |
| Reads a JSONL test list, generates audio in parallel, and saves results. |
| |
| Usage: |
| omnivoice-infer-batch --model k2-fsa/OmniVoice \ |
| --test_list test.jsonl --res_dir results/ |
| |
| Test list format (JSONL, one JSON object per line): |
| Required fields: "id", "text" |
| Voice cloning: "ref_audio", "ref_text" |
| Voice design: "instruct" |
| Optional: "language_id", "language_name", "duration", "speed" |
| """ |
|
|
| import argparse |
| import logging |
| import multiprocessing as mp |
| import os |
| import signal |
| import time |
| import traceback |
| from concurrent.futures import ProcessPoolExecutor, as_completed |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from omnivoice.models.omnivoice import OmniVoice |
| import soundfile as sf |
|
|
| from omnivoice.utils.audio import load_audio |
| from omnivoice.utils.common import str2bool |
| from omnivoice.utils.data_utils import read_test_list |
| from omnivoice.utils.duration import RuleDurationEstimator |
|
|
|
|
def get_best_device():
    """Pick the preferred compute backend, trying CUDA, then MPS, then CPU.

    Returns:
        A ``(device_type, device_count)`` tuple. ``device_count`` is the number
        of visible CUDA devices for ``"cuda"`` and ``1`` for ``"mps"``/``"cpu"``.
    """
    if not torch.cuda.is_available():
        backend = "mps" if torch.backends.mps.is_available() else "cpu"
        return backend, 1
    return "cuda", torch.cuda.device_count()
|
|
|
|
# Sample rate (Hz) used when loading reference audio for duration estimates
# and when building the dummy warmup clip.
SAMPLING_RATE = 24_000

# Per-process model handle: assigned by ``process_init`` inside each worker
# and read by ``run_inference_batch``; never set in the parent process.
worker_model = None
|
|
|
|
def get_parser():
    """Create the argument parser for the batch-inference CLI.

    Options are registered in a fixed order, which is also the order
    ``--help`` prints them:
    - checkpoint and I/O paths;
    - decoding hyper-parameters and worker layout;
    - chunking and batching controls;
    - warmup and pre/post-processing switches;
    - sampling temperatures, the denoise flag and the language override.

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    parser = argparse.ArgumentParser(description="Infer OmniVoice Model")
    add = parser.add_argument

    add(
        "--model", type=str, default="k2-fsa/OmniVoice",
        help=(
            "Path to the model checkpoint (local dir or HF repo id). "
            "Audio tokenizer is expected at <checkpoint>/audio_tokenizer/."
        ),
    )
    add(
        "--test_list", type=str, required=True,
        help=(
            "Path to the JSONL file containing test samples. "
            "Each line is a JSON object with the following fields: "
            '"id" (str, required): unique name for the output file; '
            '"text" (str, required): text to synthesize; '
            '"ref_audio" (str): path to reference audio for voice cloning; '
            '"ref_text" (str): transcript of the reference audio; '
            '"instruct" (str): instruction for voice design (used when ref_audio is absent); '
            '"language_id" (str): language code, e.g. "en"; '
            '"language_name" (str): language name, e.g. "English"; '
            '"duration" (float): target duration in seconds; '
            '"speed" (float): speaking speed multiplier. '
            "Only id and text are required; all other fields are optional."
        ),
    )
    add(
        "--res_dir", type=str, required=True,
        help="Directory to save the generated audio files.",
    )
    add(
        "--num_step", type=int, default=32,
        help="Number of steps for iterative decoding.",
    )
    add(
        "--guidance_scale", type=float, default=2.0,
        help="Scale for Classifier-Free Guidance.",
    )
    add(
        "--t_shift", type=float, default=0.1,
        help="Shift t to smaller ones if t_shift < 1.0",
    )
    add(
        "--nj_per_gpu", type=int, default=1,
        help="Number of worker processes to spawn per GPU.",
    )
    add(
        "--audio_chunk_duration", type=float, default=15.0,
        help=(
            "Maximum duration of audio chunk (in seconds) for splitting. "
            '"Not split" if <= 0.'
        ),
    )
    add(
        "--audio_chunk_threshold", type=float, default=30.0,
        help=(
            "The duration threshold (in seconds) to decide"
            " whether to split audio into chunks."
        ),
    )
    add(
        "--batch_duration", type=float, default=1000.0,
        help="Maximum total duration (reference + generated) per batch (seconds).",
    )
    add(
        "--batch_size", type=int, default=0,
        help=(
            "Fixed batch size (number of samples per batch). "
            "If > 0, use fixed-size batching instead of duration-based batching."
        ),
    )
    add(
        "--warmup", type=int, default=0,
        help=(
            "Number of dummy inference runs per worker before real inference "
            "starts, to warm up CUDA kernels and caches."
        ),
    )
    add(
        "--preprocess_prompt", type=str2bool, default=True,
        help=(
            "Whether to preprocess reference audio (silence removal, trimming). "
            "Set to False to keep raw audio."
        ),
    )
    add(
        "--postprocess_output", type=str2bool, default=True,
        help="Whether to post-process generated audio (remove silence).",
    )
    add(
        "--layer_penalty_factor", type=float, default=5.0,
        help="The penalty factor for layer-wise sampling.",
    )
    add(
        "--position_temperature", type=float, default=5.0,
        help="The temperature for position selection.",
    )
    add(
        "--class_temperature", type=float, default=0.0,
        help="The temperature for class token sampling.",
    )
    add(
        "--denoise", type=str2bool, default=True,
        help="Whether to add <|denoise|> token in the reference.",
    )
    add(
        "--lang_id", type=str, default=None,
        help=(
            "Language id to use when test_list JSONL entries do not contain "
            "language_id/language_name fields. If provided, both language_id and "
            "language_name will be set to this value."
        ),
    )
    return parser
|
|
|
|
def process_init(rank_queue, model_checkpoint, warmup=0):
    """Initialize one worker process; used as the ``ProcessPoolExecutor`` initializer.

    Steps:
    - Claim one ``(device_type, device_id)`` entry from ``rank_queue``.
    - Load the model onto that device with ``OmniVoice.from_pretrained()``.
    - Store the model in the module-level ``worker_model``.
    - Optionally run a number of warmup generations.

    Args:
        rank_queue: Shared queue of ``(device_type, device_id)`` tuples. Each
            worker consumes exactly one entry with a blocking ``get()``.
        model_checkpoint: Local directory or HF repo id forwarded to
            ``OmniVoice.from_pretrained``.
        warmup: Number of dummy ``generate`` calls to run before any real
            batch; ``0`` disables warmup.
    """
    global worker_model

    # Several workers may share one host; cap each one's CPU thread pools.
    torch.set_num_threads(2)
    torch.set_num_interop_threads(2)

    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] "
        "[Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    device_type, device_id = rank_queue.get()
    if device_type in ("cpu", "mps"):
        worker_device = device_type
    else:
        worker_device = f"cuda:{device_id}"

    # Half precision is used on accelerators. Many PyTorch CPU kernels have
    # no fast (or no) float16 implementation, so CPU workers load float32.
    dtype = torch.float32 if device_type == "cpu" else torch.float16

    logging.info(f"Initializing worker on device: {worker_device}")

    worker_model = OmniVoice.from_pretrained(
        model_checkpoint,
        device_map=worker_device,
        dtype=dtype,
    )

    if warmup > 0:
        logging.info(f"Running {warmup} warmup iterations on {worker_device}")
        # One second of random noise as a stand-in reference clip. It is
        # passed as a (waveform, sample_rate) tuple, which generate()
        # presumably accepts as ref_audio.
        dummy_ref_audio = (torch.randn(1, SAMPLING_RATE), SAMPLING_RATE)
        for _ in range(warmup):
            worker_model.generate(
                text=["hello"],
                language=["en"],
                ref_audio=[dummy_ref_audio],
                ref_text=["hello"],
            )
        logging.info(f"Warmup complete on {worker_device}")

    logging.info(f"Worker on {worker_device} initialized successfully.")
|
|
|
|
def estimate_sample_total_duration(
    duration_estimator: RuleDurationEstimator,
    text: str,
    ref_text: Optional[str],
    ref_audio_path: Optional[str],
    gen_duration: Optional[float] = None,
) -> float:
    """Return the estimated reference + generated duration (seconds) of one sample.

    Reference duration:
    - With ``ref_audio_path``, the clip is loaded at ``SAMPLING_RATE`` and
      its length measured.
    - Without it (instruct / voice-design mode), the reference contributes 0.

    Generated duration:
    - A caller-supplied ``gen_duration`` is used as is.
    - Otherwise it comes from ``duration_estimator.estimate_duration``. The
      real reference text and duration are passed when available; otherwise
      the fixed pair ("Nice to meet you.", 0.5 s) is passed.
    """
    has_reference = ref_audio_path is not None

    ref_duration = 0
    if has_reference:
        ref_duration = load_audio(ref_audio_path, SAMPLING_RATE).shape[-1] / SAMPLING_RATE

    if gen_duration is None:
        if has_reference:
            anchor_text, anchor_duration = ref_text or "", ref_duration
        else:
            anchor_text, anchor_duration = "Nice to meet you.", 0.5
        gen_duration = duration_estimator.estimate_duration(
            text, anchor_text, anchor_duration, low_threshold=2.0
        )

    return ref_duration + gen_duration
|
|
|
|
def _sort_samples_by_duration(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
) -> List[Tuple[Tuple, float]]:
    """Pair every sample with its estimated total duration, longest first.

    Each sample must unpack into exactly nine fields. Only these are read:
    - ``ref_text`` (index 1);
    - ``ref_audio`` (index 2);
    - ``text`` (index 3);
    - ``duration`` (index 6).

    The sort is stable, so samples with equal estimates keep their input order.
    """

    def _total_duration(sample: Tuple) -> float:
        _, ref_text, ref_audio_path, text, _, _, dur, _, _ = sample
        return estimate_sample_total_duration(
            duration_estimator, text, ref_text, ref_audio_path, gen_duration=dur
        )

    paired = [(sample, _total_duration(sample)) for sample in samples]
    return sorted(paired, key=lambda pair: pair[1], reverse=True)
|
|
|
|
def cluster_samples_by_duration(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
    batch_duration: float,
) -> List[List[Tuple]]:
    """Greedily pack samples into batches whose summed duration fits a budget.

    Samples are visited longest-first:
    - A sample whose estimate alone exceeds ``batch_duration`` becomes a
      singleton batch.
    - Other samples are appended to the open batch while the running total
      stays within the budget.
    - When the next sample would overflow the budget, the open batch is
      closed and a new one is started with that sample.

    Returns:
        The list of batches; every input sample appears in exactly one.
    """
    batches: List[List[Tuple]] = []
    open_batch: List[Tuple] = []
    open_total = 0.0

    for sample, duration in _sort_samples_by_duration(samples, duration_estimator):
        if duration > batch_duration:
            # Too long to share a batch with anything else: run it alone.
            batches.append([sample])
        elif open_total + duration <= batch_duration:
            open_batch.append(sample)
            open_total += duration
        else:
            # Budget would overflow: close the open batch, start a new one.
            batches.append(open_batch)
            open_batch, open_total = [sample], duration

    if open_batch:
        batches.append(open_batch)

    logging.info(f"Clustered {len(samples)} samples into {len(batches)} batches")
    return batches
|
|
|
|
def cluster_samples_by_batch_size(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
    batch_size: int,
) -> List[List[Tuple]]:
    """Split samples into fixed-size batches, sorted by duration to minimize padding.

    Samples are ordered longest-first by estimated total duration, then cut
    into consecutive chunks of ``batch_size``. Only the last batch may be
    smaller.

    Args:
        samples: Sample tuples (nine fields each, as built in ``main``).
        duration_estimator: Estimator used to rank samples by total duration.
        batch_size: Number of samples per batch; must be positive.

    Returns:
        The list of batches, covering every input sample exactly once.

    Raises:
        ValueError: If ``batch_size`` is not positive. A negative value would
            otherwise give an empty ``range`` and silently drop every sample.
    """
    # Validate before sorting, which may load every reference clip.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    ordered = [
        sample
        for sample, _ in _sort_samples_by_duration(samples, duration_estimator)
    ]
    batches = [
        ordered[start : start + batch_size]
        for start in range(0, len(ordered), batch_size)
    ]
    logging.info(
        f"Split {len(samples)} samples into {len(batches)} batches "
        f"(fixed batch_size={batch_size}, sorted by duration)"
    )
    return batches
|
|
|
|
def run_inference_batch(
    batch_samples: List[Tuple],
    res_dir: str,
    **gen_kwargs,
) -> List[Tuple]:
    """Synthesize one batch with this worker's model and save each output as WAV.

    Must run inside a worker set up by ``process_init``, because it uses the
    module-level ``worker_model``.

    Args:
        batch_samples: Tuples of the form ``(id, ref_text, ref_audio, text,
            language_id, language_name, duration, speed, instruct)``.
        res_dir: Output directory; sample ``id`` is written to
            ``<res_dir>/<id>.wav``.
        **gen_kwargs: Forwarded unchanged to ``worker_model.generate``.

    Returns:
        One ``(id, synth_time, audio_duration_s, "success")`` tuple per
        sample. ``synth_time`` is the batch's elapsed synthesis time divided
        evenly among its samples.

    Raises:
        RuntimeError: If ``generate`` returns a different number of audios
            than there are samples, rather than silently dropping outputs.
    """
    if not batch_samples:
        return []

    save_names, ref_texts, ref_audio_paths, texts = [], [], [], []
    langs, durations, speeds, instructs = [], [], [], []
    for sample in batch_samples:
        (
            save_name,
            ref_text,
            ref_audio_path,
            text,
            lang_id,
            _lang_name,
            dur,
            spd,
            instruct,
        ) = sample
        save_names.append(save_name)
        ref_texts.append(ref_text)
        ref_audio_paths.append(ref_audio_path)
        texts.append(text)
        # NOTE(review): language_name is never passed to generate(), so an
        # entry that sets only language_name gets language=None. Confirm
        # whether it should be used as a fallback.
        langs.append(lang_id)
        durations.append(dur)
        speeds.append(spd)
        instructs.append(instruct)

    def _or_none(values):
        # Pass a per-sample list only when at least one sample sets the
        # field; otherwise pass None for the whole batch.
        return values if any(v is not None for v in values) else None

    start_time = time.perf_counter()  # monotonic, unaffected by clock changes
    audios = worker_model.generate(
        text=texts,
        language=langs,
        ref_audio=_or_none(ref_audio_paths),
        ref_text=_or_none(ref_texts),
        duration=_or_none(durations),
        speed=_or_none(speeds),
        instruct=_or_none(instructs),
        **gen_kwargs,
    )
    batch_synth_time = time.perf_counter() - start_time

    audios = list(audios)
    if len(audios) != len(save_names):
        raise RuntimeError(
            f"generate() returned {len(audios)} audio(s) for "
            f"{len(save_names)} sample(s) (ids: {save_names})"
        )

    per_sample_time = batch_synth_time / len(save_names)
    sample_rate = worker_model.sampling_rate
    results = []
    for save_name, audio in zip(save_names, audios):
        # f-string so that non-string ids (e.g. numeric JSON ids) also work.
        save_path = os.path.join(res_dir, f"{save_name}.wav")
        # NOTE(review): audio.shape[-1] is treated as the sample count.
        # That presumes 1-D audio; sf.write expects (frames, channels) for
        # 2-D input. Confirm what generate() returns.
        sf.write(save_path, audio, sample_rate)
        audio_duration = audio.shape[-1] / sample_rate
        results.append((save_name, per_sample_time, audio_duration, "success"))

    return results
|
|
|
|
def _build_samples(samples_raw, lang_override=None):
    """Turn test-list entries (dicts) into the 9-field tuples used downstream.

    Tuple layout: ``(id, ref_text, ref_audio, text, language_id,
    language_name, duration, speed, instruct)``. Optional keys missing from an
    entry become ``None``. When ``lang_override`` is given, it replaces both
    language fields for every entry.

    Raises:
        KeyError: If an entry lacks the required ``"id"`` or ``"text"`` key.
    """
    samples = []
    for entry in samples_raw:
        if lang_override is not None:
            lang_id = lang_name = lang_override
        else:
            lang_id = entry.get("language_id")
            lang_name = entry.get("language_name")
        samples.append(
            (
                entry["id"],
                entry.get("ref_text"),
                entry.get("ref_audio"),
                entry["text"],
                lang_id,
                lang_name,
                entry.get("duration"),
                entry.get("speed"),
                entry.get("instruct"),
            )
        )
    return samples


def _make_batches(samples, args):
    """Group samples into batches according to ``--batch_size`` / ``--batch_duration``.

    Voice-cloning samples (``ref_audio`` set) and reference-free samples are
    batched separately. As a result, every batch either has a reference
    audio path for all samples or for none.
    """
    clone_samples = [s for s in samples if s[2] is not None]
    other_samples = [s for s in samples if s[2] is None]

    duration_estimator = RuleDurationEstimator()
    batches = []
    for subset in (clone_samples, other_samples):
        if not subset:
            continue
        if args.batch_size > 0:
            batches.extend(
                cluster_samples_by_batch_size(
                    subset, duration_estimator, args.batch_size
                )
            )
        else:
            batches.extend(
                cluster_samples_by_duration(
                    subset, duration_estimator, args.batch_duration
                )
            )
    return batches


def _kill_children(futures):
    """Cancel not-yet-started futures and hard-kill this process's children.

    Only processes reported by ``multiprocessing.active_children()`` are
    killed: the pool workers and the multiprocessing manager. Unrelated
    processes in the same process group, such as a calling shell script,
    are left alone.
    """
    for future in futures:
        future.cancel()
    for child in mp.active_children():
        child.kill()


def main():
    """CLI entry point: synthesize every ``--test_list`` entry into ``--res_dir``.

    Spawns ``num_devices * --nj_per_gpu`` worker processes, each bound to one
    device. Duration- or size-based batches are dispatched to them, and the
    aggregate real-time factor (synthesis time / audio duration) is logged.
    A failed batch is logged and skipped.

    On an unrecoverable error or Ctrl-C:
    - queued work is cancelled;
    - this process's children are killed;
    - the process exits with status 130 (interrupt) or 1 (error).
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    # Workers must start from a fresh interpreter: CUDA does not support
    # being used from fork()-ed child processes.
    mp.set_start_method("spawn", force=True)

    parser = get_parser()
    args = parser.parse_args()
    if args.nj_per_gpu < 1:
        parser.error("--nj_per_gpu must be at least 1")
    os.makedirs(args.res_dir, exist_ok=True)

    device_type, num_devices = get_best_device()
    if device_type == "cpu":
        logging.warning(
            "No GPU found. Falling back to CPU inference. This might be slow."
        )

    num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"Using {device_type} ({num_devices} device(s))."
        f" Spawning {num_processes} worker processes."
    )

    # One (device_type, device_index) token per worker. Each worker's
    # initializer takes exactly one, which spreads workers over the devices.
    manager = mp.Manager()
    rank_queue = manager.Queue()
    for rank in list(range(num_devices)) * args.nj_per_gpu:
        rank_queue.put((device_type, rank))

    samples = _build_samples(read_test_list(args.test_list), args.lang_id)

    synthesis_times = []
    audio_durations = []

    with ProcessPoolExecutor(
        max_workers=num_processes,
        initializer=process_init,
        initargs=(rank_queue, args.model, args.warmup),
    ) as executor:
        # Maps each future to the ids in its batch, for error reporting.
        future_ids = {}
        # Fatal errors are handled *inside* the ``with`` block. Leaving it via
        # an exception would first run ``executor.shutdown(wait=True)``,
        # which waits for every queued batch before any handler could run.
        try:
            logging.info("Running batch inference")
            batches = _make_batches(samples, args)

            # NOTE(review): every parsed CLI option except res_dir ends up in
            # run_inference_batch's **gen_kwargs. It is then forwarded to
            # worker_model.generate, including script-only options such as
            # model/test_list/nj_per_gpu. Presumably generate() ignores
            # unknown keys; confirm.
            args_dict = vars(args)
            for batch in batches:
                future = executor.submit(
                    run_inference_batch, batch_samples=batch, **args_dict
                )
                future_ids[future] = [sample[0] for sample in batch]

            for future in tqdm(
                as_completed(future_ids),
                total=len(future_ids),
                desc="Processing samples",
            ):
                try:
                    for s_name, synth_time, audio_dur, _status in future.result():
                        synthesis_times.append(synth_time)
                        audio_durations.append(audio_dur)
                        rtf = synth_time / audio_dur if audio_dur > 0 else float("inf")
                        logging.debug(
                            f"Processed {s_name}: Audio Duration={audio_dur:.2f}s, "
                            f"Synthesis Time={synth_time:.2f}s, RTF={rtf:.4f}"
                        )
                except Exception as e:
                    logging.error(
                        f"Failed to process batch {future_ids[future]}: {e}"
                    )
                    logging.error(f"Detailed error: {traceback.format_exc()}")
        except (Exception, KeyboardInterrupt) as e:
            logging.critical(
                f"An unrecoverable error occurred: {e}. Terminating all processes."
            )
            logging.error(f"--- DETAILED TRACEBACK ---\n{traceback.format_exc()}")
            _kill_children(future_ids)
            logging.shutdown()  # flush handlers; os._exit skips interpreter cleanup
            # Exit immediately so executor shutdown cannot block on the killed
            # workers.
            os._exit(130 if isinstance(e, KeyboardInterrupt) else 1)

    total_synthesis_time = sum(synthesis_times)
    total_audio_duration = sum(audio_durations)
    logging.info("--- Summary ---")
    logging.info(f"Total audio duration: {total_audio_duration:.2f}s")
    logging.info(f"Total synthesis time: {total_synthesis_time:.2f}s")
    if total_audio_duration > 0:
        average_rtf = total_synthesis_time / total_audio_duration
        logging.info(f"Average RTF: {average_rtf:.4f}")
    else:
        logging.warning("No speech was generated. RTF cannot be computed.")

    logging.info("Done!")
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so a normal run exits with
    # status 0, exactly as a plain call would.
    raise SystemExit(main())
|
|