File size: 2,781 Bytes
0cf2f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
"""Hugging Face Space entry point for the OmniVoice MLX demo."""

import logging
import os
from typing import Any, Dict

import numpy as np

from omnivoice.cli.demo import build_demo
from omnivoice.mlx import OmniVoiceMLX
from omnivoice.models.omnivoice import OmniVoiceGenerationConfig

# Configure root logging once for the whole process: INFO level, with a
# timestamp / logger-name / level prefix on every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)
# Explicitly pin the "omnivoice" logger to INFO as well.
logging.getLogger("omnivoice").setLevel(logging.INFO)

# Opt in to HF_HUB_ENABLE_HF_TRANSFER unless the environment already sets it
# (setdefault never overrides an existing value).
# NOTE(review): this runs after the omnivoice imports; if huggingface_hub has
# already been imported by then and reads this flag at import time, the
# assignment may have no effect -- confirm.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# Runtime configuration; each value can be overridden through the named
# environment variable and otherwise falls back to the default shown here.
CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "mlx-community/OmniVoice-4bit")
DTYPE = os.environ.get("OMNIVOICE_DTYPE", "float16")
AUDIO_TOKENIZER_DEVICE = os.environ.get("OMNIVOICE_AUDIO_TOKENIZER_DEVICE", "cpu")

# The model is loaded eagerly at import time. The progress messages are
# printed with flush=True so they are written out immediately.
print(f"Loading OmniVoice MLX model from {CHECKPOINT} ...", flush=True)
model = OmniVoiceMLX.from_pretrained(
    CHECKPOINT,
    dtype=DTYPE,
    audio_tokenizer_device=AUDIO_TOKENIZER_DEVICE,
)
# Cached once; _gen_core returns it alongside every generated waveform.
sampling_rate = model.sampling_rate
print("OmniVoice MLX model loaded.", flush=True)


def _gen_core(
    text,
    language,
    ref_audio,
    instruct,
    num_step,
    guidance_scale,
    denoise,
    speed,
    duration,
    preprocess_prompt,
    postprocess_output,
    mode,
    ref_text=None,
):
    """Run one synthesis request against the module-level model.

    Args:
        text: Text to synthesize; blank or missing text is rejected.
        language: Language choice; a falsy value or ``"Auto"`` is forwarded
            to the model as ``None``.
        ref_audio: Reference audio; required and forwarded only when
            ``mode == "clone"``.
        instruct: Optional instruction string; forwarded stripped when it
            is non-blank.
        num_step: Step count; falsy values (including 0) fall back to 32.
        guidance_scale: Guidance scale; ``None`` falls back to 2.0.
        denoise: Denoise flag; ``None`` falls back to ``True``.
        speed: Forwarded only when set and different from 1.0.
        duration: Forwarded only when set and greater than 0.
        preprocess_prompt: Coerced to ``bool`` for the generation config.
        postprocess_output: Coerced to ``bool`` for the generation config.
        mode: ``"clone"`` enables the reference-audio path; any other value
            ignores ``ref_audio`` and ``ref_text``.
        ref_text: Optional reference text (clone mode only); forwarded
            stripped when it is non-blank.

    Returns:
        ``((sampling_rate, waveform), "Done.")`` on success, where
        ``waveform`` is the first generated item clipped to [-1, 1] and
        scaled to int16. On any failure, ``(None, message)`` where
        ``message`` describes the problem.
    """
    # Guard clauses first: reject unusable requests before building any
    # config or coercing numeric inputs.
    if not text or not text.strip():
        return None, "Please enter the text to synthesize."

    clone_mode = mode == "clone"
    if clone_mode and not ref_audio:
        return None, "Please upload a reference audio."

    gen_config = OmniVoiceGenerationConfig(
        num_step=int(num_step or 32),
        guidance_scale=float(guidance_scale) if guidance_scale is not None else 2.0,
        denoise=bool(denoise) if denoise is not None else True,
        preprocess_prompt=bool(preprocess_prompt),
        postprocess_output=bool(postprocess_output),
    )

    lang = language if (language and language != "Auto") else None
    kw: Dict[str, Any] = dict(
        text=text.strip(),
        language=lang,
        generation_config=gen_config,
    )

    # Optional knobs are only forwarded when they differ from the neutral
    # value, so the model's own defaults apply otherwise.
    if speed is not None and float(speed) != 1.0:
        kw["speed"] = float(speed)
    if duration is not None and float(duration) > 0:
        kw["duration"] = float(duration)

    if clone_mode:
        kw["ref_audio"] = ref_audio
        if ref_text and ref_text.strip():
            kw["ref_text"] = ref_text.strip()

    if instruct and instruct.strip():
        kw["instruct"] = instruct.strip()

    try:
        audio = model.generate(**kw)
    except Exception as exc:
        # UI boundary: report the failure to the user, but keep the full
        # traceback in the logs instead of discarding it.
        logging.getLogger(__name__).exception("OmniVoice generation failed")
        return None, f"Error: {type(exc).__name__}: {exc}"

    # An empty or missing result would otherwise escape as an unhandled
    # IndexError/TypeError from the indexing below.
    try:
        first_item = audio[0]
    except (IndexError, TypeError):
        logging.getLogger(__name__).error("OmniVoice generation returned no audio")
        return None, "Error: the model returned no audio."

    # Clip to [-1, 1] before scaling so the int16 conversion cannot overflow.
    waveform = np.clip(first_item, -1.0, 1.0)
    waveform = (waveform * 32767).astype(np.int16)
    return (sampling_rate, waveform), "Done."


# Module-level demo object, built from the loaded model, the checkpoint name
# and this file's generation callback.
demo = build_demo(model, CHECKPOINT, generate_fn=_gen_core)

if __name__ == "__main__":
    # Enable the request queue with max_size=8, then launch whatever object
    # queue() hands back.
    queued_demo = demo.queue(max_size=8)
    queued_demo.launch()