#!/usr/bin/env python3
"""
NumberBlocks One Voice Cloning Space - VoxCPM V5
Fix: float32 on CPU + monkey-patch SDPA mask shape for CPU compatibility
Root cause of "Dimension out of range":
MiniCPM4's Attention.forward_step creates a 1D attn_mask but SDPA on CPU
expects at least 2D for proper broadcasting with GQA (Grouped Query Attention).
On GPU, the flash-attention backend handles this; on CPU the math backend does not.
"""
import os
import gradio as gr
import tempfile
import soundfile as sf
import traceback
from pathlib import Path
import torch
import torch.nn.functional as F

# Hugging Face auth token from the environment (HF_TOKEN preferred,
# HUGGINGFACE_TOKEN as fallback). NOTE(review): not referenced anywhere
# else in this file — confirm it is actually needed (e.g. picked up
# implicitly by huggingface_hub) before removing.
HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("HUGGINGFACE_TOKEN"))
# ------------------------------------------------------------------
# Monkey-patch: fix SDPA mask shape for CPU
# ------------------------------------------------------------------
_original_sdpa = F.scaled_dot_product_attention

def _cpu_safe_sdpa(query, key, value, attn_mask=None, **kwargs):
    """Wrapper around SDPA that fixes a 1D attn_mask for the CPU backend.

    The CPU math backend needs the mask to be broadcastable against the
    (B, H, L, S) attention scores; a bare (S,) mask raises
    "Dimension out of range". Reshape/expand it to (B, H, L, S).
    """
    # FIX: gate on the tensors' actual device, not on global CUDA
    # availability — the original skipped the reshape for CPU tensors on a
    # machine that also has a GPU, so CPU inference still crashed there.
    if attn_mask is not None and attn_mask.dim() == 1 and query.device.type == "cpu":
        # query shape: (B, H, L, D), key shape: (B, H_kv, S, D);
        # the 1D mask length must equal S (the key/value sequence length).
        B, H, L, D = query.shape
        S = key.shape[2]
        attn_mask = attn_mask.view(1, 1, 1, S).expand(B, H, L, S)
    return _original_sdpa(query, key, value, attn_mask=attn_mask, **kwargs)

# Apply the patch globally
F.scaled_dot_product_attention = _cpu_safe_sdpa
print("โ Patched scaled_dot_product_attention for CPU mask shape fix")
def load_model():
    """Load the VoxCPM TTS model, forcing float32 everywhere on CPU.

    Returns:
        (model, device, error): ``model`` is None and ``error`` holds the
        message string when loading failed; otherwise ``error`` is None.
    """
    try:
        from voxcpm import VoxCPM

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading VoxCPM model on {device}...")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        # Load model (optimize=False to avoid torch.compile issues)
        model = VoxCPM.from_pretrained("openbmb/VoxCPM2", load_denoiser=False, optimize=False)

        if device == "cpu":
            # CRITICAL FIX: the checkpoint defaults to a half-precision
            # dtype that the CPU kernels cannot handle — convert everything.
            print("Converting model to float32 for CPU compatibility...")
            tts = model.tts_model
            # Step 1: change config dtype so _inference creates float32 tensors.
            if hasattr(tts, 'config'):
                prev_dtype = tts.config.dtype
                tts.config.dtype = "float32"
                print(f" config.dtype: {prev_dtype} -> float32")
            # Step 2: convert all parameters and buffers.
            model.tts_model = tts.to(torch.float32)
            # Step 3: KV caches were allocated in __init__ with the old dtype;
            # patch both language-model caches when present.
            for lm_name in ("base_lm", "residual_lm"):
                lm = getattr(model.tts_model, lm_name, None)
                holder = getattr(lm, "kv_cache", None) if lm is not None else None
                if holder is not None:
                    holder.kv_cache = holder.kv_cache.to(torch.float32)
                    print(f" {lm_name} KV cache -> float32")
            print("Model conversion to float32 complete!")

        print("Model loaded successfully!")
        return model, device, None
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return None, "cpu", str(e)
# Global model state shared between the preload thread and request handlers.
MODEL_STATE = dict(model=None, device="cpu", error=None, loading=False)

def ensure_model():
    """Lazily load the model on first use.

    Returns:
        (model, device, error) from the shared MODEL_STATE.

    NOTE(review): not thread-safe — the preload thread and a concurrent
    request can both pass the ``loading`` check and trigger a double load.
    Appears benign (last writer wins), but confirm before relying on
    single-load semantics.
    """
    if MODEL_STATE["model"] is None and not MODEL_STATE["loading"]:
        MODEL_STATE["loading"] = True
        try:
            model, device, error = load_model()
        except Exception as exc:
            MODEL_STATE["error"] = str(exc)
            traceback.print_exc()
        else:
            MODEL_STATE.update(model=model, device=device, error=error)
        finally:
            MODEL_STATE["loading"] = False
    return MODEL_STATE["model"], MODEL_STATE["device"], MODEL_STATE["error"]
def generate_audio(text, reference_audio, cfg_value=2.0, steps=10):
    """Generate speech audio by cloning the reference voice.

    Args:
        text: Text to synthesize; must be non-empty after stripping.
        reference_audio: Filepath of the reference audio clip.
        cfg_value: Classifier-free-guidance strength (higher follows the
            reference timbre more closely).
        steps: Number of inference timesteps (higher = better quality, slower).

    Returns:
        (output_path, status_message); ``output_path`` is None on failure.
    """
    if not text or not reference_audio:
        return None, "โ ่ฏท่พๅ ฅๆๆฌๅๅ่้ณ้ข"
    if not text.strip():
        return None, "โ ๆๆฌไธ่ฝไธบ็ฉบ"
    try:
        model, device, error = ensure_model()
        if error:
            return None, f"โ ๆจกๅๅ ่ฝฝๅคฑ่ดฅ: {error}"
        if model is None:
            return None, "โ ๆจกๅๆญฃๅจๅ ่ฝฝไธญ๏ผ่ฏท็จๅ..."

        # Read the reference audio; downmix stereo to mono (first channel).
        ref_audio, sr = sf.read(reference_audio)
        if len(ref_audio.shape) > 1:
            ref_audio = ref_audio[:, 0]

        # Write the (possibly downmixed) reference to a temp wav file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            sf.write(tmp.name, ref_audio, sr)
            ref_path = tmp.name
        try:
            print(f"Generating with text: {text[:50]}...")
            print(f"Reference audio: {len(ref_audio)/sr:.2f}s at {sr}Hz")
            import time
            t0 = time.time()
            wav = model.generate(
                text=text,
                reference_wav_path=ref_path,
                cfg_value=float(cfg_value),
                inference_timesteps=int(steps),
            )
            elapsed = time.time() - t0
        finally:
            # FIX: always remove the temp file — the original only unlinked
            # it on success, leaking one file per failed generation.
            os.unlink(ref_path)

        sample_rate = model.tts_model.sample_rate
        output_path = "/tmp/voxcpm_output.wav"
        sf.write(output_path, wav, sample_rate)
        duration = len(wav) / sample_rate
        msg = f"โ ็ๆๆๅ! ๆถ้ฟ: {duration:.2f}s, ่ๆถ: {elapsed:.1f}s, ่ฎพๅค: {device}"
        print(msg)
        return output_path, msg
    except Exception as e:
        error_msg = f"โ ็ๆๅคฑ่ดฅ: {str(e)}"
        print(f"Error: {e}")
        traceback.print_exc()
        return None, error_msg
# Preset demo texts. NOTE(review): the keys look like mojibake'd Chinese
# labels; they are runtime data (used as button captions and as the initial
# textbox value), so they are kept byte-identical here.
PRESET_TEXTS = {
    "้ฎๅ": "Hello! I am One! I am the first Numberblock, and I love being number one!",
    "่ฎกๆฐ": "One, two, three, four, five! Counting is so much fun! I can count all the way to ten!",
    "ๆ ๆ": "Sometimes I feel a little lonely being just one, but then I remember that one is the start of everything!",
}
# Build the Gradio UI (layout nesting reconstructed — indentation was lost
# in extraction; confirm column/row grouping against the running Space).
with gr.Blocks(title="NumberBlocks One Voice Cloning") as demo:
    gr.Markdown("# ๐ญ NumberBlocks One Voice Cloning (VoxCPM V5)")
    gr.Markdown("### ไฝฟ็จ VoxCPM 2 ๆจกๅๅ ้ One ็ๅฃฐ้ณ")
    with gr.Row():
        with gr.Column():
            # Text to synthesize; pre-filled with the first preset.
            text_input = gr.Textbox(
                label="่พๅ ฅๆๆฌ",
                placeholder="่พๅ ฅ่ฆๅๆ็ๆๆฌ...",
                lines=3,
                value=PRESET_TEXTS["้ฎๅ"]
            )
            with gr.Row():
                # One button per preset; `t=txt` binds the current value as a
                # default argument, avoiding the late-binding closure pitfall.
                for name, txt in PRESET_TEXTS.items():
                    gr.Button(name).click(lambda t=txt: t, inputs=None, outputs=text_input)
        with gr.Column():
            # Reference clip of the voice to clone (delivered as a file path).
            ref_audio_input = gr.Audio(
                label="ๅ่้ณ้ข (One ็ๅฃฐ้ณ)",
                type="filepath"
            )
    with gr.Row():
        # Generation knobs; defaults mirror generate_audio's defaults.
        cfg_slider = gr.Slider(
            minimum=0.5,
            maximum=5.0,
            value=2.0,
            step=0.1,
            label="CFG Value (่ถ้ซ่ถๅๅ่้ณ่ฒ)"
        )
        steps_slider = gr.Slider(
            minimum=5,
            maximum=50,
            value=10,
            step=1,
            label="ๆจ็ๆญฅๆฐ (่ถ้ซ่ดจ้่ถๅฅฝไฝ่ถๆ ข)"
        )
    generate_btn = gr.Button("๐๏ธ ็ๆ้ณ้ข", variant="primary")
    with gr.Row():
        output_audio = gr.Audio(label="็ๆ็ปๆ")
    status_msg = gr.Markdown(value="โธ๏ธ ็ญๅพ ็ๆ...")
    # Wire the generate button: inputs feed generate_audio, whose two return
    # values land in the audio player and the status line respectively.
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, ref_audio_input, cfg_slider, steps_slider],
        outputs=[output_audio, status_msg]
    )
    gr.Markdown("---")
    gr.Markdown("### ่ฏดๆ")
    gr.Markdown("""
- **ๅ่้ณ้ข**: ไธไผ One ็ๅฃฐ้ณ็ๆฎต๏ผๅปบ่ฎฎ 5-15 ็งๆธ ๆฐ่ฏญ้ณ๏ผ
- **CFG Value**: ๆงๅถ้ณ่ฒ็ธไผผๅบฆ๏ผ้ป่ฎค 2.0๏ผ่ถ้ซ่ถๅๅ่้ณ่ฒ
- **ๆจ็ๆญฅๆฐ**: ้ป่ฎค 10๏ผ่ถ้ซ่ดจ้่ถๅฅฝไฝ็ๆ่ถๆ ข
- **ๆจกๅ**: VoxCPM 2 (openbmb/VoxCPM2)
- **V5**: CPU float32 + SDPA mask shape fix
""")
if __name__ == "__main__":
    import threading

    def _preload_model():
        # Warm the model in a background thread so the first request
        # doesn't pay the full load cost.
        print("Preloading VoxCPM model...")
        ensure_model()

    threading.Thread(target=_preload_model, daemon=True).start()
    demo.launch(server_name="0.0.0.0", server_port=7860)