# Hugging Face upload metadata (commented out so the file stays valid Python):
# ayf3's picture
# Upload app.py with huggingface_hub
# 44ebeaa verified
#!/usr/bin/env python3
"""
NumberBlocks One Voice Cloning Space - VoxCPM V5
Fix: float32 on CPU + monkey-patch SDPA mask shape for CPU compatibility
Root cause of "Dimension out of range":
MiniCPM4's Attention.forward_step creates a 1D attn_mask but SDPA on CPU
expects at least 2D for proper broadcasting with GQA (Grouped Query Attention).
On GPU, the flash-attention backend handles this; on CPU the math backend does not.
"""
import os
import tempfile
import threading
import traceback
from pathlib import Path

import gradio as gr
import soundfile as sf
import torch
import torch.nn.functional as F
# Optional Hugging Face auth token; falls back to the legacy env-var name.
HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("HUGGINGFACE_TOKEN"))
# ──────────────────────────────────────────────────────────────────
# Monkey-patch: fix SDPA mask shape for CPU
# ──────────────────────────────────────────────────────────────────
_original_sdpa = F.scaled_dot_product_attention


def _cpu_safe_sdpa(query, key, value, attn_mask=None, **kwargs):
    """Wrapper around ``scaled_dot_product_attention`` that fixes 1D masks on CPU.

    MiniCPM4's ``Attention.forward_step`` passes a 1D ``(S,)`` attn_mask; the
    CPU math backend of SDPA needs a mask that broadcasts against the
    ``(B, H, L, S)`` attention scores, so reshape it to ``(1, 1, 1, S)``.

    Fix over the previous version: test the *query tensor's* device rather
    than the global ``torch.cuda.is_available()``, so the reshape also applies
    when CPU tensors are used on a CUDA-capable machine.
    """
    if attn_mask is not None and attn_mask.dim() == 1 and query.device.type == "cpu":
        # (S,) -> (1, 1, 1, S); SDPA's own broadcasting expands this to
        # (B, H, L, S), so no explicit .expand() is needed.
        attn_mask = attn_mask.view(1, 1, 1, -1)
    return _original_sdpa(query, key, value, attn_mask=attn_mask, **kwargs)


# Apply the patch globally so every SDPA call site goes through the wrapper.
F.scaled_dot_product_attention = _cpu_safe_sdpa
print("โœ… Patched scaled_dot_product_attention for CPU mask shape fix")
def load_model():
    """Load the VoxCPM TTS model and prepare it for the current device.

    Returns:
        tuple: ``(model, device, error)`` — ``model`` is the loaded VoxCPM
        instance or ``None`` on failure, ``device`` is ``"cuda"`` or ``"cpu"``,
        and ``error`` is the failure message or ``None`` on success.
    """
    try:
        from voxcpm import VoxCPM
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading VoxCPM model on {device}...")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        # Load model (optimize=False to avoid torch.compile issues)
        model = VoxCPM.from_pretrained("openbmb/VoxCPM2", load_denoiser=False, optimize=False)
        # CRITICAL FIX: Force float32 on CPU — the checkpoint's default dtype
        # (e.g. bfloat16) is not supported by all CPU kernels.
        if device == "cpu":
            print("Converting model to float32 for CPU compatibility...")
            # Step 1: Change config dtype so _inference creates float32 tensors
            if hasattr(model.tts_model, 'config'):
                old_dtype = model.tts_model.config.dtype
                model.tts_model.config.dtype = "float32"
                print(f" config.dtype: {old_dtype} -> float32")
            # Step 2: Convert all model parameters and buffers to float32
            model.tts_model = model.tts_model.to(torch.float32)
            # Step 3: Fix KV caches (created in __init__ with old dtype).
            # NOTE(review): the nested `kv_cache.kv_cache` attribute layout is
            # assumed from VoxCPM internals — re-verify on library upgrades.
            if hasattr(model.tts_model, 'base_lm') and hasattr(model.tts_model.base_lm, 'kv_cache'):
                if model.tts_model.base_lm.kv_cache is not None:
                    model.tts_model.base_lm.kv_cache.kv_cache = model.tts_model.base_lm.kv_cache.kv_cache.to(torch.float32)
                    print(" base_lm KV cache -> float32")
            if hasattr(model.tts_model, 'residual_lm') and hasattr(model.tts_model.residual_lm, 'kv_cache'):
                if model.tts_model.residual_lm.kv_cache is not None:
                    model.tts_model.residual_lm.kv_cache.kv_cache = model.tts_model.residual_lm.kv_cache.kv_cache.to(torch.float32)
                    print(" residual_lm KV cache -> float32")
            print("Model conversion to float32 complete!")
        print("Model loaded successfully!")
        return model, device, None
    except Exception as e:
        # Top-level boundary: log and report the failure instead of raising,
        # so the UI can show the error message.
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return None, "cpu", str(e)
# Global model state shared between the preload thread and request handlers.
MODEL_STATE = {
    "model": None,    # loaded VoxCPM instance, or None until load completes
    "device": "cpu",  # "cuda" or "cpu"
    "error": None,    # last load error message, if any
    "loading": False  # True while a load is in progress
}
# Serializes the load decision so the preload thread (started in __main__)
# and concurrent Gradio request handlers cannot both start loading the model.
_MODEL_LOCK = threading.Lock()


def ensure_model():
    """Ensure the model is loaded, loading it at most once.

    Thread-safe: the check-and-set of ``MODEL_STATE["loading"]`` is done under
    a lock so only one caller performs the (slow) load; other callers return
    immediately with ``model=None`` while loading is in progress, matching the
    previous non-blocking behavior.

    Returns:
        tuple: ``(model, device, error)`` from ``MODEL_STATE``.
    """
    with _MODEL_LOCK:
        # Atomically decide whether *this* caller performs the load.
        should_load = MODEL_STATE["model"] is None and not MODEL_STATE["loading"]
        if should_load:
            MODEL_STATE["loading"] = True
    if should_load:
        try:
            model, device, error = load_model()
            MODEL_STATE["model"] = model
            MODEL_STATE["device"] = device
            MODEL_STATE["error"] = error
        except Exception as e:
            MODEL_STATE["error"] = str(e)
            traceback.print_exc()
        finally:
            MODEL_STATE["loading"] = False
    return MODEL_STATE["model"], MODEL_STATE["device"], MODEL_STATE["error"]
def generate_audio(text, reference_audio, cfg_value=2.0, steps=10):
    """Synthesize speech for *text* in the voice of *reference_audio*.

    Args:
        text: Text to synthesize.
        reference_audio: Filepath to the reference voice clip (from gr.Audio).
        cfg_value: Classifier-free guidance strength (higher = closer timbre).
        steps: Number of inference timesteps (higher = better, slower).

    Returns:
        tuple: ``(output_wav_path_or_None, status_message)``.
    """
    if not text or not reference_audio:
        return None, "โŒ ่ฏท่พ“ๅ…ฅๆ–‡ๆœฌๅ’Œๅ‚่€ƒ้Ÿณ้ข‘"
    if not text.strip():
        return None, "โŒ ๆ–‡ๆœฌไธ่ƒฝไธบ็ฉบ"
    ref_path = None
    try:
        model, device, error = ensure_model()
        if error:
            return None, f"โŒ ๆจกๅž‹ๅŠ ่ฝฝๅคฑ่ดฅ: {error}"
        if model is None:
            return None, "โŒ ๆจกๅž‹ๆญฃๅœจๅŠ ่ฝฝไธญ๏ผŒ่ฏท็จๅ€™..."
        # Read the reference audio; downmix stereo to mono (first channel).
        ref_audio, sr = sf.read(reference_audio)
        if len(ref_audio.shape) > 1:
            ref_audio = ref_audio[:, 0]
        # Write the (possibly downmixed) reference to a temp wav for the model.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            sf.write(tmp.name, ref_audio, sr)
            ref_path = tmp.name
        print(f"Generating with text: {text[:50]}...")
        print(f"Reference audio: {len(ref_audio)/sr:.2f}s at {sr}Hz")
        # Generate the audio, timing the call for the status message.
        import time
        t0 = time.time()
        wav = model.generate(
            text=text,
            reference_wav_path=ref_path,
            cfg_value=float(cfg_value),
            inference_timesteps=int(steps),
        )
        elapsed = time.time() - t0
        # Save the output to a unique temp file so concurrent requests don't
        # clobber each other (previously a fixed /tmp/voxcpm_output.wav path).
        sample_rate = model.tts_model.sample_rate
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out:
            output_path = out.name
        sf.write(output_path, wav, sample_rate)
        duration = len(wav) / sample_rate
        msg = f"โœ… ็”ŸๆˆๆˆๅŠŸ! ๆ—ถ้•ฟ: {duration:.2f}s, ่€—ๆ—ถ: {elapsed:.1f}s, ่ฎพๅค‡: {device}"
        print(msg)
        return output_path, msg
    except Exception as e:
        error_msg = f"โŒ ็”Ÿๆˆๅคฑ่ดฅ: {str(e)}"
        print(f"Error: {e}")
        traceback.print_exc()
        return None, error_msg
    finally:
        # Always remove the temporary reference file (previously leaked
        # whenever generation raised before the success-path unlink).
        if ref_path is not None:
            try:
                os.unlink(ref_path)
            except OSError:
                pass
# Preset demo sentences; the keys (Chinese labels, mojibake-encoded here)
# double as button captions in the UI — presumably greeting / counting /
# emotion; verify against the original UTF-8 source.
PRESET_TEXTS = {
    "้—ฎๅ€™": "Hello! I am One! I am the first Numberblock, and I love being number one!",
    "่ฎกๆ•ฐ": "One, two, three, four, five! Counting is so much fun! I can count all the way to ten!",
    "ๆƒ…ๆ„Ÿ": "Sometimes I feel a little lonely being just one, but then I remember that one is the start of everything!",
}
# ๅˆ›ๅปบ Gradio ็•Œ้ข
with gr.Blocks(title="NumberBlocks One Voice Cloning") as demo:
gr.Markdown("# ๐ŸŽญ NumberBlocks One Voice Cloning (VoxCPM V5)")
gr.Markdown("### ไฝฟ็”จ VoxCPM 2 ๆจกๅž‹ๅ…‹้š† One ็š„ๅฃฐ้Ÿณ")
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="่พ“ๅ…ฅๆ–‡ๆœฌ",
placeholder="่พ“ๅ…ฅ่ฆๅˆๆˆ็š„ๆ–‡ๆœฌ...",
lines=3,
value=PRESET_TEXTS["้—ฎๅ€™"]
)
with gr.Row():
for name, txt in PRESET_TEXTS.items():
gr.Button(name).click(lambda t=txt: t, inputs=None, outputs=text_input)
with gr.Column():
ref_audio_input = gr.Audio(
label="ๅ‚่€ƒ้Ÿณ้ข‘ (One ็š„ๅฃฐ้Ÿณ)",
type="filepath"
)
with gr.Row():
cfg_slider = gr.Slider(
minimum=0.5,
maximum=5.0,
value=2.0,
step=0.1,
label="CFG Value (่ถŠ้ซ˜่ถŠๅƒๅ‚่€ƒ้Ÿณ่‰ฒ)"
)
steps_slider = gr.Slider(
minimum=5,
maximum=50,
value=10,
step=1,
label="ๆŽจ็†ๆญฅๆ•ฐ (่ถŠ้ซ˜่ดจ้‡่ถŠๅฅฝไฝ†่ถŠๆ…ข)"
)
generate_btn = gr.Button("๐ŸŽ™๏ธ ็”Ÿๆˆ้Ÿณ้ข‘", variant="primary")
with gr.Row():
output_audio = gr.Audio(label="็”Ÿๆˆ็ป“ๆžœ")
status_msg = gr.Markdown(value="โธ๏ธ ็ญ‰ๅพ…็”Ÿๆˆ...")
generate_btn.click(
fn=generate_audio,
inputs=[text_input, ref_audio_input, cfg_slider, steps_slider],
outputs=[output_audio, status_msg]
)
gr.Markdown("---")
gr.Markdown("### ่ฏดๆ˜Ž")
gr.Markdown("""
- **ๅ‚่€ƒ้Ÿณ้ข‘**: ไธŠไผ  One ็š„ๅฃฐ้Ÿณ็‰‡ๆฎต๏ผˆๅปบ่ฎฎ 5-15 ็ง’ๆธ…ๆ™ฐ่ฏญ้Ÿณ๏ผ‰
- **CFG Value**: ๆŽงๅˆถ้Ÿณ่‰ฒ็›ธไผผๅบฆ๏ผŒ้ป˜่ฎค 2.0๏ผŒ่ถŠ้ซ˜่ถŠๅƒๅ‚่€ƒ้Ÿณ่‰ฒ
- **ๆŽจ็†ๆญฅๆ•ฐ**: ้ป˜่ฎค 10๏ผŒ่ถŠ้ซ˜่ดจ้‡่ถŠๅฅฝไฝ†็”Ÿๆˆ่ถŠๆ…ข
- **ๆจกๅž‹**: VoxCPM 2 (openbmb/VoxCPM2)
- **V5**: CPU float32 + SDPA mask shape fix
""")
if __name__ == "__main__":
    import threading

    def _warm_up():
        """Load the model in the background so the first request is faster."""
        print("Preloading VoxCPM model...")
        ensure_model()

    # Daemon thread: don't block interpreter shutdown if the load hangs.
    loader = threading.Thread(target=_warm_up, daemon=True)
    loader.start()
    demo.launch(server_name="0.0.0.0", server_port=7860)