Add IndicVox paper demo Space
Browse files- .gitattributes +3 -0
- README.md +29 -5
- app.py +451 -0
- assets/voices/hin_m_ref_00.wav +3 -0
- assets/voices/tam_f_ref_00.wav +3 -0
- assets/voices/tam_m_ref_00.wav +3 -0
- code_switch_prompts.json +166 -0
- packages.txt +2 -0
- requirements.txt +13 -0
- voxcpm/__init__.py +5 -0
- voxcpm/cli.py +598 -0
- voxcpm/core.py +333 -0
- voxcpm/model/__init__.py +4 -0
- voxcpm/model/utils.py +121 -0
- voxcpm/model/voxcpm.py +985 -0
- voxcpm/model/voxcpm2.py +1224 -0
- voxcpm/modules/__init__.py +0 -0
- voxcpm/modules/audiovae/__init__.py +2 -0
- voxcpm/modules/audiovae/audio_vae.py +377 -0
- voxcpm/modules/audiovae/audio_vae_v2.py +486 -0
- voxcpm/modules/layers/__init__.py +1 -0
- voxcpm/modules/layers/lora.py +130 -0
- voxcpm/modules/layers/scalar_quantization_layer.py +26 -0
- voxcpm/modules/locdit/__init__.py +3 -0
- voxcpm/modules/locdit/local_dit.py +114 -0
- voxcpm/modules/locdit/local_dit_v2.py +116 -0
- voxcpm/modules/locdit/unified_cfm.py +232 -0
- voxcpm/modules/locenc/__init__.py +1 -0
- voxcpm/modules/locenc/local_encoder.py +30 -0
- voxcpm/modules/minicpm4/__init__.py +3 -0
- voxcpm/modules/minicpm4/cache.py +47 -0
- voxcpm/modules/minicpm4/config.py +30 -0
- voxcpm/modules/minicpm4/model.py +429 -0
- voxcpm/training/__init__.py +27 -0
- voxcpm/training/accelerator.py +163 -0
- voxcpm/training/config.py +38 -0
- voxcpm/training/data.py +214 -0
- voxcpm/training/packers.py +296 -0
- voxcpm/training/state.py +20 -0
- voxcpm/training/tracker.py +78 -0
- voxcpm/utils/text_normalize.py +188 -0
- voxcpm/zipenhancer.py +72 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/voices/hin_m_ref_00.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/voices/tam_f_ref_00.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/voices/tam_m_ref_00.wav filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,36 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.12.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: "IndicVox: Hindi & Tamil Code-Switching TTS"
|
| 3 |
+
emoji: "🎙️"
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.12.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
python_version: "3.10.16"
|
| 11 |
+
suggested_hardware: a10g-small
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# IndicVox
|
| 15 |
+
|
| 16 |
+
IndicVox is a GPU-backed research demo for multilingual text-to-speech across Hindi, Tamil, and code-switched prompts. The Space exposes the paper checkpoints through a clean Gradio UI with built-in voice presets and example prompts.
|
| 17 |
+
|
| 18 |
+
## What it includes
|
| 19 |
+
|
| 20 |
+
- `Hindi Focus` for Hindi and Hindi-English prompts
|
| 21 |
+
- `Tamil Focus` for Tamil and Tamil-English prompts
|
| 22 |
+
- `Research Baseline` for direct comparison against the untuned multilingual model
|
| 23 |
+
- Built-in research voice presets for fast demo playback
|
| 24 |
+
- Zero-shot `Text Only` mode if you want to skip reference conditioning
|
| 25 |
+
|
| 26 |
+
## Usage
|
| 27 |
+
|
| 28 |
+
1. Pick a model profile.
|
| 29 |
+
2. Type a Hindi, Tamil, or code-switched prompt.
|
| 30 |
+
3. Pick a built-in voice preset or `Text Only`.
|
| 31 |
+
4. Click `Generate Speech`.
|
| 32 |
+
|
| 33 |
+
## Notes
|
| 34 |
+
|
| 35 |
+
- The base multilingual model stays resident on GPU memory and the paper checkpoints are swapped on demand.
|
| 36 |
+
- The Space is meant for inference/demo usage, not batch evaluation.
|
app.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
import traceback
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
|
| 16 |
+
APP_DIR = Path(__file__).resolve().parent
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def resolve_persist_root() -> Path:
    """Pick a writable root for caches.

    Prefers the Space's persistent ``/data`` volume when it exists and is
    writable; otherwise falls back to a ``.cache`` directory next to the app,
    creating it on demand.
    """
    persistent = Path("/data")
    if persistent.exists() and os.access(persistent, os.W_OK):
        return persistent

    fallback = APP_DIR / ".cache"
    fallback.mkdir(parents=True, exist_ok=True)
    return fallback
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
PERSIST_ROOT = resolve_persist_root()
|
| 30 |
+
HF_HOME = PERSIST_ROOT / "huggingface"
|
| 31 |
+
HF_HOME.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
os.environ.setdefault("HF_HOME", str(HF_HOME))
|
| 33 |
+
os.environ.setdefault("HF_HUB_CACHE", str(HF_HOME / "hub"))
|
| 34 |
+
|
| 35 |
+
sys.path.insert(0, str(APP_DIR))
|
| 36 |
+
|
| 37 |
+
from voxcpm import VoxCPM
|
| 38 |
+
from voxcpm.model.voxcpm import LoRAConfig
|
| 39 |
+
|
| 40 |
+
SPACE_TITLE = "IndicVox: Hindi & Tamil Code-Switching TTS"
|
| 41 |
+
MODEL_REPO_ID = "himahande45/multilingual-tts"
|
| 42 |
+
PROMPTS_FILE = APP_DIR / "code_switch_prompts.json"
|
| 43 |
+
VOICE_DIR = APP_DIR / "assets" / "voices"
|
| 44 |
+
DEFAULT_PROFILE = "Tamil Focus"
|
| 45 |
+
DEFAULT_VOICE = "Tamil Female Research Voice"
|
| 46 |
+
DEFAULT_TEXT = "இந்த experimentக்கு clean reference audio use பண்ணணும், இல்லனா output quality drop ஆகும்."
|
| 47 |
+
|
| 48 |
+
MODEL_PATTERNS = [
|
| 49 |
+
"VoxCPM2_local/*",
|
| 50 |
+
"finetune_checkpoints/step_0000500/lora_config.json",
|
| 51 |
+
"finetune_checkpoints/step_0000500/lora_weights.safetensors",
|
| 52 |
+
"finetune_checkpoints/step_0001000/lora_config.json",
|
| 53 |
+
"finetune_checkpoints/step_0001000/lora_weights.safetensors",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
PROFILES = {
|
| 57 |
+
"Tamil Focus": {
|
| 58 |
+
"description": "Best for Tamil and Tamil-English code-switched prompts.",
|
| 59 |
+
"checkpoint_dir": "finetune_checkpoints/step_0001000",
|
| 60 |
+
},
|
| 61 |
+
"Hindi Focus": {
|
| 62 |
+
"description": "Best for Hindi and Hindi-English code-switched prompts.",
|
| 63 |
+
"checkpoint_dir": "finetune_checkpoints/step_0000500",
|
| 64 |
+
},
|
| 65 |
+
"Research Baseline": {
|
| 66 |
+
"description": "Base multilingual checkpoint without paper fine-tuning.",
|
| 67 |
+
"checkpoint_dir": None,
|
| 68 |
+
},
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
VOICE_PRESETS = {
|
| 72 |
+
"Hindi Research Voice": {
|
| 73 |
+
"path": VOICE_DIR / "hin_m_ref_00.wav",
|
| 74 |
+
"transcript": "लेकिन क्या यह हम सभी कार्यक्रमों के साथ कर सकते?",
|
| 75 |
+
"summary": "Short Hindi reference used for sharper Hindi + English prompting.",
|
| 76 |
+
},
|
| 77 |
+
"Tamil Female Research Voice": {
|
| 78 |
+
"path": VOICE_DIR / "tam_f_ref_00.wav",
|
| 79 |
+
"transcript": "விக்கற நேரத்தையும் லாபத்தையும் பொறுத்து, இந்த டேக்ஸை ஷார்ட் டேர்ம் இல்ல லாங் டேர்ம்னு பிரிப்பாங்க.",
|
| 80 |
+
"summary": "Clear Tamil reference with stable conversational prosody.",
|
| 81 |
+
},
|
| 82 |
+
"Tamil Male Research Voice": {
|
| 83 |
+
"path": VOICE_DIR / "tam_m_ref_00.wav",
|
| 84 |
+
"transcript": "கொரோனா பாதிப்பு காலத்தில் எண்பது கோடி மக்களுக்கு உணவு தானியம் வழங்கப்பட்டதாகவும் அவர் தெரிவித்தார்.",
|
| 85 |
+
"summary": "Tamil male reference that holds rhythm well on longer prompts.",
|
| 86 |
+
},
|
| 87 |
+
"Text Only": {
|
| 88 |
+
"path": None,
|
| 89 |
+
"transcript": None,
|
| 90 |
+
"summary": "Zero-shot generation without a reference voice clip.",
|
| 91 |
+
},
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
CUSTOM_CSS = """
|
| 95 |
+
#app-shell {
|
| 96 |
+
max-width: 1180px;
|
| 97 |
+
margin: 0 auto;
|
| 98 |
+
}
|
| 99 |
+
#hero {
|
| 100 |
+
padding: 24px 26px 12px 26px;
|
| 101 |
+
border: 1px solid rgba(255, 255, 255, 0.08);
|
| 102 |
+
border-radius: 22px;
|
| 103 |
+
background:
|
| 104 |
+
radial-gradient(circle at top right, rgba(99, 102, 241, 0.16), transparent 34%),
|
| 105 |
+
radial-gradient(circle at bottom left, rgba(16, 185, 129, 0.14), transparent 30%),
|
| 106 |
+
rgba(15, 23, 42, 0.74);
|
| 107 |
+
}
|
| 108 |
+
.stat-chip {
|
| 109 |
+
display: inline-block;
|
| 110 |
+
margin: 6px 8px 0 0;
|
| 111 |
+
padding: 8px 12px;
|
| 112 |
+
border-radius: 999px;
|
| 113 |
+
background: rgba(255, 255, 255, 0.06);
|
| 114 |
+
font-size: 0.92rem;
|
| 115 |
+
}
|
| 116 |
+
.footnote {
|
| 117 |
+
opacity: 0.78;
|
| 118 |
+
font-size: 0.94rem;
|
| 119 |
+
}
|
| 120 |
+
footer {
|
| 121 |
+
visibility: hidden;
|
| 122 |
+
}
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
if torch.cuda.is_available():
|
| 126 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 127 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 128 |
+
torch.set_float32_matmul_precision("high")
|
| 129 |
+
|
| 130 |
+
THEME = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_examples() -> list[list[str]]:
    """Build the example rows for the UI from the code-switch prompt bank.

    Each row is ``[prompt text, model profile, voice preset]``; the indices
    below pick a fixed, representative subset of the JSON prompt bank.
    """
    prompt_bank = json.loads(PROMPTS_FILE.read_text(encoding="utf-8"))

    picks = [
        ("hi_en", 0, "Hindi Focus", "Hindi Research Voice"),
        ("hi_en", 9, "Hindi Focus", "Hindi Research Voice"),
        ("hi_en", 16, "Hindi Focus", "Hindi Research Voice"),
        ("ta_en", 0, "Tamil Focus", "Tamil Female Research Voice"),
        ("ta_en", 9, "Tamil Focus", "Tamil Female Research Voice"),
        ("ta_en", 14, "Tamil Focus", "Tamil Male Research Voice"),
    ]
    return [
        [prompt_bank[lang][idx]["text"], profile_name, voice_name]
        for lang, idx, profile_name, voice_name in picks
    ]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def profile_markdown(profile_name: str) -> str:
    """Render the selected model profile as a two-line Markdown blurb."""
    return "**{}** \n{}".format(profile_name, PROFILES[profile_name]["description"])
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def voice_markdown(voice_name: str) -> str:
    """Render the selected voice preset as Markdown.

    Presets that ship a reference clip also get their reference transcript;
    the zero-shot ``Text Only`` preset (path is None) shows summary only.
    """
    preset = VOICE_PRESETS[voice_name]
    lines = [f"**{voice_name}**", preset["summary"]]
    if preset["path"] is not None:
        lines.append(f"Reference transcript: `{preset['transcript']}`")
    # Two trailing spaces before \n force a Markdown hard line break.
    return " \n".join(lines)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def dynamic_max_len(text: str) -> int:
    """Scale the generation length budget with prompt size.

    Roughly 7.5 tokens of budget per non-whitespace-trimmed character,
    clamped to the range [280, 900]. Empty prompts count as one character.
    """
    chars = len(text.strip()) or 1
    scaled = int(chars * 7.5)
    return min(max(scaled, 280), 900)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class ModelManager:
    """Owns the single warm VoxCPM model and swaps LoRA checkpoints per profile.

    The base multilingual model is loaded once on GPU at construction; profile
    switches only load/unload LoRA adapter weights. All state transitions and
    generation calls are serialized by ``self.lock`` so concurrent Gradio
    requests cannot interleave a profile swap with a generate call.
    """

    def __init__(self) -> None:
        # Serializes activate_profile + generate; see synthesize().
        self.lock = threading.Lock()
        self.repo_dir = self.

_resolve_repo_dir()
        # Base (non-LoRA) model weights live under VoxCPM2_local in the repo.
        self.base_dir = self.repo_dir / "VoxCPM2_local"
        # Which LoRA checkpoint is currently loaded into the model (weights).
        self.loaded_profile: str | None = None
        # Which profile the UI last activated (may be the LoRA-less baseline).
        self.active_profile: str | None = None
        self.model = self._load_model()
        # Warm the default profile so the first request pays no swap cost.
        self.activate_profile(DEFAULT_PROFILE)

    def _resolve_repo_dir(self) -> Path:
        """Return the local directory holding model files.

        An explicit INDICVOX_LOCAL_MODEL_REPO path wins (and must exist);
        otherwise the needed files are fetched from the Hub via
        snapshot_download, restricted to MODEL_PATTERNS.
        """
        local_repo = os.getenv("INDICVOX_LOCAL_MODEL_REPO")
        if local_repo:
            path = Path(local_repo).expanduser().resolve()
            if path.exists():
                return path
            raise FileNotFoundError(f"INDICVOX_LOCAL_MODEL_REPO does not exist: {path}")

        # HF_TOKEN is optional; needed only if the model repo is gated/private.
        token = os.getenv("HF_TOKEN")
        snapshot_path = snapshot_download(
            repo_id=MODEL_REPO_ID,
            repo_type="model",
            allow_patterns=MODEL_PATTERNS,
            token=token,
        )
        return Path(snapshot_path)

    def _load_lora_config(self, checkpoint_dir: Path) -> LoRAConfig:
        """Parse lora_config.json from a checkpoint dir into a LoRAConfig."""
        payload = json.loads((checkpoint_dir / "lora_config.json").read_text(encoding="utf-8"))
        return LoRAConfig(**payload["lora_config"])

    def _load_model(self) -> VoxCPM:
        """Load the base VoxCPM model onto the GPU (GPU is mandatory).

        The LoRA config for the default profile is passed at construction so
        the adapter slots exist; the actual adapter weights are loaded later
        by activate_profile().
        """
        if not torch.cuda.is_available():
            raise RuntimeError("A GPU runtime is required. Request an A10G/L4 Space and restart.")

        # NOTE(review): assumes DEFAULT_PROFILE always has a non-None
        # checkpoint_dir (true for "Tamil Focus"); the baseline profile would
        # break this line.
        checkpoint_dir = self.repo_dir / PROFILES[DEFAULT_PROFILE]["checkpoint_dir"]
        lora_config = self._load_lora_config(checkpoint_dir)
        model = VoxCPM.from_pretrained(
            hf_model_id=str(self.base_dir),
            load_denoiser=False,
            optimize=False,
            lora_config=lora_config,
        )
        return model

    def activate_profile(self, profile_name: str) -> None:
        """Make ``profile_name`` the active profile, swapping LoRA weights as needed.

        Baseline (checkpoint_dir None) just disables LoRA without unloading,
        so switching back to a tuned profile is cheap. Callers must hold
        ``self.lock`` (or be in __init__ before the app serves requests).
        """
        spec = PROFILES[profile_name]
        checkpoint_dir = spec["checkpoint_dir"]

        if checkpoint_dir is None:
            # Baseline: keep any loaded adapter resident but inactive.
            self.model.set_lora_enabled(False)
            self.active_profile = profile_name
            return

        if self.loaded_profile != profile_name:
            # Unload the previous adapter before loading the new one.
            if self.loaded_profile is not None:
                self.model.unload_lora()
            self.model.load_lora(str(self.repo_dir / checkpoint_dir))
            self.loaded_profile = profile_name

        self.model.set_lora_enabled(True)
        self.active_profile = profile_name

    def synthesize(
        self,
        text: str,
        profile_name: str,
        voice_name: str,
        cfg_value: float,
        inference_steps: int,
    ) -> tuple[tuple[int, np.ndarray], str]:
        """Generate speech for ``text`` and return ((sample_rate, wav), status_md).

        Raises gr.Error on an empty prompt. Holds the manager lock for the
        whole profile-swap + generation so requests are strictly serialized.
        """
        clean_text = text.strip()
        if not clean_text:
            raise gr.Error("Enter a prompt first.")

        start = time.perf_counter()
        with self.lock:
            self.activate_profile(profile_name)
            kwargs = {
                "text": clean_text,
                "cfg_value": float(cfg_value),
                "inference_timesteps": int(inference_steps),
                # Length budget scales with prompt size (see dynamic_max_len).
                "max_len": dynamic_max_len(clean_text),
            }

            # Voice presets with a clip do reference-conditioned cloning;
            # "Text Only" (path None) falls through to zero-shot synthesis.
            voice = VOICE_PRESETS[voice_name]
            if voice["path"] is not None:
                kwargs["prompt_wav_path"] = str(voice["path"])
                kwargs["prompt_text"] = voice["transcript"]

            wav = self.model.generate(**kwargs)
            sample_rate = int(self.model.tts_model.sample_rate)

        # Normalize to a clipped float32 numpy array for gr.Audio.
        if isinstance(wav, torch.Tensor):
            wav = wav.detach().cpu().numpy()
        wav = np.asarray(wav, dtype=np.float32).squeeze()
        wav = np.clip(wav, -1.0, 1.0)

        elapsed = time.perf_counter() - start
        duration = float(wav.shape[-1]) / sample_rate if wav.size else 0.0
        # RTF = wall-clock seconds per second of audio (lower is faster).
        rtf = elapsed / duration if duration > 0 else float("nan")
        speed_line = f"RTF {rtf:.2f}x" if np.isfinite(rtf) else "RTF n/a"
        status = (
            f"**Ready** \n"
            f"Profile: `{profile_name}` \n"
            f"Voice: `{voice_name}` \n"
            f"Audio length: `{duration:.2f}s` \n"
            f"Generation time: `{elapsed:.2f}s` ({speed_line})"
        )
        return (sample_rate, wav), status

    def boot_markdown(self) -> str:
        """Markdown summary of the runtime shown in the status panel at boot."""
        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU"
        active_profile = self.active_profile or DEFAULT_PROFILE
        return (
            f"**GPU Ready** \n"
            f"Runtime: `{gpu_name}` \n"
            f"Warm profile: `{active_profile}` \n"
            f"Model source: `{MODEL_REPO_ID}`"
        )
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# Model loading happens at import time so the Space is warm before the first
# request. Failures are captured instead of raised: the UI still renders and
# boot_status() surfaces the traceback to the user.
BOOT_ERROR: str | None = None
MODEL_MANAGER: ModelManager | None = None

try:
    MODEL_MANAGER = ModelManager()
except Exception:
    BOOT_ERROR = traceback.format_exc()

# Example rows for the gr.Examples tabs (loaded even if the model failed).
EXAMPLES = load_examples()
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def synthesize(text: str, profile_name: str, voice_name: str, cfg_value: float, inference_steps: int):
    """Gradio entry point: delegate to the warm ModelManager.

    Surfaces the captured startup traceback as a gr.Error if model
    initialization failed at import time.
    """
    manager = MODEL_MANAGER
    if manager is None:
        raise gr.Error(f"Model initialization failed.\n\n{BOOT_ERROR}")
    return manager.synthesize(text, profile_name, voice_name, cfg_value, inference_steps)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def voice_preview(voice_name: str):
    """Return (preview clip path or None, preset Markdown) for a voice preset."""
    preset = VOICE_PRESETS[voice_name]
    clip = preset["path"]
    return (None if clip is None else str(clip)), voice_markdown(voice_name)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def clear_prompt() -> str:
    """Reset value for the prompt textbox (wired to the Clear button)."""
    return str()
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def boot_status() -> str:
    """Initial Markdown for the status panel.

    Shows the runtime summary when the model booted, otherwise the captured
    startup traceback in a fenced code block.
    """
    if MODEL_MANAGER is None:
        return f"**Startup Error** \n```text\n{BOOT_ERROR}\n```"
    return MODEL_MANAGER.boot_markdown()
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# Theme and CSS must be passed to the gr.Blocks constructor: Blocks.launch()
# has no `theme`/`css` parameters, so passing them there raises a TypeError
# (and the styling would never be applied).
with gr.Blocks(theme=THEME, css=CUSTOM_CSS) as demo:
    with gr.Column(elem_id="app-shell"):
        gr.HTML(
            """
            <div id="hero">
              <h1>IndicVox</h1>
              <p>Research demo for multilingual TTS across Hindi, Tamil, and code-switched prompts.</p>
              <div>
                <span class="stat-chip">GPU-backed Space</span>
                <span class="stat-chip">Warm-loaded model</span>
                <span class="stat-chip">Hindi + Tamil + English prompts</span>
              </div>
            </div>
            """
        )

        with gr.Row():
            # Left column: prompt entry and generation controls.
            with gr.Column(scale=5):
                prompt = gr.Textbox(
                    label="Prompt",
                    value=DEFAULT_TEXT,
                    lines=5,
                    max_lines=8,
                    placeholder="Type Hindi, Tamil, or code-switched text here...",
                )

                with gr.Row():
                    profile = gr.Dropdown(
                        choices=list(PROFILES.keys()),
                        value=DEFAULT_PROFILE,
                        label="Model Profile",
                        info="Switch between the Hindi-tuned and Tamil-tuned research profiles.",
                    )
                    voice = gr.Dropdown(
                        choices=list(VOICE_PRESETS.keys()),
                        value=DEFAULT_VOICE,
                        label="Voice Preset",
                        info="Built-in research voices plus a zero-shot option.",
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        cfg_value = gr.Slider(
                            minimum=1.0,
                            maximum=4.0,
                            value=2.0,
                            step=0.1,
                            label="CFG",
                            info="Higher values usually sound more guided but less relaxed.",
                        )
                        inference_steps = gr.Slider(
                            minimum=6,
                            maximum=16,
                            value=10,
                            step=1,
                            label="Diffusion Steps",
                            info="10 is the paper demo default.",
                        )

                with gr.Row():
                    generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")
                    clear_btn = gr.Button("Clear Prompt")

                with gr.Row():
                    profile_info = gr.Markdown(profile_markdown(DEFAULT_PROFILE))
                    voice_info = gr.Markdown(voice_markdown(DEFAULT_VOICE))

            # Right column: status panel and audio outputs.
            with gr.Column(scale=4):
                status = gr.Markdown(boot_status())
                output_audio = gr.Audio(
                    label="Synthesized Audio",
                    autoplay=False,
                    format="wav",
                )
                voice_preview_audio = gr.Audio(
                    label="Voice Preset Preview",
                    value=str(VOICE_PRESETS[DEFAULT_VOICE]["path"]),
                    interactive=False,
                    autoplay=False,
                    format="wav",
                )
                gr.Markdown(
                    "The demo keeps the base model resident on GPU and swaps paper checkpoints on demand.",
                    elem_classes=["footnote"],
                )

        with gr.Tabs():
            with gr.Tab("Hindi + English Examples"):
                gr.Examples(
                    examples=[row for row in EXAMPLES if row[1] == "Hindi Focus"],
                    inputs=[prompt, profile, voice],
                    cache_examples=False,
                )
            with gr.Tab("Tamil + English Examples"):
                gr.Examples(
                    examples=[row for row in EXAMPLES if row[1] == "Tamil Focus"],
                    inputs=[prompt, profile, voice],
                    cache_examples=False,
                )

        gr.Markdown(
            """
**Demo notes**

- `Hindi Focus` maps to the Hindi-strong checkpoint from the paper experiments.
- `Tamil Focus` maps to the Tamil + code-switch checkpoint and is the default for the Space.
- `Text Only` skips the reference clip and runs zero-shot synthesis.
            """,
            elem_classes=["footnote"],
        )

    # Event wiring. Both the button and Enter-in-textbox trigger synthesis;
    # only the button is exposed on the public API surface.
    generate_btn.click(
        fn=synthesize,
        inputs=[prompt, profile, voice, cfg_value, inference_steps],
        outputs=[output_audio, status],
        api_name="synthesize",
    )
    prompt.submit(
        fn=synthesize,
        inputs=[prompt, profile, voice, cfg_value, inference_steps],
        outputs=[output_audio, status],
        api_name=False,
    )
    profile.change(fn=profile_markdown, inputs=profile, outputs=profile_info, api_name=False)
    voice.change(fn=voice_preview, inputs=voice, outputs=[voice_preview_audio, voice_info], api_name=False)
    clear_btn.click(fn=clear_prompt, outputs=prompt, api_name=False)


# Serialize generation (single GPU model) and bound the request backlog.
demo.queue(default_concurrency_limit=1, max_size=16)

if __name__ == "__main__":
    demo.launch()
|
assets/voices/hin_m_ref_00.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5400ed6ce26df5efddce5e264153423f378623af05f987cc1c435c06cfd24df2
|
| 3 |
+
size 398732
|
assets/voices/tam_f_ref_00.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71564e96ac0378d91f35cf70e27f56e5ad814db267c59eac91ac370e612998f0
|
| 3 |
+
size 605228
|
assets/voices/tam_m_ref_00.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07ba36b8a8ac46246b0ae5cabe12b90325bd62eb4532dd4b6a117b306c8658d3
|
| 3 |
+
size 573452
|
code_switch_prompts.json
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hi_en": [
|
| 3 |
+
{
|
| 4 |
+
"id": "hi_en_001",
|
| 5 |
+
"text": "आज morning standup में हमने Hindi और English prompts पर ASR output compare किया।"
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"id": "hi_en_002",
|
| 9 |
+
"text": "कल client demo से पहले तुम latest checkpoint का audio sample एक बार verify कर लो।"
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"id": "hi_en_003",
|
| 13 |
+
"text": "अगर final report ready है तो उसे shared drive में upload कर दो।"
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"id": "hi_en_004",
|
| 17 |
+
"text": "ये model normal sentences अच्छा बोलता है, लेकिन code-switch parts में अभी भी थोड़ा hesitation आता है।"
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"id": "hi_en_005",
|
| 21 |
+
"text": "मुझे लगता है कि speaker similarity के लिए हमें clean reference clip use करना चाहिए।"
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "hi_en_006",
|
| 25 |
+
"text": "तुमने meeting notes में Tamil section add किया या वो अभी pending है?"
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"id": "hi_en_007",
|
| 29 |
+
"text": "आज lab में GPU free है, इसलिए full evaluation run अभी start कर देते हैं।"
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"id": "hi_en_008",
|
| 33 |
+
"text": "अगर transcript में punctuation ज्यादा हो तो Whisper कभी कभी extra words insert कर देता है।"
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"id": "hi_en_009",
|
| 37 |
+
"text": "इस experiment के लिए मैंने short reference audio चुना ताकि cloning stable रहे।"
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"id": "hi_en_010",
|
| 41 |
+
"text": "हम paper में monolingual results और code-switch results अलग tables में दिखाएँगे।"
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"id": "hi_en_011",
|
| 45 |
+
"text": "please final plots save कर लेना, वरना thesis draft फिर से update करना पड़ेगा।"
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"id": "hi_en_012",
|
| 49 |
+
"text": "आज के test set में proper nouns, news style और casual conversation तीनों mix किए गए हैं।"
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"id": "hi_en_013",
|
| 53 |
+
"text": "अगर base model Tamil शब्द गलत बोलता है तो LoRA adaptation का effect तुरंत दिख जाएगा।"
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"id": "hi_en_014",
|
| 57 |
+
"text": "मैंने summary sheet में WER, CER, switch-WER और speaker similarity सब add कर दिया है।"
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"id": "hi_en_015",
|
| 61 |
+
"text": "आज evening तक तुम generated audio folders को model-wise sort कर दो।"
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"id": "hi_en_016",
|
| 65 |
+
"text": "meeting के बाद हम ASR transcripts manually spot-check भी करेंगे ताकि obvious errors miss न हों।"
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"id": "hi_en_017",
|
| 69 |
+
"text": "ये checkpoint short prompts पर ठीक है, पर long mixed sentences में इसकी rhythm थोड़ी uneven लगती है।"
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"id": "hi_en_018",
|
| 73 |
+
"text": "अगर inference time ज्यादा हुआ तो पहले pilot run करेंगे और फिर full batch launch करेंगे।"
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"id": "hi_en_019",
|
| 77 |
+
"text": "reference speaker clean है, लेकिन generated output में English words का stress अभी consistent नहीं है।"
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"id": "hi_en_020",
|
| 81 |
+
"text": "इस बार final appendix में example prompts, transcripts और metric formulas तीनों include करना।"
|
| 82 |
+
}
|
| 83 |
+
],
|
| 84 |
+
"ta_en": [
|
| 85 |
+
{
|
| 86 |
+
"id": "ta_en_001",
|
| 87 |
+
"text": "நேத்து team meetingல புதிய checkpoint பற்றி detailedஆ பேசினோம்."
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"id": "ta_en_002",
|
| 91 |
+
"text": "இந்த experimentக்கு clean reference audio use பண்ணணும், இல்லனா output quality drop ஆகும்."
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"id": "ta_en_003",
|
| 95 |
+
"text": "final report ready ஆனதும் அதை shared folderல upload பண்ணிடு."
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"id": "ta_en_004",
|
| 99 |
+
"text": "இந்த model Tamil words நல்லா பேசுது, ஆனா English switch வரும் இடங்களில் இன்னும் slight hesitation இருக்கு."
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"id": "ta_en_005",
|
| 103 |
+
"text": "speaker similarity score stable ஆகணும்னா same voice reference தொடர்ந்து use பண்ணணும்."
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"id": "ta_en_006",
|
| 107 |
+
"text": "இன்று full evaluation run start பண்ணலாம், ஏன்னா GPU slot இப்போ free இருக்கு."
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"id": "ta_en_007",
|
| 111 |
+
"text": "Whisper transcriptல punctuation இல்லாதப்போ சில code-switch words betterஆ capture ஆகுது."
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"id": "ta_en_008",
|
| 115 |
+
"text": "paper tableல monolingual Tamil resultsவும் Tamil-English resultsவும் separateஆ காட்டணும்."
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"id": "ta_en_009",
|
| 119 |
+
"text": "இந்த promptல proper noun, news style, casual speech மூன்றும் mixedஆ இருக்கு."
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"id": "ta_en_010",
|
| 123 |
+
"text": "latest checkpoint load பண்ணதுக்குப் பிறகு ஒரு short sanity test முதலில் run பண்ணலாம்."
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"id": "ta_en_011",
|
| 127 |
+
"text": "please generated audio files எல்லாம் model-wise sort பண்ணி metrics folderக்குள் move பண்ணு."
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"id": "ta_en_012",
|
| 131 |
+
"text": "இந்த setupல base modelக்கு Tamil pronunciation கொஞ்சம் weakஆ இருந்தா LoRA gain clearஆ தெரியும்."
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"id": "ta_en_013",
|
| 135 |
+
"text": "summary sheetல WER, CER, switch-WER, speaker similarity எல்லாமே சேர்க்கணும்."
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"id": "ta_en_014",
|
| 139 |
+
"text": "meeting முடிஞ்சதும் manual spot-check பண்ணி obvious ASR mistakes இருக்கா என்று பார்க்கலாம்."
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"id": "ta_en_015",
|
| 143 |
+
"text": "short promptsல output cleanஆ இருக்கு, ஆனா long mixed sentenceல rhythm கொஞ்சம் unevenஆ இருக்கு."
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"id": "ta_en_016",
|
| 147 |
+
"text": "if the plots look clean, appendixல example promptsமும் generated transcriptsமும் add பண்ணலாம்."
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"id": "ta_en_017",
|
| 151 |
+
"text": "இந்த reference clip calmஆ இருக்குது, அதனால் generated voiceவும் naturalஆ வர வாய்ப்பு அதிகம்."
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"id": "ta_en_018",
|
| 155 |
+
"text": "tonightக்குள் full batch finish ஆயிடுச்சுனா நாளைக்கு paper draftல numbers insert பண்ணலாம்."
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"id": "ta_en_019",
|
| 159 |
+
"text": "speaker clone நல்லா இருக்கு, ஆனால் English stress pattern இன்னும் fully consistent இல்ல."
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"id": "ta_en_020",
|
| 163 |
+
"text": "இந்த round முடிஞ்சதும் next stepஆ human listening test plan பண்ணலாம்."
|
| 164 |
+
}
|
| 165 |
+
]
|
| 166 |
+
}
|
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
libsndfile1
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=6,<7
|
| 2 |
+
huggingface_hub>=1.0
|
| 3 |
+
numpy<3
|
| 4 |
+
torch>=2.5.0
|
| 5 |
+
torchaudio>=2.5.0
|
| 6 |
+
transformers>=4.36.2
|
| 7 |
+
einops>=0.8.0
|
| 8 |
+
inflect>=7.0.0
|
| 9 |
+
wetext
|
| 10 |
+
librosa>=0.10.2
|
| 11 |
+
soundfile>=0.12.1
|
| 12 |
+
pydantic>=2
|
| 13 |
+
safetensors>=0.4.5
|
voxcpm/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .core import VoxCPM
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"VoxCPM",
|
| 5 |
+
]
|
voxcpm/cli.py
ADDED
|
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
VoxCPM Command Line Interface
|
| 4 |
+
|
| 5 |
+
VoxCPM2-first CLI for voice design, cloning, and batch processing.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import soundfile as sf
|
| 15 |
+
|
| 16 |
+
from voxcpm.core import VoxCPM
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Default Hugging Face repo id used when neither --model-path nor --hf-model-id is given.
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
|
| 20 |
+
|
| 21 |
+
# -----------------------------
|
| 22 |
+
# Validators
|
| 23 |
+
# -----------------------------
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
    """Return *file_path* as a ``Path``, raising ``FileNotFoundError`` when absent."""
    candidate = Path(file_path)
    if candidate.exists():
        return candidate
    raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def require_file_exists(file_path: str, parser, file_type: str = "file") -> Path:
    """Return *file_path* as a ``Path``; report a missing path via ``parser.error()``.

    ``parser.error`` raises ``SystemExit``, so the trailing return is only
    reached for existing paths.
    """
    candidate = Path(file_path)
    if not candidate.exists():
        parser.error(f"{file_type} '{file_path}' does not exist")
    return candidate
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def validate_output_path(output_path: str) -> Path:
    """Coerce *output_path* to a ``Path``, creating its parent directory if needed."""
    target = Path(output_path)
    parent = target.parent
    parent.mkdir(parents=True, exist_ok=True)
    return target
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def validate_ranges(args, parser):
    """Reject out-of-range numeric CLI values via ``parser.error()``.

    Checks run in a fixed order; the first failing check aborts the process
    (``parser.error`` raises ``SystemExit``).
    """
    checks = (
        (0.1 <= args.cfg_value <= 10.0,
         "--cfg-value must be between 0.1 and 10.0 (recommended: 1.0–3.0)"),
        (1 <= args.inference_timesteps <= 100,
         "--inference-timesteps must be between 1 and 100 (recommended: 4–30)"),
        (args.lora_r > 0, "--lora-r must be a positive integer"),
        (args.lora_alpha > 0, "--lora-alpha must be a positive integer"),
        (0.0 <= args.lora_dropout <= 1.0,
         "--lora-dropout must be between 0.0 and 1.0"),
    )
    for ok, message in checks:
        if not ok:
            parser.error(message)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def warn_legacy_mode():
    """Emit a deprecation notice for the old flat (no-subcommand) CLI style."""
    message = (
        "Warning: legacy root CLI arguments are deprecated. "
        "Prefer `voxcpm design|clone|batch ...`."
    )
    print(message, file=sys.stderr)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_final_text(text: str, control: str | None) -> str:
|
| 72 |
+
control = (control or "").strip()
|
| 73 |
+
return f"({control}){text}" if control else text
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def resolve_prompt_text(args, parser) -> str | None:
|
| 77 |
+
prompt_text = getattr(args, "prompt_text", None)
|
| 78 |
+
prompt_file = getattr(args, "prompt_file", None)
|
| 79 |
+
|
| 80 |
+
if prompt_text and prompt_file:
|
| 81 |
+
parser.error("Use either --prompt-text or --prompt-file, not both.")
|
| 82 |
+
|
| 83 |
+
if prompt_file:
|
| 84 |
+
prompt_path = require_file_exists(prompt_file, parser, "prompt text file")
|
| 85 |
+
return prompt_path.read_text(encoding="utf-8").strip()
|
| 86 |
+
|
| 87 |
+
if prompt_text:
|
| 88 |
+
return prompt_text.strip()
|
| 89 |
+
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def detect_model_architecture(args) -> str | None:
|
| 94 |
+
model_location = getattr(args, "model_path", None) or getattr(
|
| 95 |
+
args, "hf_model_id", None
|
| 96 |
+
)
|
| 97 |
+
if not model_location:
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
if os.path.isdir(model_location):
|
| 101 |
+
config_path = Path(model_location) / "config.json"
|
| 102 |
+
if not config_path.exists():
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 106 |
+
return json.load(f).get("architecture", "voxcpm").lower()
|
| 107 |
+
|
| 108 |
+
model_hint = str(model_location).lower()
|
| 109 |
+
if "voxcpm2" in model_hint:
|
| 110 |
+
return "voxcpm2"
|
| 111 |
+
if (
|
| 112 |
+
"voxcpm1.5" in model_hint
|
| 113 |
+
or "voxcpm-1.5" in model_hint
|
| 114 |
+
or "voxcpm_1.5" in model_hint
|
| 115 |
+
):
|
| 116 |
+
return "voxcpm"
|
| 117 |
+
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def validate_prompt_related_args(args, parser, prompt_text: str | None):
|
| 122 |
+
if prompt_text and not args.prompt_audio:
|
| 123 |
+
parser.error("--prompt-text/--prompt-file requires --prompt-audio.")
|
| 124 |
+
|
| 125 |
+
if args.prompt_audio and not prompt_text:
|
| 126 |
+
parser.error("--prompt-audio requires --prompt-text or --prompt-file.")
|
| 127 |
+
|
| 128 |
+
if args.control and prompt_text:
|
| 129 |
+
parser.error(
|
| 130 |
+
"--control cannot be used together with --prompt-text or --prompt-file."
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def validate_reference_support(args, parser):
    """Reject --reference-audio for legacy (v1) VoxCPM models.

    A no-op when no reference audio was supplied or the architecture
    cannot be determined.
    """
    if not getattr(args, "reference_audio", None):
        return

    if detect_model_architecture(args) == "voxcpm":
        parser.error("--reference-audio is only supported with VoxCPM2 models.")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def validate_design_args(args, parser):
    """`design` synthesizes from text alone; reject any prompt/reference input."""
    prompt_text = resolve_prompt_text(args, parser)
    has_audio_input = bool(args.prompt_audio or args.reference_audio)
    if has_audio_input or prompt_text:
        parser.error(
            "`design` does not accept prompt/reference audio. Use `clone` instead."
        )
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def validate_clone_args(args, parser):
    """Validate clone-mode inputs and return the resolved prompt transcript.

    Clone needs some voice source: either reference audio, or prompt audio
    paired with its transcript.
    """
    prompt_text = resolve_prompt_text(args, parser)
    validate_prompt_related_args(args, parser, prompt_text)
    validate_reference_support(args, parser)

    has_voice_source = args.prompt_audio or args.reference_audio
    if not has_voice_source:
        parser.error(
            "`clone` requires --reference-audio, or --prompt-audio with --prompt-text/--prompt-file."
        )

    return prompt_text
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def validate_batch_args(args, parser):
    """Validate batch-mode prompt/reference options; return the prompt transcript.

    Unlike clone, batch mode does not require a voice source — plain text
    batches are allowed.
    """
    prompt_text = resolve_prompt_text(args, parser)
    validate_prompt_related_args(args, parser, prompt_text)
    validate_reference_support(args, parser)
    return prompt_text
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# -----------------------------
|
| 172 |
+
# Model loading
|
| 173 |
+
# -----------------------------
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def load_model(args) -> VoxCPM:
    """Construct the VoxCPM pipeline from parsed CLI arguments.

    A local --model-path takes priority; otherwise the model is fetched
    from the Hugging Face Hub. Any load failure prints the error to stderr
    and exits the process with status 1.
    """
    print("Loading VoxCPM model...", file=sys.stderr)

    # CLI flag wins; fall back to the ZIPENHANCER_MODEL_PATH env var.
    zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
        "ZIPENHANCER_MODEL_PATH", None
    )

    # Build LoRA config if provided
    lora_config = None
    lora_weights_path = getattr(args, "lora_path", None)
    if lora_weights_path:
        # Imported lazily so the CLI stays fast when LoRA is unused.
        from voxcpm.model.voxcpm import LoRAConfig

        lora_config = LoRAConfig(
            enable_lm=not args.lora_disable_lm,
            enable_dit=not args.lora_disable_dit,
            enable_proj=args.lora_enable_proj,
            r=args.lora_r,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout,
        )

        print(
            f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
            f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}",
            file=sys.stderr,
        )

    # Load local model if specified
    if args.model_path:
        try:
            model = VoxCPM(
                voxcpm_model_path=args.model_path,
                zipenhancer_model_path=zipenhancer_path,
                enable_denoiser=not args.no_denoiser,
                optimize=not args.no_optimize,
                lora_config=lora_config,
                lora_weights_path=lora_weights_path,
            )
            print("Model loaded (local).", file=sys.stderr)
            return model
        except Exception as e:
            print(f"Failed to load model (local): {e}", file=sys.stderr)
            sys.exit(1)

    # Load from Hugging Face Hub
    try:
        model = VoxCPM.from_pretrained(
            hf_model_id=args.hf_model_id,
            load_denoiser=not args.no_denoiser,
            # NOTE(review): a local zipenhancer *path* is also accepted here as an id;
            # from_pretrained treats the two interchangeably — confirm.
            zipenhancer_model_id=zipenhancer_path,
            cache_dir=args.cache_dir,
            local_files_only=args.local_files_only,
            optimize=not args.no_optimize,
            lora_config=lora_config,
            lora_weights_path=lora_weights_path,
        )
        print("Model loaded (from_pretrained).", file=sys.stderr)
        return model
    except Exception as e:
        print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
        sys.exit(1)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# -----------------------------
|
| 241 |
+
# Commands
|
| 242 |
+
# -----------------------------
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None):
    """Generate one utterance and write it to *output* as a WAV file.

    Shared backend for `design` and `clone`. Input audio paths are checked
    before the (slow) model load so usage errors fail fast.
    """
    output_path = validate_output_path(output)

    if args.prompt_audio:
        require_file_exists(args.prompt_audio, parser, "prompt audio file")
    if args.reference_audio:
        require_file_exists(args.reference_audio, parser, "reference audio file")

    model = load_model(args)

    audio_array = model.generate(
        text=text,
        prompt_wav_path=args.prompt_audio,
        prompt_text=prompt_text,
        reference_wav_path=args.reference_audio,
        cfg_value=args.cfg_value,
        inference_timesteps=args.inference_timesteps,
        normalize=args.normalize,
        # Denoising only applies when some conditioning audio is present.
        denoise=args.denoise
        and (args.prompt_audio is not None or args.reference_audio is not None),
    )

    sf.write(str(output_path), audio_array, model.tts_model.sample_rate)

    duration = len(audio_array) / model.tts_model.sample_rate
    print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def cmd_design(args, parser):
    """Handle `voxcpm design`: text (plus optional control tag) to speech."""
    validate_design_args(args, parser)
    return _run_single(
        args,
        parser,
        text=build_final_text(args.text, args.control),
        output=args.output,
        prompt_text=None,
    )
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def cmd_clone(args, parser):
    """Handle `voxcpm clone`: synthesize speech in a cloned voice."""
    transcript = validate_clone_args(args, parser)
    return _run_single(
        args,
        parser,
        text=build_final_text(args.text, args.control),
        output=args.output,
        prompt_text=transcript,
    )
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def cmd_batch(args, parser):
    """Handle `voxcpm batch`: synthesize one WAV per non-empty input line.

    Individual line failures are reported and skipped; a success summary is
    printed at the end. Outputs are named output_001.wav, output_002.wav, ...
    """
    input_file = require_file_exists(args.input, parser, "input file")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Blank lines are ignored; each remaining line is one utterance.
    with open(input_file, "r", encoding="utf-8") as f:
        texts = [line.strip() for line in f if line.strip()]

    if not texts:
        sys.exit("Error: Input file is empty")

    prompt_text = validate_batch_args(args, parser)
    model = load_model(args)

    prompt_audio_path = None
    if args.prompt_audio:
        prompt_audio_path = str(
            require_file_exists(args.prompt_audio, parser, "prompt audio file")
        )

    reference_audio_path = None
    if args.reference_audio:
        reference_audio_path = str(
            require_file_exists(args.reference_audio, parser, "reference audio file")
        )

    success_count = 0

    for i, text in enumerate(texts, 1):
        try:
            final_text = build_final_text(text, args.control)
            audio_array = model.generate(
                text=final_text,
                prompt_wav_path=prompt_audio_path,
                prompt_text=prompt_text,
                reference_wav_path=reference_audio_path,
                cfg_value=args.cfg_value,
                inference_timesteps=args.inference_timesteps,
                normalize=args.normalize,
                # Only denoise when conditioning audio is actually supplied.
                denoise=args.denoise
                and (prompt_audio_path is not None or reference_audio_path is not None),
            )

            # Zero-padded index keeps outputs sorted lexicographically.
            output_file = output_dir / f"output_{i:03d}.wav"
            sf.write(str(output_file), audio_array, model.tts_model.sample_rate)

            duration = len(audio_array) / model.tts_model.sample_rate
            print(f"Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
            success_count += 1

        except Exception as e:
            # Keep going: one bad line must not abort the whole batch.
            print(f"Failed on line {i}: {e}", file=sys.stderr)

    print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# -----------------------------
|
| 346 |
+
# Parser
|
| 347 |
+
# -----------------------------
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def _add_common_generation_args(parser):
|
| 351 |
+
parser.add_argument("--text", "-t", help="Text to synthesize")
|
| 352 |
+
parser.add_argument(
|
| 353 |
+
"--control",
|
| 354 |
+
type=str,
|
| 355 |
+
help="Control instruction for VoxCPM2 voice design/cloning",
|
| 356 |
+
)
|
| 357 |
+
parser.add_argument(
|
| 358 |
+
"--cfg-value",
|
| 359 |
+
type=float,
|
| 360 |
+
default=2.0,
|
| 361 |
+
help="CFG guidance scale (float, recommended 1.0–3.0, default: 2.0)",
|
| 362 |
+
)
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
"--inference-timesteps",
|
| 365 |
+
type=int,
|
| 366 |
+
default=10,
|
| 367 |
+
help="Inference steps (int, recommended 4–30, default: 10)",
|
| 368 |
+
)
|
| 369 |
+
parser.add_argument(
|
| 370 |
+
"--normalize", action="store_true", help="Enable text normalization"
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _add_prompt_reference_args(parser):
|
| 375 |
+
parser.add_argument(
|
| 376 |
+
"--prompt-audio",
|
| 377 |
+
"-pa",
|
| 378 |
+
help="Prompt audio file path (continuation mode, requires --prompt-text or --prompt-file)",
|
| 379 |
+
)
|
| 380 |
+
parser.add_argument(
|
| 381 |
+
"--prompt-text", "-pt", help="Text corresponding to the prompt audio"
|
| 382 |
+
)
|
| 383 |
+
parser.add_argument(
|
| 384 |
+
"--prompt-file", type=str, help="Text file corresponding to the prompt audio"
|
| 385 |
+
)
|
| 386 |
+
parser.add_argument(
|
| 387 |
+
"--reference-audio",
|
| 388 |
+
"-ra",
|
| 389 |
+
help="Reference audio for voice cloning (VoxCPM2 only)",
|
| 390 |
+
)
|
| 391 |
+
parser.add_argument(
|
| 392 |
+
"--denoise",
|
| 393 |
+
action="store_true",
|
| 394 |
+
help="Enable prompt/reference speech enhancement",
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def _add_model_args(parser):
    """Attach model-location and model-loading options."""
    parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
    parser.add_argument(
        "--hf-model-id",
        type=str,
        default=DEFAULT_HF_MODEL_ID,
        help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
    )
    parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
    parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
    parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
    parser.add_argument("--no-optimize", action="store_true", help="Disable model optimization during loading")
    parser.add_argument("--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def _add_lora_args(parser):
|
| 428 |
+
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
| 429 |
+
parser.add_argument(
|
| 430 |
+
"--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)"
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument(
|
| 433 |
+
"--lora-alpha",
|
| 434 |
+
type=int,
|
| 435 |
+
default=16,
|
| 436 |
+
help="LoRA alpha (positive int, default: 16)",
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument(
|
| 439 |
+
"--lora-dropout",
|
| 440 |
+
type=float,
|
| 441 |
+
default=0.0,
|
| 442 |
+
help="LoRA dropout rate (0.0–1.0, default: 0.0)",
|
| 443 |
+
)
|
| 444 |
+
parser.add_argument(
|
| 445 |
+
"--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers"
|
| 446 |
+
)
|
| 447 |
+
parser.add_argument(
|
| 448 |
+
"--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers"
|
| 449 |
+
)
|
| 450 |
+
parser.add_argument(
|
| 451 |
+
"--lora-enable-proj",
|
| 452 |
+
action="store_true",
|
| 453 |
+
help="Enable LoRA on projection layers",
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _build_parser():
    """Assemble the CLI argument parser.

    Three subcommands (design/clone/batch) form the supported interface;
    the same generation/model/LoRA options are also attached to the root
    parser so the deprecated flat invocation style keeps working
    (handled by _dispatch_legacy).
    """
    parser = argparse.ArgumentParser(
        description="VoxCPM CLI - VoxCPM2-first voice design, cloning, and batch processing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  voxcpm design --text "Hello world" --output out.wav
  voxcpm design --text "Hello world" --control "warm female voice" --output out.wav
  voxcpm clone --text "Hello" --reference-audio ref.wav --output out.wav
  voxcpm batch --input texts.txt --output-dir ./outs --reference-audio ref.wav
""",
    )

    # dest="command" lets main() distinguish subcommand vs legacy invocation.
    subparsers = parser.add_subparsers(dest="command")

    design_parser = subparsers.add_parser(
        "design", help="Generate speech with VoxCPM2-first voice design"
    )
    _add_common_generation_args(design_parser)
    _add_prompt_reference_args(design_parser)
    _add_model_args(design_parser)
    _add_lora_args(design_parser)
    design_parser.add_argument(
        "--output", "-o", required=True, help="Output audio file path"
    )

    clone_parser = subparsers.add_parser(
        "clone", help="Clone a voice with reference/prompt audio"
    )
    _add_common_generation_args(clone_parser)
    _add_prompt_reference_args(clone_parser)
    _add_model_args(clone_parser)
    _add_lora_args(clone_parser)
    clone_parser.add_argument(
        "--output", "-o", required=True, help="Output audio file path"
    )

    batch_parser = subparsers.add_parser(
        "batch", help="Batch-generate one line per output file"
    )
    batch_parser.add_argument(
        "--input", "-i", required=True, help="Input text file (one text per line)"
    )
    batch_parser.add_argument(
        "--output-dir", "-od", required=True, help="Output directory"
    )
    # batch has no --text (texts come from --input), so the shared
    # _add_common_generation_args helper cannot be reused here.
    batch_parser.add_argument(
        "--control",
        type=str,
        help="Control instruction for VoxCPM2 voice design/cloning",
    )
    _add_prompt_reference_args(batch_parser)
    batch_parser.add_argument(
        "--cfg-value",
        type=float,
        default=2.0,
        help="CFG guidance scale (float, recommended 1.0–3.0, default: 2.0)",
    )
    batch_parser.add_argument(
        "--inference-timesteps",
        type=int,
        default=10,
        help="Inference steps (int, recommended 4–30, default: 10)",
    )
    batch_parser.add_argument(
        "--normalize", action="store_true", help="Enable text normalization"
    )
    _add_model_args(batch_parser)
    _add_lora_args(batch_parser)

    # Legacy root arguments (deprecated flat style; see _dispatch_legacy)
    parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
    parser.add_argument(
        "--output-dir", "-od", help="Output directory (batch mode only)"
    )
    _add_common_generation_args(parser)
    parser.add_argument(
        "--output", "-o", help="Output audio file path (single or clone mode)"
    )
    _add_prompt_reference_args(parser)
    _add_model_args(parser)
    _add_lora_args(parser)

    return parser
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def _dispatch_legacy(args, parser):
    """Route the deprecated flat argument style to batch/clone/design handlers.

    --input selects batch mode; otherwise a prompt/reference option selects
    clone mode, and plain --text falls back to design mode.
    """
    warn_legacy_mode()

    if args.input and args.text:
        parser.error(
            "Use either batch mode (--input) or single mode (--text), not both."
        )

    if args.input:
        if not args.output_dir:
            parser.error("Batch mode requires --output-dir")
        return cmd_batch(args, parser)

    if not args.text or not args.output:
        parser.error("Single-sample legacy mode requires --text and --output")

    wants_clone = any(
        (args.prompt_audio, args.prompt_text, args.prompt_file, args.reference_audio)
    )
    return cmd_clone(args, parser) if wants_clone else cmd_design(args, parser)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# -----------------------------
|
| 571 |
+
# Entrypoint
|
| 572 |
+
# -----------------------------
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def main():
    """CLI entry point: parse arguments and dispatch to the chosen command."""
    parser = _build_parser()
    args = parser.parse_args()

    # Numeric sanity checks apply to every mode; all parsers define defaults
    # for these options, so the attributes always exist.
    validate_ranges(args, parser)

    if args.command == "design":
        # --output is enforced by argparse (required=True); --text is optional
        # on the subparser, so it is checked here.
        if not args.text:
            parser.error("`design` requires --text")
        return cmd_design(args, parser)

    if args.command == "clone":
        if not args.text or not args.output:
            parser.error("`clone` requires --text and --output")
        return cmd_clone(args, parser)

    if args.command == "batch":
        return cmd_batch(args, parser)

    # No subcommand given: fall back to the deprecated flat argument style.
    return _dispatch_legacy(args, parser)
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
if __name__ == "__main__":
|
| 598 |
+
main()
|
voxcpm/core.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import tempfile
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Generator, Optional
|
| 8 |
+
from huggingface_hub import snapshot_download
|
| 9 |
+
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
| 10 |
+
from .model.voxcpm2 import VoxCPM2Model
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VoxCPM:
|
| 14 |
+
def __init__(
    self,
    voxcpm_model_path: str,
    zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
    enable_denoiser: bool = True,
    optimize: bool = True,
    lora_config: Optional[LoRAConfig] = None,
    lora_weights_path: Optional[str] = None,
):
    """Initialize VoxCPM TTS pipeline.

    Args:
        voxcpm_model_path: Local filesystem path to the VoxCPM model assets
            (weights, configs, etc.). Typically the directory returned by
            a prior download step. Must contain a config.json whose
            "architecture" field selects the model class.
        zipenhancer_model_path: ModelScope acoustic noise suppression model
            id or local path. If None, denoiser will not be initialized.
        enable_denoiser: Whether to initialize the denoiser pipeline.
        optimize: Whether to optimize the model with torch.compile. True by
            default, but can be disabled for debugging. When True, a short
            warm-up generation is also run at the end of construction.
        lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
            provided without lora_config, a default config will be created.
        lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
            containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
    """
    print(
        f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
        file=sys.stderr,
    )

    # If lora_weights_path is provided but no lora_config, create a default one
    if lora_weights_path is not None and lora_config is None:
        lora_config = LoRAConfig(
            enable_lm=True,
            enable_dit=True,
            enable_proj=False,
        )
        print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)

    # Determine model type from config.json architecture field
    config_path = os.path.join(voxcpm_model_path, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Missing "architecture" means a legacy (v1) VoxCPM checkpoint.
    arch = config.get("architecture", "voxcpm").lower()

    if arch == "voxcpm2":
        self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
        print("Loaded VoxCPM2Model", file=sys.stderr)
    elif arch == "voxcpm":
        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
        print("Loaded VoxCPMModel", file=sys.stderr)
    else:
        raise ValueError(f"Unsupported architecture: {arch}")

    # Load LoRA weights if path is provided (after the base model exists)
    if lora_weights_path is not None:
        print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
        loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
        print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)

    # text_normalizer is created lazily elsewhere; None until first use.
    self.text_normalizer = None
    self.denoiser = None
    if enable_denoiser and zipenhancer_model_path is not None:
        # Imported lazily so the (heavy) denoiser stack loads only on demand.
        from .zipenhancer import ZipEnhancer

        self.denoiser = ZipEnhancer(zipenhancer_model_path)
    else:
        self.denoiser = None
    if optimize:
        # Trigger torch.compile once so the first user request is not slow.
        print("Warm up VoxCPMModel...", file=sys.stderr)
        self.tts_model.generate(
            target_text="Hello, this is the first test sentence.",
            max_len=10,
        )
|
| 87 |
+
|
| 88 |
+
@classmethod
|
| 89 |
+
def from_pretrained(
|
| 90 |
+
cls,
|
| 91 |
+
hf_model_id: str = "openbmb/VoxCPM2",
|
| 92 |
+
load_denoiser: bool = True,
|
| 93 |
+
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
| 94 |
+
cache_dir: str = None,
|
| 95 |
+
local_files_only: bool = False,
|
| 96 |
+
optimize: bool = True,
|
| 97 |
+
lora_config: Optional[LoRAConfig] = None,
|
| 98 |
+
lora_weights_path: Optional[str] = None,
|
| 99 |
+
**kwargs,
|
| 100 |
+
):
|
| 101 |
+
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
|
| 105 |
+
load_denoiser: Whether to initialize the denoiser pipeline.
|
| 106 |
+
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
| 107 |
+
zipenhancer_model_id: Denoiser model id or path for ModelScope
|
| 108 |
+
acoustic noise suppression.
|
| 109 |
+
cache_dir: Custom cache directory for the snapshot.
|
| 110 |
+
local_files_only: If True, only use local files and do not attempt
|
| 111 |
+
to download.
|
| 112 |
+
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
| 113 |
+
provided without lora_config, a default config will be created with
|
| 114 |
+
enable_lm=True and enable_dit=True.
|
| 115 |
+
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
| 116 |
+
containing lora_weights.ckpt). If provided, LoRA weights will be loaded
|
| 117 |
+
after model initialization.
|
| 118 |
+
Kwargs:
|
| 119 |
+
Additional keyword arguments passed to the ``VoxCPM`` constructor.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
|
| 123 |
+
the downloaded snapshot directory.
|
| 124 |
+
|
| 125 |
+
Raises:
|
| 126 |
+
ValueError: If neither a valid ``hf_model_id`` nor a resolvable
|
| 127 |
+
``hf_model_id`` is provided.
|
| 128 |
+
"""
|
| 129 |
+
repo_id = hf_model_id
|
| 130 |
+
if not repo_id:
|
| 131 |
+
raise ValueError("You must provide hf_model_id")
|
| 132 |
+
|
| 133 |
+
# Load from local path if provided
|
| 134 |
+
if os.path.isdir(repo_id):
|
| 135 |
+
local_path = repo_id
|
| 136 |
+
else:
|
| 137 |
+
# Otherwise, try from_pretrained (Hub); exit on failure
|
| 138 |
+
local_path = snapshot_download(
|
| 139 |
+
repo_id=repo_id,
|
| 140 |
+
cache_dir=cache_dir,
|
| 141 |
+
local_files_only=local_files_only,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return cls(
|
| 145 |
+
voxcpm_model_path=local_path,
|
| 146 |
+
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
| 147 |
+
enable_denoiser=load_denoiser,
|
| 148 |
+
optimize=optimize,
|
| 149 |
+
lora_config=lora_config,
|
| 150 |
+
lora_weights_path=lora_weights_path,
|
| 151 |
+
**kwargs,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def generate(self, *args, **kwargs) -> np.ndarray:
|
| 155 |
+
return next(self._generate(*args, streaming=False, **kwargs))
|
| 156 |
+
|
| 157 |
+
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
| 158 |
+
return self._generate(*args, streaming=True, **kwargs)
|
| 159 |
+
|
| 160 |
+
def _generate(
|
| 161 |
+
self,
|
| 162 |
+
text: str,
|
| 163 |
+
prompt_wav_path: str = None,
|
| 164 |
+
prompt_text: str = None,
|
| 165 |
+
reference_wav_path: str = None,
|
| 166 |
+
cfg_value: float = 2.0,
|
| 167 |
+
inference_timesteps: int = 10,
|
| 168 |
+
min_len: int = 2,
|
| 169 |
+
max_len: int = 4096,
|
| 170 |
+
normalize: bool = False,
|
| 171 |
+
denoise: bool = False,
|
| 172 |
+
retry_badcase: bool = True,
|
| 173 |
+
retry_badcase_max_times: int = 3,
|
| 174 |
+
retry_badcase_ratio_threshold: float = 6.0,
|
| 175 |
+
streaming: bool = False,
|
| 176 |
+
) -> Generator[np.ndarray, None, None]:
|
| 177 |
+
"""Synthesize speech for the given text and return a single waveform.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
text: Input text to synthesize.
|
| 181 |
+
prompt_wav_path: Path to prompt audio for continuation mode.
|
| 182 |
+
Must be paired with ``prompt_text``.
|
| 183 |
+
prompt_text: Text content corresponding to the prompt audio.
|
| 184 |
+
reference_wav_path: Path to reference audio for voice cloning
|
| 185 |
+
(structurally isolated via ref_audio tokens). Can be used
|
| 186 |
+
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
|
| 187 |
+
cfg_value: Guidance scale for the generation model.
|
| 188 |
+
inference_timesteps: Number of inference steps.
|
| 189 |
+
min_len: Minimum audio length.
|
| 190 |
+
max_len: Maximum token length during generation.
|
| 191 |
+
normalize: Whether to run text normalization before generation.
|
| 192 |
+
denoise: Whether to denoise the prompt/reference audio if a
|
| 193 |
+
denoiser is available.
|
| 194 |
+
retry_badcase: Whether to retry badcase.
|
| 195 |
+
retry_badcase_max_times: Maximum number of times to retry badcase.
|
| 196 |
+
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
| 197 |
+
streaming: Whether to return a generator of audio chunks.
|
| 198 |
+
Returns:
|
| 199 |
+
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
| 200 |
+
Yields audio chunks for each generation step if ``streaming=True``,
|
| 201 |
+
otherwise yields a single array containing the final audio.
|
| 202 |
+
"""
|
| 203 |
+
if not text.strip() or not isinstance(text, str):
|
| 204 |
+
raise ValueError("target text must be a non-empty string")
|
| 205 |
+
|
| 206 |
+
if prompt_wav_path is not None:
|
| 207 |
+
if not os.path.exists(prompt_wav_path):
|
| 208 |
+
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
| 209 |
+
|
| 210 |
+
if reference_wav_path is not None:
|
| 211 |
+
if not os.path.exists(reference_wav_path):
|
| 212 |
+
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
|
| 213 |
+
|
| 214 |
+
if (prompt_wav_path is None) != (prompt_text is None):
|
| 215 |
+
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
| 216 |
+
|
| 217 |
+
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
|
| 218 |
+
if reference_wav_path is not None and not is_v2:
|
| 219 |
+
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
|
| 220 |
+
|
| 221 |
+
text = text.replace("\n", " ")
|
| 222 |
+
text = re.sub(r"\s+", " ", text)
|
| 223 |
+
temp_files = []
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
actual_prompt_path = prompt_wav_path
|
| 227 |
+
actual_ref_path = reference_wav_path
|
| 228 |
+
|
| 229 |
+
if denoise and self.denoiser is not None:
|
| 230 |
+
if prompt_wav_path is not None:
|
| 231 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
| 232 |
+
temp_files.append(tmp.name)
|
| 233 |
+
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
|
| 234 |
+
actual_prompt_path = temp_files[-1]
|
| 235 |
+
if reference_wav_path is not None:
|
| 236 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
| 237 |
+
temp_files.append(tmp.name)
|
| 238 |
+
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
|
| 239 |
+
actual_ref_path = temp_files[-1]
|
| 240 |
+
|
| 241 |
+
if actual_prompt_path is not None or actual_ref_path is not None:
|
| 242 |
+
if is_v2:
|
| 243 |
+
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
| 244 |
+
prompt_text=prompt_text,
|
| 245 |
+
prompt_wav_path=actual_prompt_path,
|
| 246 |
+
reference_wav_path=actual_ref_path,
|
| 247 |
+
)
|
| 248 |
+
else:
|
| 249 |
+
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
| 250 |
+
prompt_text=prompt_text,
|
| 251 |
+
prompt_wav_path=actual_prompt_path,
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
fixed_prompt_cache = None
|
| 255 |
+
|
| 256 |
+
if normalize:
|
| 257 |
+
if self.text_normalizer is None:
|
| 258 |
+
from .utils.text_normalize import TextNormalizer
|
| 259 |
+
|
| 260 |
+
self.text_normalizer = TextNormalizer()
|
| 261 |
+
text = self.text_normalizer.normalize(text)
|
| 262 |
+
|
| 263 |
+
generate_result = self.tts_model._generate_with_prompt_cache(
|
| 264 |
+
target_text=text,
|
| 265 |
+
prompt_cache=fixed_prompt_cache,
|
| 266 |
+
min_len=min_len,
|
| 267 |
+
max_len=max_len,
|
| 268 |
+
inference_timesteps=inference_timesteps,
|
| 269 |
+
cfg_value=cfg_value,
|
| 270 |
+
retry_badcase=retry_badcase,
|
| 271 |
+
retry_badcase_max_times=retry_badcase_max_times,
|
| 272 |
+
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
| 273 |
+
streaming=streaming,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
for wav, _, _ in generate_result:
|
| 277 |
+
yield wav.squeeze(0).cpu().numpy()
|
| 278 |
+
|
| 279 |
+
finally:
|
| 280 |
+
for tmp_path in temp_files:
|
| 281 |
+
if tmp_path and os.path.exists(tmp_path):
|
| 282 |
+
try:
|
| 283 |
+
os.unlink(tmp_path)
|
| 284 |
+
except OSError:
|
| 285 |
+
pass
|
| 286 |
+
|
| 287 |
+
# ------------------------------------------------------------------ #
|
| 288 |
+
# LoRA Interface (delegated to VoxCPMModel)
|
| 289 |
+
# ------------------------------------------------------------------ #
|
| 290 |
+
def load_lora(self, lora_weights_path: str) -> tuple:
|
| 291 |
+
"""Load LoRA weights from a checkpoint file.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
lora_weights_path: Path to LoRA weights (.pth file or directory
|
| 295 |
+
containing lora_weights.ckpt).
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
|
| 299 |
+
|
| 300 |
+
Raises:
|
| 301 |
+
RuntimeError: If model was not initialized with LoRA config.
|
| 302 |
+
"""
|
| 303 |
+
if self.tts_model.lora_config is None:
|
| 304 |
+
raise RuntimeError(
|
| 305 |
+
"Cannot load LoRA weights: model was not initialized with LoRA config. "
|
| 306 |
+
"Please reinitialize with lora_config or lora_weights_path parameter."
|
| 307 |
+
)
|
| 308 |
+
return self.tts_model.load_lora_weights(lora_weights_path)
|
| 309 |
+
|
| 310 |
+
def unload_lora(self):
|
| 311 |
+
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
|
| 312 |
+
self.tts_model.reset_lora_weights()
|
| 313 |
+
|
| 314 |
+
def set_lora_enabled(self, enabled: bool):
|
| 315 |
+
"""Enable or disable LoRA layers without unloading weights.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
enabled: If True, LoRA layers are active; if False, only base model is used.
|
| 319 |
+
"""
|
| 320 |
+
self.tts_model.set_lora_enabled(enabled)
|
| 321 |
+
|
| 322 |
+
def get_lora_state_dict(self) -> dict:
|
| 323 |
+
"""Get current LoRA parameters state dict.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
dict: State dict containing all LoRA parameters (lora_A, lora_B).
|
| 327 |
+
"""
|
| 328 |
+
return self.tts_model.get_lora_state_dict()
|
| 329 |
+
|
| 330 |
+
@property
|
| 331 |
+
def lora_enabled(self) -> bool:
|
| 332 |
+
"""Check if LoRA is currently configured."""
|
| 333 |
+
return self.tts_model.lora_config is not None
|
voxcpm/model/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .voxcpm import VoxCPMModel
|
| 2 |
+
from .voxcpm2 import VoxCPM2Model
|
| 3 |
+
|
| 4 |
+
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
|
voxcpm/model/utils.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import PreTrainedTokenizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
    """Wrap a tokenizer so multi-character Chinese vocab tokens split into single characters.

    The returned wrapper preserves the base tokenizer's calling interface while
    guaranteeing that any vocabulary token made entirely of two or more CJK
    characters is emitted as individual characters, for consistent Chinese
    tokenization.

    Args:
        tokenizer: The base tokenizer to wrap.

    Returns:
        A CharTokenizerWrapper instance that handles multi-character Chinese tokens.

    Example:
        >>> from transformers import LlamaTokenizerFast
        >>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
        >>> wrapped_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        >>> tokens = wrapped_tokenizer("你好世界")
    """
    # Pre-compute once: vocabulary entries of length >= 2 consisting purely of
    # CJK unified ideographs (U+4E00..U+9FFF).
    splittable = {
        tok
        for tok in tokenizer.vocab.keys()
        if len(tok) >= 2 and all("\u4e00" <= ch <= "\u9fff" for ch in tok)
    }

    class CharTokenizerWrapper:
        """Tokenizer facade that splits multi-character Chinese tokens into characters."""

        def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
            """Store the wrapped tokenizer and the precomputed splittable-token set."""
            self.tokenizer = base_tokenizer
            self.multichar_tokens = splittable

        def tokenize(self, text: str, **kwargs) -> List[str]:
            """Tokenize *text*, expanding multi-character Chinese tokens to characters.

            Args:
                text: Input text to tokenize.
                **kwargs: Forwarded to the base tokenizer.

            Returns:
                Token list with multi-character Chinese tokens split, e.g.
                ["你", "好", "世", "界"] instead of ["你好", "世界"].

            Raises:
                TypeError: If *text* is not a string.
            """
            if not isinstance(text, str):
                raise TypeError(f"Expected string input, got {type(text)}")

            expanded: List[str] = []
            for tok in self.tokenizer.tokenize(text, **kwargs):
                # Drop a possible SentencePiece subword prefix before lookup.
                stripped = tok.replace("▁", "")
                if stripped in self.multichar_tokens:
                    # A string iterates per character, so extend splits it.
                    expanded.extend(stripped)
                else:
                    expanded.append(tok)
            return expanded

        def __call__(self, text: str, **kwargs) -> List[int]:
            """Tokenize *text* and return token IDs, mirroring the base interface.

            Raises:
                ValueError: If tokenization or ID conversion fails.
            """
            try:
                return self.tokenizer.convert_tokens_to_ids(self.tokenize(text, **kwargs))
            except Exception as e:
                raise ValueError(f"Tokenization failed: {str(e)}") from e

    return CharTokenizerWrapper(tokenizer)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_dtype(dtype: str) -> torch.dtype:
    """Resolve a dtype name (long or short form) to the torch dtype.

    Args:
        dtype: One of "bfloat16"/"bf16", "float16"/"fp16", "float32"/"fp32".

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: If the name is not one of the supported aliases.
    """
    aliases = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if dtype in aliases:
        return aliases[dtype]
    raise ValueError(f"Unsupported dtype: {dtype}")
|
voxcpm/model/voxcpm.py
ADDED
|
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoxCPM: A Tokenizer-free speech generation model
|
| 3 |
+
|
| 4 |
+
This module contains the main VoxCPM model implementation, including configuration classes
|
| 5 |
+
and the core VoxCPMModel for text-to-speech generation.
|
| 6 |
+
|
| 7 |
+
Copyright 2025 OpenBMB
|
| 8 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
you may not use this file except in compliance with the License.
|
| 10 |
+
You may obtain a copy of the License at
|
| 11 |
+
|
| 12 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
|
| 14 |
+
Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
See the License for the specific language governing permissions and
|
| 18 |
+
limitations under the License.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
from typing import Tuple, Union, Generator, List, Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torchaudio
|
| 28 |
+
import warnings
|
| 29 |
+
from einops import rearrange
|
| 30 |
+
from pydantic import BaseModel
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from safetensors.torch import load_file
|
| 34 |
+
|
| 35 |
+
SAFETENSORS_AVAILABLE = True
|
| 36 |
+
except ImportError:
|
| 37 |
+
SAFETENSORS_AVAILABLE = False
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
from transformers import LlamaTokenizerFast
|
| 40 |
+
|
| 41 |
+
from ..modules.audiovae import AudioVAE, AudioVAEConfig
|
| 42 |
+
from ..modules.layers import ScalarQuantizationLayer
|
| 43 |
+
from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
| 44 |
+
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
| 45 |
+
from ..modules.locenc import VoxCPMLocEnc
|
| 46 |
+
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
| 47 |
+
from .utils import get_dtype, mask_multichar_chinese_tokens
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class VoxCPMEncoderConfig(BaseModel):
    """Size hyper-parameters for the local encoder transformer."""

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # Fix: default was None on a plain `int` annotation; Optional makes the
    # schema honest. None lets the model derive per-head key/value width.
    kv_channels: Optional[int] = None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class VoxCPMDitConfig(BaseModel):
    """Size hyper-parameters for the local DiT flow-matching estimator."""

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # Fix: default was None on a plain `int` annotation; Optional makes the
    # schema honest. None lets the model derive per-head key/value width.
    kv_channels: Optional[int] = None

    # Conditional flow-matching sampler configuration; required (no default).
    cfm_config: CfmConfig
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class VoxCPMConfig(BaseModel):
    """Top-level configuration for ``VoxCPMModel``.

    Bundles the backbone LM config with the encoder/DiT sub-configs plus
    quantization and runtime settings.
    """

    lm_config: MiniCPM4Config
    # NOTE(review): presumably the number of feature frames grouped per LM
    # step — confirm against the training/packing code.
    patch_size: int = 2
    # Dimensionality of each audio feature vector (matches encoder/DiT input_dim).
    feat_dim: int = 64
    # Layer count for the residual acoustic LM built from a copy of lm_config.
    residual_lm_num_layers: int = 6
    # Parameters forwarded to ScalarQuantizationLayer.
    scalar_quantization_latent_dim: int = 256
    scalar_quantization_scale: int = 9

    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    audio_vae_config: Optional[AudioVAEConfig] = None

    # KV-cache length set up for both LMs; also the generation cap.
    max_length: int = 4096
    # Requested device; the model falls back to mps/cpu when CUDA is absent.
    device: str = "cuda"
    # Name resolved through get_dtype() (e.g. "bfloat16", "fp16", "float32").
    dtype: str = "bfloat16"
    # Passed to UnifiedCFM as mean_mode.
    dit_mean_mode: bool = False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class LoRAConfig(BaseModel):
    """Configuration controlling where and how LoRA adapters are injected."""

    enable_lm: bool = False  # Apply LoRA to base_lm + residual_lm
    enable_dit: bool = False  # Apply LoRA to VoxCPMLocDiT
    enable_proj: bool = False  # Apply LoRA to projection Linear layers

    # Standard LoRA hyper-parameters: rank, scaling alpha, dropout probability.
    r: int = 8
    alpha: int = 16
    dropout: float = 0.0

    # Target linear layer names for LM & DiT (matched by attribute name)
    target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    # Projection layer attribute names to find on VoxCPMModel
    target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
VoxCPMConfig.model_rebuild()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class VoxCPMModel(nn.Module):
|
| 106 |
+
    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAE,
        lora_config: LoRAConfig = None,
    ):
        """Assemble all VoxCPM submodules from *config*.

        Args:
            config: Full model configuration (LM, encoder, DiT, runtime settings).
            tokenizer: Base text tokenizer; wrapped to split multi-char Chinese tokens.
            audio_vae: Pre-built audio VAE providing chunk_size / sample_rate.
            lora_config: Optional LoRA configuration; when given, adapters are
                injected at the end of construction.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size
        self.device = config.device
        # Fall back from the configured device when CUDA is unavailable.
        if not torch.cuda.is_available():
            if torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)

        # Text-Semantic LM: the main backbone, with a KV cache pre-allocated
        # for batch size 1 up to config.max_length positions.
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        # NOTE(review): special token ids 101/102 mark audio start/end —
        # confirm they match the tokenizer's vocabulary.
        self.audio_start_token = 101
        self.audio_end_token = 102

        # Residual Acoustic LM: a shallower copy of the backbone config with
        # no embedding table (vocab_size = 0).
        residual_lm_config = config.lm_config.model_copy(deep=True)
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        # Local Encoder: backbone config resized to the encoder dimensions.
        encoder_config = config.lm_config.model_copy(deep=True)
        encoder_config.hidden_size = config.encoder_config.hidden_dim
        encoder_config.intermediate_size = config.encoder_config.ffn_dim
        encoder_config.num_attention_heads = config.encoder_config.num_heads
        encoder_config.num_hidden_layers = config.encoder_config.num_layers
        encoder_config.kv_channels = config.encoder_config.kv_channels
        encoder_config.vocab_size = 0
        self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)

        # Local DiT: backbone config resized to the DiT dimensions, used as
        # the estimator inside the conditional flow-matching decoder.
        decoder_config = config.lm_config.model_copy(deep=True)
        decoder_config.hidden_size = config.dit_config.hidden_dim
        decoder_config.intermediate_size = config.dit_config.ffn_dim
        decoder_config.num_attention_heads = config.dit_config.num_heads
        decoder_config.num_hidden_layers = config.dit_config.num_layers
        decoder_config.kv_channels = config.dit_config.kv_channels
        decoder_config.vocab_size = 0
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
            mean_mode=config.dit_mean_mode,
        )

        # Projection layers bridging encoder / LM / DiT hidden sizes.
        self.fsq_layer = ScalarQuantizationLayer(
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale,
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
        self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)

        # Stop Predictor: 2-way head deciding whether generation should stop.
        self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
        self.stop_loss = nn.CrossEntropyLoss(reduction="none")

        # Audio VAE: supplied pre-built; expose its framing parameters.
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self.sample_rate = audio_vae.sample_rate

        # Inject LoRA adapters last so every target module already exists.
        if self.lora_config is not None:
            self._apply_lora()
|
| 190 |
+
|
| 191 |
+
    def _apply_lora(self):
        """Inject LoRA adapters into the LM / DiT / projection layers per self.lora_config."""
        cfg = self.lora_config
        lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)

        # LM: base_lm + residual_lm — wrap the configured attention projections.
        if cfg.enable_lm:
            for lm in [self.base_lm, self.residual_lm]:
                apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)

        # DiT: only the estimator network inside feat_decoder.
        if cfg.enable_dit:
            apply_lora_to_named_linear_modules(
                self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
            )

        # Projection layers: replace matching nn.Linear attributes on self
        # with LoRALinear wrappers; silently skip names that are absent or
        # not plain Linear layers.
        if cfg.enable_proj:
            from ..modules.layers.lora import LoRALinear

            for attr_name in cfg.target_proj_modules:
                module = getattr(self, attr_name, None)
                if isinstance(module, nn.Linear):
                    setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
|
| 215 |
+
|
| 216 |
+
def optimize(self, disable: bool = False):
    """Optionally wrap hot-path submodules with ``torch.compile``.

    Args:
        disable: If True, skip optimization entirely.

    Returns:
        self, so the call can be chained (e.g. ``model.eval().optimize()``).

    Notes:
        Compilation requires a CUDA device and an installed ``triton``.
        If either precondition fails, or ``torch.compile`` itself raises,
        the error is reported on stderr and the model is returned
        unchanged (eager mode) — this is deliberately best-effort.
    """
    if disable:
        return self
    try:
        # Guard: compile is only attempted on CUDA.
        if self.device != "cuda":
            raise ValueError("VoxCPMModel can only be optimized on CUDA device")
        try:
            import triton  # noqa: F401
        except ImportError:
            raise ValueError("triton is not installed")
        # Compile the incremental decode step of both LMs, plus the
        # feature encoder and DiT estimator used in the generation loop.
        self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
        self.residual_lm.forward_step = torch.compile(
            self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
        )
        self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
        self.feat_decoder.estimator = torch.compile(
            self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
        )
    except Exception as e:
        # Fall back to eager mode rather than failing model load.
        print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
    return self
def forward(
    self,
    text_tokens: torch.Tensor,
    text_mask: torch.Tensor,
    audio_feats: torch.Tensor,
    audio_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    position_ids: torch.Tensor,
    labels: torch.Tensor,
    *,
    progress: float = 0.0,
    sample_generate: bool = False,
):
    """Training forward pass: diffusion loss plus stop-prediction loss.

    Args:
        text_tokens: Text token ids; embedded via ``base_lm.embed_tokens``.
            Assumed shape (B, T) to align with the masks — TODO confirm.
        text_mask: Per-position mask selecting text positions in the
            interleaved text/audio sequence.
        audio_feats: Audio VAE latents, shape (B, T, P, D) — unpacked below.
        audio_mask: Per-position mask selecting audio positions.
        loss_mask: Per-position mask over positions that contribute to
            the losses.
        position_ids: Accepted but unused (deleted immediately).
        labels: Integer targets (2 classes) for the stop predictor.
        progress: Training progress scalar forwarded to the diffusion loss.
        sample_generate: If True, additionally run the sampler to produce
            ``feat_pred`` for monitoring.

    Returns:
        Dict with keys ``loss/diff``, ``loss/stop``, ``feat_gt`` and
        ``feat_pred`` (``feat_pred`` is None unless ``sample_generate``).
    """
    del position_ids  # not used yet

    # Move everything onto the model device; masks are cast to the model
    # dtype because they are used multiplicatively below.
    text_tokens = text_tokens.to(self.device, dtype=torch.long)
    text_mask = text_mask.to(self.device, dtype=self._dtype())
    audio_feats = audio_feats.to(self.device, dtype=self._dtype())
    audio_mask = audio_mask.to(self.device, dtype=self._dtype())
    loss_mask = loss_mask.to(self.device, dtype=self._dtype())
    labels = labels.to(self.device, dtype=torch.long)

    B, T, P, D = audio_feats.shape
    # Encode audio patches and project into the LM hidden space.
    feat_embed = self.feat_encoder(audio_feats)
    feat_embed = self.enc_to_lm_proj(feat_embed)

    # Embedding scale only applies when the config uses muP.
    scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
    if not getattr(self.config.lm_config, "use_mup", False):
        scale_emb = 1.0
    text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
    # Interleave text and audio embeddings via complementary masks.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed

    enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
    enc_outputs = enc_outputs.to(self._dtype())
    # Scalar-quantize only the audio positions; text positions pass through.
    enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    # Shift right by one step (zero-pad the first position) so position t
    # conditions on outputs up to t-1.
    lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)

    # Residual LM refines the base LM output plus the audio embedding.
    residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
    residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
    residual_outputs = residual_outputs.to(self._dtype())
    residual_hidden = torch.cat(
        (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
        dim=1,
    )

    # Combine both LM streams as the DiT conditioning, flattened per patch.
    dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
    dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")

    # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
    target_dtype = self._dtype()

    feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
    # Condition for each patch is the previous patch (zero for the first).
    feat_cond = torch.cat(
        (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
        dim=1,
    )
    feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")

    # Expand the per-position loss mask to per-patch-element granularity.
    loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
    loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)

    diff_loss = self.feat_decoder.compute_loss(
        feat_gt.transpose(1, 2).contiguous(),
        dit_hidden,
        cond=feat_cond.transpose(1, 2).contiguous(),
        tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
        progress=progress,
    )

    # Stop predictor: 2-way classification per position, masked mean.
    stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
    stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
    denom = torch.clamp(loss_mask.sum(), min=1.0)
    stop_loss = (stop_losses * loss_mask).sum() / denom

    feat_pred = None
    if sample_generate:
        feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
        feat_pred_seq = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=feat_cond_for_sample,
            # NOTE(review): ``inference_cfg_rate`` (a CFG guidance scale)
            # is passed here as ``n_timesteps`` (a step count) — this looks
            # like the wrong config field; confirm the intended attribute.
            n_timesteps=(
                self.config.dit_config.cfm_config.inference_cfg_rate
                if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
                else 10
            ),
        )
        feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)

    feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)

    return {
        "loss/diff": diff_loss,
        "loss/stop": stop_loss,
        "feat_gt": feat_gt_tensor,
        "feat_pred": feat_pred,
    }
def _dtype(self):
    """Resolve the configured dtype string into a torch dtype object."""
    configured = self.config.dtype
    return get_dtype(configured)
def generate(self, *args, **kwargs) -> torch.Tensor:
    """Non-streaming generation: drive ``_generate`` to its single result."""
    result_iter = self._generate(*args, streaming=False, **kwargs)
    return next(result_iter)
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
    """Streaming generation: return the generator of decoded audio chunks."""
    chunk_stream = self._generate(*args, streaming=True, **kwargs)
    return chunk_stream
@torch.inference_mode()
def _generate(
    self,
    target_text: str,
    prompt_text: str = "",
    prompt_wav_path: str = "",
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
    streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
    """Generate speech for ``target_text``, optionally cloning a prompt voice.

    Builds the interleaved text/audio input (with or without a prompt
    waveform), runs ``_inference``, and decodes latents through the audio
    VAE. Yields one waveform tensor per step when ``streaming=True``,
    otherwise yields a single final waveform.

    Args:
        target_text: Text to synthesize.
        prompt_text: Transcript of the prompt audio (voice-cloning mode).
        prompt_wav_path: Path to the prompt waveform; empty string means
            no prompt.
        min_len / max_len: Bounds on the number of generated patches.
        inference_timesteps: Diffusion sampling steps per patch.
        cfg_value: Classifier-free guidance scale.
        retry_badcase: Re-run when generated audio is implausibly long
            relative to the text (non-streaming only).
        retry_badcase_max_times: Max number of retries.
        retry_badcase_ratio_threshold: Audio/text length ratio above which
            a generation is considered a bad case.
        streaming: Yield per-step audio chunks instead of a final tensor.
    """
    if retry_badcase and streaming:
        warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
        retry_badcase = False
    if len(prompt_wav_path) == 0:
        # No prompt: the sequence is the target text followed by the
        # audio-start token; audio features start out as zeros.
        text = target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_start_token],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]

        audio_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
        audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)

    else:
        # Voice cloning: prepend the prompt transcript and encode the
        # prompt waveform into VAE latents appended after the text.
        text = prompt_text + target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]

        audio, sr = torchaudio.load(prompt_wav_path)
        if audio.size(0) > 1:
            # Downmix multi-channel prompt audio to mono.
            audio = audio.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        patch_len = self.patch_size * self.chunk_size

        if audio.size(1) % patch_len != 0:
            # Left padding: pad at the start so the valid audio data stays
            # at the end of the sequence.
            padding_size = patch_len - audio.size(1) % patch_len
            audio = torch.nn.functional.pad(audio, (padding_size, 0))

        # (B, D, T)
        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
        # Regroup latents into (T, P, D) patches.
        audio_feat = audio_feat.view(
            self.audio_vae.latent_dim,
            -1,
            self.patch_size,
        ).permute(1, 2, 0)
        audio_length = audio_feat.size(0)
        # Pad text with zeros over the audio span, and audio features with
        # zeros over the text span, so both sequences align position-wise.
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_token = torch.cat([text_token, text_pad_token])
        audio_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
        text_mask = (
            torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
        )
        audio_mask = (
            torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
        )

    # Add the batch dimension and move to the model device/dtype.
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
    audio_mask = audio_mask.unsqueeze(0).to(self.device)

    target_text_length = len(self.text_tokenizer(target_text))

    retry_badcase_times = 0
    while retry_badcase_times < retry_badcase_max_times:
        inference_result = self._inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            max_len=min(
                int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
            ),  # avoid too long audio
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            streaming=streaming,
        )
        if streaming:
            patch_len = self.patch_size * self.chunk_size
            # Each streamed latent chunk carries context; only the last
            # patch worth of samples is new audio.
            for latent_pred, _ in inference_result:
                decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                yield decode_audio
            break
        else:
            latent_pred, pred_audio_feat = next(inference_result)
            if retry_badcase:
                # Bad case: generated audio far longer than the text warrants.
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(
                        f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                        file=sys.stderr,
                    )
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break

    if not streaming:
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
        yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
    self,
    prompt_text: str,
    prompt_wav_path: str,
):
    """
    Build a reusable prompt cache for fast subsequent generation.

    Args:
        prompt_text: Transcript of the prompt audio (required).
        prompt_wav_path: Path to the prompt waveform file (required).

    Returns:
        dict with ``prompt_text`` (raw text) and ``audio_feat`` (VAE
        latents shaped (T, patch, latent_dim)). Tokenization is deferred
        to generation time for consistency.
    """
    if not prompt_text or not prompt_wav_path:
        raise ValueError("prompt_text and prompt_wav_path are required")

    waveform, sr = torchaudio.load(prompt_wav_path)

    # Downmix to mono if the prompt has multiple channels.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Match the VAE's expected sample rate.
    if sr != self.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

    patch_len = self.patch_size * self.chunk_size
    remainder = waveform.size(1) % patch_len
    if remainder != 0:
        # Left-pad so the valid audio stays at the end of the sequence.
        waveform = torch.nn.functional.pad(waveform, (patch_len - remainder, 0))

    # Encode and regroup latents into per-patch layout.
    latents = self.audio_vae.encode(waveform.to(self.device), self.sample_rate).cpu()
    latents = latents.view(
        self.audio_vae.latent_dim,
        -1,
        self.patch_size,
    ).permute(1, 2, 0)

    # Only raw text and audio features are cached.
    return {
        "prompt_text": prompt_text,
        "audio_feat": latents,
    }
def merge_prompt_cache(
    self,
    original_cache: dict,
    new_text: str,
    new_audio_feat: torch.Tensor,
):
    """
    Merge an existing prompt cache with newly generated content to keep
    the voice stable across successive generations.

    Args:
        original_cache: Previously built prompt cache (may be None).
        new_text: Text of the newly generated segment.
        new_audio_feat: Audio features of the newly generated segment.

    Returns:
        dict with ``prompt_text`` and ``audio_feat`` covering both the
        original prompt and the new content.
    """
    # Without a prior cache the new content simply becomes the cache.
    if original_cache is None:
        return {
            "prompt_text": new_text,
            "audio_feat": new_audio_feat,
        }

    combined_text = original_cache["prompt_text"] + new_text
    combined_feat = torch.cat([original_cache["audio_feat"], new_audio_feat], dim=0)

    return {
        "prompt_text": combined_text,
        "audio_feat": combined_feat,
    }
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run cached-prompt generation to completion and return its single result."""
    result_iter = self._generate_with_prompt_cache(*args, streaming=False, **kwargs)
    return next(result_iter)
def generate_with_prompt_cache_streaming(
    self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
    """Cached-prompt generation that yields (audio chunk, text tokens, features) per step."""
    step_stream = self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
    return step_stream
@torch.inference_mode()
def _generate_with_prompt_cache(
    self,
    target_text: str,
    prompt_cache: dict,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
    streaming: bool = False,
    streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
    """
    Generate audio using pre-built prompt cache.

    Args:
        target_text: Text to convert to speech
        prompt_cache: Cache built by build_prompt_cache (can be None)
        min_len: Minimum audio length to avoid very short audio
        max_len: Maximum audio length
        inference_timesteps: Number of diffusion sampling steps
        cfg_value: Classifier-free guidance value
        retry_badcase: Whether to retry on bad cases
        retry_badcase_max_times: Maximum retry attempts
        retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
        streaming: Whether to return a generator of audio chunks
        streaming_prefix_len: Number of prefix audio patches to use for streaming mode

    Returns:
        Generator of Tuple containing:
            - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
            - Tensor of new text tokens
            - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
    """
    if retry_badcase and streaming:
        warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
        retry_badcase = False
    # get prompt from cache
    if prompt_cache is None:
        # No prompt: empty audio-feature prefix, plain target text.
        prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        text = target_text
    else:
        prompt_audio_feat = prompt_cache["audio_feat"]
        prompt_text = prompt_cache["prompt_text"]
        text = prompt_text + target_text

    # Tokenize the combined text and append the audio-start token.
    text_token = torch.LongTensor(self.text_tokenizer(text))
    text_token = torch.cat(
        [
            text_token,
            torch.tensor(
                [self.audio_start_token],
                dtype=torch.int32,
                device=text_token.device,
            ),
        ],
        dim=-1,
    )

    # Tokens of the target text alone, returned to the caller for cache merging.
    target_text_token = torch.LongTensor(self.text_tokenizer(target_text))

    audio_length = prompt_audio_feat.size(0)
    text_length = text_token.shape[0]
    # Zero-pad text over the audio span and audio features over the text
    # span so the two sequences align position-wise.
    text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
    audio_pad_feat = torch.zeros(
        (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
        dtype=torch.float32,
        device=text_token.device,
    )
    text_token = torch.cat([text_token, text_pad_token])
    audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
    text_mask = (
        torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
    )
    audio_mask = (
        torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
    )

    # Add batch dimension and move to model device/dtype.
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
    audio_mask = audio_mask.unsqueeze(0).to(self.device)

    # run inference
    target_text_length = len(self.text_tokenizer(target_text))
    retry_badcase_times = 0
    while retry_badcase_times < retry_badcase_max_times:
        inference_result = self._inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            max_len=min(
                int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
            ),  # avoid too long audio
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            streaming=streaming,
            streaming_prefix_len=streaming_prefix_len,
        )
        if streaming:
            patch_len = self.patch_size * self.chunk_size
            # Each streamed latent chunk carries context; only the final
            # patch worth of samples is new audio.
            for latent_pred, pred_audio_feat in inference_result:
                decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                yield (decode_audio, target_text_token, pred_audio_feat)
            break
        else:
            latent_pred, pred_audio_feat = next(inference_result)
            if retry_badcase:
                # Bad case: generated audio far longer than the text warrants.
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(
                        f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                        file=sys.stderr,
                    )
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
    if not streaming:
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
        patch_len = self.patch_size * self.chunk_size
        if audio_mask.sum().item() > 0:
            # A prompt was present: drop the decoded prompt-context prefix.
            decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
        else:
            decode_audio = decode_audio[..., :].squeeze(1).cpu()
        yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Non-streaming core inference: return the final (latents, features) pair."""
    step_iter = self._inference(*args, streaming=False, **kwargs)
    return next(step_iter)
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
    """Streaming core inference: return the per-step generator."""
    step_stream = self._inference(*args, streaming=True, **kwargs)
    return step_stream
@torch.inference_mode()
def _inference(
    self,
    text: torch.Tensor,
    text_mask: torch.Tensor,
    feat: torch.Tensor,
    feat_mask: torch.Tensor,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    streaming: bool = False,
    streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
    """Core inference method for audio generation.

    This is the main inference loop that generates audio features
    using the language model and diffusion transformer.

    Args:
        text: Input text tokens
        text_mask: Mask for text tokens
        feat: Input audio features, shape (B, T, P, D)
        feat_mask: Mask for audio features
        min_len: Minimum generation length
        max_len: Maximum generation length
        inference_timesteps: Number of diffusion steps
        cfg_value: Classifier-free guidance value
        streaming: Whether to yield each step latent feature or just the final result
        streaming_prefix_len: Number of context patches included in each streamed chunk

    Returns:
        Generator of Tuple containing:
            - Predicted latent feature at the current step if ``streaming=True``, else final latent features
            - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
    """
    B, T, P, D = feat.shape

    feat_embed = self.feat_encoder(feat)  # [b, t, h_feat]
    feat_embed = self.enc_to_lm_proj(feat_embed)

    # Embedding scale only applies under muP.
    if self.config.lm_config.use_mup:
        scale_emb = self.config.lm_config.scale_emb
    else:
        scale_emb = 1.0

    text_embed = self.base_lm.embed_tokens(text) * scale_emb
    # Interleave text and audio embeddings via complementary masks.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

    # Condition for the first generated patch is the last input patch.
    prefix_feat_cond = feat[:, -1, ...]  # b, p, d
    pred_feat_seq = []  # b, t, p, d
    curr_embed = None

    # Prepare prompt context patches for streaming mode
    # When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
    prompt_context_patches = []
    audio_patch_count = int(feat_mask.sum().item())
    if audio_patch_count > 0:
        context_len = min(streaming_prefix_len - 1, audio_patch_count)
        # Take the last context_len patches from prompt audio as initial context
        # Split into list of [b, 1, p, d] tensors to match pred_feat_seq format
        prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
        pred_feat_seq = prompt_context_patches + pred_feat_seq

    # Prefill the base LM over the whole prompt and keep its KV cache
    # so the loop below can decode one step at a time.
    enc_outputs, kv_cache_tuple = self.base_lm(
        inputs_embeds=combined_embed,
        is_causal=True,
    )
    self.base_lm.kv_cache.fill_caches(kv_cache_tuple)

    # Scalar-quantize audio positions; text positions pass through.
    enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    lm_hidden = enc_outputs[:, -1, :]

    # Prefill the residual LM the same way.
    residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
        inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
        is_causal=True,
    )
    self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
    residual_hidden = residual_enc_outputs[:, -1, :]

    for i in tqdm(range(max_len)):
        # Combine both LM streams into the DiT conditioning vector.
        dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
        dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
        dit_hidden = dit_hidden_1 + dit_hidden_2  # [b, h_dit]

        # Sample the next audio patch with the diffusion decoder,
        # conditioned on the previous patch.
        pred_feat = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=prefix_feat_cond.transpose(1, 2).contiguous(),
            n_timesteps=inference_timesteps,
            cfg_value=cfg_value,
        ).transpose(
            1, 2
        )  # [b, p, d]

        # Re-embed the predicted patch for the next LM step.
        curr_embed = self.feat_encoder(pred_feat.unsqueeze(1))  # b, 1, c
        curr_embed = self.enc_to_lm_proj(curr_embed)

        pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
        prefix_feat_cond = pred_feat

        if streaming:
            # return the last three predicted latent features to provide enough context for smooth decoding
            pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
            feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)

            yield feat_pred, pred_feat_seq

        # Stop predictor decides whether generation should end (after min_len).
        stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
        if i > min_len and stop_flag == 1:
            break

        # Single-step decode through both LMs using their KV caches.
        lm_hidden = self.base_lm.forward_step(
            curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()

        lm_hidden = self.fsq_layer(lm_hidden)
        residual_hidden = self.residual_lm.forward_step(
            lm_hidden + curr_embed[:, 0, :],
            torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
        ).clone()

    if not streaming:
        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
        yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
    """Load a VoxCPM model and its AudioVAE from a local checkpoint directory.

    Args:
        path: Directory containing ``config.json``, tokenizer files, model
            weights (``model.safetensors`` or ``pytorch_model.bin``) and
            AudioVAE weights (``audiovae.safetensors`` or ``audiovae.pth``).
        optimize: If True (inference mode only), run ``optimize()`` to
            torch.compile hot submodules.
        training: If True, keep weights in their load dtype, freeze the
            VAE (and non-LoRA params when ``lora_config`` is set), and
            skip eval/compile.
        lora_config: Optional LoRA configuration to inject adapters.

    Returns:
        A constructed model — training-ready if ``training``; otherwise
        moved to its device, set to eval, and optionally compiled.

    Raises:
        FileNotFoundError: If model or AudioVAE weights are missing.
    """
    config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
    tokenizer = LlamaTokenizerFast.from_pretrained(path)
    audio_vae_config = getattr(config, "audio_vae_config", None)
    audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
    # Try to load AudioVAE from safetensors first, fallback to pytorch
    audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
    audiovae_pth_path = os.path.join(path, "audiovae.pth")
    if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
        vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
    elif os.path.exists(audiovae_pth_path):
        print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
        checkpoint = torch.load(
            audiovae_pth_path,
            map_location="cpu",
            weights_only=True,
        )
        vae_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(
            f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
        )
    model = cls(config, tokenizer, audio_vae, lora_config)
    if not training:
        lm_dtype = get_dtype(model.config.dtype)
        model = model.to(lm_dtype)
    else:  # training mode
        for name, param in model.named_parameters():
            if "audio_vae" in name:  # freeze VAE weights
                param.requires_grad = False
                continue
            if lora_config is not None:
                if "lora" not in name:  # freeze non-LoRA weights
                    param.requires_grad = False
        # Keep the frozen VAE in full precision during training.
        model.audio_vae = model.audio_vae.to(torch.float32)

    # Try to load from safetensors first, fallback to pytorch_model.bin
    safetensors_path = os.path.join(path, "model.safetensors")
    pytorch_model_path = os.path.join(path, "pytorch_model.bin")

    if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
        model_state_dict = load_file(safetensors_path)
    elif os.path.exists(pytorch_model_path):
        print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
        checkpoint = torch.load(
            pytorch_model_path,
            map_location="cpu",
            weights_only=True,
        )
        model_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")

    # Graft the VAE weights into the combined state dict under their prefix.
    for kw, val in vae_state_dict.items():
        model_state_dict[f"audio_vae.{kw}"] = val

    # LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
    # Using strict=False since pretrained weights don't contain lora_A/lora_B.
    model.load_state_dict(model_state_dict, strict=False)
    if training:
        return model
    return model.to(model.device).eval().optimize(disable=not optimize)
| 913 |
+
# ------------------------------------------------------------------ #
|
| 914 |
+
# LoRA Weight Management
|
| 915 |
+
# ------------------------------------------------------------------ #
|
| 916 |
+
def _iter_lora_modules(self):
    """Yield every LoRALinear submodule contained in this model."""
    from ..modules.layers.lora import LoRALinear

    yield from (m for m in self.modules() if isinstance(m, LoRALinear))
|
| 923 |
+
|
| 924 |
+
def load_lora_weights(self, lora_path: str, device: str = None):
    """
    Load LoRA weights from file, supports calling after torch.compile.
    Uses named_parameters() to handle compile's _orig_mod wrapper.
    Supports both safetensors and pytorch formats.

    Args:
        lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
        device: Target device, defaults to model's current device
    Returns:
        tuple: (loaded_keys, skipped_keys)
    """
    from pathlib import Path

    device = device or self.device
    lora_p = Path(lora_path)

    # Resolve candidate files: a directory implies the canonical file names,
    # otherwise the path's suffix decides which format it refers to.
    # Try safetensors first, then fallback to .ckpt
    if lora_p.is_dir():
        safetensors_file = lora_p / "lora_weights.safetensors"
        ckpt_file = lora_p / "lora_weights.ckpt"
    else:
        safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
        ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None

    # Load from safetensors if available
    if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
        state_dict = load_file(str(safetensors_file), device=device)
    elif ckpt_file and ckpt_file.exists():
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load .ckpt/.pth files from trusted sources. Elsewhere in this
        # package torch.load is called with weights_only=True; confirm whether
        # these checkpoints really need full unpickling.
        ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
        # Lightning-style checkpoints nest tensors under "state_dict".
        state_dict = ckpt.get("state_dict", ckpt)
    else:
        raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")

    # Build param mapping (handle torch.compile's _orig_mod prefix)
    model_params = dict(self.named_parameters())
    key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}

    # Copy tensors in place so any compiled graphs keep referencing the same storage.
    loaded_keys, skipped_keys = [], []
    for key, value in state_dict.items():
        target_key = key if key in model_params else key_mapping.get(key)
        if target_key:
            model_params[target_key].data.copy_(value.to(device))
            loaded_keys.append(key)
        else:
            skipped_keys.append(key)

    return loaded_keys, skipped_keys
|
| 972 |
+
|
| 973 |
+
def set_lora_enabled(self, enabled: bool):
    """Switch every LoRA layer in the model on or off in one call."""
    for lora_layer in self._iter_lora_modules():
        lora_layer.set_enabled(enabled)
|
| 977 |
+
|
| 978 |
+
def reset_lora_weights(self):
    """Re-initialize every LoRA adapter (A: kaiming, B: zeros).

    Since B starts at zero this effectively unloads any previously
    loaded LoRA weights.
    """
    for lora_layer in self._iter_lora_modules():
        lora_layer.reset_lora_parameters()
|
| 982 |
+
|
| 983 |
+
def get_lora_state_dict(self) -> dict:
    """Collect a cloned copy of every LoRA parameter (lora_A / lora_B)."""
    lora_state = {}
    for param_name, param in self.named_parameters():
        if "lora_" in param_name:
            lora_state[param_name] = param.data.clone()
    return lora_state
|
voxcpm/model/voxcpm2.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoxCPM: A Tokenizer-free speech generation model
|
| 3 |
+
|
| 4 |
+
This module contains the main VoxCPM model implementation, including configuration classes
|
| 5 |
+
and the core VoxCPMModel for text-to-speech generation.
|
| 6 |
+
|
| 7 |
+
Copyright 2026 OpenBMB
|
| 8 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
you may not use this file except in compliance with the License.
|
| 10 |
+
You may obtain a copy of the License at
|
| 11 |
+
|
| 12 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
|
| 14 |
+
Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
See the License for the specific language governing permissions and
|
| 18 |
+
limitations under the License.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
from typing import Tuple, Union, Generator, List, Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import warnings
|
| 28 |
+
import librosa
|
| 29 |
+
import numpy as np
|
| 30 |
+
from einops import rearrange
|
| 31 |
+
from pydantic import BaseModel
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
from safetensors.torch import load_file
|
| 35 |
+
|
| 36 |
+
SAFETENSORS_AVAILABLE = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
SAFETENSORS_AVAILABLE = False
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
from transformers import LlamaTokenizerFast
|
| 41 |
+
|
| 42 |
+
from ..modules.audiovae import AudioVAEV2, AudioVAEConfigV2
|
| 43 |
+
from ..modules.layers import ScalarQuantizationLayer
|
| 44 |
+
from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
| 45 |
+
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
| 46 |
+
from ..modules.locenc import VoxCPMLocEnc
|
| 47 |
+
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
| 48 |
+
from .utils import get_dtype, mask_multichar_chinese_tokens
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _trim_audio_silence_vad(
|
| 52 |
+
audio: torch.Tensor,
|
| 53 |
+
sample_rate: int,
|
| 54 |
+
max_silence_ms: float = 200.0,
|
| 55 |
+
top_db: float = 35.0,
|
| 56 |
+
) -> torch.Tensor:
|
| 57 |
+
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
|
| 58 |
+
|
| 59 |
+
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
audio: (1, T) 的音频 tensor
|
| 63 |
+
sample_rate: 采样率
|
| 64 |
+
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
|
| 65 |
+
top_db: 低于参考电平多少 dB 视为静音
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
截取后的 (1, T') tensor
|
| 69 |
+
"""
|
| 70 |
+
if audio.numel() == 0:
|
| 71 |
+
return audio
|
| 72 |
+
y = audio.squeeze(0).numpy()
|
| 73 |
+
n = len(y)
|
| 74 |
+
frame_length = 2048
|
| 75 |
+
hop_length = 512
|
| 76 |
+
ref = np.max(np.abs(y))
|
| 77 |
+
if ref <= 0:
|
| 78 |
+
return audio
|
| 79 |
+
threshold = ref * (10.0 ** (-top_db / 20.0))
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
_, (start, end) = librosa.effects.trim(
|
| 83 |
+
y, top_db=top_db, ref=np.max, frame_length=frame_length, hop_length=hop_length
|
| 84 |
+
)
|
| 85 |
+
except Exception:
|
| 86 |
+
start, end = 0, n
|
| 87 |
+
|
| 88 |
+
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等)
|
| 89 |
+
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
| 90 |
+
last_voice_frame = -1
|
| 91 |
+
for j in range(n_frames):
|
| 92 |
+
idx = j * hop_length
|
| 93 |
+
if idx + frame_length > n:
|
| 94 |
+
break
|
| 95 |
+
rms = np.sqrt(np.mean(y[idx : idx + frame_length] ** 2))
|
| 96 |
+
if rms >= threshold:
|
| 97 |
+
last_voice_frame = j
|
| 98 |
+
if last_voice_frame >= 0:
|
| 99 |
+
end_by_vad = min(n, (last_voice_frame + 1) * hop_length + (frame_length - hop_length))
|
| 100 |
+
end = min(end, end_by_vad)
|
| 101 |
+
|
| 102 |
+
max_silence_samples = int(max_silence_ms * sample_rate / 1000.0)
|
| 103 |
+
new_start = max(0, start - max_silence_samples)
|
| 104 |
+
new_end = min(n, end + max_silence_samples)
|
| 105 |
+
return audio[:, new_start:new_end]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class VoxCPMEncoderConfig(BaseModel):
    """Hyper-parameters of the local feature encoder transformer."""

    hidden_dim: int = 1024  # encoder hidden size
    ffn_dim: int = 4096  # feed-forward inner size
    num_heads: int = 16  # attention heads
    num_layers: int = 4  # transformer layers
    # None lets downstream code derive the value. The previous `int = None`
    # annotation was type-incorrect and (under pydantic v2) rejects an explicit
    # null supplied by a config file, even though None is the intended default.
    kv_channels: Optional[int] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class VoxCPMDitConfig(BaseModel):
    """Hyper-parameters of the local DiT (diffusion transformer) decoder."""

    hidden_dim: int = 1024  # DiT hidden size
    ffn_dim: int = 4096  # feed-forward inner size
    num_heads: int = 16  # attention heads
    num_layers: int = 4  # transformer layers
    # None lets downstream code derive the value. The previous `int = None`
    # annotation was type-incorrect and (under pydantic v2) rejects an explicit
    # null supplied by a config file, even though None is the intended default.
    kv_channels: Optional[int] = None
    dit_mean_mode: bool = False  # forwarded to UnifiedCFM's mean_mode

    cfm_config: CfmConfig  # conditional flow-matching settings (required)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class VoxCPMConfig(BaseModel):
    """Top-level model configuration aggregating all submodule configs."""

    lm_config: MiniCPM4Config  # backbone LM config (also templated for encoder/decoder/residual LM)
    patch_size: int = 4  # latent frames grouped into one LM step
    feat_dim: int = 64  # per-frame latent feature dimension
    residual_lm_num_layers: int = 8  # depth of the residual acoustic LM
    residual_lm_no_rope: bool = False  # disable RoPE in the residual LM
    scalar_quantization_latent_dim: int = 512  # FSQ latent dimension
    scalar_quantization_scale: int = 9  # FSQ quantization scale

    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    audio_vae_config: Optional[AudioVAEConfigV2] = None  # optional; VAE may be supplied externally

    max_length: int = 8192  # KV-cache length set up for both LMs
    device: str = "cuda"  # requested device; falls back to mps/cpu at runtime
    dtype: str = "bfloat16"  # model compute dtype name
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class LoRAConfig(BaseModel):
    """Settings controlling where and how LoRA adapters are injected."""

    enable_lm: bool = False  # Apply LoRA to base_lm + residual_lm
    enable_dit: bool = False  # Apply LoRA to VoxCPMLocDiT
    enable_proj: bool = False  # Apply LoRA to projection Linear layers

    r: int = 8  # adapter rank
    alpha: int = 16  # scaling numerator (effective scale alpha / r)
    dropout: float = 0.0  # dropout on the LoRA path

    # Target linear layer names for LM & DiT (matched by attribute name)
    # NOTE: mutable defaults are safe here — pydantic copies them per instance.
    target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    # Projection layer attribute names to find on VoxCPM2Model
    target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj", "fusion_concat_proj"]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
VoxCPMConfig.model_rebuild()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class VoxCPM2Model(nn.Module):
|
| 165 |
+
def __init__(
    self,
    config: VoxCPMConfig,
    tokenizer: LlamaTokenizerFast,
    audio_vae: AudioVAEV2,
    lora_config: LoRAConfig = None,
):
    """Build all submodules from the configuration.

    Args:
        config: full model configuration.
        tokenizer: text tokenizer; multi-char Chinese tokens are masked out.
        audio_vae: audio VAE used for latent encode/decode.
        lora_config: optional LoRA settings; when given, adapters are
            injected at the end of construction.
    """
    super().__init__()
    self.config = config
    self.lora_config = lora_config
    self.feat_dim = config.feat_dim
    self.patch_size = config.patch_size
    self.device = config.device
    # Fall back from the configured device when CUDA is unavailable.
    if not torch.cuda.is_available():
        if torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
    print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)

    # Text-Semantic LM
    self.base_lm = MiniCPMModel(config.lm_config)
    self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

    self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
    # Special token ids delimiting (reference) audio segments in the token stream.
    self.audio_start_token = 101
    self.audio_end_token = 102
    self.ref_audio_start_token = 103
    self.ref_audio_end_token = 104

    # Residual Acoustic LM: a shallower copy of the base LM config without a vocab.
    residual_lm_config = config.lm_config.model_copy(deep=True)
    residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
    residual_lm_config.vocab_size = 0
    residual_lm_config.no_rope = config.residual_lm_no_rope
    self.residual_lm = MiniCPMModel(residual_lm_config)
    self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

    # Local Encoder: base LM config resized by the encoder sub-config.
    encoder_config = config.lm_config.model_copy(deep=True)
    encoder_config.hidden_size = config.encoder_config.hidden_dim
    encoder_config.intermediate_size = config.encoder_config.ffn_dim
    encoder_config.num_attention_heads = config.encoder_config.num_heads
    encoder_config.num_hidden_layers = config.encoder_config.num_layers
    encoder_config.kv_channels = config.encoder_config.kv_channels
    encoder_config.vocab_size = 0
    self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)

    # Local DiT: base LM config resized by the dit sub-config, wrapped in a CFM sampler.
    decoder_config = config.lm_config.model_copy(deep=True)
    decoder_config.hidden_size = config.dit_config.hidden_dim
    decoder_config.intermediate_size = config.dit_config.ffn_dim
    decoder_config.num_attention_heads = config.dit_config.num_heads
    decoder_config.num_hidden_layers = config.dit_config.num_layers
    decoder_config.kv_channels = config.dit_config.kv_channels
    decoder_config.vocab_size = 0
    self.feat_decoder = UnifiedCFM(
        in_channels=config.feat_dim,
        cfm_params=config.dit_config.cfm_config,
        estimator=VoxCPMLocDiTV2(decoder_config, in_channels=config.feat_dim),
        mean_mode=config.dit_config.dit_mean_mode,
    )

    # Projection layers between the encoder, LMs and DiT hidden spaces.
    self.fsq_layer = ScalarQuantizationLayer(
        config.lm_config.hidden_size,
        config.lm_config.hidden_size,
        config.scalar_quantization_latent_dim,
        config.scalar_quantization_scale,
    )
    self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
    self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
    self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
    self.fusion_concat_proj = nn.Linear(config.lm_config.hidden_size * 2, config.lm_config.hidden_size)

    # Stop Predictor: 2-way classifier on the LM hidden state.
    self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
    self.stop_actn = nn.SiLU()
    self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
    self.stop_loss = nn.CrossEntropyLoss(reduction="none")

    # Audio VAE (frozen elsewhere); output rate may differ from the encode rate.
    self.audio_vae = audio_vae
    self.chunk_size = audio_vae.chunk_size
    self._encode_sample_rate = audio_vae.sample_rate
    self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)

    if self.lora_config is not None:
        self._apply_lora()
|
| 254 |
+
|
| 255 |
+
def _apply_lora(self):
    """Inject LoRA adapters into the LM stacks, the DiT, and/or projection layers."""
    cfg = self.lora_config
    adapter_args = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)

    # Language models: base_lm and residual_lm share one target list.
    if cfg.enable_lm:
        for language_model in (self.base_lm, self.residual_lm):
            apply_lora_to_named_linear_modules(
                language_model, target_submodule_names=cfg.target_modules_lm, **adapter_args
            )

    # Diffusion transformer (the CFM estimator).
    if cfg.enable_dit:
        apply_lora_to_named_linear_modules(
            self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **adapter_args
        )

    # Stand-alone projection Linear layers are wrapped individually.
    if cfg.enable_proj:
        from ..modules.layers.lora import LoRALinear

        for attr_name in cfg.target_proj_modules:
            candidate = getattr(self, attr_name, None)
            if isinstance(candidate, nn.Linear):
                setattr(self, attr_name, LoRALinear(base=candidate, **adapter_args))
|
| 279 |
+
|
| 280 |
+
def optimize(self, disable: bool = False):
    """Compile the inference-critical submodules with torch.compile.

    Degrades gracefully (with a stderr warning) when CUDA or triton is
    unavailable or compilation fails, returning the uncompiled model.
    """
    if disable:
        return self
    try:
        if self.device != "cuda":
            raise ValueError("VoxCPMModel can only be optimized on CUDA device")
        try:
            import triton  # noqa: F401
        except ImportError:
            raise ValueError("triton is not installed")
        # Compile each hot-path callable in place on its owner.
        compile_targets = (
            (self.base_lm, "forward_step"),
            (self.residual_lm, "forward_step"),
            (self, "feat_encoder"),
            (self.feat_decoder, "estimator"),
        )
        for owner, attr in compile_targets:
            compiled = torch.compile(getattr(owner, attr), mode="reduce-overhead", fullgraph=True)
            setattr(owner, attr, compiled)
    except Exception as e:
        print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
    return self
|
| 301 |
+
|
| 302 |
+
def forward(
    self,
    text_tokens: torch.Tensor,
    text_mask: torch.Tensor,
    audio_feats: torch.Tensor,
    audio_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    position_ids: torch.Tensor,
    labels: torch.Tensor,
    *,
    progress: float = 0.0,
    sample_generate: bool = False,
):
    """Training forward pass: diffusion (CFM) loss plus stop-prediction loss.

    Args:
        text_tokens: (B, T) token ids; audio positions hold placeholder ids.
        text_mask: (B, T) 1 where the step is a text token.
        audio_feats: (B, T, P, D) latent audio patches (P = patch_size, D = feat_dim).
        audio_mask: (B, T) 1 where the step is an audio frame.
        loss_mask: (B, T) 1 where losses should be counted.
        position_ids: unused (deleted immediately).
        labels: (B, T) targets for the 2-way stop classifier.
        progress: training progress value forwarded to the CFM loss.
        sample_generate: when True, additionally run diffusion sampling and
            return the sampled features as "feat_pred".

    Returns:
        dict with "loss/diff", "loss/stop", "feat_gt" of shape (B, D, T*P),
        and "feat_pred" ((B, D, T*P) or None).
    """
    del position_ids  # not used yet

    # Move inputs to the model device; masks/features in the model dtype.
    text_tokens = text_tokens.to(self.device, dtype=torch.long)
    text_mask = text_mask.to(self.device, dtype=self._dtype())
    audio_feats = audio_feats.to(self.device, dtype=self._dtype())
    audio_mask = audio_mask.to(self.device, dtype=self._dtype())
    loss_mask = loss_mask.to(self.device, dtype=self._dtype())
    labels = labels.to(self.device, dtype=torch.long)

    # Encode audio patches and project them into the LM embedding space.
    B, T, P, D = audio_feats.shape
    feat_embed = self.feat_encoder(audio_feats)
    feat_embed = self.enc_to_lm_proj(feat_embed)

    # muP embedding scale; only applied when use_mup is set on the LM config.
    scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
    if not getattr(self.config.lm_config, "use_mup", False):
        scale_emb = 1.0
    text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
    # Per-step selection of text vs. audio embedding via the two masks.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed

    enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
    enc_outputs = enc_outputs.to(self._dtype())
    # Quantize (FSQ) only the audio positions; text positions pass through unchanged.
    enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    # Shift right by one step: the hidden state at t conditions the prediction of step t.
    lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)

    # Residual acoustic LM consumes the fused (semantic, acoustic) features.
    residual_inputs = self.fusion_concat_proj(
        torch.cat((enc_outputs, audio_mask.unsqueeze(-1) * feat_embed), dim=-1)
    )
    residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
    residual_outputs = residual_outputs.to(self._dtype())
    # Same one-step right shift for the residual stream.
    residual_hidden = torch.cat(
        (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
        dim=1,
    )

    # Per-step DiT conditioning: concatenation of both projected streams, flattened over time.
    dit_hidden = torch.cat((self.lm_to_dit_proj(lm_hidden), self.res_to_dit_proj(residual_hidden)), dim=-1)
    dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")

    # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
    target_dtype = self._dtype()

    feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
    # Condition each step's diffusion on the previous step's patch (zeros for step 0).
    feat_cond = torch.cat(
        (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
        dim=1,
    )
    feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")

    # Expand the per-step loss mask to per-frame granularity inside each patch.
    loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
    loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)

    diff_loss = self.feat_decoder.compute_loss(
        feat_gt.transpose(1, 2).contiguous(),
        dit_hidden,
        cond=feat_cond.transpose(1, 2).contiguous(),
        tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
        progress=progress,
    )

    # Stop classifier on the shifted LM hidden state; masked mean over valid steps.
    stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
    stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
    # Guard against an all-masked batch.
    denom = torch.clamp(loss_mask.sum(), min=1.0)
    stop_loss = (stop_losses * loss_mask).sum() / denom

    feat_pred = None
    if sample_generate:
        feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
        feat_pred_seq = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=feat_cond_for_sample,
            # NOTE(review): `inference_cfg_rate` is a CFG guidance rate elsewhere
            # in this file, yet here it is passed as `n_timesteps` (step count).
            # This looks like the wrong attribute — confirm the intended field.
            n_timesteps=(
                self.config.dit_config.cfm_config.inference_cfg_rate
                if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
                else 10
            ),
        )
        feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)

    feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)

    return {
        "loss/diff": diff_loss,
        "loss/stop": stop_loss,
        "feat_gt": feat_gt_tensor,
        "feat_pred": feat_pred,
    }
|
| 401 |
+
|
| 402 |
+
def _dtype(self):
    """Resolve the configured dtype name (e.g. "bfloat16") via get_dtype."""
    return get_dtype(self.config.dtype)
|
| 404 |
+
|
| 405 |
+
def _encode_wav(self, wav_path: str, padding_mode: str = "right") -> torch.Tensor:
    """Load, trim, pad and VAE-encode an audio file.

    Args:
        wav_path: path to the audio file.
        padding_mode: "right" (default) or "left" padding for alignment.

    Returns:
        audio_feat: (T, P, D) tensor of latent patches.
    """
    audio, _ = librosa.load(wav_path, sr=self._encode_sample_rate, mono=True)
    audio = torch.from_numpy(audio).unsqueeze(0)
    # Keep at most 200 ms of silence on each side before encoding.
    audio = _trim_audio_silence_vad(audio, self._encode_sample_rate, max_silence_ms=200.0)
    # One latent patch spans patch_size VAE chunks of samples; pad to a multiple.
    patch_len = self.patch_size * self.chunk_size
    if audio.size(1) % patch_len != 0:
        padding_size = patch_len - audio.size(1) % patch_len
        # Left padding is used by callers for continuation prompts (presumably to
        # keep the clip's end aligned to the patch grid) — see the prompt path.
        pad = (padding_size, 0) if padding_mode == "left" else (0, padding_size)
        audio = torch.nn.functional.pad(audio, pad)
    feat = self.audio_vae.encode(audio.to(self.device), self._encode_sample_rate).cpu()
    # Reshape (latent_dim, T*P) -> (T, P, latent_dim) patch layout.
    return feat.view(self.audio_vae.latent_dim, -1, self.patch_size).permute(1, 2, 0)
|
| 425 |
+
|
| 426 |
+
def _make_ref_prefix(self, ref_feat: torch.Tensor, device: torch.device):
    """Build the [ref_start ref_audio ref_end] prefix segments.

    Returns:
        tokens, feats, text_mask, audio_mask (all 1-D / 2-D tensors)
    """
    ref_len = ref_feat.size(0)
    total = ref_len + 2

    # Token ids: ref_start, one placeholder 0 per reference frame, ref_end.
    tokens = torch.zeros(total, dtype=torch.int32, device=device)
    tokens[0] = self.ref_audio_start_token
    tokens[-1] = self.ref_audio_end_token

    # Latent patches: zero-patch sentinels surrounding the reference features.
    zero_patch = torch.zeros((1, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32, device=device)
    feats = torch.cat([zero_patch, ref_feat, zero_patch], dim=0)

    # Text mask marks the two boundary tokens; audio mask marks the frames between.
    t_mask = torch.zeros(total, dtype=torch.int32, device=device)
    t_mask[0] = 1
    t_mask[-1] = 1
    a_mask = torch.ones(total, dtype=torch.int32, device=device)
    a_mask[0] = 0
    a_mask[-1] = 0
    return tokens, feats, t_mask, a_mask
|
| 457 |
+
|
| 458 |
+
def generate(self, *args, **kwargs) -> torch.Tensor:
    """Non-streaming synthesis: run `_generate` once and return its single yield.

    Accepts the same arguments as `_generate` (except `streaming`).
    """
    return next(self._generate(*args, streaming=False, **kwargs))
|
| 460 |
+
|
| 461 |
+
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
    """Streaming synthesis: yield audio chunks as `_generate` produces them.

    Accepts the same arguments as `_generate` (except `streaming`).
    """
    return self._generate(*args, streaming=True, **kwargs)
|
| 463 |
+
|
| 464 |
+
@torch.inference_mode()
def _generate(
    self,
    target_text: str,
    prompt_text: str = "",
    prompt_wav_path: str = "",
    reference_wav_path: str = "",
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
    streaming: bool = False,
    streaming_prefix_len: int = 4,
) -> Generator[torch.Tensor, None, None]:
    """Generate audio for ``target_text``, optionally conditioned on prompt/reference audio.

    Conditioning mode is selected from the path arguments:
      - ``reference_wav_path`` + ``prompt_wav_path``: reference isolation prefix
        followed by continuation of the prompt audio.
      - ``reference_wav_path`` only: voice cloning via an isolated
        [ref_start .. ref_end] prefix (see ``_make_ref_prefix``).
      - neither: zero-shot generation from text alone.
      - ``prompt_wav_path`` only: plain continuation (``prompt_text`` should be
        the transcript of the prompt audio).

    Args:
        target_text: text to synthesize.
        prompt_text: transcript prepended to ``target_text`` in continuation modes.
        prompt_wav_path: audio to continue from.
        reference_wav_path: audio used as isolated voice reference.
        min_len: minimum number of generation steps before the stop head may end generation.
        max_len: hard cap on generation steps (also capped by the text-length ratio).
        inference_timesteps: diffusion sampling steps per patch.
        cfg_value: classifier-free guidance scale.
        retry_badcase: retry when generated length >= target_text_length *
            ``retry_badcase_ratio_threshold`` (non-streaming only).
        retry_badcase_max_times: maximum retry attempts.
        retry_badcase_ratio_threshold: audio/text length ratio triggering a retry.
        streaming: yield decoded chunks per step instead of one final tensor.
        streaming_prefix_len: number of trailing patches decoded together per chunk.

    Yields:
        Decoded audio tensor per step when ``streaming=True``, otherwise a single
        final decoded audio tensor.
    """
    if retry_badcase and streaming:
        warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
        retry_badcase = False

    if reference_wav_path and prompt_wav_path:
        # Combined mode: reference isolation prefix + continuation suffix
        text = prompt_text + target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        # Append the audio-start marker after the text tokens.
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]

        ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
        prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
        prompt_audio_length = prompt_feat.size(0)

        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)

        # Zero placeholders: token ids for audio positions, features for text positions.
        prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )

        # Sequence layout: [ref prefix | text | prompt audio]
        text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_feat], dim=0)
        text_mask = torch.cat(
            [
                ref_t_mask,
                torch.ones(text_length, dtype=torch.int32).to(text_token.device),
                torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )
        audio_mask = torch.cat(
            [
                ref_a_mask,
                torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
                torch.ones(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )

    elif reference_wav_path:
        # Reference-only mode (prompt isolation)
        text = target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]

        ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)

        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        # Sequence layout: [ref prefix | text]
        text_token = torch.cat([ref_tokens, text_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat], dim=0)
        text_mask = torch.cat(
            [
                ref_t_mask,
                torch.ones(text_length, dtype=torch.int32).to(text_token.device),
            ]
        )
        audio_mask = torch.cat(
            [
                ref_a_mask,
                torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
            ]
        )

    elif len(prompt_wav_path) == 0:
        # Zero-shot mode
        text = target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]

        # No conditioning audio at all: features are all-zero placeholders.
        audio_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_mask = torch.ones(text_length, dtype=torch.int32).to(text_token.device)
        audio_mask = torch.zeros(text_length, dtype=torch.int32).to(text_token.device)

    else:
        # Continuation-only mode
        text = prompt_text + target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]

        prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
        prompt_audio_length = prompt_feat.size(0)
        prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        # Sequence layout: [text | prompt audio]
        text_token = torch.cat([text_token, prompt_pad_token])
        audio_feat = torch.cat([text_pad_feat, prompt_feat], dim=0)
        text_mask = torch.cat(
            [
                torch.ones(text_length, dtype=torch.int32),
                torch.zeros(prompt_audio_length, dtype=torch.int32),
            ]
        ).to(text_token.device)
        audio_mask = torch.cat(
            [
                torch.zeros(text_length, dtype=torch.int32),
                torch.ones(prompt_audio_length, dtype=torch.int32),
            ]
        ).to(text_token.device)

    # Add batch dimension and move everything to the model device/dtype.
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
    audio_mask = audio_mask.unsqueeze(0).to(self.device)

    target_text_length = len(self.text_tokenizer(target_text))

    retry_badcase_times = 0
    while retry_badcase_times < retry_badcase_max_times:
        inference_result = self._inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            # Cap generation length by the badcase ratio even when retry is off.
            max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            streaming=streaming,
            streaming_prefix_len=streaming_prefix_len,
        )
        if streaming:
            # Samples produced per decoded patch chunk.
            out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
            for latent_pred, _ in inference_result:
                decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                # Keep only the newest chunk; earlier patches were context for smooth decoding.
                decode_audio = decode_audio[..., -out_patch_len:].squeeze(1).cpu()
                yield decode_audio
            break
        else:
            latent_pred, pred_audio_feat = next(inference_result)
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(
                        f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                        file=sys.stderr,
                    )
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break

    if not streaming:
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
        out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
        has_continuation = bool(prompt_wav_path)
        if has_continuation:
            # Drop the prompt-context patches that were prepended for decoding continuity.
            decode_audio = decode_audio[..., out_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
        else:
            decode_audio = decode_audio.squeeze(1).cpu()
        yield decode_audio
|
| 673 |
+
|
| 674 |
+
@torch.inference_mode()
def build_prompt_cache(
    self,
    prompt_text: str = None,
    prompt_wav_path: str = None,
    reference_wav_path: str = None,
):
    """Pre-encode conditioning audio into a reusable prompt cache.

    Accepts the same combinations as ``generate()``:
      - ``reference_wav_path`` only -> "reference" mode (isolated voice cloning)
      - ``prompt_text`` + ``prompt_wav_path`` -> "continuation" mode
      - all three -> "ref_continuation" mode

    Args:
        prompt_text: transcript of ``prompt_wav_path``; must be paired with it.
        prompt_wav_path: audio to continue from; must be paired with ``prompt_text``.
        reference_wav_path: audio used as an isolated voice reference.

    Returns:
        dict consumed by ``_generate_with_prompt_cache`` (keys: ``mode`` and,
        depending on mode, ``ref_audio_feat`` / ``prompt_text`` / ``audio_feat``).

    Raises:
        ValueError: when only one of prompt_text/prompt_wav_path is given, or
            when no conditioning input is given at all.
    """
    if (prompt_wav_path is None) != (prompt_text is None):
        raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
    if prompt_wav_path is None and reference_wav_path is None:
        raise ValueError("At least one of prompt_wav_path or reference_wav_path must be provided")

    has_ref = bool(reference_wav_path)
    has_prompt = bool(prompt_wav_path) and prompt_text is not None

    cache = {}
    if has_ref:
        # Reference audio is right-padded so its tail aligns with the patch grid.
        cache["ref_audio_feat"] = self._encode_wav(reference_wav_path, padding_mode="right")
    if has_prompt:
        cache["prompt_text"] = prompt_text
        # Prompt audio is left-padded so generation continues seamlessly from its end.
        cache["audio_feat"] = self._encode_wav(prompt_wav_path, padding_mode="left")

    if has_ref and has_prompt:
        cache["mode"] = "ref_continuation"
    elif has_ref:
        cache["mode"] = "reference"
    else:
        cache["mode"] = "continuation"

    return cache
|
| 724 |
+
|
| 725 |
+
def merge_prompt_cache(
    self,
    original_cache: dict,
    new_text: str,
    new_audio_feat: torch.Tensor,
):
    """Fold newly generated text/audio back into a prompt cache.

    Keeps the voice stable across multi-chunk generation by appending the new
    content to the cached continuation prompt (and preserving any reference
    features).

    Args:
        original_cache: existing cache (any mode), or None for a fresh start.
        new_text: text that was just synthesized.
        new_audio_feat: audio features that were just generated.

    Returns:
        A cache dict with concatenated ``prompt_text`` / ``audio_feat`` and an
        updated ``mode``.
    """
    if original_cache is None:
        # Nothing to merge with: the new content alone forms a continuation cache.
        return {
            "prompt_text": new_text,
            "audio_feat": new_audio_feat,
            "mode": "continuation",
        }

    has_ref = "ref_audio_feat" in original_cache
    prev_text = original_cache.get("prompt_text", "")
    # Empty feature tensor with matching trailing dims when no prior audio exists.
    prev_feat = original_cache.get(
        "audio_feat", new_audio_feat.new_empty(0, *new_audio_feat.shape[1:])
    )

    merged = {}
    if has_ref:
        merged["ref_audio_feat"] = original_cache["ref_audio_feat"]
    merged["prompt_text"] = prev_text + new_text
    merged["audio_feat"] = torch.cat([prev_feat, new_audio_feat], dim=0)
    merged["mode"] = "ref_continuation" if has_ref else "continuation"
    return merged
|
| 756 |
+
|
| 757 |
+
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Non-streaming cached generation: return the final (audio, target tokens, audio features)."""
    result_iter = self._generate_with_prompt_cache(*args, streaming=False, **kwargs)
    return next(result_iter)
|
| 759 |
+
|
| 760 |
+
def generate_with_prompt_cache_streaming(
    self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
    """Streaming cached generation: return a generator of (chunk, target tokens, features so far)."""
    stream = self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
    return stream
|
| 764 |
+
|
| 765 |
+
@torch.inference_mode()
def _generate_with_prompt_cache(
    self,
    target_text: str,
    prompt_cache: dict,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
    streaming: bool = False,
    streaming_prefix_len: int = 4,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
    """
    Generate audio using pre-built prompt cache.

    Args:
        target_text: Text to convert to speech
        prompt_cache: Cache built by ``build_prompt_cache()``. Can be None
            for zero-shot generation.
        min_len: Minimum audio length to avoid very short audio
        max_len: Maximum audio length
        inference_timesteps: Number of diffusion sampling steps
        cfg_value: Classifier-free guidance value
        retry_badcase: Whether to retry on bad cases
        retry_badcase_max_times: Maximum retry attempts
        retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
        streaming: Whether to return a generator of audio chunks
        streaming_prefix_len: Number of prefix audio patches to use for streaming mode

    Returns:
        Generator of Tuple containing:
            - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
            - Tensor of new text tokens
            - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
    """
    if retry_badcase and streaming:
        warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
        retry_badcase = False

    # Determine mode from cache
    if prompt_cache is None:
        mode = "zero_shot"
        text = target_text
    else:
        mode = prompt_cache.get("mode", "continuation")
        if mode in ("continuation", "ref_continuation"):
            # Continuation modes prepend the cached prompt transcript.
            prompt_text = prompt_cache.get("prompt_text", "")
            text = prompt_text + target_text
        else:
            text = target_text

    text_token = torch.LongTensor(self.text_tokenizer(text))
    # Append the audio-start marker after the text tokens.
    text_token = torch.cat(
        [
            text_token,
            torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
        ],
        dim=-1,
    )

    # Tokens of only the target text are yielded back so callers can merge caches.
    target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
    text_length = text_token.shape[0]

    if mode in ("zero_shot", "continuation"):
        # Zero-shot falls through with an empty prompt feature tensor.
        prompt_audio_feat = (
            prompt_cache["audio_feat"]
            if prompt_cache
            else torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        )
        audio_length = prompt_audio_feat.size(0)
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        # Sequence layout: [text | prompt audio]
        text_token = torch.cat([text_token, text_pad_token])
        audio_feat = torch.cat([text_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat(
            [torch.ones(text_length, dtype=torch.int32), torch.zeros(audio_length, dtype=torch.int32)]
        ).to(text_token.device)
        audio_mask = torch.cat(
            [torch.zeros(text_length, dtype=torch.int32), torch.ones(audio_length, dtype=torch.int32)]
        ).to(text_token.device)

    elif mode == "reference":
        # Sequence layout: [ref prefix | text]
        ref_audio_feat = prompt_cache["ref_audio_feat"]
        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_audio_feat, text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([ref_tokens, text_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat], dim=0)
        text_mask = torch.cat([ref_t_mask, torch.ones(text_length, dtype=torch.int32).to(text_token.device)])
        audio_mask = torch.cat([ref_a_mask, torch.zeros(text_length, dtype=torch.int32).to(text_token.device)])

    else:
        # ref_continuation mode; sequence layout: [ref prefix | text | prompt audio]
        ref_audio_feat = prompt_cache["ref_audio_feat"]
        prompt_audio_feat = prompt_cache["audio_feat"]
        prompt_audio_length = prompt_audio_feat.size(0)

        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_audio_feat, text_token.device)

        prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )

        text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat(
            [
                ref_t_mask,
                torch.ones(text_length, dtype=torch.int32).to(text_token.device),
                torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )
        audio_mask = torch.cat(
            [
                ref_a_mask,
                torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
                torch.ones(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )

    # Add batch dimension and move everything to the model device/dtype.
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
    audio_mask = audio_mask.unsqueeze(0).to(self.device)

    # run inference
    target_text_length = len(self.text_tokenizer(target_text))
    retry_badcase_times = 0
    while retry_badcase_times < retry_badcase_max_times:
        inference_result = self._inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            # Cap generation length by the badcase ratio even when retry is off.
            max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            streaming=streaming,
            streaming_prefix_len=streaming_prefix_len,
        )
        if streaming:
            # Samples produced per decoded patch chunk.
            out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
            for latent_pred, pred_audio_feat in inference_result:
                decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                # Keep only the newest chunk; earlier patches were context for smooth decoding.
                decode_audio = decode_audio[..., -out_patch_len:].squeeze(1).cpu()
                yield (decode_audio, target_text_token, pred_audio_feat)
            break
        else:
            latent_pred, pred_audio_feat = next(inference_result)
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(
                        f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                        file=sys.stderr,
                    )
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
    if not streaming:
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
        out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
        if mode in ("continuation", "ref_continuation"):
            # Drop the prompt-context patches that were prepended for decoding continuity.
            decode_audio = decode_audio[..., out_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
        else:
            decode_audio = decode_audio[..., :].squeeze(1).cpu()
        yield (decode_audio, target_text_token, pred_audio_feat)
|
| 948 |
+
|
| 949 |
+
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the core generation loop to completion and return the final (latents, features) pair."""
    result_iter = self._inference(*args, streaming=False, **kwargs)
    return next(result_iter)
|
| 951 |
+
|
| 952 |
+
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
    """Step-wise core generation loop: return a generator of per-step (latents, features)."""
    stream = self._inference(*args, streaming=True, **kwargs)
    return stream
|
| 954 |
+
|
| 955 |
+
@torch.inference_mode()
def _inference(
    self,
    text: torch.Tensor,
    text_mask: torch.Tensor,
    feat: torch.Tensor,
    feat_mask: torch.Tensor,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    streaming: bool = False,
    streaming_prefix_len: int = 4,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
    """Core inference method for audio generation.

    This is the main inference loop that generates audio features
    using the language model and diffusion transformer.

    Args:
        text: Input text tokens
        text_mask: Mask for text tokens
        feat: Input audio features
        feat_mask: Mask for audio features
        min_len: Minimum generation length
        max_len: Maximum generation length
        inference_timesteps: Number of diffusion steps
        cfg_value: Classifier-free guidance value
        streaming: Whether to yield each step latent feature or just the final result

    Returns:
        Generator of Tuple containing:
            - Predicted latent feature at the current step if ``streaming=True``, else final latent features
            - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
    """
    # feat is [batch, time, patch, latent_dim] (names from the unpack below).
    B, T, P, D = feat.shape

    feat_embed = self.feat_encoder(feat)  # [b, t, h_feat]
    feat_embed = self.enc_to_lm_proj(feat_embed)

    # muP parameterization rescales the token embeddings.
    if self.config.lm_config.use_mup:
        scale_emb = self.config.lm_config.scale_emb
    else:
        scale_emb = 1.0

    text_embed = self.base_lm.embed_tokens(text) * scale_emb
    # Per-position select: text embedding where text_mask=1, audio embedding where feat_mask=1.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

    prefix_feat_cond = feat[:, -1, ...]  # b, p, d
    pred_feat_seq = []  # b, t, p, d
    curr_embed = None

    # Prepare prompt context patches for streaming mode
    # - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1)
    #   trailing audio patches as initial context so the VAE can decode smoothly.
    # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
    has_continuation_audio = feat_mask[0, -1].item() == 1
    if has_continuation_audio:
        audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
        context_len = min(streaming_prefix_len - 1, len(audio_indices))
        last_audio_indices = audio_indices[-context_len:]
        pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
    else:
        pred_feat_seq = []

    # Prefill: run the full conditioning sequence once and seed the KV cache.
    enc_outputs, kv_cache_tuple = self.base_lm(
        inputs_embeds=combined_embed,
        is_causal=True,
    )
    self.base_lm.kv_cache.fill_caches(kv_cache_tuple)

    # FSQ quantization is applied only at audio positions; text positions pass through.
    enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    lm_hidden = enc_outputs[:, -1, :]

    # Residual LM consumes base-LM outputs fused with the audio embeddings.
    residual_enc_inputs = self.fusion_concat_proj(
        torch.cat((enc_outputs, feat_mask.unsqueeze(-1) * feat_embed), dim=-1)
    )
    residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
        inputs_embeds=residual_enc_inputs,
        is_causal=True,
    )
    self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
    residual_hidden = residual_enc_outputs[:, -1, :]

    # Autoregressive decode loop: one feature patch per iteration.
    for i in tqdm(range(max_len)):
        dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
        dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
        dit_hidden = torch.cat((dit_hidden_1, dit_hidden_2), dim=-1)

        # Diffusion decoder samples the next patch conditioned on the previous one.
        pred_feat = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=prefix_feat_cond.transpose(1, 2).contiguous(),
            n_timesteps=inference_timesteps,
            cfg_value=cfg_value,
        ).transpose(
            1, 2
        )  # [b, p, d]

        curr_embed = self.feat_encoder(pred_feat.unsqueeze(1))  # b, 1, c
        curr_embed = self.enc_to_lm_proj(curr_embed)

        pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
        prefix_feat_cond = pred_feat

        if streaming:
            # return the last three predicted latent features to provide enough context for smooth decoding
            pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
            feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)

            yield feat_pred, pred_feat_seq

        # Stop head decides termination, but only after min_len steps.
        stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
        if i > min_len and stop_flag == 1:
            break

        # Single-token step through the base LM using its KV cache.
        lm_hidden = self.base_lm.forward_step(
            curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()

        lm_hidden = self.fsq_layer(lm_hidden)
        curr_residual_input = self.fusion_concat_proj(torch.cat((lm_hidden, curr_embed[:, 0, :]), dim=-1))
        residual_hidden = self.residual_lm.forward_step(
            curr_residual_input, torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()

    if not streaming:
        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
        yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
| 1085 |
+
|
| 1086 |
+
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
    """Load a model (config, tokenizer, AudioVAE and weights) from a local directory.

    Args:
        path: directory containing config.json, tokenizer files and checkpoints.
        optimize: when not training, run ``optimize()`` on the loaded model
            (``disable=not optimize``).
        training: keep the model in training setup (no dtype cast, VAE frozen,
            optionally only LoRA params trainable).
        lora_config: optional LoRA configuration; when given in training mode,
            all non-LoRA parameters are frozen.

    Returns:
        The constructed model (eval+optimized unless ``training=True``).

    Raises:
        FileNotFoundError: when neither the safetensors nor the pytorch
            checkpoint file exists (for the AudioVAE or the main model).
    """
    # Fix: close config.json deterministically instead of leaking the handle.
    with open(os.path.join(path, "config.json")) as config_file:
        config = VoxCPMConfig.model_validate_json(config_file.read())
    tokenizer = LlamaTokenizerFast.from_pretrained(path)
    audio_vae_config = getattr(config, "audio_vae_config", None)
    audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
    # Try to load AudioVAE from safetensors first, fallback to pytorch
    audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
    audiovae_pth_path = os.path.join(path, "audiovae.pth")
    if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
        vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
    elif os.path.exists(audiovae_pth_path):
        print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
        checkpoint = torch.load(
            audiovae_pth_path,
            map_location="cpu",
            weights_only=True,
        )
        vae_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(
            f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
        )
    model = cls(config, tokenizer, audio_vae, lora_config)
    if not training:
        lm_dtype = get_dtype(model.config.dtype)
        model = model.to(lm_dtype)
    else:  # training mode
        for name, param in model.named_parameters():
            if "audio_vae" in name:  # freeze VAE weights
                param.requires_grad = False
                continue
            if lora_config is not None:
                if "lora" not in name:  # freeze non-LoRA weights
                    param.requires_grad = False
        # VAE always runs in float32 during training.
        model.audio_vae = model.audio_vae.to(torch.float32)

    # Try to load from safetensors first, fallback to pytorch_model.bin
    safetensors_path = os.path.join(path, "model.safetensors")
    pytorch_model_path = os.path.join(path, "pytorch_model.bin")

    if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
        model_state_dict = load_file(safetensors_path)
    elif os.path.exists(pytorch_model_path):
        print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
        checkpoint = torch.load(
            pytorch_model_path,
            map_location="cpu",
            weights_only=True,
        )
        model_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")

    # Merge VAE weights into the main state dict under the audio_vae.* prefix.
    for kw, val in vae_state_dict.items():
        model_state_dict[f"audio_vae.{kw}"] = val

    # LoRALinear keeps weight/bias compatible with nn.Linear but adds
    # lora_A/lora_B, which are absent from base pretrained checkpoints.
    model.load_state_dict(model_state_dict, strict=False)
    if training:
        return model
    return model.to(model.device).eval().optimize(disable=not optimize)
|
| 1151 |
+
|
| 1152 |
+
# ------------------------------------------------------------------ #
|
| 1153 |
+
# LoRA Weight Management
|
| 1154 |
+
# ------------------------------------------------------------------ #
|
| 1155 |
+
def _iter_lora_modules(self):
    """Yield every LoRALinear submodule contained in this model."""
    from ..modules.layers.lora import LoRALinear

    return (mod for mod in self.modules() if isinstance(mod, LoRALinear))
|
| 1162 |
+
|
| 1163 |
+
def load_lora_weights(self, lora_path: str, device: str = None):
    """
    Load LoRA weights from file, supports calling after torch.compile.
    Uses named_parameters() to handle compile's _orig_mod wrapper.
    Supports both safetensors and pytorch formats.

    Args:
        lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
        device: Target device, defaults to model's current device
    Returns:
        tuple: (loaded_keys, skipped_keys) — checkpoint keys that were copied
        into the model vs. keys with no matching model parameter.
    """
    from pathlib import Path

    device = device or self.device
    lora_p = Path(lora_path)

    # Try safetensors first, then fallback to .ckpt
    if lora_p.is_dir():
        safetensors_file = lora_p / "lora_weights.safetensors"
        ckpt_file = lora_p / "lora_weights.ckpt"
    else:
        # A file path selects exactly one format; the other stays None.
        safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
        ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None

    # Load from safetensors if available
    if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
        state_dict = load_file(str(safetensors_file), device=device)
    elif ckpt_file and ckpt_file.exists():
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load .ckpt files from trusted sources.
        ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
        # Lightning-style checkpoints nest weights under "state_dict".
        state_dict = ckpt.get("state_dict", ckpt)
    else:
        raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")

    # Build param mapping (handle torch.compile's _orig_mod prefix)
    model_params = dict(self.named_parameters())
    # Maps an uncompiled key (no "._orig_mod.") to its compiled counterpart.
    key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}

    loaded_keys, skipped_keys = [], []
    for key, value in state_dict.items():
        # Prefer an exact match, else try the compiled-name alias.
        target_key = key if key in model_params else key_mapping.get(key)
        if target_key:
            # In-place copy keeps optimizer/compile references valid.
            model_params[target_key].data.copy_(value.to(device))
            loaded_keys.append(key)
        else:
            skipped_keys.append(key)

    return loaded_keys, skipped_keys
|
| 1211 |
+
|
| 1212 |
+
def set_lora_enabled(self, enabled: bool):
    """Turn every LoRA layer in the model on (True) or off (False)."""
    for lora_layer in self._iter_lora_modules():
        lora_layer.set_enabled(enabled)
|
| 1216 |
+
|
| 1217 |
+
def reset_lora_weights(self):
    """Re-initialize every LoRA layer (A: kaiming, B: zeros), which makes the
    adapters a no-op — i.e. unloads any previously loaded LoRA weights."""
    for lora_layer in self._iter_lora_modules():
        lora_layer.reset_lora_parameters()
|
| 1221 |
+
|
| 1222 |
+
def get_lora_state_dict(self) -> dict:
    """Collect cloned copies of every LoRA parameter (lora_A / lora_B)."""
    lora_params = {}
    for name, param in self.named_parameters():
        if "lora_" in name:
            # Clone so later training steps do not mutate the snapshot.
            lora_params[name] = param.data.clone()
    return lora_params
|
voxcpm/modules/__init__.py
ADDED
|
File without changes
|
voxcpm/modules/audiovae/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .audio_vae import AudioVAE, AudioVAEConfig
|
| 2 |
+
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
|
voxcpm/modules/audiovae/audio_vae.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def WNConv1d(*args, **kwargs):
    """Build a 1-D convolution wrapped in weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def WNConvTranspose1d(*args, **kwargs):
    """Build a 1-D transposed convolution wrapped in weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CausalConv1d(nn.Conv1d):
    """nn.Conv1d that pads only on the left (past) side, so the output at
    time t never depends on inputs after t."""

    def __init__(self, *args, padding: int = 0, **kwargs):
        # Intercept `padding` before the base class sees it: the base conv
        # performs no symmetric padding; instead the full amount (both sides
        # of a symmetric pad, i.e. 2 * padding) is applied on the left only.
        super().__init__(*args, **kwargs)
        self._left_pad = 2 * padding

    def forward(self, x):
        return super().forward(F.pad(x, (self._left_pad, 0)))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CausalTransposeConv1d(nn.ConvTranspose1d):
    """Transposed Conv1d made causal by trimming trailing output samples."""

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        # `padding`/`output_padding` are intercepted (not forwarded to the
        # base class); the equivalent trim is applied to the output instead.
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        # NOTE(review): when 2*padding == output_padding the slice end is -0,
        # which would empty the output; callers in this file always pass
        # padding = ceil(stride/2) >= 1 with output_padding = stride % 2 — confirm.
        return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def WNCausalConv1d(*args, **kwargs):
    """Causal 1-D convolution with weight normalization applied."""
    conv = CausalConv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def WNCausalTransposeConv1d(*args, **kwargs):
    """Causal 1-D transposed convolution with weight normalization applied."""
    deconv = CausalTransposeConv1d(*args, **kwargs)
    return weight_norm(deconv)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1/alpha) * sin^2(alpha * x); the 1e-9 guards
    # against division by zero when alpha is (near) zero.
    shape = x.shape
    # Flatten any trailing dims so the op runs on a (B, C, T') view.
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Snake1d(nn.Module):
    """Snake activation with a learnable per-channel `alpha` parameter."""

    def __init__(self, channels):
        super().__init__()
        # Shape (1, C, 1) so alpha broadcasts over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def init_weights(m):
|
| 68 |
+
if isinstance(m, nn.Conv1d):
|
| 69 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 70 |
+
if m.bias is not None:
|
| 71 |
+
nn.init.constant_(m.bias, 0)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CausalResidualUnit(nn.Module):
    """Causal residual block: Snake -> dilated causal conv -> Snake -> 1x1 conv,
    with an identity skip connection.

    Args:
        dim: Channel count (input == output).
        dilation: Dilation of the main convolution.
        kernel: Kernel size of the main convolution.
        groups: Group count for the main convolution (depthwise when == dim).
    """

    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        # BUG FIX: the padding was computed from a hard-coded kernel size of 7
        # instead of the `kernel` argument; use the actual kernel so non-default
        # kernel sizes remain length-preserving. Behavior is unchanged for the
        # default kernel=7.
        pad = ((kernel - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # The causal conv is configured to preserve length, so no trimming is
        # ever required (the original dead trimming branch is removed).
        pad = (x.shape[-1] - y.shape[-1]) // 2
        assert pad == 0
        return x + y
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class CausalEncoderBlock(nn.Module):
    """Encoder stage: three dilated causal residual units followed by a
    strided causal conv that downsamples the sequence by `stride`.

    Args:
        output_dim: Channels after downsampling.
        input_dim: Channels entering the block (defaults to output_dim // 2).
        stride: Downsampling factor.
        groups: Group count for the residual convolutions.
    """

    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
        super().__init__()
        # `or` treats 0 as unset — fine here since channel counts are positive.
        input_dim = input_dim or output_dim // 2
        self.block = nn.Sequential(
            CausalResidualUnit(input_dim, dilation=1, groups=groups),
            CausalResidualUnit(input_dim, dilation=3, groups=groups),
            CausalResidualUnit(input_dim, dilation=9, groups=groups),
            Snake1d(input_dim),
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class CausalEncoder(nn.Module):
    """Strided causal conv encoder producing Gaussian posterior parameters.

    Each stage doubles the channel count while downsampling by its stride;
    forward() returns the final hidden state plus mu/logvar projections.
    """

    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: list = [2, 4, 8, 8],
        depthwise: bool = False,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]

        # NOTE(review): `groups` is recomputed here but never used afterwards.
        groups = d_model if depthwise else 1

        # Create two convolutions, one each for mu and logvar
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)

        # Wrap the layer list into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class NoiseBlock(nn.Module):
    """Adds input-conditioned noise: x + randn * (1x1 conv of x)."""

    def __init__(self, dim):
        super().__init__()
        # Per-channel gain for the injected noise (no bias).
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        batch, _, frames = x.shape
        gain = self.linear(x)
        noise = torch.randn((batch, 1, frames), device=x.device, dtype=x.dtype)
        return x + noise * gain
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class CausalDecoderBlock(nn.Module):
    """Decoder stage: upsample by `stride` with a causal transposed conv,
    optionally inject noise, then three dilated causal residual units.

    Args:
        input_dim: Channels entering the block.
        output_dim: Channels after upsampling.
        stride: Upsampling factor.
        groups: Group count for the residual convolutions.
        use_noise_block: Insert a NoiseBlock right after upsampling.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        layers = [
            Snake1d(input_dim),
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        ]
        if use_noise_block:
            layers.append(NoiseBlock(output_dim))
        layers.extend(
            [
                CausalResidualUnit(output_dim, dilation=1, groups=groups),
                CausalResidualUnit(output_dim, dilation=3, groups=groups),
                CausalResidualUnit(output_dim, dilation=9, groups=groups),
            ]
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class TransposeLastTwoDim(torch.nn.Module):
    """Swap the last two dimensions of a tensor, e.g. (B, C, T) <-> (B, T, C)."""

    def forward(self, x):
        return x.transpose(-1, -2)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class CausalDecoder(nn.Module):
    """Causal upsampling decoder: latent sequence -> waveform in [-1, 1].

    Args:
        input_channel: Latent channel count entering the decoder.
        channels: Width after the input conv; halved at each upsampling stage.
        rates: Per-stage upsampling factors.
        depthwise: Use depthwise-separable style convolutions.
        d_out: Output channels (1 for mono audio).
        use_noise_block: Insert NoiseBlocks after each upsampling conv.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
    ):
        super().__init__()

        # Add first conv layer (depthwise 7-tap + pointwise 1x1 when depthwise)
        if depthwise:
            layers = [
                WNCausalConv1d(
                    input_channel,
                    input_channel,
                    kernel_size=7,
                    padding=3,
                    groups=input_channel,
                ),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]
        else:
            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks; channel width halves at each stage.
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            groups = output_dim if depthwise else 1
            layers += [
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=groups,
                    use_noise_block=use_noise_block,
                )
            ]

        # Add final conv layer.
        # NOTE(review): `output_dim` leaks out of the loop above, so `rates`
        # must be non-empty or this raises NameError — confirm callers.
        layers += [
            Snake1d(output_dim),
            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class AudioVAEConfig(BaseModel):
    """Hyper-parameters for AudioVAE.

    Defaults describe a 16 kHz model with a total hop of 2*5*8*8 = 640 samples.
    (Pydantic copies mutable defaults per instance, so the list defaults are safe.)
    """

    # Base channel width of the encoder's first conv layer.
    encoder_dim: int = 128
    # Per-stage downsampling factors; their product is the hop length.
    encoder_rates: List[int] = [2, 5, 8, 8]
    # Channel dimension of the latent (mu / logvar) sequence.
    latent_dim: int = 64
    # Channel width entering the decoder before progressive halving.
    decoder_dim: int = 1536
    # Per-stage upsampling factors (mirror of encoder_rates).
    decoder_rates: List[int] = [8, 8, 5, 2]
    # Use grouped (depthwise) convolutions in the residual stacks.
    depthwise: bool = True
    # Audio sample rate the model operates at.
    sample_rate: int = 16000
    # Insert NoiseBlocks in the decoder for stochastic detail.
    use_noise_block: bool = False
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class AudioVAE(nn.Module):
    """Causal convolutional VAE over raw waveforms.

    Encodes 1-channel audio into a latent sequence downsampled by
    prod(encoder_rates) and decodes latents back to audio in [-1, 1].

    Args:
        config: Model hyper-parameters; defaults to AudioVAEConfig() when None.
    """

    def __init__(
        self,
        config: AudioVAEConfig = None,
    ):
        # Use the default configuration when none is provided.
        if config is None:
            config = AudioVAEConfig()

        super().__init__()

        encoder_dim = config.encoder_dim
        encoder_rates = config.encoder_rates
        latent_dim = config.latent_dim
        decoder_dim = config.decoder_dim
        decoder_rates = config.decoder_rates
        depthwise = config.depthwise
        sample_rate = config.sample_rate
        use_noise_block = config.use_noise_block

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.depthwise = depthwise

        self.use_noise_block = use_noise_block

        # Fallback latent width if the config explicitly sets it to None.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim
        # Total downsampling factor (samples per latent frame).
        self.hop_length = np.prod(encoder_rates)
        self.encoder = CausalEncoder(
            encoder_dim,
            latent_dim,
            encoder_rates,
            depthwise=depthwise,
        )

        self.decoder = CausalDecoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            depthwise=depthwise,
            use_noise_block=use_noise_block,
        )
        self.sample_rate = sample_rate
        # Same value as hop_length, kept as a plain Python int via math.prod.
        self.chunk_size = math.prod(encoder_rates)

    def preprocess(self, audio_data, sample_rate):
        """Right-pad audio so its length is a multiple of the hop length.

        Raises AssertionError if `sample_rate` differs from the model's rate
        (no resampling is performed).
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        pad_to = self.hop_length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input
        length : int, optional
            Number of samples in output audio, by default None

        Returns
        -------
        dict
            A dictionary with the following keys:
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
        """
        return self.decoder(z)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """Encode audio to the posterior mean (mu) latent sequence.

        Args:
            audio_data: Tensor[B x 1 x T] (or [B x T], auto-unsqueezed)
            sample_rate: int
        Returns:
            z: Tensor[B x D x T]
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        audio_data = self.preprocess(audio_data, sample_rate)
        return self.encoder(audio_data)["mu"]
|
voxcpm/modules/audiovae/audio_vae_v2.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def WNConv1d(*args, **kwargs):
    """Build a 1-D convolution wrapped in weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def WNConvTranspose1d(*args, **kwargs):
    """Build a 1-D transposed convolution wrapped in weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CausalConv1d(nn.Conv1d):
    """Left-padded (causal) Conv1d. `output_padding` reduces the amount of
    left padding so strided stages can line up exactly."""

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        # The base conv performs no symmetric padding; the causal amount
        # (2 * padding minus the output_padding correction) is applied on the
        # left only in forward().
        super().__init__(*args, **kwargs)
        self._left_pad = padding * 2 - output_padding

    def forward(self, x):
        return super().forward(F.pad(x, (self._left_pad, 0)))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CausalTransposeConv1d(nn.ConvTranspose1d):
    """Transposed Conv1d made causal by trimming trailing output samples."""

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        # `padding`/`output_padding` are intercepted (not forwarded to the
        # base class); the equivalent trim is applied to the output instead.
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        # NOTE(review): when 2*padding == output_padding the slice end is -0,
        # which would empty the output; callers in this file always pass
        # padding = ceil(stride/2) >= 1 with output_padding = stride % 2 — confirm.
        return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def WNCausalConv1d(*args, **kwargs):
    """Causal 1-D convolution with weight normalization applied."""
    conv = CausalConv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def WNCausalTransposeConv1d(*args, **kwargs):
    """Causal 1-D transposed convolution with weight normalization applied."""
    deconv = CausalTransposeConv1d(*args, **kwargs)
    return weight_norm(deconv)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1/alpha) * sin^2(alpha * x); the 1e-9 guards
    # against division by zero when alpha is (near) zero.
    shape = x.shape
    # Flatten any trailing dims so the op runs on a (B, C, T') view.
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Snake1d(nn.Module):
    """Snake activation with a learnable per-channel `alpha` parameter."""

    def __init__(self, channels):
        super().__init__()
        # Shape (1, C, 1) so alpha broadcasts over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def init_weights(m):
    """Truncated-normal init for Conv1d weights with zeroed bias.

    Intended for use with `module.apply(init_weights)`; non-Conv1d modules
    are left untouched.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.trunc_normal_(m.weight, std=0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class CausalResidualUnit(nn.Module):
    """Causal residual block: Snake -> dilated causal conv -> Snake -> 1x1 conv,
    with an identity skip connection.

    Args:
        dim: Channel count (input == output).
        dilation: Dilation of the main convolution.
        kernel: Kernel size of the main convolution.
        groups: Group count for the main convolution (depthwise when == dim).
    """

    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        # BUG FIX: the padding was computed from a hard-coded kernel size of 7
        # instead of the `kernel` argument; use the actual kernel so non-default
        # kernel sizes remain length-preserving. Behavior is unchanged for the
        # default kernel=7.
        pad = ((kernel - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # The causal conv is configured to preserve length, so no trimming is
        # ever required (the original dead trimming branch is removed).
        pad = (x.shape[-1] - y.shape[-1]) // 2
        assert pad == 0
        return x + y
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class CausalEncoderBlock(nn.Module):
    """Encoder stage: three dilated causal residual units followed by a
    strided causal conv that downsamples the sequence by `stride`.

    Unlike the v1 block, the downsampling conv also receives
    `output_padding=stride % 2` so odd strides align exactly.

    Args:
        output_dim: Channels after downsampling.
        input_dim: Channels entering the block (defaults to output_dim // 2).
        stride: Downsampling factor.
        groups: Group count for the residual convolutions.
    """

    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
        super().__init__()
        # `or` treats 0 as unset — fine here since channel counts are positive.
        input_dim = input_dim or output_dim // 2
        self.block = nn.Sequential(
            CausalResidualUnit(input_dim, dilation=1, groups=groups),
            CausalResidualUnit(input_dim, dilation=3, groups=groups),
            CausalResidualUnit(input_dim, dilation=9, groups=groups),
            Snake1d(input_dim),
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        )

    def forward(self, x):
        return self.block(x)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class CausalEncoder(nn.Module):
    """Strided causal conv encoder producing Gaussian posterior parameters.

    Each stage doubles the channel count while downsampling by its stride;
    forward() returns the final hidden state plus mu/logvar projections.
    """

    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: list = [2, 4, 8, 8],
        depthwise: bool = False,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]

        # NOTE(review): `groups` is recomputed here but never used afterwards.
        groups = d_model if depthwise else 1

        # Create two convolutions, one each for mu and logvar
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)

        # Wrap the layer list into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class NoiseBlock(nn.Module):
    """Adds input-conditioned noise: x + randn * (1x1 conv of x)."""

    def __init__(self, dim):
        super().__init__()
        # Per-channel gain for the injected noise (no bias).
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        batch, _, frames = x.shape
        gain = self.linear(x)
        noise = torch.randn((batch, 1, frames), device=x.device, dtype=x.dtype)
        return x + noise * gain
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class CausalDecoderBlock(nn.Module):
    """Decoder stage: upsample by `stride` with a causal transposed conv,
    optionally inject noise, then three dilated causal residual units.

    Args:
        input_dim: Channels entering the block.
        output_dim: Channels after upsampling.
        stride: Upsampling factor.
        groups: Group count for the residual convolutions.
        use_noise_block: Insert a NoiseBlock right after upsampling.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        layers = [
            Snake1d(input_dim),
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        ]
        if use_noise_block:
            layers.append(NoiseBlock(output_dim))
        layers.extend(
            [
                CausalResidualUnit(output_dim, dilation=1, groups=groups),
                CausalResidualUnit(output_dim, dilation=3, groups=groups),
                CausalResidualUnit(output_dim, dilation=9, groups=groups),
            ]
        )
        self.block = nn.Sequential(*layers)
        # Recorded for external inspection (not used by forward()).
        self.input_channels = input_dim

    def forward(self, x):
        return self.block(x)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class TransposeLastTwoDim(torch.nn.Module):
    """Swap the last two dimensions of a tensor, e.g. (B, C, T) <-> (B, T, C)."""

    def forward(self, x):
        return x.transpose(-1, -2)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class SampleRateConditionLayer(nn.Module):
    """Condition a feature map on a bucketed sample-rate index.

    Supported `cond_type` values:
        "scale_bias"      — learned per-bucket scale (init 1) and bias (init 0)
        "scale_bias_init" — same, but normally-initialized embeddings
        "add"             — learned per-bucket additive embedding
        "concat"          — per-bucket embedding concatenated on channels,
                            then projected back by `out_layer` (required)

    Args:
        input_dim: Channel count of the conditioned features.
        sr_bin_buckets: Number of sample-rate buckets.
        cond_type: Conditioning mechanism (see above).
        cond_dim: Embedding width used only by "concat".
        out_layer: Append a Snake + 1x1 conv projection after conditioning.
    """

    def __init__(
        self,
        input_dim: int,
        sr_bin_buckets: int = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        out_layer: bool = False,
    ):
        super().__init__()

        self.cond_type, out_layer_in_dim = cond_type, input_dim

        if cond_type == "scale_bias":
            # Identity at init: scale=1, bias=0.
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.ones_(self.scale_embed.weight)
            nn.init.zeros_(self.bias_embed.weight)
        elif cond_type == "scale_bias_init":
            # Randomized around identity: scale ~ N(1, 1), bias ~ N(0, 1).
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.scale_embed.weight, mean=1)
            nn.init.normal_(self.bias_embed.weight)
        elif cond_type == "add":
            self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.cond_embed.weight)
        elif cond_type == "concat":
            self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
            assert out_layer, "out_layer must be True for concat cond_type"
            # Projection input grows by the concatenated embedding width.
            out_layer_in_dim = input_dim + cond_dim
        else:
            raise ValueError(f"Invalid cond_type: {cond_type}")

        if out_layer:
            self.out_layer = nn.Sequential(
                Snake1d(out_layer_in_dim),
                WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
            )
        else:
            self.out_layer = nn.Identity()

    def forward(self, x, sr_cond):
        """Apply conditioning; `x` is (B, C, T), `sr_cond` is a bucket index tensor."""
        if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
            # unsqueeze(-1) broadcasts the per-channel terms over time.
            x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
        elif self.cond_type == "add":
            x = x + self.cond_embed(sr_cond).unsqueeze(-1)
        elif self.cond_type == "concat":
            # Tile the embedding along time before channel-concat.
            x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)

        return self.out_layer(x)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class CausalDecoder(nn.Module):
    """Causal upsampling decoder: stem conv -> one CausalDecoderBlock per stride -> Snake1d + conv + tanh.

    When ``sr_bin_boundaries`` is given, the layer stack is kept as an nn.ModuleList and a
    SampleRateConditionLayer is paired with every CausalDecoderBlock, so the decoder can be
    conditioned on a requested output sample rate (bucketized via ``torch.bucketize``).
    Otherwise the stack is a plain nn.Sequential and ``sr_cond`` is ignored.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
        sr_bin_boundaries: Optional[List[int]] = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        cond_out_layer: bool = False,
    ):
        super().__init__()

        # Add first conv layer (depthwise variant: 7-tap depthwise conv + 1x1 pointwise)
        if depthwise:
            layers = [
                WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]
        else:
            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks; channel width halves at every stage
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            groups = output_dim if depthwise else 1
            layers += [
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=groups,
                    use_noise_block=use_noise_block,
                )
            ]

        # Add final conv layer projecting to d_out waveform channels, squashed to [-1, 1]
        layers += [
            Snake1d(output_dim),
            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        if sr_bin_boundaries is None:
            # No sample-rate conditioning: a plain sequential stack suffices.
            self.model = nn.Sequential(*layers)
            self.sr_bin_boundaries = None
        else:
            # Layers must stay individually addressable so condition layers can be interleaved.
            self.model = nn.ModuleList(layers)

            self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
            self.sr_bin_buckets = len(sr_bin_boundaries) + 1

            # One SampleRateConditionLayer per CausalDecoderBlock, None for every other layer.
            # NOTE(review): matching by class *name* rather than isinstance — assumes
            # CausalDecoderBlock is never subclassed; confirm before relying on it.
            cond_layers = []
            for layer in self.model:
                if layer.__class__.__name__ == "CausalDecoderBlock":
                    cond_layers.append(
                        SampleRateConditionLayer(
                            input_dim=layer.input_channels,
                            sr_bin_buckets=self.sr_bin_buckets,
                            cond_type=cond_type,
                            cond_dim=cond_dim,
                            out_layer=cond_out_layer,
                        )
                    )
                else:
                    cond_layers.append(None)
            self.sr_cond_model = nn.ModuleList(cond_layers)

    def get_sr_idx(self, sr):
        # Map a sample rate (Hz) to its bucket index in [0, sr_bin_buckets).
        return torch.bucketize(sr, self.sr_bin_boundaries)

    def forward(self, x, sr_cond=None):
        if self.sr_bin_boundaries is not None:
            # assert sr_cond is not None
            sr_cond = self.get_sr_idx(sr_cond)

            # Apply the matching condition layer (if any) before each decoder layer.
            for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
                if sr_cond_layer is not None:
                    x = sr_cond_layer(x, sr_cond)
                x = layer(x)
            return x
        else:
            return self.model(x)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class AudioVAEConfig(BaseModel):
    """Configuration for AudioVAE: encoder/decoder widths, strides and sample-rate conditioning."""

    # Base channel width of the causal encoder
    encoder_dim: int = 128
    # Encoder downsampling strides; their product (2*5*8*8 = 640) is the hop length
    encoder_rates: List[int] = [2, 5, 8, 8]
    # Dimensionality of the continuous latent sequence
    latent_dim: int = 64
    # Base channel width of the causal decoder
    decoder_dim: int = 2048
    # Decoder upsampling strides (product 8*6*5*2*2*2 = 1920 output samples per latent frame)
    decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
    # Use depthwise-separable convolutions in encoder and decoder
    depthwise: bool = True
    # Sample rate (Hz) expected at the encoder input
    sample_rate: int = 16000
    # Default sample rate (Hz) the decoder targets when no condition is given
    out_sample_rate: int = 48000
    # Forwarded to CausalDecoderBlock
    use_noise_block: bool = False
    # Bucket boundaries (Hz) for sample-rate conditioning; None disables conditioning
    sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
    # How the sample-rate embedding is applied: "scale_bias" | "scale_bias_init" | "add" | "concat"
    cond_type: str = "scale_bias"
    # Embedding width used by the "concat" condition type
    cond_dim: int = 128
    # Whether each condition layer appends a Snake1d + 1x1 conv output projection
    cond_out_layer: bool = False
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class AudioVAE(nn.Module):
    """
    Causal audio VAE: a CausalEncoder that compresses waveforms into a continuous
    latent sequence and a CausalDecoder that reconstructs audio, optionally
    conditioned on a target output sample rate.

    Args:
        config: AudioVAEConfig with encoder/decoder dimensions, strides and
            sample-rate-conditioning options. Defaults to AudioVAEConfig().
    """

    def __init__(
        self,
        config: AudioVAEConfig = None,
    ):
        # Fall back to the default configuration when none is provided.
        if config is None:
            config = AudioVAEConfig()

        super().__init__()

        encoder_dim = config.encoder_dim
        encoder_rates = config.encoder_rates
        latent_dim = config.latent_dim
        decoder_dim = config.decoder_dim
        decoder_rates = config.decoder_rates
        depthwise = config.depthwise
        sample_rate = config.sample_rate
        out_sample_rate = config.out_sample_rate
        use_noise_block = config.use_noise_block
        sr_bin_boundaries = config.sr_bin_boundaries
        cond_type = config.cond_type
        cond_dim = config.cond_dim
        cond_out_layer = config.cond_out_layer

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.depthwise = depthwise

        self.use_noise_block = use_noise_block

        # Derive the latent width from the encoder geometry when unspecified.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim
        # Input samples consumed per latent frame (product of encoder strides).
        self.hop_length = np.prod(encoder_rates)
        self.encoder = CausalEncoder(
            encoder_dim,
            latent_dim,
            encoder_rates,
            depthwise=depthwise,
        )

        self.decoder = CausalDecoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            depthwise=depthwise,
            use_noise_block=use_noise_block,
            sr_bin_boundaries=sr_bin_boundaries,
            cond_type=cond_type,
            cond_dim=cond_dim,
            cond_out_layer=cond_out_layer,
        )
        self.sample_rate = sample_rate
        self.out_sample_rate = out_sample_rate
        self.sr_bin_boundaries = sr_bin_boundaries
        # Same quantity as hop_length, computed with math.prod (plain int).
        self.chunk_size = math.prod(encoder_rates)

    def preprocess(self, audio_data, sample_rate):
        """Right-pad audio along the last dim to a multiple of the encoder hop length."""
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        pad_to = self.hop_length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
        """Decode given latent codes and return audio data.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input.
        sr_cond : Tensor, optional
            Requested output sample rate(s); when None and sample-rate
            conditioning is enabled, falls back to ``self.out_sample_rate``.

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data.
        """
        if self.sr_bin_boundaries is not None:
            # use default output sample rate
            if sr_cond is None:
                sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
        return self.decoder(z, sr_cond)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """
        Encode audio into the latent mean sequence.

        Args:
            audio_data: Tensor[B x 1 x T] (a 2-D [B x T] input gains a channel dim)
            sample_rate: int, must equal ``self.sample_rate``
        Returns:
            z: Tensor[B x D x T]
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        audio_data = self.preprocess(audio_data, sample_rate)
        # Only the mean of the latent distribution is returned.
        return self.encoder(audio_data)["mu"]
|
voxcpm/modules/layers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .scalar_quantization_layer import ScalarQuantizationLayer
|
voxcpm/modules/layers/lora.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LoRALinear(nn.Module):
    """
    LoRA linear layer that directly owns weight/bias, keeping the exact same
    state_dict key structure as nn.Linear.

    state_dict layout:
        - weight: original weight (same key as nn.Linear)
        - bias:   original bias (same key as nn.Linear)
        - lora_A: LoRA low-rank matrix A
        - lora_B: LoRA low-rank matrix B

    Benefit of this design: loading pretrained weights requires no key remapping.
    """

    def __init__(
        self,
        base: nn.Linear,
        r: int,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."

        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.alpha = alpha
        # Standard LoRA scaling alpha / r (0 disables the LoRA branch).
        self._base_scaling = alpha / r if r > 0 else 0.0

        # Store scaling in a buffer so mutating the value does not trigger a
        # torch.compile recompilation; persistent=False keeps it out of the
        # state_dict, avoiding missing-key errors when loading checkpoints.
        self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)

        # Take ownership of weight and bias from the wrapped Linear (shared Parameters).
        self.weight = base.weight
        self.bias = base.bias  # may be None

        # LoRA parameters: A initialized like a Linear weight, B zeroed so the
        # layer starts out exactly equal to the base Linear.
        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
            self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        else:
            self.register_parameter("lora_A", None)
            self.register_parameter("lora_B", None)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base Linear computation.
        result = F.linear(x, self.weight, self.bias)
        if self.r <= 0 or self.lora_A is None:
            return result
        # LoRA: result + dropout(x @ A^T @ B^T) * scaling
        lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
        return result + self.dropout(lora_out) * self.scaling

    def reset_lora_parameters(self):
        """Reset the LoRA parameters to their initial state."""
        if self.r > 0 and self.lora_A is not None:
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def set_enabled(self, enabled: bool):
        """Enable/disable the LoRA branch (via scaling, torch.compile friendly)."""
        # fill_ mutates the buffer in place, so no recompilation is triggered.
        self.scaling.fill_(self._base_scaling if enabled else 0.0)

    @property
    def enabled(self) -> bool:
        # The branch is considered enabled iff its effective scale is nonzero.
        return self.scaling.item() != 0.0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
|
| 84 |
+
"""
|
| 85 |
+
根据类似 'layers.0.self_attn.q_proj' 的全名,返回 parent module(即 q_proj 的上一级)。
|
| 86 |
+
"""
|
| 87 |
+
parts = name.split(".")
|
| 88 |
+
if len(parts) == 1:
|
| 89 |
+
return root
|
| 90 |
+
parent = root
|
| 91 |
+
for p in parts[:-1]:
|
| 92 |
+
if not hasattr(parent, p):
|
| 93 |
+
return None
|
| 94 |
+
parent = getattr(parent, p)
|
| 95 |
+
return parent
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def apply_lora_to_named_linear_modules(
    root: nn.Module,
    *,
    target_submodule_names: list[str],
    r: int,
    alpha: float,
    dropout: float,
) -> None:
    """
    Inject LoRA into every nn.Linear under *root* whose attribute name matches
    one of *target_submodule_names*.

    For example, target_submodule_names=["q_proj", "v_proj"] replaces every
    nn.Linear named *.q_proj / *.v_proj with a LoRALinear wrapping it.
    """
    wanted = set(target_submodule_names)
    # Snapshot named_modules() so replacements do not disturb the iteration.
    for full_name, module in list(root.named_modules()):
        leaf = full_name.rsplit(".", 1)[-1]
        if leaf not in wanted or not isinstance(module, nn.Linear):
            continue

        parent = _get_parent_module(root, full_name)
        if parent is None:
            continue

        # Swap the original Linear for a LoRA-wrapped version in place.
        replacement = LoRALinear(
            base=module,
            r=r,
            alpha=alpha,
            dropout=dropout,
        )
        setattr(parent, leaf, replacement)
|
voxcpm/modules/layers/scalar_quantization_layer.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ScalarQuantizationLayer(nn.Module):
    """
    Scalar quantization bottleneck: project to a low-dimensional latent, squash
    with tanh, round each coordinate to a grid of (2*scale + 1) levels, then
    project back out. Training uses a straight-through estimator so gradients
    bypass the non-differentiable rounding.
    """

    def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.scale = scale

        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden):
        z = torch.tanh(self.in_proj(hidden))
        quantized = torch.round(z * self.scale) / self.scale

        if self.training:
            # Straight-through: forward sees quantized values, backward sees
            # the identity gradient of z.
            z = z + (quantized - z).detach()
        else:
            z = quantized

        return self.out_proj(z)
|
voxcpm/modules/locdit/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .unified_cfm import UnifiedCFM, CfmConfig
|
| 2 |
+
from .local_dit import VoxCPMLocDiT
|
| 3 |
+
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
|
voxcpm/modules/locdit/local_dit.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal timestep embedding: (B,) -> (B, dim) as [sin | cos] halves."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        # Promote a 0-d scalar to a batch of one.
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=x.device) * -step)
        args = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((args.sin(), args.cos()), dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) for refining a timestep embedding."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act = nn.SiLU()
        # Output width defaults to the hidden width unless overridden.
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class VoxCPMLocDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    A non-causal MiniCPM transformer (vocab_size == 0, so no token embedding)
    serves as the denoiser. The sequence fed to it is
    [mu + timestep token] + [prefix condition frames] + [noisy latents],
    and only the positions of the noisy latents are projected back to latent space.
    """

    def __init__(
        self,
        config: MiniCPM4Config,
        in_channels: int = 64,
    ):
        super().__init__()
        # The denoiser maps latents to latents: identical in/out channel counts.
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.config = config

        # Project latents / condition frames into the transformer width.
        self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)

        # Separate MLPs embed the absolute timestep t and the step size dt.
        self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
        self.time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )
        self.delta_time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )

        assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
        self.decoder = MiniCPMModel(config)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        dt: torch.Tensor,
    ):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of inputs
        mu: (N, C) tensor of hidden embedding
            NOTE(review): mu is added to the (N, hidden_size) timestep embedding
            below, so it must actually be hidden_size wide — confirm with callers.
        t: (N,) tensor of diffusion timesteps
        cond: (N, C, T') tensor of prefix conditions
        dt: (N,) used for mean velocity (may be supported in the future...)
        """
        # (N, C, T) -> (N, T, hidden_size)
        x = self.in_proj(x.transpose(1, 2).contiguous())

        cond = self.cond_proj(cond.transpose(1, 2).contiguous())
        prefix = cond.size(1)

        # Fuse absolute-timestep and step-size embeddings into one vector.
        t = self.time_embeddings(t).to(x.dtype)
        t = self.time_mlp(t)
        dt = self.time_embeddings(dt).to(x.dtype)
        dt = self.delta_time_mlp(dt)
        t = t + dt

        # Prepend (mu + t) as a single conditioning token, then cond, then latents.
        x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
        hidden, _ = self.decoder(x, is_causal=False)
        # Keep only the noisy-latent positions (drop the 1 + prefix leading tokens).
        hidden = hidden[:, prefix + 1 :, :]
        hidden = self.out_proj(hidden)

        return hidden.transpose(1, 2).contiguous()
|
voxcpm/modules/locdit/local_dit_v2.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal timestep embedding: (B,) -> (B, dim) as [sin | cos] halves."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        # Promote a 0-d scalar to a batch of one.
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=x.device) * -step)
        args = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((args.sin(), args.cos()), dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) for refining a timestep embedding."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act = nn.SiLU()
        # Output width defaults to the hidden width unless overridden.
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class VoxCPMLocDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone (v2).

    Unlike the v1 variant, `mu` is reshaped into one or more conditioning
    tokens that are prepended as-is, followed by a separate timestep token,
    the prefix condition frames, and finally the noisy latents.
    """

    def __init__(
        self,
        config: MiniCPM4Config,
        in_channels: int = 64,
    ):
        super().__init__()
        # The denoiser maps latents to latents: identical in/out channel counts.
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.config = config

        # Project latents / condition frames into the transformer width.
        self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)

        # Separate MLPs embed the absolute timestep t and the step size dt.
        self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
        self.time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )
        self.delta_time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )

        assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
        self.decoder = MiniCPMModel(config)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        dt: torch.Tensor,
    ):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of inputs
        mu: (N, C) tensor of hidden embedding
            NOTE(review): mu is viewed as (N, -1, hidden_size) below, so its
            width must be a multiple of hidden_size — confirm with callers.
        t: (N,) tensor of diffusion timesteps
        cond: (N, C, T') tensor of prefix conditions
        dt: (N,) used for mean velocity (may be supported in the future...)
        """
        # (N, C, T) -> (N, T, hidden_size)
        x = self.in_proj(x.transpose(1, 2).contiguous())

        cond = self.cond_proj(cond.transpose(1, 2).contiguous())
        prefix = cond.size(1)

        # Fuse absolute-timestep and step-size embeddings into one vector.
        t = self.time_embeddings(t).to(x.dtype)
        t = self.time_mlp(t)
        dt = self.time_embeddings(dt).to(x.dtype)
        dt = self.delta_time_mlp(dt)
        t = t + dt

        # Split mu into conditioning tokens of width hidden_size and build the
        # sequence [mu tokens, timestep token, prefix cond, noisy latents].
        mu = mu.view(x.size(0), -1, x.size(-1))
        x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)

        hidden, _ = self.decoder(x, is_causal=False)
        # Keep only the noisy-latent positions (drop mu tokens + time token + prefix).
        hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
        hidden = self.out_proj(hidden)

        return hidden.transpose(1, 2).contiguous()
|
voxcpm/modules/locdit/unified_cfm.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.func import jvp
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
from .local_dit import VoxCPMLocDiT
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CfmConfig(BaseModel):
    """Hyper-parameters for conditional flow matching (consumed by UnifiedCFM)."""

    # Minimum noise level of the probability path
    sigma_min: float = 1e-6
    # ODE solver name (only Euler integration is implemented here)
    solver: str = "euler"
    # Training-time sampler for diffusion times: "log-norm" or "uniform"
    t_scheduler: str = "log-norm"
    # Probability of dropping the conditioning (mu) during training, for CFG
    training_cfg_rate: float = 0.1
    # Default guidance strength at inference
    inference_cfg_rate: float = 1.0
    # Regression loss type (e.g. "l1")
    reg_loss_type: str = "l1"
    # (start, end) range for the fraction of samples trained with r != t —
    # presumably interpolated by training progress; see compute_loss
    ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
    # (start, end) range for the probability of noising the prefix condition,
    # interpolated by training progress in compute_loss
    noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
    # Scale of the Gaussian noise added to the prefix condition
    noise_cond_scale: float = 0.0
|
| 22 |
+
|
| 23 |
+
class UnifiedCFM(torch.nn.Module):
|
| 24 |
+
    def __init__(
        self,
        in_channels: int,
        cfm_params: CfmConfig,
        estimator: VoxCPMLocDiT,
        mean_mode: bool = False,
    ):
        """Conditional-flow-matching wrapper around a VoxCPMLocDiT velocity estimator.

        in_channels: latent channel count of the modeled patches
        cfm_params: CfmConfig hyper-parameters (unpacked onto attributes below)
        estimator: DiT predicting the flow velocity field
        mean_mode: when True the estimator receives the real step size dt at
            inference; otherwise dt is zeroed (see solve_euler)
        """
        super().__init__()
        # Unpack the config onto the module for cheap attribute access.
        self.solver = cfm_params.solver
        self.sigma_min = cfm_params.sigma_min
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        self.reg_loss_type = cfm_params.reg_loss_type
        self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
        self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
        self.noise_cond_scale = cfm_params.noise_cond_scale

        self.in_channels = in_channels
        self.mean_mode = mean_mode

        self.estimator = estimator
|
| 46 |
+
|
| 47 |
+
# ------------------------------------------------------------------ #
|
| 48 |
+
# Inference
|
| 49 |
+
# ------------------------------------------------------------------ #
|
| 50 |
+
    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        patch_size: int,
        cond: torch.Tensor,
        temperature: float = 1.0,
        cfg_value: float = 1.0,
        sway_sampling_coef: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """Sample a latent patch by integrating the flow ODE from t=1 (noise) to t=0.

        mu: (B, D) conditioning embedding
        n_timesteps: number of Euler steps
        patch_size: number of latent frames T to generate
        cond: (B, in_channels, T') prefix condition frames
        temperature: stddev multiplier for the initial Gaussian noise
        cfg_value: classifier-free-guidance strength passed to solve_euler
        sway_sampling_coef: strength of the non-uniform ("sway") time warp
        use_cfg_zero_star: enable the CFG-Zero* rescaling in solve_euler
        """
        b, _ = mu.shape
        t = patch_size
        z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature

        # Uniform grid from 1 down to 0, warped by the sway-sampling schedule.
        t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)

        return self.solve_euler(
            x=z,
            t_span=t_span,
            mu=mu,
            cond=cond,
            cfg_value=cfg_value,
            use_cfg_zero_star=use_cfg_zero_star,
        )
|
| 77 |
+
|
| 78 |
+
def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
|
| 79 |
+
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
| 80 |
+
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
|
| 81 |
+
st_star = dot_product / squared_norm
|
| 82 |
+
return st_star
|
| 83 |
+
|
| 84 |
+
    def solve_euler(
        self,
        x: torch.Tensor,
        t_span: torch.Tensor,
        mu: torch.Tensor,
        cond: torch.Tensor,
        cfg_value: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """Euler-integrate the flow from t_span[0] down to t_span[-1].

        x: (B, C, T) initial noise; mu: (B, D) conditioning; cond: (B, C, T')
        prefix frames. CFG is evaluated by batching a conditional (mu) and an
        unconditional (zeroed mu) copy through the estimator in one call.
        With use_cfg_zero_star, the unconditional branch is rescaled by the
        least-squares scale st* and the earliest ~4% of steps output zero
        velocity (CFG-Zero*).
        """
        # current time, final time (unused), and first step size
        t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]

        sol = []
        zero_init_steps = max(1, int(len(t_span) * 0.04))
        for step in range(1, len(t_span)):
            if use_cfg_zero_star and step <= zero_init_steps:
                # CFG-Zero*: skip the estimator entirely for the first steps.
                dphi_dt = torch.zeros_like(x)
            else:
                # Classifier-Free Guidance inference introduced in VoiceBox
                b = x.size(0)
                # Assemble a doubled batch: rows [:b] conditional, rows [b:] unconditional.
                x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
                t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
                x_in[:b], x_in[b:] = x, x
                # mu_in[b:] stays zero — that is the unconditional branch.
                mu_in[:b] = mu
                t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
                dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
                # not used now
                if not self.mean_mode:
                    dt_in = torch.zeros_like(dt_in)
                cond_in[:b], cond_in[b:] = cond, cond

                dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
                # Split back into conditional / unconditional velocity estimates.
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)

                if use_cfg_zero_star:
                    positive_flat = dphi_dt.view(b, -1)
                    negative_flat = cfg_dphi_dt.view(b, -1)
                    st_star = self.optimized_scale(positive_flat, negative_flat)
                    st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
                else:
                    st_star = 1.0

                # Guided velocity: uncond * st* + cfg_value * (cond - uncond * st*).
                dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)

            # One Euler step backwards in time (t decreases toward 0).
            x = x - dt * dphi_dt
            t = t - dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t - t_span[step + 1]

        return sol[-1]
|
| 137 |
+
|
| 138 |
+
# ------------------------------------------------------------------ #
|
| 139 |
+
# Training loss
|
| 140 |
+
# ------------------------------------------------------------------ #
|
| 141 |
+
def adaptive_loss_weighting(
|
| 142 |
+
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
|
| 143 |
+
):
|
| 144 |
+
weights = 1.0 / ((losses + epsilon).pow(p))
|
| 145 |
+
if mask is not None:
|
| 146 |
+
weights = weights * mask
|
| 147 |
+
return weights.detach()
|
| 148 |
+
|
| 149 |
+
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
    """Sample per-example time pairs ``(r, t)`` for flow-matching training.

    Times are drawn either from a logit-normal ("log-norm") or a uniform
    scheduler. For a ``ratio_r_neq_t`` fraction of samples the pair is
    decoupled (r <= t); otherwise r == t.
    """
    batch = x.shape[0]
    tensor_kwargs = dict(device=x.device, dtype=x.dtype)

    if self.t_scheduler == "log-norm":
        # Logit-normal: sigmoid of a Gaussian with the given mu/sigma.
        r = torch.sigmoid(torch.randn(batch, **tensor_kwargs) * sigma + mu)
        t = torch.sigmoid(torch.randn(batch, **tensor_kwargs) * sigma + mu)
    elif self.t_scheduler == "uniform":
        r = torch.rand(batch, **tensor_kwargs)
        t = torch.rand(batch, **tensor_kwargs)
    else:
        raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")

    # Decoupled samples keep an ordered pair (min, max); the rest collapse to (t, t).
    decouple = torch.rand(batch, **tensor_kwargs) < ratio_r_neq_t
    lower = torch.min(r, t)
    upper = torch.max(r, t)
    r = torch.where(decouple, lower, t)
    t = torch.where(decouple, upper, t)

    return r.squeeze(), t.squeeze()
|
| 170 |
+
|
| 171 |
+
def compute_loss(
    self,
    x1: torch.Tensor,
    mu: torch.Tensor,
    cond: torch.Tensor | None = None,
    tgt_mask: torch.Tensor | None = None,
    progress: float = 0.0,
):
    """Flow-matching training loss (MeanFlow variant when ``self.mean_mode``).

    Args:
        x1: clean target features, Tensor(batch, channels, frames).
        mu: conditioning features; randomly zeroed per-sample for
            classifier-free-guidance training.
            # assumes mu is 2-D per sample (broadcast via view(-1, 1)) — TODO confirm against caller
        cond: optional acoustic condition; defaults to zeros shaped like ``x1``.
        tgt_mask: optional frame mask used to weight and normalize the loss.
        progress: training progress in [0, 1]; anneals the noisy-condition
            probability and the r != t ratio.
    Returns:
        Scalar loss tensor.
    """
    b, _, _ = x1.shape

    # Classifier-free guidance: drop conditioning for a random subset of samples
    # so the estimator also learns the unconditional velocity field.
    if self.training_cfg_rate > 0:
        cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
        mu = mu * cfg_mask.view(-1, 1)

    if cond is None:
        cond = torch.zeros_like(x1)

    # With probability annealed across noise_cond_prob_range, perturb the
    # condition with Gaussian noise to build robustness to imperfect prompts.
    noisy_mask = torch.rand(b, device=x1.device) > (
        1.0
        - (
            self.noise_cond_prob_range[0]
            + progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
        )
    )
    cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale

    # Fraction of samples with decoupled r != t (MeanFlow training only).
    ratio_r_neq_t = (
        self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
        if self.mean_mode
        else 0.0
    )

    r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
    r_ = r.detach().clone()
    t_ = t.detach().clone()
    z = torch.randn_like(x1)
    # Straight-line interpolation between data x1 (t=0) and noise z (t=1),
    # whose exact velocity is v = z - x1.
    y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
    v = z - x1

    def model_fn(z_sample, r_sample, t_sample):
        # dt = t - r informs the estimator of the averaging interval (0 when r == t).
        return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)

    if self.mean_mode:
        # MeanFlow target: u_tgt = v - (t - r) * du/dt, where du/dt comes from
        # a jacobian-vector product along the tangent direction (v, 0, 1).
        v_r = torch.zeros_like(r)
        v_t = torch.ones_like(t)
        from torch.backends.cuda import sdp_kernel

        # Restrict SDPA to the math kernel: flash / mem-efficient kernels do
        # not support the forward-mode differentiation used by jvp here.
        with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
            u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
        u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
    else:
        # Plain flow matching: regress the velocity directly.
        u_pred = model_fn(y, r, t)
        u_tgt = v

    # Per-sample, per-frame MSE (mean over the channel dim); target detached.
    losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
    if tgt_mask is not None:
        weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
        loss = (weights * losses).sum() / torch.sum(tgt_mask)
    else:
        loss = losses.mean()

    return loss
|
voxcpm/modules/locenc/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .local_encoder import VoxCPMLocEnc
|
voxcpm/modules/locenc/local_encoder.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VoxCPMLocEnc(nn.Module):
    """Local patch encoder: summarizes each patch of features into a single
    vector via a learned CLS token and a small bidirectional MiniCPM encoder."""

    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        super().__init__()
        self.config = config
        # Learned CLS token prepended to every patch before encoding.
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)

        assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
        self.encoder = MiniCPMModel(config)

    def forward(self, x):
        """
        x: [B, T, P, D]
        """
        batch, steps, patch_len, _ = x.shape

        hidden = self.in_proj(x)
        cls_tokens = self.special_token.expand(batch, steps, 1, -1)
        hidden = torch.cat([cls_tokens, hidden], dim=2)
        # Fold (batch, time) together so each patch is encoded independently.
        hidden = hidden.reshape(batch * steps, patch_len + 1, -1)
        encoded, _ = self.encoder(hidden, is_causal=False)
        # The CLS position (index 0) is the per-patch summary.
        summary = encoded[:, 0, :]
        return summary.reshape(batch, steps, -1)
|
voxcpm/modules/minicpm4/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import MiniCPM4Config
|
| 2 |
+
from .model import MiniCPMModel
|
| 3 |
+
from .cache import StaticKVCache
|
voxcpm/modules/minicpm4/cache.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class StaticKVCache:
    """Pre-allocated key/value cache for incremental transformer decoding.

    All layers share one storage tensor of shape
    ``[2 (key/value), num_layers, batch, num_kv_heads, max_length, dim_kv_head]``;
    ``current_length`` tracks how many positions have been written.
    """

    def __init__(
        self,
        num_layers: int,
        num_kv_heads: int,
        dim_kv_head: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        max_length: int = 8192,
    ):
        self.max_length = max_length
        self.num_layers = num_layers

        cache_shape = (2, num_layers, batch_size, num_kv_heads, max_length, dim_kv_head)
        self.kv_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
        self.current_length = 0

    def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (key_cache, value_cache) views for one layer; in-place writes
        through the views are visible to every holder of the cache."""
        return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]

    def step(self) -> int:
        """Reserve the next write slot and return its position index."""
        if self.current_length >= self.max_length:
            raise ValueError("KV cache is full")

        position = self.current_length
        self.current_length = position + 1
        return position

    def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
        """Reset the cache and seed it from per-layer (key, value) prefixes.

        Each entry is expected shaped [batch, num_kv_heads, prefix_len, dim_kv_head].
        """
        prefix_len = kv_caches[0][0].size(2)
        self.current_length = prefix_len
        self.kv_cache.zero_()
        for layer_idx in range(self.num_layers):
            keys, values = kv_caches[layer_idx]
            self.kv_cache[0, layer_idx, :, :, :prefix_len, :] = keys
            self.kv_cache[1, layer_idx, :, :, :prefix_len, :] = values
|
voxcpm/modules/minicpm4/config.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional

from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RopeScalingConfig(BaseModel):
    """Rotary-embedding scaling parameters (LongRoPE-style), consumed by MiniCPMLongRoPE."""

    # Scaling strategy name; stored but not branched on in the visible code.
    type: str
    # Per-frequency rescaling factors applied when the sequence exceeds the
    # original training context length.
    long_factor: List[float]
    # Per-frequency rescaling factors used within the original context length.
    short_factor: List[float]
    # Context length the base model was originally trained with.
    original_max_position_embeddings: int
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MiniCPM4Config(BaseModel):
    """Architecture hyper-parameters for the MiniCPM4 transformer backbone."""

    bos_token_id: int
    eos_token_id: int
    hidden_size: int
    intermediate_size: int
    max_position_embeddings: int
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    rms_norm_eps: float
    rope_scaling: RopeScalingConfig
    # 0 disables the token embedding table (callers feed pre-embedded features).
    vocab_size: int
    use_mup: bool = True
    scale_emb: float
    dim_model_base: int
    # muP residual-branch scale used by the decoder layers.
    scale_depth: float
    rope_theta: float
    # Explicit per-head key/value width; consumers fall back to
    # hidden_size // num_attention_heads when this is None. Was annotated
    # `int = None`, which mis-declared the accepted type to pydantic.
    kv_channels: Optional[int] = None
    no_rope: bool = False
|
voxcpm/modules/minicpm4/model.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import MiniCPM4Config
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from typing import List, Tuple
|
| 5 |
+
import math
|
| 6 |
+
from .cache import StaticKVCache
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    """RMS-normalize ``hidden`` over its last dim (computed in fp32) and scale by ``weight``."""
    input_dtype = hidden.dtype
    # Mean of squares in float32 for numerical stability.
    mean_square = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    normalized = (hidden * torch.rsqrt(mean_square + eps)).to(input_dtype)
    return normalized * weight
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MiniCPMRMSNorm(nn.Module):
    """RMS normalization layer with a learned gain (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Gain initialized to 1 so the layer starts as pure normalization.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Delegate to the shared functional implementation.
        return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def rotate_half(x):
    """Rotate the last dim by mapping its halves (a, b) -> (-b, a)."""
    # Ceil split matches torch.chunk(2) when the last dim is odd.
    split_index = (x.size(-1) + 1) // 2
    first_half = x[..., :split_index]
    second_half = x[..., split_index:]
    return torch.cat((-second_half, first_half), dim=-1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q: Tensor(batch_size, num_heads, seq_len, head_dim)
        k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
        cos: Tensor(seq_len, head_dim)
        sin: Tensor(seq_len, head_dim)
    Returns:
        Rotated (q, k), each cast back to the query's original dtype.
    """
    out_dtype = q.dtype
    # Rotation is computed in float32 for numerical stability.
    q32 = q.to(torch.float32)
    k32 = k.to(torch.float32)
    rotated_q = (q32 * cos) + (rotate_half(q32) * sin)
    rotated_k = (k32 * cos) + (rotate_half(k32) * sin)
    return rotated_q.to(out_dtype), rotated_k.to(out_dtype)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MiniCPMLongRoPE(nn.Module):
    """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.config = config
        # Per-head dim: explicit kv_channels if set, otherwise even split of hidden_size.
        self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
        self.base = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings

        self.short_factor = config.rope_scaling.short_factor
        self.long_factor = config.rope_scaling.long_factor
        self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings

        # LongRoPE attention scaling: sqrt(1 + log(extension ratio) / log(original length)).
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.max_seq_len_cached = 0

        # Non-persistent placeholders; filled below (and kept out of state_dict).
        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)

        self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Precompute and cache the cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        # Pick the per-frequency rescaling factors depending on whether we are
        # extending beyond the original training context length.
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)

        freqs = torch.mul(
            torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
        )

        # Duplicate the frequencies so the table covers the full head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
        self.sin_cached = emb.sin().to(dtype) * self.scaling_factor

    def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            position_ids: Tensor(seq_len) or Tensor(batch_size, seq_len)
        Returns:
            Tensor(seq_len, head_dim), Tensor(seq_len, head_dim) — cos/sin rows
            gathered from the precomputed cache.
        """
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        return cos, sin
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class MiniCPMAttention(nn.Module):
    """Multi-head self-attention with grouped-query KV heads (GQA) and rotary embeddings."""

    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Per-head width; config.kv_channels overrides the default even split.
        self.head_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        # NOTE(review): hard-coded and not read elsewhere in this class; RoPE uses
        # config.rope_theta — confirm whether this attribute is still needed.
        self.rope_theta = 10000.0

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Full-sequence attention.

        Args:
            hidden_states: Tensor(batch, seq_len, hidden_size).
            position_emb: (cos, sin) rotary tables, or None to skip RoPE.
            is_causal: apply a causal attention mask when True.
        Returns:
            (attn_output, (key_states, value_states)); the K/V pair can later
            seed a KV cache for incremental decoding.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [B, S, H*D] -> [B, H, S, D]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_emb is not None:
            cos, sin = position_emb
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # ref: https://github.com/pytorch/pytorch/issues/163597
        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()
        # enable_gqa lets SDPA expand the fewer KV heads across query-head groups.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=is_causal,
            enable_gqa=True,
        )

        # [B, H, S, D] -> [B, S, H*D]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        past_key_value = (key_states, value_states)
        return attn_output, past_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: int,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Single-token decode step writing into a preallocated KV cache.

        Args:
            hidden_states: Tensor(batch, hidden_size) for the current token.
            position_emb: (cos, sin) rotary tables for this position, or None.
            position_id: absolute position index to write into the cache.
            kv_cache: this layer's (key_cache, value_cache) views.
        Returns:
            Tensor(batch, hidden_size).
        """
        bsz, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Single-token sequences: [B, 1, H, D] -> [B, H, 1, D]
        query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_emb is not None:
            cos, sin = position_emb
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        key_cache, value_cache = kv_cache

        # In-place write of the current token's K/V into the static cache.
        key_cache[:, :, position_id, :] = key_states
        value_cache[:, :, position_id, :] = value_states

        # Attend only to positions filled so far (indices <= position_id).
        attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)

        # ref: https://github.com/pytorch/pytorch/issues/163597
        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
        query_states = query_states.contiguous()
        key_cache = key_cache.contiguous()
        value_cache = value_cache.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_cache,
            value_cache,
            attn_mask=attn_mask,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class MiniCPMMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(SiLU(gate(x)) * up(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class MiniCPMDecoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention + gated MLP, with
    optional muP-style residual-branch scaling."""

    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)

        self.mlp = MiniCPMMLP(config)
        self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # muP residual scaling: each branch output is multiplied by
        # scale_depth / sqrt(num_hidden_layers) when use_mup is set.
        self.scale_depth = config.scale_depth
        self.num_hidden_layers = config.num_hidden_layers
        self.use_mup = config.use_mup

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_emb: (cos, sin) rotary tables, or None to skip RoPE
            is_causal (`bool`): whether the attention mask is causal
        Returns:
            (hidden_states, present_key_value), where present_key_value is this
            layer's (key, value) pair for seeding a KV cache.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_emb=position_emb,
            is_causal=is_causal,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states, present_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Single-token decode step; writes this layer's K/V into ``kv_cache``.

        Args:
            hidden_states: Tensor(batch, hidden_size) for the current token.
            position_emb: (cos, sin) rotary tables, or None to skip RoPE.
            position_id: absolute position of the current token.
            kv_cache: this layer's preallocated (key_cache, value_cache).
        Returns:
            Tensor(batch, hidden_size).
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states = self.self_attn.forward_step(
            hidden_states=hidden_states,
            position_emb=position_emb,
            position_id=position_id,
            kv_cache=kv_cache,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class MiniCPMModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]

    Args:
        config: MiniCPMConfig
    """

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.config = config

        # vocab_size == 0 means callers feed pre-computed embeddings directly.
        if config.vocab_size > 0:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        else:
            self.embed_tokens = nn.Identity()

        self.layers = nn.ModuleList(
            [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if config.no_rope:
            self.rope_emb = None
        else:
            self.rope_emb = MiniCPMLongRoPE(config)

        # Created on demand by setup_cache() before incremental decoding.
        self.kv_cache = None

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
            is_causal: bool, whether the attention mask is causal
        Returns:
            hidden_states: Tensor(batch_size, seq_length, hidden_size)
            next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
        """
        if self.rope_emb is not None:
            # Sequential positions 0..seq_len-1 for the rotary tables.
            position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
            position_emb = self.rope_emb(position_ids)
        else:
            position_emb = None
        hidden_states = inputs_embeds

        next_decoder_cache = []

        for decoder_layer in self.layers:

            hidden_states, this_cache = decoder_layer(
                hidden_states,
                position_emb,
                is_causal,
            )
            next_decoder_cache.append(this_cache)
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_decoder_cache

    def forward_step(
        self,
        inputs_embeds: torch.Tensor,
        position_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single-token decode step through all layers using the static KV cache.
        ``setup_cache`` must have been called first.

        Args:
            inputs_embeds: Tensor(batch_size, hidden_size)
            position_id: absolute position of the current token
        Returns:
            hidden_states: Tensor(batch_size, hidden_size)
        """
        assert self.kv_cache is not None, "KV cache is not setup"

        if self.rope_emb is not None:
            position_emb = self.rope_emb(position_id)
        else:
            position_emb = None
        hidden_states = inputs_embeds

        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer.forward_step(
                hidden_states,
                position_emb,
                position_id,
                self.kv_cache.get_layer_cache(i),
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
        """Allocate a StaticKVCache sized for ``batch_size`` x ``max_length`` decoding."""
        self.kv_cache = StaticKVCache(
            num_layers=self.config.num_hidden_layers,
            num_kv_heads=self.config.num_key_value_heads,
            dim_kv_head=(
                self.config.hidden_size // self.config.num_attention_heads
                if self.config.kv_channels is None
                else self.config.kv_channels
            ),
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            max_length=max_length,
        )
|
voxcpm/training/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training utilities for VoxCPM fine-tuning.
|
| 3 |
+
|
| 4 |
+
This package mirrors the training mechanics used in the minicpm-audio
|
| 5 |
+
tooling while relying solely on local audio-text datasets managed via
|
| 6 |
+
the HuggingFace ``datasets`` library.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from .accelerator import Accelerator
|
| 10 |
+
from .tracker import TrainingTracker
|
| 11 |
+
from .data import (
|
| 12 |
+
load_audio_text_datasets,
|
| 13 |
+
HFVoxCPMDataset,
|
| 14 |
+
build_dataloader,
|
| 15 |
+
BatchProcessor,
|
| 16 |
+
)
|
| 17 |
+
from .state import TrainingState
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"Accelerator",
|
| 21 |
+
"TrainingTracker",
|
| 22 |
+
"HFVoxCPMDataset",
|
| 23 |
+
"BatchProcessor",
|
| 24 |
+
"TrainingState",
|
| 25 |
+
"load_audio_text_datasets",
|
| 26 |
+
"build_dataloader",
|
| 27 |
+
]
|
voxcpm/training/accelerator.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import typing
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
import torch.utils.data
|
| 12 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Accelerator:
    """
    Simplified accelerator that mirrors the behaviour of the minicpm-audio
    training utilities.

    It initializes a distributed process group when launched via ``torchrun``
    (detected through the ``WORLD_SIZE`` environment variable) and exposes
    helpers for AMP, gradient scaling and preparing models/dataloaders for
    DDP.
    """

    class _NoOpScaler:
        """Stand-in for ``torch.amp.GradScaler`` when AMP is disabled.

        Exposes the same ``scale``/``step``/``unscale_``/``update`` surface so
        the training loop stays uniform, but performs no loss scaling.
        (The original defined this class inside ``__init__``, re-creating the
        class object on every ``Accelerator`` instantiation.)
        """

        def step(self, optimizer):
            optimizer.step()

        def scale(self, loss):
            return loss

        def unscale_(self, optimizer):
            return optimizer

        def update(self):
            pass

    def __init__(self, amp: bool = False, seed: int = 42):
        """
        Args:
            amp: enable CUDA automatic mixed precision (with gradient scaling).
            seed: random seed applied to torch / numpy / random so every rank
                initializes identical model weights.
        """
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))

        # torchrun sets WORLD_SIZE > 1; initialize the NCCL process group once.
        if self.world_size > 1 and not dist.is_initialized():
            dist.init_process_group("nccl", init_method="env://")

        self.rank = dist.get_rank() if dist.is_initialized() else 0
        self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        self.amp = amp

        # Seed before any parameters are created to keep initialization
        # consistent across multiple GPUs.
        self._set_seed(seed)

        # Real gradient scaler only when AMP is on AND CUDA exists; otherwise
        # the no-op stand-in keeps the training loop code path identical.
        self.scaler = (
            torch.amp.GradScaler("cuda")
            if (amp and torch.cuda.is_available())
            else Accelerator._NoOpScaler()
        )
        self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        self._ddp_model = None  # set by prepare_model(); required for no_sync()

    def _set_seed(self, seed: int):
        """Seed torch / numpy / random (and all CUDA devices when present)."""
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def __enter__(self):
        # Pin this process to its local CUDA device for the scope's duration.
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def barrier(self):
        """Synchronize all processes (no-op when not distributed)."""
        if dist.is_initialized():
            dist.barrier()

    def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
        """All-reduce ``tensor`` across processes in place and return it.

        Passthrough when the process group is not initialized. NOTE: the
        default ``ReduceOp.AVG`` is only supported by the NCCL backend.
        """
        if dist.is_initialized():
            dist.all_reduce(tensor, op=op)
        return tensor

    # ------------------------------------------------------------------ #
    # Model helpers
    # ------------------------------------------------------------------ #
    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Move ``model`` to this rank's device and wrap it in DDP if needed.

        Extra ``kwargs`` are forwarded to ``DistributedDataParallel``.
        """
        if hasattr(model, "device"):  # make sure the matrix will be moved to the correct device
            model.device = self.device
        model = model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
            self._ddp_model = model  # keep the DDP wrapper so no_sync() can reach it
        return model

    @contextlib.contextmanager
    def no_sync(self):
        """
        Context manager that skips gradient synchronization during gradient
        accumulation. Use it for every micro-batch except the last one.
        Falls through to a plain ``yield`` when no DDP model is registered.
        """
        if self._ddp_model is not None:
            with self._ddp_model.no_sync():
                yield
        else:
            yield

    @property
    def device(self):
        """Preferred compute device: CUDA (per local rank) > MPS > CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda", self.local_rank)
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    # ------------------------------------------------------------------ #
    # AMP helpers
    # ------------------------------------------------------------------ #
    def autocast(self, *args, **kwargs):
        """CUDA autocast context gated by ``self.amp``."""
        # Positional args are kept ahead of the keyword for readability; the
        # original interleaved them (``enabled=..., *args``) which obscured
        # the actual argument binding.
        return torch.amp.autocast("cuda", *args, enabled=self.amp, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backpropagate through the (possibly scaled) loss."""
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Optimizer step routed through the gradient scaler."""
        self.scaler.step(optimizer)

    def update(self):
        """Update the gradient scaler's scale factor (no-op without AMP)."""
        self.scaler.update()

    # ------------------------------------------------------------------ #
    # Data helpers
    # ------------------------------------------------------------------ #
    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        *,
        batch_size: int,
        num_workers: int = 0,
        shuffle: bool = True,
        collate_fn=None,
        drop_last: bool = False,
    ) -> torch.utils.data.DataLoader:
        """Build a DataLoader, attaching a ``DistributedSampler`` when running
        multi-process (the sampler then owns shuffling)."""
        sampler = None
        if self.world_size > 1:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
            )

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            # DataLoader forbids shuffle=True together with a sampler.
            shuffle=shuffle and sampler is None,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            drop_last=drop_last,
            # Pinning host memory only helps (and only applies) with CUDA;
            # unconditional True emitted warnings on CPU-only runs.
            pin_memory=torch.cuda.is_available(),
        )

    @staticmethod
    def unwrap(model: torch.nn.Module) -> torch.nn.Module:
        """Return the underlying module from a DDP wrapper (or the model itself)."""
        return model.module if hasattr(model, "module") else model
|
voxcpm/training/config.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argbind
|
| 4 |
+
import yaml
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
    """
    Read a YAML configuration file into a dictionary suitable for argbind.

    Raises:
        ValueError: if the file's top-level value is not a mapping.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args_with_config(config_path: str | Path | None = None):
    """
    Helper to unify CLI arguments and YAML configuration.

    CLI arguments are parsed first; when ``config_path`` is given, the YAML
    file is parsed inside the CLI argbind scope and merged on top.
    NOTE(review): ``cli_args.update(yaml_args)`` gives the YAML file the last
    word for shared keys — confirm YAML-over-CLI is the intended precedence.

    Usage mirrors minicpm-audio:
        args = parse_args_with_config("conf/voxcpm/finetune.yml")
        with argbind.scope(args):
            ...
    """
    cli_args = argbind.parse_args()
    if config_path is None:
        return cli_args

    yaml_args = load_yaml_config(config_path)
    # Parse the YAML within the CLI scope; argv=[] stops argbind from
    # re-reading the real command line a second time.
    with argbind.scope(cli_args):
        yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
    cli_args.update(yaml_args)
    return cli_args
|
voxcpm/training/data.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import argbind
|
| 5 |
+
import torch
|
| 6 |
+
from datasets import Audio, Dataset, DatasetDict, load_dataset
|
| 7 |
+
from torch.utils.data import Dataset as TorchDataset
|
| 8 |
+
|
| 9 |
+
from ..model.voxcpm import VoxCPMConfig
|
| 10 |
+
from ..modules.audiovae import AudioVAE
|
| 11 |
+
from .packers import AudioFeatureProcessingPacker
|
| 12 |
+
|
| 13 |
+
DEFAULT_TEXT_COLUMN = "text"
|
| 14 |
+
DEFAULT_AUDIO_COLUMN = "audio"
|
| 15 |
+
DEFAULT_ID_COLUMN = "dataset_id"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@argbind.bind()
def load_audio_text_datasets(
    train_manifest: str,
    val_manifest: str = "",
    text_column: str = DEFAULT_TEXT_COLUMN,
    audio_column: str = DEFAULT_AUDIO_COLUMN,
    dataset_id_column: str = DEFAULT_ID_COLUMN,
    sample_rate: int = 16_000,
    num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
    """
    Load train (and optional validation) JSON manifests as HuggingFace
    audio-text datasets, normalizing column names to the canonical
    ``text`` / ``audio`` / ``dataset_id`` and decoding audio at
    ``sample_rate``.
    """
    manifests = {"train": train_manifest}
    if val_manifest:
        manifests["validation"] = val_manifest

    raw: DatasetDict = load_dataset("json", data_files=manifests)

    def normalize(ds: Dataset) -> Dataset:
        if audio_column not in ds.column_names:
            raise ValueError(f"Expected '{audio_column}' column in manifest.")
        # Cast while the manifest's original column name is still valid, then
        # rename everything to the canonical column names. HF datasets does
        # not compute durations automatically for an Audio column.
        ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
        if audio_column != DEFAULT_AUDIO_COLUMN:
            ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
        if text_column != DEFAULT_TEXT_COLUMN:
            ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
        if dataset_id_column and dataset_id_column in ds.column_names:
            if dataset_id_column != DEFAULT_ID_COLUMN:
                ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
        else:
            # Manifest carries no dataset-id column: default every row to 0.
            ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
        return ds

    train_split = normalize(raw["train"])
    val_split = normalize(raw["validation"]) if "validation" in raw else None
    return train_split, val_split
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def compute_sample_lengths(
    ds: Dataset,
    audio_vae_fps: int = 25,
    patch_size: int = 1,
) -> List[int]:
    """
    Estimate each sample's packed sequence length (text + audio) after the
    packer runs, used to filter out overly long samples.

    The arithmetic mirrors ``AudioFeatureProcessingPacker`` / ``AudioVAE``:

    - text length: ``len(text_ids)``
    - audio length:
        ``duration(s) * audio_vae_fps`` -> approximate VAE frame count ``t_vae``
        ``t_seq = ceil(t_vae / patch_size)``
    - total sequence length: roughly ``text_len + t_seq + 2``
      (the +2 accounts for the audio start/end markers added by the packer)

    Optimized: columns are read in bulk instead of row by row.
    """
    # Bulk column access is much faster than per-item indexing on HF datasets.
    text_lens = [len(ids) for ids in ds["text_ids"]]

    if "duration" in ds.column_names:
        durations = ds["duration"]
    else:
        # Fallback: derive duration from the decoded audio. Slow (decodes
        # every file) but unavoidable without a precomputed duration column.
        durations = []
        for i in range(len(ds)):
            audio = ds[i][DEFAULT_AUDIO_COLUMN]
            durations.append(len(audio["array"]) / float(audio["sampling_rate"]))

    # Length formula applied per sample (see docstring).
    return [
        text_len + math.ceil(math.ceil(float(duration) * audio_vae_fps) / patch_size) + 2
        for text_len, duration in zip(text_lens, durations)
    ]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class HFVoxCPMDataset(TorchDataset):
    """
    Thin wrapper around a tokenized HuggingFace dataset that exposes
    PyTorch-friendly samples and a padding collate function.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        record = self.dataset[idx]
        clip = record[DEFAULT_AUDIO_COLUMN]
        return {
            "text_ids": record["text_ids"],
            "audio_array": clip["array"],
            "audio_sampling_rate": clip["sampling_rate"],
            "dataset_id": record.get(DEFAULT_ID_COLUMN, 0),
            "is_prompt": record.get("is_prompt", False),
        }

    @staticmethod
    def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
        """Right-pad 1-D tensors to a common length and stack them."""
        if not seqs:
            return torch.empty(0)
        target = max(seq.shape[0] for seq in seqs)
        stacked = []
        for seq in seqs:
            deficit = target - seq.shape[0]
            if deficit > 0:
                seq = torch.nn.functional.pad(seq, (0, deficit), value=pad_value)
            stacked.append(seq)
        return torch.stack(stacked)

    @classmethod
    def collate_fn(cls, batch: List[Dict]):
        """Assemble variable-length samples into padded batch tensors.

        Text is padded with -100, audio with -100.0 (sentinels the packer
        later strips); ``task_ids`` are all 1 (TTS).
        """
        texts = [torch.tensor(item["text_ids"], dtype=torch.int32) for item in batch]
        audios = [torch.tensor(item["audio_array"], dtype=torch.float32) for item in batch]
        return {
            "text_tokens": cls.pad_sequences(texts, pad_value=-100),
            "audio_tokens": cls.pad_sequences(audios, pad_value=-100.0),
            "task_ids": torch.ones(len(batch), dtype=torch.int32),
            "dataset_ids": torch.tensor([item["dataset_id"] for item in batch], dtype=torch.int32),
            "is_prompts": [bool(item.get("is_prompt", False)) for item in batch],
        }
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class BatchProcessor:
    """
    Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror the
    minicpm-audio mechanics: raw batch tensors are moved to the target device
    and packed into the multimodal layout VoxCPM consumes.
    """

    def __init__(
        self,
        *,
        config: VoxCPMConfig,
        audio_vae: AudioVAE,
        dataset_cnt: int,
        device: torch.device,
    ):
        self.device = device
        self.dataset_cnt = dataset_cnt
        self.audio_vae = audio_vae
        self.audio_vae.to(device)
        # Packer geometry comes straight from the model config.
        self.packer = AudioFeatureProcessingPacker(
            dataset_cnt=dataset_cnt,
            max_len=config.max_length,
            patch_size=config.patch_size,
            feat_dim=config.feat_dim,
            audio_vae=self.audio_vae,
        )

    def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Move the tensor fields onto ``self.device`` and run the packer."""
        moved = {
            name: batch[name].to(self.device)
            for name in ("audio_tokens", "text_tokens", "task_ids", "dataset_ids")
        }
        return self.packer(
            audio_tokens=moved["audio_tokens"],
            text_tokens=moved["text_tokens"],
            task_ids=moved["task_ids"],
            dataset_ids=moved["dataset_ids"],
            is_prompts=batch["is_prompts"],
        )
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def build_dataloader(
    hf_dataset: Dataset,
    *,
    accelerator,
    batch_size: int,
    num_workers: int,
    drop_last: bool = False,
) -> torch.utils.data.DataLoader:
    """
    Wrap a HuggingFace dataset in ``HFVoxCPMDataset`` and hand it to the
    accelerator, which attaches a ``DistributedSampler`` when running
    multi-process. Uses padding-based batching via the dataset's collate_fn.
    """
    wrapped = HFVoxCPMDataset(hf_dataset)
    return accelerator.prepare_dataloader(
        wrapped,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=HFVoxCPMDataset.collate_fn,
        drop_last=drop_last,
    )
|
voxcpm/training/packers.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AudioFeatureProcessingPacker:
    """
    Adapted from the minicpm-audio training utilities. It converts raw text and
    audio tokens into the packed multimodal representation required by VoxCPM.

    Per-sample layout (see ``process_tts_data``):
        [text tokens..., audio_start, <audio frames...>, audio_end]
    with parallel text/audio masks, a loss mask covering the audio region, and
    a binary stop label on the final audio frame.
    """

    def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
        # Special token ids delimiting the audio region in the text stream.
        self.audio_start_id = 101
        self.audio_end_id = 102
        # unused now
        self.audio_prompt_start_id = 103
        self.audio_prompt_end_id = 104
        self.text_eos_token_id = 2

        self.patch_size = patch_size
        # Number of raw waveform samples covered by one packed audio step
        # (VAE hop length times patch size).
        self.patch_len = audio_vae.hop_length * self.patch_size
        self.feat_dim = feat_dim
        self.dataset_cnt = max(dataset_cnt, 1)
        self.max_len = max_len

        self.audio_vae = audio_vae

        # Task dispatch tables; only the TTS task exists at the moment.
        self.process_functions = {"tts": self.process_tts_data}
        self.task_id_map = {"tts": 1}
        self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _first_pad_position(tokens: torch.Tensor):
        # -100 is the padding sentinel written by HFVoxCPMDataset.collate_fn;
        # returns the index of the first pad element, or None if unpadded.
        positions = (tokens == -100).nonzero(as_tuple=True)
        if positions[0].numel() == 0:
            return None
        return int(positions[0][0])

    def unpad_text_tokens(self, tokens: torch.Tensor):
        """Strip trailing -100 padding from a 1-D text token tensor."""
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def unpad_audio_tokens(self, tokens: torch.Tensor):
        """Strip trailing -100 padding from a 1-D audio sample tensor."""
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def encode_audio(self, wav: torch.Tensor):
        """
        Encode raw waveform into latent features using AudioVAE.

        AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
        We then transpose to [B, T, D] to match downstream expectations.
        """
        wav = wav.unsqueeze(0)  # [1, T]
        wav = wav.unsqueeze(1)  # [1, 1, T]
        wav_len = wav.size(-1)
        # Right-pad the waveform so it divides evenly into whole patches.
        if wav_len % self.patch_len != 0:
            padding_size = self.patch_len - wav_len % self.patch_len
            wav = torch.nn.functional.pad(wav, (0, padding_size))

        # The VAE is frozen here; gradients are never needed for encoding.
        with torch.no_grad():
            z = self.audio_vae.encode(wav, self.audio_vae.sample_rate)  # [1, D, T']
        feat = z.transpose(1, 2)  # [1, T', D]
        return feat

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #
    def __call__(
        self,
        audio_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        task_ids: torch.Tensor,
        dataset_ids: torch.Tensor,
        is_prompts: List[bool],
    ) -> Dict[str, torch.Tensor]:
        """
        Padding-based batching: each sample in the input batch is processed
        independently and then padded to a common length (capped by ``max_len``).
        The result tensors all have shape [B, T, ...].
        """
        device = audio_tokens.device
        # Per-dataset accounting may need more slots than configured when the
        # batch carries larger dataset ids than ``dataset_cnt``.
        max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
        dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)

        text_tokens_list: List[torch.Tensor] = []
        audio_feats_list: List[torch.Tensor] = []
        text_mask_list: List[torch.Tensor] = []
        audio_mask_list: List[torch.Tensor] = []
        loss_mask_list: List[torch.Tensor] = []
        labels_list: List[torch.Tensor] = []
        audio_task_ids_list: List[torch.Tensor] = []
        audio_dataset_ids_list: List[torch.Tensor] = []
        lengths: List[int] = []

        # Per-dataset throughput counters (seconds of audio / text tokens).
        audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
        text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)

        for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
            audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
        ):
            unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
            unpad_text_token = self.unpad_text_tokens(text_token)
            usage = self.id_to_task[task_id]

            # Dispatch to the task-specific packer (currently always TTS).
            (
                packed_text,
                audio_feat,
                text_mask,
                audio_mask,
                loss_mask,
                labels,
                audio_duration,
                text_token_count,
            ) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)

            audio_duration_consumed[dataset_idx] += audio_duration
            text_token_consumed[dataset_idx] += text_token_count

            # Mark each audio position with its task id / dataset id
            # (dataset id is shifted by +1 so 0 means "not audio").
            audio_task_id = torch.zeros_like(audio_mask)
            audio_task_id[audio_mask == 1] = self.task_id_map[usage]

            audio_dataset_id = torch.zeros_like(audio_mask)
            audio_dataset_id[audio_mask == 1] = dataset_idx + 1

            text_tokens_list.append(packed_text)
            text_mask_list.append(text_mask)
            audio_feats_list.append(audio_feat)
            audio_mask_list.append(audio_mask)
            loss_mask_list.append(loss_mask)
            labels_list.append(labels)
            audio_task_ids_list.append(audio_task_id)
            audio_dataset_ids_list.append(audio_dataset_id)
            lengths.append(packed_text.shape[0])

        # Determine padded length per batch (cap by self.max_len)
        if lengths:
            max_len = min(self.max_len, max(lengths))
        else:
            max_len = self.max_len

        def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
            # Pad (or truncate) a [T] tensor to exactly max_len.
            if x.size(0) >= max_len:
                return x[:max_len]
            pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)

        def pad_3d(x: torch.Tensor) -> torch.Tensor:
            # x: [T, P, D]
            if x.size(0) >= max_len:
                return x[:max_len]
            pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)

        if lengths:
            text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
            text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
            audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
            audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
            loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
            labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
            audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
            audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)

            # Position ids: [B, T], simple 0..L_i-1 then padded with 0
            position_ids_list = []
            for L in lengths:
                L_clip = min(L, max_len)
                pos = torch.arange(0, L_clip, device=device)
                if L_clip < max_len:
                    pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
                    pos = torch.cat([pos, pad], dim=0)
                position_ids_list.append(pos)
            position_ids = torch.stack(position_ids_list, dim=0)
        else:
            # Empty batch fallback (shouldn't really happen)
            text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
            text_mask_batch = torch.zeros_like(text_tokens_batch)
            audio_feats_batch = torch.zeros(
                (0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
            )
            audio_mask_batch = torch.zeros_like(text_tokens_batch)
            loss_mask_batch = torch.zeros_like(text_tokens_batch)
            labels_batch = torch.zeros_like(text_tokens_batch)
            audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
            audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
            position_ids = torch.zeros_like(text_tokens_batch)

        # Counters are reported as integers downstream.
        audio_duration_consumed = audio_duration_consumed.to(torch.long)
        text_token_consumed = text_token_consumed.to(torch.long)

        return {
            "text_tokens": text_tokens_batch,
            "audio_feats": audio_feats_batch,
            "text_mask": text_mask_batch,
            "audio_mask": audio_mask_batch,
            "loss_mask": loss_mask_batch,
            "position_ids": position_ids,
            "labels": labels_batch,
            "audio_task_ids": audio_task_ids_batch,
            "audio_dataset_ids": audio_dataset_ids_batch,
            "audio_duration_consumed": audio_duration_consumed,
            "text_token_consumed": text_token_consumed,
        }

    # ------------------------------------------------------------------ #
    # Feature extraction helpers
    # ------------------------------------------------------------------ #
    def extract_audio_feats(self, audio_data: torch.Tensor):
        """Encode a waveform to VAE features grouped into patches [1, T, P, D]."""
        audio_feats = self.encode_audio(audio_data)
        # Pad the time axis so it divides evenly into patches of patch_size.
        if audio_feats.size(1) % self.patch_size != 0:
            audio_feats_ = audio_feats.transpose(1, 2)
            # ("padding" here is the padded tensor, not the pad amount.)
            padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
            audio_feats = padding.transpose(1, 2)

        # 25 = assumed AudioVAE frame rate (frames per second); matches the
        # compute_sample_lengths default — TODO confirm against the VAE config.
        audio_duration = audio_feats.size(1) / 25
        audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
        return audio_feats, audio_duration

    def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
        """
        Pack one TTS sample into aligned text/audio streams.

        Returns (packed_text, audio_feats, text_mask, audio_mask, loss_mask,
        labels, audio_duration, text_token_count). Prompt samples use the
        prompt start/end markers and are excluded from the loss.
        """
        # Text stream: [text..., audio_start] (prompt variant when is_prompt).
        text_token_info = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_prompt_start_id if is_prompt else self.audio_start_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )
        text_token_count = len(text_token)
        text_length = text_token_info.shape[0]
        audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
        audio_feat_info = audio_feat_info.squeeze(0)
        audio_length = audio_feat_info.shape[0]

        # Text stream gains zero-token placeholders under the audio region,
        # then the closing audio_end marker.
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_token_info = torch.cat(
            [
                text_token_info,
                text_pad_token,
                torch.tensor(
                    [self.audio_prompt_end_id if is_prompt else self.audio_end_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ]
        )
        # Audio stream gains zero-feature placeholders under the text region
        # (prefix) and a single trailing placeholder under audio_end.
        audio_pad_feat = torch.zeros(
            (text_length, self.patch_size, audio_feat_info.size(-1)),
            dtype=torch.float32,
            device=text_token.device,
        )
        audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)

        # Masks select which positions belong to text vs. audio; the loss
        # covers only the audio region and only for non-prompt samples.
        text_mask = (
            torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
            .type(torch.int32)
            .to(text_token.device)
        )
        audio_mask = (
            torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
            .type(torch.int32)
            .to(text_token.device)
        )
        loss_mask = (
            torch.cat(
                [
                    torch.zeros(text_length),
                    torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
                    torch.zeros(1),
                ]
            )
            .type(torch.int32)
            .to(text_token.device)
        )

        # Binary stop label: 1 on the last real audio frame (position -2,
        # i.e. just before the audio_end marker).
        labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
        labels[-2] = 1

        return (
            text_token_info,
            audio_feat_info,
            text_mask,
            audio_mask,
            loss_mask,
            labels,
            audio_duration,
            text_token_count,
        )
|
voxcpm/training/state.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class TrainingState:
    """
    Container that mirrors the object returned in the minicpm-audio training
    loop. It holds persistent references to the model, optimizer, scheduler,
    dataloaders and tracker.

    Fields are annotated as ``object`` because the concrete classes come from
    the training framework and are deliberately not imported here.
    """

    generator: object        # the model being trained (the audio generator)
    optimizer: object        # optimizer stepping the generator's parameters
    scheduler: object        # learning-rate scheduler paired with the optimizer
    train_loader: object     # dataloader yielding training batches
    val_loader: object       # dataloader yielding validation batches
    tracker: object          # progress/metrics tracker (see training/tracker.py)
    batch_processor: object  # callable that prepares raw batches for the model -- presumably; TODO confirm
|
voxcpm/training/tracker.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrainingTracker:
    """
    Minimal progress tracker for the training loop (modeled after the
    minicpm-audio training workflow).

    Responsibilities:
      * keep the current global ``step`` (checkpointable via ``state_dict``),
      * emit rank-0-only messages to stderr and, optionally, a log file,
      * forward numeric metrics to an optional TensorBoard-style ``writer``,
      * expose a no-op ``live()`` context manager for interface parity.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Timestamp of the previous metrics log; used to report the interval.
        self._last_log_time: float | None = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        """Emit *message* to stderr and the log file (rank 0 only)."""
        if self.rank != 0:
            return
        print(message, flush=True, file=sys.stderr)
        if self.log_file:
            with self.log_file.open("a", encoding="utf-8") as fh:
                fh.write(message + "\n")

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """Log *metrics* for *split*, including the time since the last log."""
        if self.rank != 0:
            return
        now = time.time()
        dt_str = ""
        if self._last_log_time is not None:
            dt_str = f", log interval: {now - self._last_log_time:.2f}s"
        self._last_log_time = now

        formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
        self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
        if self.writer is not None:
            # Only scalars are forwarded to the writer.
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Announce completion of a *split* phase."""
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        """Serializable tracker state (only the step counter)."""
        return {"step": self.step}

    def load_state_dict(self, state):
        """Restore state produced by :meth:`state_dict`."""
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        """No-op context manager kept for interface compatibility."""
        yield
|
voxcpm/utils/text_normalize.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# some functions are copied from https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/utils/frontend_utils.py
|
| 2 |
+
import re
|
| 3 |
+
import regex
|
| 4 |
+
import inflect
|
| 5 |
+
from wetext import Normalizer
|
| 6 |
+
|
| 7 |
+
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# whether contain chinese character
|
| 11 |
+
def contains_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph."""
    return re.search(r"[\u4e00-\u9fff]+", text) is not None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# replace special symbol
|
| 16 |
+
def replace_corner_mark(text):
    """Rewrite superscript/math symbols as their spoken Chinese forms."""
    # Single pass over the string; every key is a single character and no
    # replacement string contains another key, so this matches the original
    # chain of .replace() calls exactly.
    spoken = {"²": "平方", "³": "立方", "√": "根号", "≈": "约等于", "<": "小于"}
    return text.translate(str.maketrans(spoken))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# remove meaningless symbol
|
| 26 |
+
def remove_bracket(text):
    """Strip brackets, backticks and em-dashes that carry no spoken content."""
    # Applied in order; brackets and the em-dash become spaces, backticks vanish.
    for old, new in (("(", " "), (")", " "), ("【", " "), ("】", " "), ("`", ""), ("——", " ")):
        text = text.replace(old, new)
    return text
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# spell Arabic numerals
|
| 35 |
+
def spell_out_number(text: str, inflect_parser):
    """Replace every maximal run of digit characters with its spelled-out words.

    Uses a character scan rather than a ``\\d`` regex because ``str.isdigit``
    also accepts characters such as "²" that ``\\d`` would not match, and the
    original relied on ``isdigit`` semantics.
    """
    pieces = []
    run_start = None  # index where the current digit run began, or None
    for idx, ch in enumerate(text):
        if ch.isdigit():
            if run_start is None:
                run_start = idx
        else:
            if run_start is not None:
                pieces.append(inflect_parser.number_to_words(text[run_start:idx]))
                run_start = None
            pieces.append(ch)
    # Flush a digit run that reaches the end of the string.
    if run_start is not None:
        pieces.append(inflect_parser.number_to_words(text[run_start:]))
    return "".join(pieces)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# split paragrah logic:
|
| 55 |
+
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
| 56 |
+
# 2. cal sentence len according to lang
|
| 57 |
+
# 3. split sentence according to puncatation
|
| 58 |
+
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
    """Split *text* into sentences on punctuation, then re-pack them into chunks.

    Packing rules: a chunk is closed once it is longer than ``token_min_n``
    and absorbing the next sentence would exceed ``token_max_n``; a short
    trailing chunk (< ``merge_len``) is folded into the previous chunk.

    Args:
        text: input paragraph.
        tokenize: callable used to measure length for non-Chinese text.
        lang: "zh" measures length in characters, anything else in tokens.
        token_max_n: soft upper bound on chunk length.
        token_min_n: minimum chunk length before it may be closed.
        merge_len: threshold below which the last chunk is merged back.
        comma_split: additionally split on ASCII/fullwidth commas.

    Returns:
        list of re-packed utterance strings.
    """
    # Length of a candidate chunk: characters for Chinese, tokens otherwise.
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    # Whether a trailing chunk is short enough to fold into the previous one.
    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    # Sentence-final punctuation per language (zh includes fullwidth forms).
    if lang == "zh":
        pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
    else:
        pounc = [".", "?", "!", ";", ":"]
    if comma_split:
        pounc.extend([",", ","])
    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            # Keep a closing quote attached to the sentence it terminates.
            # NOTE(review): assumes utts is non-empty here; a punctuation mark
            # immediately at st (empty segment) skips the append above — confirm
            # inputs cannot start a segment with punctuation followed by a quote.
            if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1
    # No punctuation at all: treat the whole text as one sentence.
    if len(utts) == 0:
        if lang == "zh":
            utts.append(text + "。")
        else:
            utts.append(text + ".")
    final_utts = []
    cur_utt = ""
    for utt in utts:
        # Close the current chunk once it is long enough and would overflow
        # token_max_n by absorbing the next sentence.
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# remove blank between chinese character
|
| 111 |
+
# remove blank between chinese characters (keep spaces separating ASCII words)
def replace_blank(text: str):
    """Drop each space unless both neighbours are non-space ASCII characters.

    Used after Chinese normalization so "中文 测试" becomes "中文测试" while
    "hello world" keeps its word-separating space.

    Fixes the original out-of-range neighbour lookups: a space at the start or
    end of the string is now simply removed instead of raising IndexError
    (trailing space) or wrapping around to ``text[-1]`` (leading space).
    """
    out_str = []
    last = len(text) - 1
    for i, c in enumerate(text):
        if c == " ":
            keep = (
                0 < i < last  # a kept space needs a neighbour on both sides
                and text[i - 1].isascii() and text[i - 1] != " "
                and text[i + 1].isascii() and text[i + 1] != " "
            )
            if keep:
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def clean_markdown(md_text: str) -> str:
    """Strip Markdown/HTML markup from *md_text*, keeping the readable text."""
    # (pattern, replacement, flags) applied in order. Order matters: fenced
    # code blocks before inline code, images before links.
    rules = (
        (r"```.*?```", "", re.DOTALL),          # fenced code blocks (multi-line)
        (r"`[^`]*`", "", 0),                    # inline code spans
        (r"!\[[^\]]*\]\([^\)]+\)", "", 0),      # images ![alt](url)
        (r"\[([^\]]+)\]\([^)]+\)", r"\1", 0),   # links [text](url) -> text
        (r"^(\s*)-\s+", r"\1", re.MULTILINE),   # unordered list bullets
        (r"<[^>]+>", "", 0),                    # HTML tags
        (r"^#{1,6}\s*", "", re.MULTILINE),      # heading hashes
        (r"\n\s*\n", "\n", 0),                  # collapse blank lines
    )
    for pattern, repl, flags in rules:
        md_text = re.sub(pattern, repl, md_text, flags=flags)
    return md_text.strip()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def clean_text(text):
    """Normalize raw input for TTS: strip Markdown, emoji, and line breaks."""
    # Remove Markdown markup first so emoji/newline handling sees plain text.
    text = clean_markdown(text)
    # Strip emoji, with or without the VS-16 presentation selector.
    emoji_pattern = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE)
    text = emoji_pattern.sub("", text)
    # Flatten whitespace control characters and normalize curly double quotes.
    for old, new in (("\n", " "), ("\t", " "), ("“", '"'), ("”", '"')):
        text = text.replace(old, new)
    return text
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class TextNormalizer:
    """Front-end text normalizer for TTS.

    Routes input to a Chinese or English wetext normalizer depending on
    whether the text contains CJK characters, with extra regex fixes for
    math symbols and digit spelling.
    """

    def __init__(self, tokenizer=None):
        # tokenizer: kept for interface parity; not used by normalize() here.
        self.tokenizer = tokenizer
        # wetext rule-based normalizers ("tn" = text normalization).
        self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
        self.en_tn_model = Normalizer(lang="en", operator="tn")
        # inflect engine spells out digit runs left by the EN normalizer.
        self.inflect_parser = inflect.engine()

    def normalize(self, text, split=False):
        """Normalize *text* for synthesis.

        Returns the normalized string when ``split`` is False.
        NOTE(review): when ``split`` is True the function falls through and
        returns None — confirm whether a sentence-splitting branch (e.g. via
        split_paragraph) was intended here.
        """
        # Strip Markdown markup, emoji and line breaks first.
        lang = "zh" if contains_chinese(text) else "en"
        text = clean_text(text)
        if lang == "zh":
            text = text.replace(
                "=", "等于"
            )  # fix: "550 + 320 等于 870 千卡。" being wrongly normalized to "...八七十千卡."
            if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text):  # avoid English hyphens being read as minus
                text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text)  # fix: "x-2" normalized as "x负2"
            text = self.zh_tn_model.normalize(text)
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = remove_bracket(text)
        else:
            text = self.en_tn_model.normalize(text)
            text = spell_out_number(text, self.inflect_parser)
        if split is False:
            return text
|
voxcpm/zipenhancer.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ZipEnhancer Module - Audio Denoising Enhancer
|
| 3 |
+
|
| 4 |
+
Provides on-demand import ZipEnhancer functionality for audio denoising processing.
|
| 5 |
+
Related dependencies are imported only when denoising functionality is needed.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
from typing import Optional
|
| 11 |
+
import torchaudio
|
| 12 |
+
from modelscope.pipelines import pipeline
|
| 13 |
+
from modelscope.utils.constant import Tasks
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ZipEnhancer:
    """ZipEnhancer audio denoising enhancer.

    Thin wrapper around the ModelScope acoustic-noise-suppression pipeline,
    with optional loudness normalization of the denoised output.
    """

    def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
        """
        Initialize ZipEnhancer.

        Args:
            model_path: ModelScope model id or local path.
        """
        self.model_path = model_path
        # The pipeline loads (and may download) the model eagerly here.
        self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)

    def _normalize_loudness(self, wav_path: str):
        """
        Normalize the audio file at *wav_path* to -20 LUFS, in place.

        Args:
            wav_path: Audio file path (read and overwritten).
        """
        audio, sr = torchaudio.load(wav_path)
        loudness = torchaudio.functional.loudness(audio, sr)
        # gain is in dB; shift the measured loudness to the -20 LUFS target.
        normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
        torchaudio.save(wav_path, normalized_audio, sr)

    def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
        """
        Denoise an audio file.

        Args:
            input_path: Input audio file path.
            output_path: Output audio file path (a temp .wav is created when omitted).
            normalize_loudness: Whether to loudness-normalize the denoised audio.

        Returns:
            str: Path of the denoised audio file.

        Raises:
            FileNotFoundError: If *input_path* does not exist.
            RuntimeError: If denoising or normalization fails. The output file
                (including a caller-supplied one) is removed on failure, and
                the original exception is chained as ``__cause__``.
        """
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
        # Create a temporary file if no output path is specified.
        if output_path is None:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                output_path = tmp_file.name
        try:
            # Run the denoising pipeline.
            self._pipeline(input_path, output_path=output_path)
            # Optional loudness normalization of the result.
            if normalize_loudness:
                self._normalize_loudness(output_path)
            return output_path
        except Exception as e:
            # Remove the (possibly partial) output file before propagating.
            if output_path and os.path.exists(output_path):
                try:
                    os.unlink(output_path)
                except OSError:
                    pass
            # Chain the original exception so the root cause is preserved
            # (the original code dropped it, hiding the real failure).
            raise RuntimeError(f"Audio denoising processing failed: {e}") from e
|