omnivoice-demo / app.py
Batuka0901's picture
Update app.py
1a8fb2f verified
from __future__ import annotations
import datetime as dt
import logging
import os
import uuid
from typing import Any
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import HfApi
from omnivoice import OmniVoice, OmniVoiceGenerationConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
log = logging.getLogger("app")
CHECKPOINT = os.getenv("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "Batuka0901/omnivoice_user_voices")
HF_TOKEN = os.getenv("HF_TOKEN")
print(f"[init] loading OmniVoice from {CHECKPOINT} (CPU at startup) ...")
model = OmniVoice.from_pretrained(
CHECKPOINT,
dtype=torch.float16,
load_asr=True,
)
SR = model.sampling_rate
print(f"[init] model loaded on CPU; sr={SR}")
_moved_to_cuda = False
SAMPLE_TEXT = (
"Өглөөний наран уулын цаанаас аажмаар мандаж, байгаль дэлхий "
"нойрноосоо сэрж эхэллээ. Ойн гүнд шувуудын жиргээ намуухан "
"сонсогдож, тунгалаг горхины чимээ түүнтэй хослон урсна. Хөвсгөр "
"цагаан үүлс тэнгэрт алгуур нүүж, цэвэр агаар цээж дүүрэн мэдрэгдэх "
"нь юутай таатай. Энэхүү амар амгалан орчинд хүн бүхний сэтгэл "
"тэнийж, ирээдүйн сайн сайханд итгэх итгэл улам бүр нэмэгдсээр "
"байна."
)
def _gen_config():
return OmniVoiceGenerationConfig(
num_step=32,
guidance_scale=2.0,
denoise=True,
preprocess_prompt=True,
postprocess_output=True,
)
@spaces.GPU(duration=120)
def synthesize(target_text: str, ref_audio_path: str | None):
global _moved_to_cuda
if not target_text or not target_text.strip():
raise gr.Error("Хэлүүлэх текстээ оруулна уу.")
if not ref_audio_path:
raise gr.Error("Өөрийн дуу хоолой оруулна уу.")
if not _moved_to_cuda and torch.cuda.is_available():
log.info("moving model to cuda ...")
model.to("cuda")
_moved_to_cuda = True
try:
clone_prompt = model.create_voice_clone_prompt(
ref_audio=ref_audio_path,
ref_text=SAMPLE_TEXT,
)
kw: dict[str, Any] = dict(
text=target_text.strip(),
language="Mongolian",
generation_config=_gen_config(),
voice_clone_prompt=clone_prompt,
)
audio = model.generate(**kw)
except Exception as e:
log.exception("generate failed")
raise gr.Error(f"{type(e).__name__}: {e}")
waveform = (audio[0] * 32767).astype(np.int16)
out = (SR, waveform)
try:
_upload_audio_to_dataset(ref_audio_path)
except Exception:
log.exception("dataset upload failed")
return out
def _upload_audio_to_dataset(audio_path: str) -> None:
if not HF_TOKEN:
raise RuntimeError("HF_TOKEN not set — cannot upload to dataset.")
api = HfApi(token=HF_TOKEN)
api.create_repo(
DATASET_REPO_ID, repo_type="dataset",
exist_ok=True, private=True, token=HF_TOKEN,
)
ts = dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
rid = uuid.uuid4().hex[:8]
audio_remote = f"audio/{ts}-{rid}.wav"
api.upload_file(
path_or_fileobj=audio_path,
path_in_repo=audio_remote,
repo_id=DATASET_REPO_ID,
repo_type="dataset",
token=HF_TOKEN,
)
DESCRIPTION = """
# OmniVoice — Voice Cloning (Монгол)
**Алхам:**
1. Доорх **жишээ текст**-ийг уншиж **бичлэг хийнэ**.
2. Хэлүүлэх текстээ оруулна.
3. ** Generate** дарах → clone хийгдэн таны оруулсан текстийг таны дуу хоолойгоор үүсгэнэ.
"""
HIDE_FOOTER_CSS = """
footer {visibility: hidden !important; display: none !important;}
.gradio-container .footer {display: none !important;}
"""
with gr.Blocks(title="OmniVoice — Voice Cloning (MN)", css=HIDE_FOOTER_CSS) as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=1):
gr.Textbox(
label=" Уншиж жишээ текст (~30 секунд)",
value=SAMPLE_TEXT,
lines=8,
interactive=False,
)
ref_audio = gr.Audio(
label="Чанартай audio гаргахын тулд ~30 секунд болон тод хэлнэ.",
sources=["microphone"],
type="filepath",
)
with gr.Column(scale=1):
target_text = gr.Textbox(
label=" Хэлүүлэх текст",
lines=4,
placeholder="Энэ текстийг таны хоолойгоор уншиж өгнө...",
)
btn = gr.Button(" Generate", variant="primary", size="lg")
out_audio = gr.Audio(label="Гаралт", type="numpy")
btn.click(synthesize, inputs=[target_text, ref_audio], outputs=out_audio)
if __name__ == "__main__":
demo.queue().launch(show_api=False)