Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Hugging Face Space entry point for the OmniVoice MLX demo.""" | |
| import logging | |
| import os | |
| from typing import Any, Dict | |
| import numpy as np | |
| from omnivoice.cli.demo import build_demo | |
| from omnivoice.mlx import OmniVoiceMLX | |
| from omnivoice.models.omnivoice import OmniVoiceGenerationConfig | |
# Emit timestamped INFO-level logs; the "omnivoice" package logger is also
# pinned to INFO explicitly.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    level=logging.INFO,
)
logging.getLogger("omnivoice").setLevel(logging.INFO)

# Ask for hf_transfer-accelerated Hub downloads unless the environment
# already sets this variable.
# NOTE(review): huggingface_hub typically reads this variable once, when it is
# first imported. If the omnivoice imports at the top of this file already
# pulled it in, this default may have no effect -- confirm, and if so set it
# before those imports (and make sure `hf_transfer` is installed).
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# Runtime configuration, each value overridable via an environment variable.
CHECKPOINT, DTYPE, AUDIO_TOKENIZER_DEVICE = (
    os.environ.get("OMNIVOICE_MODEL", "mlx-community/OmniVoice-4bit"),
    os.environ.get("OMNIVOICE_DTYPE", "float16"),
    os.environ.get("OMNIVOICE_AUDIO_TOKENIZER_DEVICE", "cpu"),
)

# Load the model once at import time; it is shared by build_demo and _gen_core.
print(f"Loading OmniVoice MLX model from {CHECKPOINT} ...", flush=True)
model = OmniVoiceMLX.from_pretrained(
    CHECKPOINT, dtype=DTYPE, audio_tokenizer_device=AUDIO_TOKENIZER_DEVICE
)
# Sample rate returned to the UI alongside every generated waveform.
sampling_rate = model.sampling_rate
print("OmniVoice MLX model loaded.", flush=True)
def _to_int16_pcm(samples):
    """Convert a float waveform nominally in [-1, 1] to 16-bit PCM samples.

    The samples are first materialized as a float32 NumPy array. The model may
    run in float16 (``DTYPE`` defaults to ``"float16"``), and ``32767`` is not
    representable in float16: it rounds to 32768, so scaling a float16 ``1.0``
    would overflow int16 and wrap around to -32768. NaNs are mapped to silence
    (0) and infinities are clipped, so the integer cast is always well defined.

    Args:
        samples: Array-like waveform; values outside [-1, 1] are clipped.

    Returns:
        ``np.ndarray`` of dtype ``int16`` with the same shape as ``samples``.
    """
    waveform = np.asarray(samples, dtype=np.float32)
    waveform = np.clip(np.nan_to_num(waveform, nan=0.0), -1.0, 1.0)
    return (waveform * 32767).astype(np.int16)


def _gen_core(
    text,
    language,
    ref_audio,
    instruct,
    num_step,
    guidance_scale,
    denoise,
    speed,
    duration,
    preprocess_prompt,
    postprocess_output,
    mode,
    ref_text=None,
):
    """Run one synthesis request for the demo UI.

    This is handed to ``build_demo`` as ``generate_fn``; the positional order
    of the arguments is presumably fixed by that caller -- do not reorder.

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        language: Language name; ``"Auto"`` or an empty value is sent as None.
        ref_audio: Reference audio, forwarded unchanged (presumably a file path
            from the UI -- confirm). Required, and only used, when ``mode`` is
            ``"clone"``.
        instruct: Optional instruction text; forwarded only if non-blank.
        num_step: Number of generation steps; falls back to 32 when falsy.
        guidance_scale: Guidance scale; falls back to 2.0 when None.
        denoise: Denoise flag; falls back to True when None.
        speed: Speed factor; forwarded only when it is not 1.0.
        duration: Target duration; forwarded only when it is positive.
        preprocess_prompt: Coerced to bool and put on the generation config.
        postprocess_output: Coerced to bool and put on the generation config.
        mode: ``"clone"`` enables ref_audio/ref_text; other values ignore them.
        ref_text: Optional transcript of ``ref_audio`` (clone mode only).

    Returns:
        ``((sampling_rate, int16_waveform), "Done.")`` on success, or
        ``(None, message)`` on invalid input or a generation error.
    """
    if not text or not text.strip():
        return None, "Please enter the text to synthesize."
    # Reject an incomplete clone request before doing any other work.
    if mode == "clone" and not ref_audio:
        return None, "Please upload a reference audio."

    gen_config = OmniVoiceGenerationConfig(
        num_step=int(num_step or 32),
        guidance_scale=float(guidance_scale) if guidance_scale is not None else 2.0,
        denoise=bool(denoise) if denoise is not None else True,
        preprocess_prompt=bool(preprocess_prompt),
        postprocess_output=bool(postprocess_output),
    )
    kw: Dict[str, Any] = dict(
        text=text.strip(),
        language=language if (language and language != "Auto") else None,
        generation_config=gen_config,
    )
    # speed is only sent when it differs from 1.0, duration only when positive.
    if speed is not None and float(speed) != 1.0:
        kw["speed"] = float(speed)
    if duration is not None and float(duration) > 0:
        kw["duration"] = float(duration)
    if mode == "clone":
        kw["ref_audio"] = ref_audio
        if ref_text and ref_text.strip():
            kw["ref_text"] = ref_text.strip()
    if instruct and instruct.strip():
        kw["instruct"] = instruct.strip()

    try:
        audio = model.generate(**kw)
    except Exception as exc:  # UI boundary: report the failure instead of raising
        # Keep the full traceback in the server logs; the UI gets a one-liner.
        logging.getLogger(__name__).exception("OmniVoice generation failed")
        return None, f"Error: {type(exc).__name__}: {exc}"
    # Presumably one waveform per request; the first element is used.
    return (sampling_rate, _to_int16_pcm(audio[0])), "Done."
# Build the demo app around the shared model and route requests to _gen_core
# (presumably a Gradio Blocks app, given the .queue()/.launch() calls below).
demo = build_demo(model, CHECKPOINT, generate_fn=_gen_core)

if __name__ == "__main__":
    # Enable request queuing, holding at most 8 pending events, then serve.
    queued_demo = demo.queue(max_size=8)
    queued_demo.launch()