himahande45 commited on
Commit
402a61f
·
verified ·
1 Parent(s): 5903ab3

Add IndicVox paper demo Space

Browse files
Files changed (42) hide show
  1. .gitattributes +3 -0
  2. README.md +29 -5
  3. app.py +451 -0
  4. assets/voices/hin_m_ref_00.wav +3 -0
  5. assets/voices/tam_f_ref_00.wav +3 -0
  6. assets/voices/tam_m_ref_00.wav +3 -0
  7. code_switch_prompts.json +166 -0
  8. packages.txt +2 -0
  9. requirements.txt +13 -0
  10. voxcpm/__init__.py +5 -0
  11. voxcpm/cli.py +598 -0
  12. voxcpm/core.py +333 -0
  13. voxcpm/model/__init__.py +4 -0
  14. voxcpm/model/utils.py +121 -0
  15. voxcpm/model/voxcpm.py +985 -0
  16. voxcpm/model/voxcpm2.py +1224 -0
  17. voxcpm/modules/__init__.py +0 -0
  18. voxcpm/modules/audiovae/__init__.py +2 -0
  19. voxcpm/modules/audiovae/audio_vae.py +377 -0
  20. voxcpm/modules/audiovae/audio_vae_v2.py +486 -0
  21. voxcpm/modules/layers/__init__.py +1 -0
  22. voxcpm/modules/layers/lora.py +130 -0
  23. voxcpm/modules/layers/scalar_quantization_layer.py +26 -0
  24. voxcpm/modules/locdit/__init__.py +3 -0
  25. voxcpm/modules/locdit/local_dit.py +114 -0
  26. voxcpm/modules/locdit/local_dit_v2.py +116 -0
  27. voxcpm/modules/locdit/unified_cfm.py +232 -0
  28. voxcpm/modules/locenc/__init__.py +1 -0
  29. voxcpm/modules/locenc/local_encoder.py +30 -0
  30. voxcpm/modules/minicpm4/__init__.py +3 -0
  31. voxcpm/modules/minicpm4/cache.py +47 -0
  32. voxcpm/modules/minicpm4/config.py +30 -0
  33. voxcpm/modules/minicpm4/model.py +429 -0
  34. voxcpm/training/__init__.py +27 -0
  35. voxcpm/training/accelerator.py +163 -0
  36. voxcpm/training/config.py +38 -0
  37. voxcpm/training/data.py +214 -0
  38. voxcpm/training/packers.py +296 -0
  39. voxcpm/training/state.py +20 -0
  40. voxcpm/training/tracker.py +78 -0
  41. voxcpm/utils/text_normalize.py +188 -0
  42. voxcpm/zipenhancer.py +72 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/voices/hin_m_ref_00.wav filter=lfs diff=lfs merge=lfs -text
37
+ assets/voices/tam_f_ref_00.wav filter=lfs diff=lfs merge=lfs -text
38
+ assets/voices/tam_m_ref_00.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,36 @@
1
  ---
2
- title: Indicvox Hindi Tamil Codeswitching Tts
3
- emoji: 🏃
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 6.12.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: "IndicVox: Hindi & Tamil Code-Switching TTS"
3
+ emoji: "🎙️"
4
+ colorFrom: indigo
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 6.12.0
8
  app_file: app.py
9
  pinned: false
10
+ python_version: "3.10.16"
11
+ suggested_hardware: a10g-small
12
  ---
13
 
14
+ # IndicVox
15
+
16
+ IndicVox is a GPU-backed research demo for multilingual text-to-speech across Hindi, Tamil, and code-switched prompts. The Space exposes the paper checkpoints through a clean Gradio UI with built-in voice presets and example prompts.
17
+
18
+ ## What it includes
19
+
20
+ - `Hindi Focus` for Hindi and Hindi-English prompts
21
+ - `Tamil Focus` for Tamil and Tamil-English prompts
22
+ - `Research Baseline` for direct comparison against the untuned multilingual model
23
+ - Built-in research voice presets for fast demo playback
24
+ - Zero-shot `Text Only` mode if you want to skip reference conditioning
25
+
26
+ ## Usage
27
+
28
+ 1. Pick a model profile.
29
+ 2. Type a Hindi, Tamil, or code-switched prompt.
30
+ 3. Pick a built-in voice preset or `Text Only`.
31
+ 4. Click `Generate Speech`.
32
+
33
+ ## Notes
34
+
35
+ - The base multilingual model stays resident in GPU memory and the paper checkpoints are swapped on demand.
36
+ - The Space is meant for inference/demo usage, not batch evaluation.
app.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ import threading
7
+ import time
8
+ import traceback
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ from huggingface_hub import snapshot_download
15
+
16
+ APP_DIR = Path(__file__).resolve().parent
17
+
18
+
19
def resolve_persist_root() -> Path:
    """Pick a writable root directory for persistent caches.

    Prefers the Space's persistent ``/data`` volume when it exists and is
    writable; otherwise falls back to a ``.cache`` directory next to the app.
    """
    persistent = Path("/data")
    if persistent.exists() and os.access(persistent, os.W_OK):
        return persistent

    fallback = APP_DIR / ".cache"
    fallback.mkdir(parents=True, exist_ok=True)
    return fallback
27
+
28
+
29
# Route Hugging Face caches to the persistent volume (or the local fallback)
# before any hub download runs; setdefault respects pre-set env overrides.
PERSIST_ROOT = resolve_persist_root()
HF_HOME = PERSIST_ROOT / "huggingface"
HF_HOME.mkdir(parents=True, exist_ok=True)
os.environ.setdefault("HF_HOME", str(HF_HOME))
os.environ.setdefault("HF_HUB_CACHE", str(HF_HOME / "hub"))

# Make the vendored `voxcpm` package (shipped inside this Space) importable.
sys.path.insert(0, str(APP_DIR))

from voxcpm import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
39
+
40
SPACE_TITLE = "IndicVox: Hindi & Tamil Code-Switching TTS"
MODEL_REPO_ID = "himahande45/multilingual-tts"
PROMPTS_FILE = APP_DIR / "code_switch_prompts.json"
VOICE_DIR = APP_DIR / "assets" / "voices"
DEFAULT_PROFILE = "Tamil Focus"
DEFAULT_VOICE = "Tamil Female Research Voice"
# Tamil-English code-switched sentence pre-filled in the prompt textbox.
DEFAULT_TEXT = "இந்த experimentக்கு clean reference audio use பண்ணணும், இல்லனா output quality drop ஆகும்."

# Only these files are pulled from the model repo: the base model directory
# plus the two LoRA checkpoints referenced by PROFILES below.
MODEL_PATTERNS = [
    "VoxCPM2_local/*",
    "finetune_checkpoints/step_0000500/lora_config.json",
    "finetune_checkpoints/step_0000500/lora_weights.safetensors",
    "finetune_checkpoints/step_0001000/lora_config.json",
    "finetune_checkpoints/step_0001000/lora_weights.safetensors",
]

# UI profile name -> LoRA checkpoint dir relative to the snapshot.
# checkpoint_dir=None means the base model with LoRA disabled.
PROFILES = {
    "Tamil Focus": {
        "description": "Best for Tamil and Tamil-English code-switched prompts.",
        "checkpoint_dir": "finetune_checkpoints/step_0001000",
    },
    "Hindi Focus": {
        "description": "Best for Hindi and Hindi-English code-switched prompts.",
        "checkpoint_dir": "finetune_checkpoints/step_0000500",
    },
    "Research Baseline": {
        "description": "Base multilingual checkpoint without paper fine-tuning.",
        "checkpoint_dir": None,
    },
}
70
+
71
# Voice preset name -> reference wav, its transcript (required by the model
# for prompt conditioning), and a short UI summary. "Text Only" (path=None)
# selects zero-shot synthesis without a reference clip.
VOICE_PRESETS = {
    "Hindi Research Voice": {
        "path": VOICE_DIR / "hin_m_ref_00.wav",
        "transcript": "लेकिन क्या यह हम सभी कार्यक्रमों के साथ कर सकते?",
        "summary": "Short Hindi reference used for sharper Hindi + English prompting.",
    },
    "Tamil Female Research Voice": {
        "path": VOICE_DIR / "tam_f_ref_00.wav",
        "transcript": "விக்கற நேரத்தையும் லாபத்தையும் பொறுத்து, இந்த டேக்ஸை ஷார்ட் டேர்ம் இல்ல லாங் டேர்ம்னு பிரிப்பாங்க.",
        "summary": "Clear Tamil reference with stable conversational prosody.",
    },
    "Tamil Male Research Voice": {
        "path": VOICE_DIR / "tam_m_ref_00.wav",
        "transcript": "கொரோனா பாதிப்பு காலத்தில் எண்பது கோடி மக்களுக்கு உணவு தானியம் வழங்கப்பட்டதாகவும் அவர் தெரிவித்தார்.",
        "summary": "Tamil male reference that holds rhythm well on longer prompts.",
    },
    "Text Only": {
        "path": None,
        "transcript": None,
        "summary": "Zero-shot generation without a reference voice clip.",
    },
}
93
+
94
# Styling for the hero card, stat chips, and footnotes; the `footer` rule
# hides the default Gradio footer.
CUSTOM_CSS = """
#app-shell {
    max-width: 1180px;
    margin: 0 auto;
}
#hero {
    padding: 24px 26px 12px 26px;
    border: 1px solid rgba(255, 255, 255, 0.08);
    border-radius: 22px;
    background:
        radial-gradient(circle at top right, rgba(99, 102, 241, 0.16), transparent 34%),
        radial-gradient(circle at bottom left, rgba(16, 185, 129, 0.14), transparent 30%),
        rgba(15, 23, 42, 0.74);
}
.stat-chip {
    display: inline-block;
    margin: 6px 8px 0 0;
    padding: 8px 12px;
    border-radius: 999px;
    background: rgba(255, 255, 255, 0.06);
    font-size: 0.92rem;
}
.footnote {
    opacity: 0.78;
    font-size: 0.94rem;
}
footer {
    visibility: hidden;
}
"""

# TF32 trades a little matmul precision for throughput on Ampere+ GPUs.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

THEME = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
131
+
132
+
133
def load_examples() -> list[list[str]]:
    """Build the example rows shown in the UI tabs.

    Each row is ``[prompt text, profile name, voice preset name]``, drawn
    from the bundled code-switch prompt bank.
    """
    with PROMPTS_FILE.open("r", encoding="utf-8") as handle:
        bank = json.load(handle)

    picks = [
        ("hi_en", 0, "Hindi Focus", "Hindi Research Voice"),
        ("hi_en", 9, "Hindi Focus", "Hindi Research Voice"),
        ("hi_en", 16, "Hindi Focus", "Hindi Research Voice"),
        ("ta_en", 0, "Tamil Focus", "Tamil Female Research Voice"),
        ("ta_en", 9, "Tamil Focus", "Tamil Female Research Voice"),
        ("ta_en", 14, "Tamil Focus", "Tamil Male Research Voice"),
    ]
    return [[bank[lang][idx]["text"], profile, voice] for lang, idx, profile, voice in picks]
145
+
146
+
147
def profile_markdown(profile_name: str) -> str:
    """Render the Markdown blurb shown under the model-profile dropdown."""
    blurb = PROFILES[profile_name]["description"]
    return f"**{profile_name}** \n{blurb}"
150
+
151
+
152
def voice_markdown(voice_name: str) -> str:
    """Render a Markdown blurb for a voice preset, including its reference transcript when one exists."""
    preset = VOICE_PRESETS[voice_name]
    header = f"**{voice_name}** \n{preset['summary']}"
    if preset["path"] is None:
        return header
    return header + f" \nReference transcript: `{preset['transcript']}`"
158
+
159
+
160
def dynamic_max_len(text: str) -> int:
    """Estimate a generation-length budget from the prompt size.

    Scales ~7.5 units per visible character, clamped to [280, 900] so short
    prompts still get headroom and long prompts stay bounded.
    """
    visible_chars = len(text.strip()) or 1
    scaled = int(visible_chars * 7.5)
    return min(max(scaled, 280), 900)
163
+
164
+
165
class ModelManager:
    """Owns the VoxCPM model and swaps LoRA checkpoints between UI profiles.

    The base multilingual model is loaded once at startup and kept resident;
    switching profiles only loads/unloads the (much smaller) LoRA weights.
    Inference and profile switching are serialized under ``self.lock``.
    """

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.repo_dir = self._resolve_repo_dir()
        self.base_dir = self.repo_dir / "VoxCPM2_local"
        # Profile whose LoRA weights are currently loaded into the model.
        self.loaded_profile: str | None = None
        # Profile currently selected for inference; may be the baseline
        # (LoRA disabled) while `loaded_profile` still holds weights.
        self.active_profile: str | None = None
        self.model = self._load_model()
        self.activate_profile(DEFAULT_PROFILE)

    def _resolve_repo_dir(self) -> Path:
        """Locate the model repo: a local override dir, or a hub snapshot.

        ``INDICVOX_LOCAL_MODEL_REPO`` short-circuits the download for local
        development; otherwise only MODEL_PATTERNS files are fetched.
        """
        local_repo = os.getenv("INDICVOX_LOCAL_MODEL_REPO")
        if local_repo:
            path = Path(local_repo).expanduser().resolve()
            if path.exists():
                return path
            raise FileNotFoundError(f"INDICVOX_LOCAL_MODEL_REPO does not exist: {path}")

        token = os.getenv("HF_TOKEN")
        snapshot_path = snapshot_download(
            repo_id=MODEL_REPO_ID,
            repo_type="model",
            allow_patterns=MODEL_PATTERNS,
            token=token,
        )
        return Path(snapshot_path)

    def _load_lora_config(self, checkpoint_dir: Path) -> LoRAConfig:
        """Parse a checkpoint's lora_config.json into a LoRAConfig."""
        payload = json.loads((checkpoint_dir / "lora_config.json").read_text(encoding="utf-8"))
        return LoRAConfig(**payload["lora_config"])

    def _load_model(self) -> VoxCPM:
        """Load the base model (GPU required), wired for the default profile's LoRA shape."""
        if not torch.cuda.is_available():
            raise RuntimeError("A GPU runtime is required. Request an A10G/L4 Space and restart.")

        # The LoRA config (not weights) is needed at construction time so the
        # adapter layers exist; weights are attached via activate_profile().
        checkpoint_dir = self.repo_dir / PROFILES[DEFAULT_PROFILE]["checkpoint_dir"]
        lora_config = self._load_lora_config(checkpoint_dir)
        model = VoxCPM.from_pretrained(
            hf_model_id=str(self.base_dir),
            load_denoiser=False,
            optimize=False,
            lora_config=lora_config,
        )
        return model

    def activate_profile(self, profile_name: str) -> None:
        """Make *profile_name* the inference profile, (re)loading LoRA weights if needed.

        Baseline profile: disable LoRA but keep any loaded weights cached.
        Fine-tuned profile: swap in its weights only when a different
        checkpoint is currently loaded.
        """
        spec = PROFILES[profile_name]
        checkpoint_dir = spec["checkpoint_dir"]

        if checkpoint_dir is None:
            self.model.set_lora_enabled(False)
            self.active_profile = profile_name
            return

        if self.loaded_profile != profile_name:
            if self.loaded_profile is not None:
                self.model.unload_lora()
            self.model.load_lora(str(self.repo_dir / checkpoint_dir))
            self.loaded_profile = profile_name

        self.model.set_lora_enabled(True)
        self.active_profile = profile_name

    def synthesize(
        self,
        text: str,
        profile_name: str,
        voice_name: str,
        cfg_value: float,
        inference_steps: int,
    ) -> tuple[tuple[int, np.ndarray], str]:
        """Generate audio for *text* and return ((sample_rate, wav), status markdown).

        Raises gr.Error on an empty prompt. The lock serializes profile
        switching and generation across concurrent requests.
        """
        clean_text = text.strip()
        if not clean_text:
            raise gr.Error("Enter a prompt first.")

        start = time.perf_counter()
        with self.lock:
            self.activate_profile(profile_name)
            kwargs = {
                "text": clean_text,
                "cfg_value": float(cfg_value),
                "inference_timesteps": int(inference_steps),
                "max_len": dynamic_max_len(clean_text),
            }

            # Reference-conditioned presets pass the clip and its transcript;
            # the "Text Only" preset (path=None) runs zero-shot.
            voice = VOICE_PRESETS[voice_name]
            if voice["path"] is not None:
                kwargs["prompt_wav_path"] = str(voice["path"])
                kwargs["prompt_text"] = voice["transcript"]

            wav = self.model.generate(**kwargs)
            sample_rate = int(self.model.tts_model.sample_rate)

            # Normalize the output to a clipped float32 numpy vector for gr.Audio.
            if isinstance(wav, torch.Tensor):
                wav = wav.detach().cpu().numpy()
            wav = np.asarray(wav, dtype=np.float32).squeeze()
            wav = np.clip(wav, -1.0, 1.0)

            elapsed = time.perf_counter() - start
            duration = float(wav.shape[-1]) / sample_rate if wav.size else 0.0
            # Real-time factor: generation seconds per second of audio.
            rtf = elapsed / duration if duration > 0 else float("nan")
            speed_line = f"RTF {rtf:.2f}x" if np.isfinite(rtf) else "RTF n/a"
            status = (
                f"**Ready** \n"
                f"Profile: `{profile_name}` \n"
                f"Voice: `{voice_name}` \n"
                f"Audio length: `{duration:.2f}s` \n"
                f"Generation time: `{elapsed:.2f}s` ({speed_line})"
            )
            return (sample_rate, wav), status

    def boot_markdown(self) -> str:
        """Markdown for the status panel after a successful startup."""
        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU"
        active_profile = self.active_profile or DEFAULT_PROFILE
        return (
            f"**GPU Ready** \n"
            f"Runtime: `{gpu_name}` \n"
            f"Warm profile: `{active_profile}` \n"
            f"Model source: `{MODEL_REPO_ID}`"
        )
285
+
286
+
287
# Set on startup failure so the UI can surface the traceback instead of
# leaving the Space in a crash loop.
BOOT_ERROR: str | None = None
MODEL_MANAGER: ModelManager | None = None

try:
    MODEL_MANAGER = ModelManager()
except Exception:
    # Keep the app importable; boot_status() renders this traceback in the UI.
    BOOT_ERROR = traceback.format_exc()

EXAMPLES = load_examples()
296
+
297
+
298
def synthesize(text: str, profile_name: str, voice_name: str, cfg_value: float, inference_steps: int):
    """Gradio entry point: delegate to the model manager, or surface the boot error."""
    manager = MODEL_MANAGER
    if manager is None:
        raise gr.Error(f"Model initialization failed.\n\n{BOOT_ERROR}")
    return manager.synthesize(text, profile_name, voice_name, cfg_value, inference_steps)
302
+
303
+
304
def voice_preview(voice_name: str):
    """Return (preview audio path or None, Markdown blurb) for the chosen voice preset."""
    preset_path = VOICE_PRESETS[voice_name]["path"]
    if preset_path is None:
        return None, voice_markdown(voice_name)
    return str(preset_path), voice_markdown(voice_name)
308
+
309
+
310
def clear_prompt() -> str:
    """Value handed back to the prompt textbox when the user clicks Clear."""
    return ""
312
+
313
+
314
def boot_status() -> str:
    """Markdown for the status panel: runtime info on success, traceback on failure."""
    if MODEL_MANAGER is None:
        return f"**Startup Error** \n```text\n{BOOT_ERROR}\n```"
    return MODEL_MANAGER.boot_markdown()
318
+
319
+
320
# FIX: `Blocks.launch()` has no `theme`/`css` keyword arguments — they belong
# to the `gr.Blocks(...)` constructor. The original code built `gr.Blocks()`
# bare and called `demo.launch(theme=THEME, css=CUSTOM_CSS)`, so the custom
# theme/CSS were never applied and launch raised a TypeError at startup.
with gr.Blocks(theme=THEME, css=CUSTOM_CSS) as demo:
    with gr.Column(elem_id="app-shell"):
        # Hero banner styled by CUSTOM_CSS (#hero / .stat-chip).
        gr.HTML(
            """
            <div id="hero">
                <h1>IndicVox</h1>
                <p>Research demo for multilingual TTS across Hindi, Tamil, and code-switched prompts.</p>
                <div>
                    <span class="stat-chip">GPU-backed Space</span>
                    <span class="stat-chip">Warm-loaded model</span>
                    <span class="stat-chip">Hindi + Tamil + English prompts</span>
                </div>
            </div>
            """
        )

        with gr.Row():
            # Left column: prompt, profile/voice selection, advanced knobs.
            with gr.Column(scale=5):
                prompt = gr.Textbox(
                    label="Prompt",
                    value=DEFAULT_TEXT,
                    lines=5,
                    max_lines=8,
                    placeholder="Type Hindi, Tamil, or code-switched text here...",
                )

                with gr.Row():
                    profile = gr.Dropdown(
                        choices=list(PROFILES.keys()),
                        value=DEFAULT_PROFILE,
                        label="Model Profile",
                        info="Switch between the Hindi-tuned and Tamil-tuned research profiles.",
                    )
                    voice = gr.Dropdown(
                        choices=list(VOICE_PRESETS.keys()),
                        value=DEFAULT_VOICE,
                        label="Voice Preset",
                        info="Built-in research voices plus a zero-shot option.",
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        cfg_value = gr.Slider(
                            minimum=1.0,
                            maximum=4.0,
                            value=2.0,
                            step=0.1,
                            label="CFG",
                            info="Higher values usually sound more guided but less relaxed.",
                        )
                        inference_steps = gr.Slider(
                            minimum=6,
                            maximum=16,
                            value=10,
                            step=1,
                            label="Diffusion Steps",
                            info="10 is the paper demo default.",
                        )

                with gr.Row():
                    generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")
                    clear_btn = gr.Button("Clear Prompt")

                with gr.Row():
                    profile_info = gr.Markdown(profile_markdown(DEFAULT_PROFILE))
                    voice_info = gr.Markdown(voice_markdown(DEFAULT_VOICE))

            # Right column: status, generated audio, and preset preview.
            with gr.Column(scale=4):
                status = gr.Markdown(boot_status())
                output_audio = gr.Audio(
                    label="Synthesized Audio",
                    autoplay=False,
                    format="wav",
                )
                voice_preview_audio = gr.Audio(
                    label="Voice Preset Preview",
                    value=str(VOICE_PRESETS[DEFAULT_VOICE]["path"]),
                    interactive=False,
                    autoplay=False,
                    format="wav",
                )
                gr.Markdown(
                    "The demo keeps the base model resident on GPU and swaps paper checkpoints on demand.",
                    elem_classes=["footnote"],
                )

        # Example tabs, filtered from the shared EXAMPLES rows by profile.
        with gr.Tabs():
            with gr.Tab("Hindi + English Examples"):
                gr.Examples(
                    examples=[row for row in EXAMPLES if row[1] == "Hindi Focus"],
                    inputs=[prompt, profile, voice],
                    cache_examples=False,
                )
            with gr.Tab("Tamil + English Examples"):
                gr.Examples(
                    examples=[row for row in EXAMPLES if row[1] == "Tamil Focus"],
                    inputs=[prompt, profile, voice],
                    cache_examples=False,
                )

        gr.Markdown(
            """
            **Demo notes**

            - `Hindi Focus` maps to the Hindi-strong checkpoint from the paper experiments.
            - `Tamil Focus` maps to the Tamil + code-switch checkpoint and is the default for the Space.
            - `Text Only` skips the reference clip and runs zero-shot synthesis.
            """,
            elem_classes=["footnote"],
        )

    # Event wiring. Only the generate button is exposed on the API surface.
    generate_btn.click(
        fn=synthesize,
        inputs=[prompt, profile, voice, cfg_value, inference_steps],
        outputs=[output_audio, status],
        api_name="synthesize",
    )
    prompt.submit(
        fn=synthesize,
        inputs=[prompt, profile, voice, cfg_value, inference_steps],
        outputs=[output_audio, status],
        api_name=False,
    )
    profile.change(fn=profile_markdown, inputs=profile, outputs=profile_info, api_name=False)
    voice.change(fn=voice_preview, inputs=voice, outputs=[voice_preview_audio, voice_info], api_name=False)
    clear_btn.click(fn=clear_prompt, outputs=prompt, api_name=False)


# Serialize generation (the model is a single GPU resource) and bound the queue.
demo.queue(default_concurrency_limit=1, max_size=16)

if __name__ == "__main__":
    demo.launch()
assets/voices/hin_m_ref_00.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5400ed6ce26df5efddce5e264153423f378623af05f987cc1c435c06cfd24df2
3
+ size 398732
assets/voices/tam_f_ref_00.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71564e96ac0378d91f35cf70e27f56e5ad814db267c59eac91ac370e612998f0
3
+ size 605228
assets/voices/tam_m_ref_00.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07ba36b8a8ac46246b0ae5cabe12b90325bd62eb4532dd4b6a117b306c8658d3
3
+ size 573452
code_switch_prompts.json ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hi_en": [
3
+ {
4
+ "id": "hi_en_001",
5
+ "text": "आज morning standup में हमने Hindi और English prompts पर ASR output compare किया।"
6
+ },
7
+ {
8
+ "id": "hi_en_002",
9
+ "text": "कल client demo से पहले तुम latest checkpoint का audio sample एक बार verify कर लो।"
10
+ },
11
+ {
12
+ "id": "hi_en_003",
13
+ "text": "अगर final report ready है तो उसे shared drive में upload कर दो।"
14
+ },
15
+ {
16
+ "id": "hi_en_004",
17
+ "text": "ये model normal sentences अच्छा बोलता है, लेकिन code-switch parts में अभी भी थोड़ा hesitation आता है।"
18
+ },
19
+ {
20
+ "id": "hi_en_005",
21
+ "text": "मुझे लगता है कि speaker similarity के लिए हमें clean reference clip use करना चाहिए।"
22
+ },
23
+ {
24
+ "id": "hi_en_006",
25
+ "text": "तुमने meeting notes में Tamil section add किया या वो अभी pending है?"
26
+ },
27
+ {
28
+ "id": "hi_en_007",
29
+ "text": "आज lab में GPU free है, इसलिए full evaluation run अभी start कर देते हैं।"
30
+ },
31
+ {
32
+ "id": "hi_en_008",
33
+ "text": "अगर transcript में punctuation ज्यादा हो तो Whisper कभी कभी extra words insert कर देता है।"
34
+ },
35
+ {
36
+ "id": "hi_en_009",
37
+ "text": "इस experiment के लिए मैंने short reference audio चुना ताकि cloning stable रहे।"
38
+ },
39
+ {
40
+ "id": "hi_en_010",
41
+ "text": "हम paper में monolingual results और code-switch results अलग tables में दिखाएँगे।"
42
+ },
43
+ {
44
+ "id": "hi_en_011",
45
+ "text": "please final plots save कर लेना, वरना thesis draft फिर से update करना पड़ेगा।"
46
+ },
47
+ {
48
+ "id": "hi_en_012",
49
+ "text": "आज के test set में proper nouns, news style और casual conversation तीनों mix किए गए हैं।"
50
+ },
51
+ {
52
+ "id": "hi_en_013",
53
+ "text": "अगर base model Tamil शब्द गलत बोलता है तो LoRA adaptation का effect तुरंत दिख जाएगा।"
54
+ },
55
+ {
56
+ "id": "hi_en_014",
57
+ "text": "मैंने summary sheet में WER, CER, switch-WER और speaker similarity सब add कर दिया है।"
58
+ },
59
+ {
60
+ "id": "hi_en_015",
61
+ "text": "आज evening तक तुम generated audio folders को model-wise sort कर दो।"
62
+ },
63
+ {
64
+ "id": "hi_en_016",
65
+ "text": "meeting के बाद हम ASR transcripts manually spot-check भी करेंगे ताकि obvious errors miss न हों।"
66
+ },
67
+ {
68
+ "id": "hi_en_017",
69
+ "text": "ये checkpoint short prompts पर ठीक है, पर long mixed sentences में इसकी rhythm थोड़ी uneven लगती है।"
70
+ },
71
+ {
72
+ "id": "hi_en_018",
73
+ "text": "अगर inference time ज्यादा हुआ तो पहले pilot run करेंगे और फिर full batch launch करेंगे।"
74
+ },
75
+ {
76
+ "id": "hi_en_019",
77
+ "text": "reference speaker clean है, लेकिन generated output में English words का stress अभी consistent नहीं है।"
78
+ },
79
+ {
80
+ "id": "hi_en_020",
81
+ "text": "इस बार final appendix में example prompts, transcripts और metric formulas तीनों include करना।"
82
+ }
83
+ ],
84
+ "ta_en": [
85
+ {
86
+ "id": "ta_en_001",
87
+ "text": "நேத்து team meetingல புதிய checkpoint பற்றி detailedஆ பேசினோம்."
88
+ },
89
+ {
90
+ "id": "ta_en_002",
91
+ "text": "இந்த experimentக்கு clean reference audio use பண்ணணும், இல்லனா output quality drop ஆகும்."
92
+ },
93
+ {
94
+ "id": "ta_en_003",
95
+ "text": "final report ready ஆனதும் அதை shared folderல upload பண்ணிடு."
96
+ },
97
+ {
98
+ "id": "ta_en_004",
99
+ "text": "இந்த model Tamil words நல்லா பேசுது, ஆனா English switch வரும் இடங்களில் இன்னும் slight hesitation இருக்கு."
100
+ },
101
+ {
102
+ "id": "ta_en_005",
103
+ "text": "speaker similarity score stable ஆகணும்னா same voice reference தொடர்ந்து use பண்ணணும்."
104
+ },
105
+ {
106
+ "id": "ta_en_006",
107
+ "text": "இன்று full evaluation run start பண்ணலாம், ஏன்னா GPU slot இப்போ free இருக்கு."
108
+ },
109
+ {
110
+ "id": "ta_en_007",
111
+ "text": "Whisper transcriptல punctuation இல்லாதப்போ சில code-switch words betterஆ capture ஆகுது."
112
+ },
113
+ {
114
+ "id": "ta_en_008",
115
+ "text": "paper tableல monolingual Tamil resultsவும் Tamil-English resultsவும் separateஆ காட்டணும்."
116
+ },
117
+ {
118
+ "id": "ta_en_009",
119
+ "text": "இந்த promptல proper noun, news style, casual speech மூன்றும் mixedஆ இருக்கு."
120
+ },
121
+ {
122
+ "id": "ta_en_010",
123
+ "text": "latest checkpoint load பண்ணதுக்குப் பிறகு ஒரு short sanity test முதலில் run பண்ணலாம்."
124
+ },
125
+ {
126
+ "id": "ta_en_011",
127
+ "text": "please generated audio files எல்லாம் model-wise sort பண்ணி metrics folderக்குள் move பண்ணு."
128
+ },
129
+ {
130
+ "id": "ta_en_012",
131
+ "text": "இந்த setupல base modelக்கு Tamil pronunciation கொஞ்சம் weakஆ இருந்தா LoRA gain clearஆ தெரியும்."
132
+ },
133
+ {
134
+ "id": "ta_en_013",
135
+ "text": "summary sheetல WER, CER, switch-WER, speaker similarity எல்லாமே சேர்க்கணும்."
136
+ },
137
+ {
138
+ "id": "ta_en_014",
139
+ "text": "meeting முடிஞ்சதும் manual spot-check பண்ணி obvious ASR mistakes இருக்கா என்று பார்க்கலாம்."
140
+ },
141
+ {
142
+ "id": "ta_en_015",
143
+ "text": "short promptsல output cleanஆ இருக்கு, ஆனா long mixed sentenceல rhythm கொஞ்சம் unevenஆ இருக்கு."
144
+ },
145
+ {
146
+ "id": "ta_en_016",
147
+ "text": "if the plots look clean, appendixல example promptsமும் generated transcriptsமும் add பண்ணலாம்."
148
+ },
149
+ {
150
+ "id": "ta_en_017",
151
+ "text": "இந்த reference clip calmஆ இருக்குது, அதனால் generated voiceவும் naturalஆ வர வாய்ப்பு அதிகம்."
152
+ },
153
+ {
154
+ "id": "ta_en_018",
155
+ "text": "tonightக்குள் full batch finish ஆயிடுச்சுனா நாளைக்கு paper draftல numbers insert பண்ணலாம்."
156
+ },
157
+ {
158
+ "id": "ta_en_019",
159
+ "text": "speaker clone நல்லா இருக்கு, ஆனால் English stress pattern இன்னும் fully consistent இல்ல."
160
+ },
161
+ {
162
+ "id": "ta_en_020",
163
+ "text": "இந்த round முடிஞ்சதும் next stepஆ human listening test plan பண்ணலாம்."
164
+ }
165
+ ]
166
+ }
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ffmpeg
2
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=6,<7
2
+ huggingface_hub>=1.0
3
+ numpy<3
4
+ torch>=2.5.0
5
+ torchaudio>=2.5.0
6
+ transformers>=4.36.2
7
+ einops>=0.8.0
8
+ inflect>=7.0.0
9
+ wetext
10
+ librosa>=0.10.2
11
+ soundfile>=0.12.1
12
+ pydantic>=2
13
+ safetensors>=0.4.5
voxcpm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .core import VoxCPM
2
+
3
+ __all__ = [
4
+ "VoxCPM",
5
+ ]
voxcpm/cli.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ VoxCPM Command Line Interface
4
+
5
+ VoxCPM2-first CLI for voice design, cloning, and batch processing.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import os
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ import soundfile as sf
15
+
16
+ from voxcpm.core import VoxCPM
17
+
18
+
19
+ DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
20
+
21
+ # -----------------------------
22
+ # Validators
23
+ # -----------------------------
24
+
25
+
26
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
    """Return *file_path* as a Path, raising FileNotFoundError when it is absent.

    *file_type* is a human-readable label used in the error message.
    """
    candidate = Path(file_path)
    if candidate.exists():
        return candidate
    raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
31
+
32
+
33
def require_file_exists(file_path: str, parser, file_type: str = "file") -> Path:
    """Resolve *file_path* via validate_file_exists, converting a missing file into an argparse error."""
    try:
        resolved = validate_file_exists(file_path, file_type)
    except FileNotFoundError as missing:
        parser.error(str(missing))
    else:
        return resolved
38
+
39
+
40
def validate_output_path(output_path: str) -> Path:
    """Return *output_path* as a Path, creating its parent directory if needed."""
    target = Path(output_path)
    parent_dir = target.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    return target
44
+
45
+
46
def validate_ranges(args, parser):
    """Validate numeric argument ranges, reporting the first violation via parser.error()."""
    violations = [
        (
            not (0.1 <= args.cfg_value <= 10.0),
            "--cfg-value must be between 0.1 and 10.0 (recommended: 1.0–3.0)",
        ),
        (
            not (1 <= args.inference_timesteps <= 100),
            "--inference-timesteps must be between 1 and 100 (recommended: 4–30)",
        ),
        (args.lora_r <= 0, "--lora-r must be a positive integer"),
        (args.lora_alpha <= 0, "--lora-alpha must be a positive integer"),
        (
            not (0.0 <= args.lora_dropout <= 1.0),
            "--lora-dropout must be between 0.0 and 1.0",
        ),
    ]
    # parser.error() raises SystemExit, so only the first failure is reported —
    # same short-circuit behavior as sequential if/error checks.
    for failed, message in violations:
        if failed:
            parser.error(message)
62
+
63
+
64
def warn_legacy_mode():
    """Emit a deprecation notice on stderr for the legacy flat CLI invocation."""
    message = "Warning: legacy root CLI arguments are deprecated. Prefer `voxcpm design|clone|batch ...`."
    print(message, file=sys.stderr)
69
+
70
+
71
def build_final_text(text: str, control: str | None) -> str:
    """Prefix *text* with a parenthesized control tag when one is supplied.

    Empty or whitespace-only control strings (and None) leave the text as-is.
    """
    tag = control.strip() if control else ""
    if not tag:
        return text
    return f"({tag}){text}"
74
+
75
+
76
def resolve_prompt_text(args, parser) -> str | None:
    """Resolve prompt text from --prompt-text / --prompt-file.

    Returns the stripped text, or None when neither option is set; supplying
    both is an argparse error.
    """
    inline_text = getattr(args, "prompt_text", None)
    text_file = getattr(args, "prompt_file", None)

    if inline_text and text_file:
        parser.error("Use either --prompt-text or --prompt-file, not both.")

    if text_file:
        source = require_file_exists(text_file, parser, "prompt text file")
        return source.read_text(encoding="utf-8").strip()

    return inline_text.strip() if inline_text else None
91
+
92
+
93
def detect_model_architecture(args) -> str | None:
    """Best-effort guess of the model architecture ("voxcpm" / "voxcpm2").

    Local directories are inspected via their config.json (missing key
    defaults to "voxcpm"); otherwise the repo id / path string itself is
    matched against known naming patterns. Returns None when undetermined.
    """
    model_location = getattr(args, "model_path", None) or getattr(args, "hf_model_id", None)
    if not model_location:
        return None

    if os.path.isdir(model_location):
        config_path = Path(model_location) / "config.json"
        if not config_path.exists():
            return None
        config = json.loads(config_path.read_text(encoding="utf-8"))
        return config.get("architecture", "voxcpm").lower()

    hint = str(model_location).lower()
    if "voxcpm2" in hint:
        return "voxcpm2"
    legacy_markers = ("voxcpm1.5", "voxcpm-1.5", "voxcpm_1.5")
    if any(marker in hint for marker in legacy_markers):
        return "voxcpm"

    return None
119
+
120
+
121
+ def validate_prompt_related_args(args, parser, prompt_text: str | None):
122
+ if prompt_text and not args.prompt_audio:
123
+ parser.error("--prompt-text/--prompt-file requires --prompt-audio.")
124
+
125
+ if args.prompt_audio and not prompt_text:
126
+ parser.error("--prompt-audio requires --prompt-text or --prompt-file.")
127
+
128
+ if args.control and prompt_text:
129
+ parser.error(
130
+ "--control cannot be used together with --prompt-text or --prompt-file."
131
+ )
132
+
133
+
134
def validate_reference_support(args, parser):
    """Reject --reference-audio when the selected model family cannot use it."""
    if not getattr(args, "reference_audio", None):
        return
    # Only the first-generation "voxcpm" architecture lacks reference support.
    if detect_model_architecture(args) == "voxcpm":
        parser.error("--reference-audio is only supported with VoxCPM2 models.")
141
+
142
+
143
def validate_design_args(args, parser):
    """`design` synthesizes from scratch; any prompt/reference audio is an error."""
    resolved_prompt = resolve_prompt_text(args, parser)
    if args.prompt_audio or args.reference_audio or resolved_prompt:
        parser.error(
            "`design` does not accept prompt/reference audio. Use `clone` instead."
        )
149
+
150
+
151
def validate_clone_args(args, parser):
    """Validate `clone` inputs and return the resolved prompt transcript."""
    prompt_text = resolve_prompt_text(args, parser)
    validate_prompt_related_args(args, parser, prompt_text)
    validate_reference_support(args, parser)

    # Cloning needs at least one voice source to imitate.
    has_voice_source = args.prompt_audio or args.reference_audio
    if not has_voice_source:
        parser.error(
            "`clone` requires --reference-audio, or --prompt-audio with --prompt-text/--prompt-file."
        )

    return prompt_text
162
+
163
+
164
def validate_batch_args(args, parser):
    """Validate batch-mode voice inputs; returns the prompt transcript (or None).

    Unlike `clone`, batch mode may run with no voice source at all.
    """
    prompt_text = resolve_prompt_text(args, parser)
    validate_prompt_related_args(args, parser, prompt_text)
    validate_reference_support(args, parser)
    return prompt_text
169
+
170
+
171
+ # -----------------------------
172
+ # Model loading
173
+ # -----------------------------
174
+
175
+
176
def load_model(args) -> VoxCPM:
    """Build a :class:`VoxCPM` instance from parsed CLI arguments.

    Resolution order: an explicit local --model-path wins; otherwise the
    model is fetched from the Hugging Face Hub via ``from_pretrained``.
    Optional LoRA weights are attached when --lora-path is given.
    Exits the process with status 1 if loading fails.
    """
    print("Loading VoxCPM model...", file=sys.stderr)

    # CLI flag takes precedence over the ZIPENHANCER_MODEL_PATH env var.
    zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
        "ZIPENHANCER_MODEL_PATH", None
    )

    # Build LoRA config if provided
    lora_config = None
    lora_weights_path = getattr(args, "lora_path", None)
    if lora_weights_path:
        # Imported lazily so plain (non-LoRA) runs don't pay for it.
        from voxcpm.model.voxcpm import LoRAConfig

        lora_config = LoRAConfig(
            enable_lm=not args.lora_disable_lm,
            enable_dit=not args.lora_disable_dit,
            enable_proj=args.lora_enable_proj,
            r=args.lora_r,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout,
        )

        print(
            f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
            f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}",
            file=sys.stderr,
        )

    # Load local model if specified
    if args.model_path:
        try:
            model = VoxCPM(
                voxcpm_model_path=args.model_path,
                zipenhancer_model_path=zipenhancer_path,
                enable_denoiser=not args.no_denoiser,
                optimize=not args.no_optimize,
                lora_config=lora_config,
                lora_weights_path=lora_weights_path,
            )
            print("Model loaded (local).", file=sys.stderr)
            return model
        except Exception as e:
            print(f"Failed to load model (local): {e}", file=sys.stderr)
            sys.exit(1)

    # Load from Hugging Face Hub
    try:
        model = VoxCPM.from_pretrained(
            hf_model_id=args.hf_model_id,
            load_denoiser=not args.no_denoiser,
            zipenhancer_model_id=zipenhancer_path,
            cache_dir=args.cache_dir,
            local_files_only=args.local_files_only,
            optimize=not args.no_optimize,
            lora_config=lora_config,
            lora_weights_path=lora_weights_path,
        )
        print("Model loaded (from_pretrained).", file=sys.stderr)
        return model
    except Exception as e:
        print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
        sys.exit(1)
238
+
239
+
240
+ # -----------------------------
241
+ # Commands
242
+ # -----------------------------
243
+
244
+
245
def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None):
    """Generate one utterance and write it to ``output``.

    Shared implementation behind `design` and `clone` (and their legacy
    forms): validates the audio inputs, loads the model, runs a single
    generation pass, and saves the waveform via soundfile.
    """
    output_path = validate_output_path(output)

    # Fail fast on missing input files before paying the model-load cost.
    if args.prompt_audio:
        require_file_exists(args.prompt_audio, parser, "prompt audio file")
    if args.reference_audio:
        require_file_exists(args.reference_audio, parser, "reference audio file")

    model = load_model(args)

    audio_array = model.generate(
        text=text,
        prompt_wav_path=args.prompt_audio,
        prompt_text=prompt_text,
        reference_wav_path=args.reference_audio,
        cfg_value=args.cfg_value,
        inference_timesteps=args.inference_timesteps,
        normalize=args.normalize,
        # Denoising only makes sense when there is prompt/reference audio.
        denoise=args.denoise
        and (args.prompt_audio is not None or args.reference_audio is not None),
    )

    sf.write(str(output_path), audio_array, model.tts_model.sample_rate)

    duration = len(audio_array) / model.tts_model.sample_rate
    print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
271
+
272
+
273
def cmd_design(args, parser):
    """Handle the `design` subcommand: synthesize with an optional control tag."""
    validate_design_args(args, parser)
    text_with_control = build_final_text(args.text, args.control)
    return _run_single(
        args,
        parser,
        text=text_with_control,
        output=args.output,
        prompt_text=None,
    )
279
+
280
+
281
def cmd_clone(args, parser):
    """Handle the `clone` subcommand: synthesize in a reference voice."""
    resolved_prompt = validate_clone_args(args, parser)
    text_with_control = build_final_text(args.text, args.control)
    return _run_single(
        args,
        parser,
        text=text_with_control,
        output=args.output,
        prompt_text=resolved_prompt,
    )
287
+
288
+
289
def cmd_batch(args, parser):
    """Handle the `batch` subcommand: one output WAV per non-empty input line.

    Reads texts from --input, generates audio for each, and writes
    ``output_NNN.wav`` files into --output-dir. A failure on one line is
    logged and skipped rather than aborting the whole batch.
    """
    input_file = require_file_exists(args.input, parser, "input file")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(input_file, "r", encoding="utf-8") as f:
        texts = [line.strip() for line in f if line.strip()]

    if not texts:
        sys.exit("Error: Input file is empty")

    prompt_text = validate_batch_args(args, parser)
    model = load_model(args)

    prompt_audio_path = None
    if args.prompt_audio:
        prompt_audio_path = str(
            require_file_exists(args.prompt_audio, parser, "prompt audio file")
        )

    reference_audio_path = None
    if args.reference_audio:
        reference_audio_path = str(
            require_file_exists(args.reference_audio, parser, "reference audio file")
        )

    success_count = 0

    for i, text in enumerate(texts, 1):
        try:
            final_text = build_final_text(text, args.control)
            audio_array = model.generate(
                text=final_text,
                prompt_wav_path=prompt_audio_path,
                prompt_text=prompt_text,
                reference_wav_path=reference_audio_path,
                cfg_value=args.cfg_value,
                inference_timesteps=args.inference_timesteps,
                normalize=args.normalize,
                # Denoising is only meaningful with prompt/reference audio.
                denoise=args.denoise
                and (prompt_audio_path is not None or reference_audio_path is not None),
            )

            output_file = output_dir / f"output_{i:03d}.wav"
            sf.write(str(output_file), audio_array, model.tts_model.sample_rate)

            duration = len(audio_array) / model.tts_model.sample_rate
            print(f"Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
            success_count += 1

        except Exception as e:
            # Keep going: one bad line should not kill the whole batch.
            print(f"Failed on line {i}: {e}", file=sys.stderr)

    print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
343
+
344
+
345
+ # -----------------------------
346
+ # Parser
347
+ # -----------------------------
348
+
349
+
350
def _add_common_generation_args(parser):
    """Register generation options shared by the `design` and `clone` subcommands."""
    parser.add_argument("--text", "-t", help="Text to synthesize")
    parser.add_argument(
        "--control",
        type=str,
        help="Control instruction for VoxCPM2 voice design/cloning",
    )
    parser.add_argument(
        "--cfg-value",
        type=float,
        default=2.0,
        help="CFG guidance scale (float, recommended 1.0–3.0, default: 2.0)",
    )
    parser.add_argument(
        "--inference-timesteps",
        type=int,
        default=10,
        help="Inference steps (int, recommended 4–30, default: 10)",
    )
    parser.add_argument(
        "--normalize", action="store_true", help="Enable text normalization"
    )
372
+
373
+
374
def _add_prompt_reference_args(parser):
    """Register prompt-audio / reference-audio options (voice-cloning inputs)."""
    parser.add_argument(
        "--prompt-audio",
        "-pa",
        help="Prompt audio file path (continuation mode, requires --prompt-text or --prompt-file)",
    )
    parser.add_argument(
        "--prompt-text", "-pt", help="Text corresponding to the prompt audio"
    )
    parser.add_argument(
        "--prompt-file", type=str, help="Text file corresponding to the prompt audio"
    )
    parser.add_argument(
        "--reference-audio",
        "-ra",
        help="Reference audio for voice cloning (VoxCPM2 only)",
    )
    parser.add_argument(
        "--denoise",
        action="store_true",
        help="Enable prompt/reference speech enhancement",
    )
396
+
397
+
398
def _add_model_args(parser):
    """Register model-selection and model-loading options."""
    parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
    parser.add_argument(
        "--hf-model-id",
        type=str,
        default=DEFAULT_HF_MODEL_ID,
        help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
    )
    parser.add_argument(
        "--cache-dir", type=str, help="Cache directory for Hub downloads"
    )
    parser.add_argument(
        "--local-files-only", action="store_true", help="Disable network access"
    )
    parser.add_argument(
        "--no-denoiser", action="store_true", help="Disable denoiser model loading"
    )
    parser.add_argument(
        "--no-optimize",
        action="store_true",
        help="Disable model optimization during loading",
    )
    parser.add_argument(
        "--zipenhancer-path",
        type=str,
        help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)",
    )
425
+
426
+
427
def _add_lora_args(parser):
    """Register LoRA fine-tuning options (weights path plus hyperparameters)."""
    parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
    parser.add_argument(
        "--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)"
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=16,
        help="LoRA alpha (positive int, default: 16)",
    )
    parser.add_argument(
        "--lora-dropout",
        type=float,
        default=0.0,
        help="LoRA dropout rate (0.0–1.0, default: 0.0)",
    )
    parser.add_argument(
        "--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers"
    )
    parser.add_argument(
        "--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers"
    )
    parser.add_argument(
        "--lora-enable-proj",
        action="store_true",
        help="Enable LoRA on projection layers",
    )
455
+
456
+
457
def _build_parser():
    """Construct the CLI argument parser.

    Defines the `design`, `clone`, and `batch` subcommands, then registers
    the same option groups at the root level so legacy (subcommand-less)
    invocations keep working.
    """
    parser = argparse.ArgumentParser(
        description="VoxCPM CLI - VoxCPM2-first voice design, cloning, and batch processing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  voxcpm design --text "Hello world" --output out.wav
  voxcpm design --text "Hello world" --control "warm female voice" --output out.wav
  voxcpm clone --text "Hello" --reference-audio ref.wav --output out.wav
  voxcpm batch --input texts.txt --output-dir ./outs --reference-audio ref.wav
""",
    )

    subparsers = parser.add_subparsers(dest="command")

    # `design`: generate a new voice from text (plus optional control tag).
    design_parser = subparsers.add_parser(
        "design", help="Generate speech with VoxCPM2-first voice design"
    )
    _add_common_generation_args(design_parser)
    _add_prompt_reference_args(design_parser)
    _add_model_args(design_parser)
    _add_lora_args(design_parser)
    design_parser.add_argument(
        "--output", "-o", required=True, help="Output audio file path"
    )

    # `clone`: reproduce a voice from reference/prompt audio.
    clone_parser = subparsers.add_parser(
        "clone", help="Clone a voice with reference/prompt audio"
    )
    _add_common_generation_args(clone_parser)
    _add_prompt_reference_args(clone_parser)
    _add_model_args(clone_parser)
    _add_lora_args(clone_parser)
    clone_parser.add_argument(
        "--output", "-o", required=True, help="Output audio file path"
    )

    # `batch`: one generated file per input line.
    batch_parser = subparsers.add_parser(
        "batch", help="Batch-generate one line per output file"
    )
    batch_parser.add_argument(
        "--input", "-i", required=True, help="Input text file (one text per line)"
    )
    batch_parser.add_argument(
        "--output-dir", "-od", required=True, help="Output directory"
    )
    batch_parser.add_argument(
        "--control",
        type=str,
        help="Control instruction for VoxCPM2 voice design/cloning",
    )
    _add_prompt_reference_args(batch_parser)
    batch_parser.add_argument(
        "--cfg-value",
        type=float,
        default=2.0,
        help="CFG guidance scale (float, recommended 1.0–3.0, default: 2.0)",
    )
    batch_parser.add_argument(
        "--inference-timesteps",
        type=int,
        default=10,
        help="Inference steps (int, recommended 4–30, default: 10)",
    )
    batch_parser.add_argument(
        "--normalize", action="store_true", help="Enable text normalization"
    )
    _add_model_args(batch_parser)
    _add_lora_args(batch_parser)

    # Legacy root arguments
    parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
    parser.add_argument(
        "--output-dir", "-od", help="Output directory (batch mode only)"
    )
    _add_common_generation_args(parser)
    parser.add_argument(
        "--output", "-o", help="Output audio file path (single or clone mode)"
    )
    _add_prompt_reference_args(parser)
    _add_model_args(parser)
    _add_lora_args(parser)

    return parser
541
+
542
+
543
def _dispatch_legacy(args, parser):
    """Route a legacy (no-subcommand) invocation to the matching handler."""
    warn_legacy_mode()

    if args.input and args.text:
        parser.error(
            "Use either batch mode (--input) or single mode (--text), not both."
        )

    # Legacy batch mode is selected by --input alone.
    if args.input:
        if not args.output_dir:
            parser.error("Batch mode requires --output-dir")
        return cmd_batch(args, parser)

    if not args.text or not args.output:
        parser.error("Single-sample legacy mode requires --text and --output")

    # Any voice-source option turns the request into a clone; else a design.
    uses_voice_source = bool(
        args.prompt_audio
        or args.prompt_text
        or args.prompt_file
        or args.reference_audio
    )
    if uses_voice_source:
        return cmd_clone(args, parser)
    return cmd_design(args, parser)
568
+
569
+
570
+ # -----------------------------
571
+ # Entrypoint
572
+ # -----------------------------
573
+
574
+
575
def main():
    """CLI entrypoint: parse arguments, validate, and dispatch to a command."""
    parser = _build_parser()
    args = parser.parse_args()

    validate_ranges(args, parser)

    command = args.command
    if command == "design":
        if not args.text:
            parser.error("`design` requires --text")
        return cmd_design(args, parser)

    if command == "clone":
        if not args.text or not args.output:
            parser.error("`clone` requires --text and --output")
        return cmd_clone(args, parser)

    if command == "batch":
        return cmd_batch(args, parser)

    # No subcommand given: fall back to the deprecated root-argument mode.
    return _dispatch_legacy(args, parser)
595
+
596
+
597
+ if __name__ == "__main__":
598
+ main()
voxcpm/core.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ import json
5
+ import tempfile
6
+ import numpy as np
7
+ from typing import Generator, Optional
8
+ from huggingface_hub import snapshot_download
9
+ from .model.voxcpm import VoxCPMModel, LoRAConfig
10
+ from .model.voxcpm2 import VoxCPM2Model
11
+
12
+
13
class VoxCPM:
    """High-level TTS pipeline wrapping a VoxCPM/VoxCPM2 model plus an optional denoiser."""

    def __init__(
        self,
        voxcpm_model_path: str,
        zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
        enable_denoiser: bool = True,
        optimize: bool = True,
        lora_config: Optional[LoRAConfig] = None,
        lora_weights_path: Optional[str] = None,
    ):
        """Initialize VoxCPM TTS pipeline.

        Args:
            voxcpm_model_path: Local filesystem path to the VoxCPM model assets
                (weights, configs, etc.). Typically the directory returned by
                a prior download step.
            zipenhancer_model_path: ModelScope acoustic noise suppression model
                id or local path. If None, denoiser will not be initialized.
            enable_denoiser: Whether to initialize the denoiser pipeline.
            optimize: Whether to optimize the model with torch.compile. True by
                default, but can be disabled for debugging.
            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
                provided without lora_config, a default config will be created.
            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.

        Raises:
            ValueError: If config.json declares an unsupported architecture.
        """
        print(
            f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
            file=sys.stderr,
        )

        # If lora_weights_path is provided but no lora_config, create a default one
        if lora_weights_path is not None and lora_config is None:
            lora_config = LoRAConfig(
                enable_lm=True,
                enable_dit=True,
                enable_proj=False,
            )
            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)

        # Determine model type from config.json architecture field
        config_path = os.path.join(voxcpm_model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        arch = config.get("architecture", "voxcpm").lower()

        if arch == "voxcpm2":
            self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
            print("Loaded VoxCPM2Model", file=sys.stderr)
        elif arch == "voxcpm":
            self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
            print("Loaded VoxCPMModel", file=sys.stderr)
        else:
            raise ValueError(f"Unsupported architecture: {arch}")

        # Load LoRA weights if path is provided
        if lora_weights_path is not None:
            print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)

        # Text normalizer is created lazily on first use (see _generate).
        self.text_normalizer = None
        # Denoiser is optional; a single assignment replaces the previous
        # redundant double initialization.
        self.denoiser = None
        if enable_denoiser and zipenhancer_model_path is not None:
            from .zipenhancer import ZipEnhancer

            self.denoiser = ZipEnhancer(zipenhancer_model_path)

        # torch.compile needs a real forward pass to trigger compilation; warm
        # up here so the first user request does not absorb compile latency.
        if optimize:
            print("Warm up VoxCPMModel...", file=sys.stderr)
            self.tts_model.generate(
                target_text="Hello, this is the first test sentence.",
                max_len=10,
            )

    @classmethod
    def from_pretrained(
        cls,
        hf_model_id: str = "openbmb/VoxCPM2",
        load_denoiser: bool = True,
        zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
        optimize: bool = True,
        lora_config: Optional[LoRAConfig] = None,
        lora_weights_path: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.

        Args:
            hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
            load_denoiser: Whether to initialize the denoiser pipeline.
            zipenhancer_model_id: Denoiser model id or path for ModelScope
                acoustic noise suppression.
            cache_dir: Custom cache directory for the snapshot.
            local_files_only: If True, only use local files and do not attempt
                to download.
            optimize: Whether to optimize the model with torch.compile. True by
                default, but can be disabled for debugging.
            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
                provided without lora_config, a default config will be created with
                enable_lm=True and enable_dit=True.
            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
                after model initialization.
        Kwargs:
            Additional keyword arguments passed to the ``VoxCPM`` constructor.

        Returns:
            VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
            the downloaded snapshot directory.

        Raises:
            ValueError: If ``hf_model_id`` is empty.
        """
        repo_id = hf_model_id
        if not repo_id:
            raise ValueError("You must provide hf_model_id")

        # Load from local path if provided
        if os.path.isdir(repo_id):
            local_path = repo_id
        else:
            # Otherwise resolve a Hub snapshot (may hit the network).
            local_path = snapshot_download(
                repo_id=repo_id,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )

        return cls(
            voxcpm_model_path=local_path,
            zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
            enable_denoiser=load_denoiser,
            optimize=optimize,
            lora_config=lora_config,
            lora_weights_path=lora_weights_path,
            **kwargs,
        )

    def generate(self, *args, **kwargs) -> np.ndarray:
        """Synthesize the full utterance and return one waveform array."""
        return next(self._generate(*args, streaming=False, **kwargs))

    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
        """Synthesize incrementally, yielding waveform chunks as they are produced."""
        return self._generate(*args, streaming=True, **kwargs)

    def _generate(
        self,
        text: str,
        prompt_wav_path: str = None,
        prompt_text: str = None,
        reference_wav_path: str = None,
        cfg_value: float = 2.0,
        inference_timesteps: int = 10,
        min_len: int = 2,
        max_len: int = 4096,
        normalize: bool = False,
        denoise: bool = False,
        retry_badcase: bool = True,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
        streaming: bool = False,
    ) -> Generator[np.ndarray, None, None]:
        """Synthesize speech for the given text and return a single waveform.

        Args:
            text: Input text to synthesize.
            prompt_wav_path: Path to prompt audio for continuation mode.
                Must be paired with ``prompt_text``.
            prompt_text: Text content corresponding to the prompt audio.
            reference_wav_path: Path to reference audio for voice cloning
                (structurally isolated via ref_audio tokens). Can be used
                alone or combined with ``prompt_wav_path`` + ``prompt_text``.
            cfg_value: Guidance scale for the generation model.
            inference_timesteps: Number of inference steps.
            min_len: Minimum audio length.
            max_len: Maximum token length during generation.
            normalize: Whether to run text normalization before generation.
            denoise: Whether to denoise the prompt/reference audio if a
                denoiser is available.
            retry_badcase: Whether to retry badcase.
            retry_badcase_max_times: Maximum number of times to retry badcase.
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
            streaming: Whether to return a generator of audio chunks.
        Returns:
            Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
            Yields audio chunks for each generation step if ``streaming=True``,
            otherwise yields a single array containing the final audio.
        """
        # Check the type BEFORE calling .strip(); the previous order raised
        # AttributeError (not the intended ValueError) for non-string input.
        if not isinstance(text, str) or not text.strip():
            raise ValueError("target text must be a non-empty string")

        if prompt_wav_path is not None:
            if not os.path.exists(prompt_wav_path):
                raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")

        if reference_wav_path is not None:
            if not os.path.exists(reference_wav_path):
                raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")

        if (prompt_wav_path is None) != (prompt_text is None):
            raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")

        is_v2 = isinstance(self.tts_model, VoxCPM2Model)
        if reference_wav_path is not None and not is_v2:
            raise ValueError("reference_wav_path is only supported with VoxCPM2 models")

        # Collapse all whitespace (incl. newlines) into single spaces.
        text = text.replace("\n", " ")
        text = re.sub(r"\s+", " ", text)
        temp_files = []

        try:
            actual_prompt_path = prompt_wav_path
            actual_ref_path = reference_wav_path

            # Optionally run the enhancer over prompt/reference audio, writing
            # the cleaned copies to temp files that are removed in `finally`.
            if denoise and self.denoiser is not None:
                if prompt_wav_path is not None:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                        temp_files.append(tmp.name)
                    self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
                    actual_prompt_path = temp_files[-1]
                if reference_wav_path is not None:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                        temp_files.append(tmp.name)
                    self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
                    actual_ref_path = temp_files[-1]

            if actual_prompt_path is not None or actual_ref_path is not None:
                if is_v2:
                    fixed_prompt_cache = self.tts_model.build_prompt_cache(
                        prompt_text=prompt_text,
                        prompt_wav_path=actual_prompt_path,
                        reference_wav_path=actual_ref_path,
                    )
                else:
                    # v1 models have no reference-audio channel.
                    fixed_prompt_cache = self.tts_model.build_prompt_cache(
                        prompt_text=prompt_text,
                        prompt_wav_path=actual_prompt_path,
                    )
            else:
                fixed_prompt_cache = None

            if normalize:
                # Lazy import/instantiation: the normalizer is heavy and only
                # needed when normalization is requested.
                if self.text_normalizer is None:
                    from .utils.text_normalize import TextNormalizer

                    self.text_normalizer = TextNormalizer()
                text = self.text_normalizer.normalize(text)

            generate_result = self.tts_model._generate_with_prompt_cache(
                target_text=text,
                prompt_cache=fixed_prompt_cache,
                min_len=min_len,
                max_len=max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
                streaming=streaming,
            )

            for wav, _, _ in generate_result:
                yield wav.squeeze(0).cpu().numpy()

        finally:
            # Best-effort cleanup of any denoised temp copies.
            for tmp_path in temp_files:
                if tmp_path and os.path.exists(tmp_path):
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass

    # ------------------------------------------------------------------ #
    # LoRA Interface (delegated to VoxCPMModel)
    # ------------------------------------------------------------------ #
    def load_lora(self, lora_weights_path: str) -> tuple:
        """Load LoRA weights from a checkpoint file.

        Args:
            lora_weights_path: Path to LoRA weights (.pth file or directory
                containing lora_weights.ckpt).

        Returns:
            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.

        Raises:
            RuntimeError: If model was not initialized with LoRA config.
        """
        if self.tts_model.lora_config is None:
            raise RuntimeError(
                "Cannot load LoRA weights: model was not initialized with LoRA config. "
                "Please reinitialize with lora_config or lora_weights_path parameter."
            )
        return self.tts_model.load_lora_weights(lora_weights_path)

    def unload_lora(self):
        """Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
        self.tts_model.reset_lora_weights()

    def set_lora_enabled(self, enabled: bool):
        """Enable or disable LoRA layers without unloading weights.

        Args:
            enabled: If True, LoRA layers are active; if False, only base model is used.
        """
        self.tts_model.set_lora_enabled(enabled)

    def get_lora_state_dict(self) -> dict:
        """Get current LoRA parameters state dict.

        Returns:
            dict: State dict containing all LoRA parameters (lora_A, lora_B).
        """
        return self.tts_model.get_lora_state_dict()

    @property
    def lora_enabled(self) -> bool:
        """Check if LoRA is currently configured."""
        return self.tts_model.lora_config is not None
voxcpm/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .voxcpm import VoxCPMModel
2
+ from .voxcpm2 import VoxCPM2Model
3
+
4
+ __all__ = ["VoxCPMModel", "VoxCPM2Model"]
voxcpm/model/utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import torch
3
+ from transformers import PreTrainedTokenizer
4
+
5
+
6
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
    """Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.

    This function creates a wrapper around the provided tokenizer that automatically
    splits multi-character Chinese tokens into individual characters. This is useful
    for ensuring consistent tokenization of Chinese text.

    Args:
        tokenizer: The base tokenizer to wrap

    Returns:
        A CharTokenizerWrapper instance that handles multi-character Chinese tokens

    Example:
        >>> from transformers import LlamaTokenizerFast
        >>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
        >>> wrapped_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        >>> tokens = wrapped_tokenizer("你好世界")
    """
    # Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
    # \u4e00-\u9fff is the CJK Unified Ideographs block.
    multichar_tokens = {
        token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
    }

    class CharTokenizerWrapper:
        """Wrapper class for tokenizers that handles multi-character Chinese tokens.

        This wrapper automatically splits multi-character Chinese tokens into
        individual characters while preserving the original tokenizer's interface.
        """

        def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
            """Initialize the wrapper with a base tokenizer.

            Args:
                base_tokenizer: The tokenizer to wrap
            """
            self.tokenizer = base_tokenizer
            # Captured from the enclosing function; computed once per wrap.
            self.multichar_tokens = multichar_tokens

        def tokenize(self, text: str, **kwargs) -> List[str]:
            """Tokenize text and split multi-character Chinese tokens into single characters.

            Args:
                text: Input text to tokenize
                **kwargs: Additional arguments passed to the base tokenizer

            Returns:
                List of processed tokens with multi-character Chinese tokens split

            Example:
                >>> wrapper = CharTokenizerWrapper(tokenizer)
                >>> tokens = wrapper.tokenize("你好世界")
                >>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
            """
            if not isinstance(text, str):
                raise TypeError(f"Expected string input, got {type(text)}")

            tokens = self.tokenizer.tokenize(text, **kwargs)
            processed = []

            for token in tokens:
                # Remove possible subword prefix (SentencePiece word-boundary marker)
                clean_token = token.replace("▁", "")

                if clean_token in self.multichar_tokens:
                    # Split multi-character token into single characters
                    chars = list(clean_token)
                    processed.extend(chars)
                else:
                    processed.append(token)

            return processed

        def __call__(self, text: str, **kwargs) -> List[int]:
            """Call the tokenizer and return token IDs.

            This method provides the same interface as the original tokenizer
            but with multi-character Chinese token handling.

            Args:
                text: Input text to tokenize
                **kwargs: Additional arguments passed to the base tokenizer

            Returns:
                List of token IDs

            Raises:
                TypeError: If input is not a string
                ValueError: If tokenization fails
            """
            try:
                tokens = self.tokenize(text, **kwargs)
                result = self.tokenizer.convert_tokens_to_ids(tokens)
                return result
            except Exception as e:
                raise ValueError(f"Tokenization failed: {str(e)}") from e

    return CharTokenizerWrapper(tokenizer)
105
+
106
+
107
def get_dtype(dtype: str):
    """Map a dtype name (e.g. "bf16", "float32") to the torch dtype object.

    Accepts both long ("bfloat16") and short ("bf16") aliases.

    Raises:
        ValueError: If the name is not a recognized alias.
    """
    aliases = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if dtype not in aliases:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return aliases[dtype]
voxcpm/model/voxcpm.py ADDED
@@ -0,0 +1,985 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoxCPM: A Tokenizer-free speech generation model
3
+
4
+ This module contains the main VoxCPM model implementation, including configuration classes
5
+ and the core VoxCPMModel for text-to-speech generation.
6
+
7
+ Copyright 2025 OpenBMB
8
+ Licensed under the Apache License, Version 2.0 (the "License");
9
+ you may not use this file except in compliance with the License.
10
+ You may obtain a copy of the License at
11
+
12
+ http://www.apache.org/licenses/LICENSE-2.0
13
+
14
+ Unless required by applicable law or agreed to in writing, software
15
+ distributed under the License is distributed on an "AS IS" BASIS,
16
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ See the License for the specific language governing permissions and
18
+ limitations under the License.
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ from typing import Tuple, Union, Generator, List, Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torchaudio
28
+ import warnings
29
+ from einops import rearrange
30
+ from pydantic import BaseModel
31
+
32
+ try:
33
+ from safetensors.torch import load_file
34
+
35
+ SAFETENSORS_AVAILABLE = True
36
+ except ImportError:
37
+ SAFETENSORS_AVAILABLE = False
38
+ from tqdm import tqdm
39
+ from transformers import LlamaTokenizerFast
40
+
41
+ from ..modules.audiovae import AudioVAE, AudioVAEConfig
42
+ from ..modules.layers import ScalarQuantizationLayer
43
+ from ..modules.layers.lora import apply_lora_to_named_linear_modules
44
+ from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
45
+ from ..modules.locenc import VoxCPMLocEnc
46
+ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
47
+ from .utils import get_dtype, mask_multichar_chinese_tokens
48
+
49
+
50
class VoxCPMEncoderConfig(BaseModel):
    """Configuration for the local (audio-feature) encoder transformer."""

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # None means "use the downstream default"; the previous plain `int`
    # annotation would reject an explicit `kv_channels=None` under pydantic v2.
    kv_channels: Optional[int] = None
56
+
57
+
58
class VoxCPMDitConfig(BaseModel):
    """Configuration for the local DiT (diffusion transformer) decoder."""

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # None means "use the downstream default"; the previous plain `int`
    # annotation would reject an explicit `kv_channels=None` under pydantic v2.
    kv_channels: Optional[int] = None

    cfm_config: CfmConfig
66
+
67
+
68
class VoxCPMConfig(BaseModel):
    """Top-level VoxCPM model configuration (LM + encoder + DiT + audio VAE)."""

    lm_config: MiniCPM4Config  # backbone text-semantic LM configuration
    patch_size: int = 2  # audio-feature patches consumed/produced per LM step
    feat_dim: int = 64  # audio VAE latent feature dimension
    residual_lm_num_layers: int = 6  # depth of the residual acoustic LM
    scalar_quantization_latent_dim: int = 256  # FSQ latent dimension
    scalar_quantization_scale: int = 9  # FSQ quantization scale

    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    audio_vae_config: Optional[AudioVAEConfig] = None

    max_length: int = 4096  # KV-cache / sequence-length budget
    device: str = "cuda"  # model falls back to mps/cpu at init when CUDA is absent
    dtype: str = "bfloat16"  # see get_dtype for accepted names
    dit_mean_mode: bool = False  # forwarded to UnifiedCFM(mean_mode=...)
84
+
85
+
86
class LoRAConfig(BaseModel):
    """LoRA injection settings for the LMs, the DiT, and the projection layers."""

    enable_lm: bool = False  # Apply LoRA to base_lm + residual_lm
    enable_dit: bool = False  # Apply LoRA to VoxCPMLocDiT
    enable_proj: bool = False  # Apply LoRA to projection Linear layers

    r: int = 8  # LoRA rank
    alpha: int = 16  # LoRA scaling factor
    dropout: float = 0.0  # dropout applied inside the LoRA branch

    # Target linear layer names for LM & DiT (matched by attribute name)
    target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    # Projection layer attribute names to find on VoxCPMModel
    target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
100
+
101
+
102
+ VoxCPMConfig.model_rebuild()
103
+
104
+
105
+ class VoxCPMModel(nn.Module):
106
    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAE,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the VoxCPM model from its sub-module configurations.

        Args:
            config: Top-level model configuration.
            tokenizer: Base tokenizer; wrapped so multi-character Chinese
                tokens are split into single characters.
            audio_vae: Audio VAE used to encode/decode waveform latents.
            lora_config: Optional LoRA settings; when given, adapters are
                injected after all sub-modules are constructed.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size
        self.device = config.device
        # Fall back to MPS, then CPU, when CUDA is unavailable.
        if not torch.cuda.is_available():
            if torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        # Special token ids marking audio-segment boundaries in the token stream.
        self.audio_start_token = 101
        self.audio_end_token = 102

        # Residual Acoustic LM: same family as the base LM, but shallower and
        # with no embedding table of its own (vocab_size = 0).
        residual_lm_config = config.lm_config.model_copy(deep=True)
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        # Local Encoder: base LM config overridden by the encoder-specific dims.
        encoder_config = config.lm_config.model_copy(deep=True)
        encoder_config.hidden_size = config.encoder_config.hidden_dim
        encoder_config.intermediate_size = config.encoder_config.ffn_dim
        encoder_config.num_attention_heads = config.encoder_config.num_heads
        encoder_config.num_hidden_layers = config.encoder_config.num_layers
        encoder_config.kv_channels = config.encoder_config.kv_channels
        encoder_config.vocab_size = 0
        self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)

        # Local DiT: base LM config overridden by the DiT-specific dims.
        decoder_config = config.lm_config.model_copy(deep=True)
        decoder_config.hidden_size = config.dit_config.hidden_dim
        decoder_config.intermediate_size = config.dit_config.ffn_dim
        decoder_config.num_attention_heads = config.dit_config.num_heads
        decoder_config.num_hidden_layers = config.dit_config.num_layers
        decoder_config.kv_channels = config.dit_config.kv_channels
        decoder_config.vocab_size = 0
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
            mean_mode=config.dit_mean_mode,
        )

        # Projection layers between encoder / LM / DiT hidden spaces.
        self.fsq_layer = ScalarQuantizationLayer(
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale,
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
        self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)

        # Stop Predictor: 2-way head deciding when generation should stop.
        self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
        self.stop_loss = nn.CrossEntropyLoss(reduction="none")

        # Audio VAE
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self.sample_rate = audio_vae.sample_rate

        if self.lora_config is not None:
            self._apply_lora()
190
+
191
    def _apply_lora(self):
        """Inject LoRA adapters into the LM / DiT / projection layers."""
        cfg = self.lora_config
        lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)

        # LM: base_lm + residual_lm
        if cfg.enable_lm:
            for lm in [self.base_lm, self.residual_lm]:
                apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)

        # DiT: feat_decoder.estimator
        if cfg.enable_dit:
            apply_lora_to_named_linear_modules(
                self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
            )

        # Projection layers: wrap matching nn.Linear attributes in LoRALinear.
        if cfg.enable_proj:
            from ..modules.layers.lora import LoRALinear

            for attr_name in cfg.target_proj_modules:
                module = getattr(self, attr_name, None)
                if isinstance(module, nn.Linear):
                    setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
215
+
216
    def optimize(self, disable: bool = False):
        """Optionally compile hot-path sub-modules with torch.compile.

        Compilation requires a CUDA device and the triton package; on any
        failure the model is left unmodified and a warning is printed.

        Args:
            disable: When True, skip optimization entirely.

        Returns:
            self, so the call can be chained.
        """
        if disable:
            return self
        try:
            if self.device != "cuda":
                raise ValueError("VoxCPMModel can only be optimized on CUDA device")
            try:
                import triton  # noqa: F401
            except ImportError:
                raise ValueError("triton is not installed")
            self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.residual_lm.forward_step = torch.compile(
                self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
            )
            self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
            self.feat_decoder.estimator = torch.compile(
                self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
            )
        except Exception as e:
            # Best-effort: compilation is an optimization, never a hard failure.
            print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
        return self
237
+
238
    def forward(
        self,
        text_tokens: torch.Tensor,
        text_mask: torch.Tensor,
        audio_feats: torch.Tensor,
        audio_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        labels: torch.Tensor,
        *,
        progress: float = 0.0,
        sample_generate: bool = False,
    ):
        """Training forward pass.

        Computes the diffusion (flow-matching) loss over audio feature patches
        and the stop-prediction cross-entropy loss.

        Args:
            text_tokens: Text token ids; indexed as (batch, time).
            text_mask: 1 where a position holds a text token, else 0.
            audio_feats: Audio feature patches shaped (B, T, P, D).
            audio_mask: 1 where a position holds an audio patch, else 0.
            loss_mask: Positions contributing to both losses.
            position_ids: Unused (deleted immediately; reserved).
            labels: Per-position class labels for the stop head.
            progress: Training progress forwarded to the CFM loss.
            sample_generate: When True, also sample predicted features.

        Returns:
            Dict with "loss/diff", "loss/stop", "feat_gt", and "feat_pred"
            ("feat_pred" is None unless sample_generate is True).
        """
        del position_ids  # not used yet

        text_tokens = text_tokens.to(self.device, dtype=torch.long)
        text_mask = text_mask.to(self.device, dtype=self._dtype())
        audio_feats = audio_feats.to(self.device, dtype=self._dtype())
        audio_mask = audio_mask.to(self.device, dtype=self._dtype())
        loss_mask = loss_mask.to(self.device, dtype=self._dtype())
        labels = labels.to(self.device, dtype=torch.long)

        B, T, P, D = audio_feats.shape
        feat_embed = self.feat_encoder(audio_feats)
        feat_embed = self.enc_to_lm_proj(feat_embed)

        # Embedding scale only applies under the muP parametrization.
        scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
        if not getattr(self.config.lm_config, "use_mup", False):
            scale_emb = 1.0
        text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
        combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed

        enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
        enc_outputs = enc_outputs.to(self._dtype())
        # Quantize only audio positions; text positions pass through unchanged.
        enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        # Shift right by one step so the state at t-1 conditions the prediction for t.
        lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)

        residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
        residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
        residual_outputs = residual_outputs.to(self._dtype())
        residual_hidden = torch.cat(
            (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
            dim=1,
        )

        dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
        dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")

        # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
        target_dtype = self._dtype()

        feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
        # Condition on the previous patch (zeros at t=0).
        feat_cond = torch.cat(
            (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
            dim=1,
        )
        feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")

        loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
        loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)

        diff_loss = self.feat_decoder.compute_loss(
            feat_gt.transpose(1, 2).contiguous(),
            dit_hidden,
            cond=feat_cond.transpose(1, 2).contiguous(),
            tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
            progress=progress,
        )

        stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
        stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
        # Clamp the denominator so an all-zero mask cannot divide by zero.
        denom = torch.clamp(loss_mask.sum(), min=1.0)
        stop_loss = (stop_losses * loss_mask).sum() / denom

        feat_pred = None
        if sample_generate:
            feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
            feat_pred_seq = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=feat_cond_for_sample,
                # NOTE(review): this passes inference_cfg_rate as n_timesteps,
                # which looks like a config-field mix-up — confirm the intent.
                n_timesteps=(
                    self.config.dit_config.cfm_config.inference_cfg_rate
                    if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
                    else 10
                ),
            )
            feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)

        feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)

        return {
            "loss/diff": diff_loss,
            "loss/stop": stop_loss,
            "feat_gt": feat_gt_tensor,
            "feat_pred": feat_pred,
        }
335
+
336
    def _dtype(self):
        """Return the torch dtype corresponding to ``config.dtype``."""
        return get_dtype(self.config.dtype)
338
+
339
    def generate(self, *args, **kwargs) -> torch.Tensor:
        """Non-streaming generation: return the single final decoded audio tensor."""
        return next(self._generate(*args, streaming=False, **kwargs))
341
+
342
    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
        """Streaming generation: yield decoded audio chunks as they are produced."""
        return self._generate(*args, streaming=True, **kwargs)
344
+
345
    @torch.inference_mode()
    def _generate(
        self,
        target_text: str,
        prompt_text: str = "",
        prompt_wav_path: str = "",
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
        streaming: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Generate audio for *target_text*, optionally voice-cloning from a prompt wav.

        When a prompt wav is given, the prompt text is prepended and the prompt
        audio is VAE-encoded as conditioning. Yields decoded audio chunks when
        ``streaming=True``; otherwise yields a single final audio tensor.
        """
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False
        if len(prompt_wav_path) == 0:
            # No prompt audio: condition on the target text only.
            text = target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
            text_token = torch.cat(
                [
                    text_token,
                    torch.tensor(
                        [self.audio_start_token],
                        dtype=torch.int32,
                        device=text_token.device,
                    ),
                ],
                dim=-1,
            )
            text_length = text_token.shape[0]

            audio_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
            audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)

        else:
            # Prompt audio given: prepend prompt text and encode the prompt wav.
            text = prompt_text + target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
            text_token = torch.cat(
                [
                    text_token,
                    torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
                ],
                dim=-1,
            )
            text_length = text_token.shape[0]

            audio, sr = torchaudio.load(prompt_wav_path)
            if audio.size(0) > 1:
                # Downmix multi-channel audio to mono.
                audio = audio.mean(dim=0, keepdim=True)

            if sr != self.sample_rate:
                audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

            patch_len = self.patch_size * self.chunk_size

            if audio.size(1) % patch_len != 0:
                # Left padding: pad at the start so valid audio stays at the end of the sequence.
                padding_size = patch_len - audio.size(1) % patch_len
                audio = torch.nn.functional.pad(audio, (padding_size, 0))

            # (B, D, T)
            audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
            audio_feat = audio_feat.view(
                self.audio_vae.latent_dim,
                -1,
                self.patch_size,
            ).permute(1, 2, 0)
            audio_length = audio_feat.size(0)
            text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
            text_token = torch.cat([text_token, text_pad_token])
            audio_pad_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
            text_mask = (
                torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
            )
            audio_mask = (
                torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
            )

        # Add a batch dimension and move everything onto the model device.
        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        target_text_length = len(self.text_tokenizer(target_text))

        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=min(
                    int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
                ),  # avoid too long audio
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                streaming=streaming,
            )
            if streaming:
                patch_len = self.patch_size * self.chunk_size
                for latent_pred, _ in inference_result:
                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                    # Keep only the newest patch from the decoded context window.
                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                    yield decode_audio
                break
            else:
                latent_pred, pred_audio_feat = next(inference_result)
                if retry_badcase:
                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                        print(
                            f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                            file=sys.stderr,
                        )
                        retry_badcase_times += 1
                        continue
                    else:
                        break
                else:
                    break

        if not streaming:
            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
            yield decode_audio
483
+
484
    @torch.inference_mode()
    def build_prompt_cache(
        self,
        prompt_text: str,
        prompt_wav_path: str,
    ):
        """
        Build prompt cache for subsequent fast generation.

        Args:
            prompt_text: prompt text (required)
            prompt_wav_path: prompt audio path (required)

        Returns:
            prompt_cache: dict with prompt_text (raw text) and audio features.
            Text tokenization will be done during generation for consistency.

        Raises:
            ValueError: If either argument is missing/empty.
        """
        if not prompt_text or not prompt_wav_path:
            raise ValueError("prompt_text and prompt_wav_path are required")

        # load audio
        audio, sr = torchaudio.load(prompt_wav_path)
        if audio.size(0) > 1:
            # Downmix multi-channel audio to mono.
            audio = audio.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        patch_len = self.patch_size * self.chunk_size

        if audio.size(1) % patch_len != 0:
            # Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
            padding_size = patch_len - audio.size(1) % patch_len
            audio = torch.nn.functional.pad(audio, (padding_size, 0))

        # extract audio features
        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()

        audio_feat = audio_feat.view(
            self.audio_vae.latent_dim,
            -1,
            self.patch_size,
        ).permute(
            1, 2, 0
        )  # (D, T, P)
        # build prompt cache - only save raw text and audio features
        prompt_cache = {
            "prompt_text": prompt_text,
            "audio_feat": audio_feat,
        }

        return prompt_cache
536
+
537
+ def merge_prompt_cache(
538
+ self,
539
+ original_cache: dict,
540
+ new_text: str,
541
+ new_audio_feat: torch.Tensor,
542
+ ):
543
+ """
544
+ Merge original prompt cache with newly generated content to stabilize voice.
545
+
546
+ Args:
547
+ original_cache: original prompt cache
548
+ new_text: newly generated text
549
+ new_audio_feat: newly generated audio features
550
+
551
+ Returns:
552
+ merged_cache: merged cache with prompt_text and audio_feat
553
+ """
554
+ if original_cache is None:
555
+ return {
556
+ "prompt_text": new_text,
557
+ "audio_feat": new_audio_feat,
558
+ }
559
+ original_prompt_text = original_cache["prompt_text"]
560
+ original_audio_feat = original_cache["audio_feat"]
561
+ # Merge text by concatenation
562
+ merged_prompt_text = original_prompt_text + new_text
563
+ merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
564
+
565
+ # build new cache
566
+ merged_cache = {
567
+ "prompt_text": merged_prompt_text,
568
+ "audio_feat": merged_audio_feat,
569
+ }
570
+
571
+ return merged_cache
572
+
573
    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Non-streaming cache-based generation: return (audio, target tokens, features)."""
        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
575
+
576
    def generate_with_prompt_cache_streaming(
        self, *args, **kwargs
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
        """Streaming cache-based generation: yield (audio chunk, target tokens, features) per step."""
        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
580
+
581
    @torch.inference_mode()
    def _generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
        streaming: bool = False,
        streaming_prefix_len: int = 3,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """
        Generate audio using pre-built prompt cache.

        Args:
            target_text: Text to convert to speech
            prompt_cache: Cache built by build_prompt_cache (can be None)
            min_len: Minimum audio length to avoid very short audio
            max_len: Maximum audio length
            inference_timesteps: Number of diffusion sampling steps
            cfg_value: Classifier-free guidance value
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
            streaming: Whether to return a generator of audio chunks
            streaming_prefix_len: Number of prefix audio patches to use for streaming mode

        Returns:
            Generator of Tuple containing:
                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
                - Tensor of new text tokens
                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
        """
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False
        # get prompt from cache
        if prompt_cache is None:
            prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
            text = target_text
        else:
            prompt_audio_feat = prompt_cache["audio_feat"]
            prompt_text = prompt_cache["prompt_text"]
            text = prompt_text + target_text

        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_start_token],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )

        target_text_token = torch.LongTensor(self.text_tokenizer(target_text))

        # Lay out [text | prompt audio] with matching pad regions and masks.
        audio_length = prompt_audio_feat.size(0)
        text_length = text_token.shape[0]
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        audio_pad_feat = torch.zeros(
            (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([text_token, text_pad_token])
        audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
        text_mask = (
            torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
        )
        audio_mask = (
            torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
        )

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        # run inference
        target_text_length = len(self.text_tokenizer(target_text))
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=min(
                    int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
                ),  # avoid too long audio
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                streaming=streaming,
                streaming_prefix_len=streaming_prefix_len,
            )
            if streaming:
                patch_len = self.patch_size * self.chunk_size
                for latent_pred, pred_audio_feat in inference_result:
                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                    # Keep only the newest patch from the decoded context window.
                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                    yield (decode_audio, target_text_token, pred_audio_feat)
                break
            else:
                latent_pred, pred_audio_feat = next(inference_result)
                if retry_badcase:
                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                        print(
                            f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                            file=sys.stderr,
                        )
                        retry_badcase_times += 1
                        continue
                    else:
                        break
                else:
                    break
        if not streaming:
            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
            patch_len = self.patch_size * self.chunk_size
            if audio_mask.sum().item() > 0:
                # Presumably trims the prompt-context patches that were prepended
                # to the predicted sequence for decoding continuity — TODO confirm
                # against _inference's prompt_context_patches handling.
                decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
            else:
                decode_audio = decode_audio[..., :].squeeze(1).cpu()
            yield (decode_audio, target_text_token, pred_audio_feat)
714
+
715
    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Non-streaming core inference: return (final latents, predicted feature sequence)."""
        return next(self._inference(*args, streaming=False, **kwargs))
717
+
718
    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
        """Streaming core inference: yield (step latents, feature sequence so far) per step."""
        return self._inference(*args, streaming=True, **kwargs)
720
+
721
+ @torch.inference_mode()
722
+ def _inference(
723
+ self,
724
+ text: torch.Tensor,
725
+ text_mask: torch.Tensor,
726
+ feat: torch.Tensor,
727
+ feat_mask: torch.Tensor,
728
+ min_len: int = 2,
729
+ max_len: int = 2000,
730
+ inference_timesteps: int = 10,
731
+ cfg_value: float = 2.0,
732
+ streaming: bool = False,
733
+ streaming_prefix_len: int = 3,
734
+ ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
735
+ """Core inference method for audio generation.
736
+
737
+ This is the main inference loop that generates audio features
738
+ using the language model and diffusion transformer.
739
+
740
+ Args:
741
+ text: Input text tokens
742
+ text_mask: Mask for text tokens
743
+ feat: Input audio features
744
+ feat_mask: Mask for audio features
745
+ min_len: Minimum generation length
746
+ max_len: Maximum generation length
747
+ inference_timesteps: Number of diffusion steps
748
+ cfg_value: Classifier-free guidance value
749
+ streaming: Whether to yield each step latent feature or just the final result
750
+
751
+ Returns:
752
+ Generator of Tuple containing:
753
+ - Predicted latent feature at the current step if ``streaming=True``, else final latent features
754
+ - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
755
+ """
756
+ B, T, P, D = feat.shape
757
+
758
+ feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
759
+ feat_embed = self.enc_to_lm_proj(feat_embed)
760
+
761
+ if self.config.lm_config.use_mup:
762
+ scale_emb = self.config.lm_config.scale_emb
763
+ else:
764
+ scale_emb = 1.0
765
+
766
+ text_embed = self.base_lm.embed_tokens(text) * scale_emb
767
+ combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
768
+
769
+ prefix_feat_cond = feat[:, -1, ...] # b, p, d
770
+ pred_feat_seq = [] # b, t, p, d
771
+ curr_embed = None
772
+
773
+ # Prepare prompt context patches for streaming mode
774
+ # When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
775
+ prompt_context_patches = []
776
+ audio_patch_count = int(feat_mask.sum().item())
777
+ if audio_patch_count > 0:
778
+ context_len = min(streaming_prefix_len - 1, audio_patch_count)
779
+ # Take the last context_len patches from prompt audio as initial context
780
+ # Split into list of [b, 1, p, d] tensors to match pred_feat_seq format
781
+ prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
782
+ pred_feat_seq = prompt_context_patches + pred_feat_seq
783
+
784
+ enc_outputs, kv_cache_tuple = self.base_lm(
785
+ inputs_embeds=combined_embed,
786
+ is_causal=True,
787
+ )
788
+ self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
789
+
790
+ enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
791
+ lm_hidden = enc_outputs[:, -1, :]
792
+
793
+ residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
794
+ inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
795
+ is_causal=True,
796
+ )
797
+ self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
798
+ residual_hidden = residual_enc_outputs[:, -1, :]
799
+
800
+ for i in tqdm(range(max_len)):
801
+ dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
802
+ dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
803
+ dit_hidden = dit_hidden_1 + dit_hidden_2 # [b, h_dit]
804
+
805
+ pred_feat = self.feat_decoder(
806
+ mu=dit_hidden,
807
+ patch_size=self.patch_size,
808
+ cond=prefix_feat_cond.transpose(1, 2).contiguous(),
809
+ n_timesteps=inference_timesteps,
810
+ cfg_value=cfg_value,
811
+ ).transpose(
812
+ 1, 2
813
+ ) # [b, p, d]
814
+
815
+ curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
816
+ curr_embed = self.enc_to_lm_proj(curr_embed)
817
+
818
+ pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
819
+ prefix_feat_cond = pred_feat
820
+
821
+ if streaming:
822
+ # return the last three predicted latent features to provide enough context for smooth decoding
823
+ pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
824
+ feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
825
+
826
+ yield feat_pred, pred_feat_seq
827
+
828
+ stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
829
+ if i > min_len and stop_flag == 1:
830
+ break
831
+
832
+ lm_hidden = self.base_lm.forward_step(
833
+ curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
834
+ ).clone()
835
+
836
+ lm_hidden = self.fsq_layer(lm_hidden)
837
+ residual_hidden = self.residual_lm.forward_step(
838
+ lm_hidden + curr_embed[:, 0, :],
839
+ torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
840
+ ).clone()
841
+
842
+ if not streaming:
843
+ pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
844
+ feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
845
+ yield feat_pred, pred_feat_seq.squeeze(0).cpu()
846
+
847
+ @classmethod
848
+ def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
849
+ config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
850
+ tokenizer = LlamaTokenizerFast.from_pretrained(path)
851
+ audio_vae_config = getattr(config, "audio_vae_config", None)
852
+ audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
853
+ # Try to load AudioVAE from safetensors first, fallback to pytorch
854
+ audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
855
+ audiovae_pth_path = os.path.join(path, "audiovae.pth")
856
+ if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
857
+ print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
858
+ vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
859
+ elif os.path.exists(audiovae_pth_path):
860
+ print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
861
+ checkpoint = torch.load(
862
+ audiovae_pth_path,
863
+ map_location="cpu",
864
+ weights_only=True,
865
+ )
866
+ vae_state_dict = checkpoint.get("state_dict", checkpoint)
867
+ else:
868
+ raise FileNotFoundError(
869
+ f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
870
+ )
871
+ model = cls(config, tokenizer, audio_vae, lora_config)
872
+ if not training:
873
+ lm_dtype = get_dtype(model.config.dtype)
874
+ model = model.to(lm_dtype)
875
+ else: # training mode
876
+ for name, param in model.named_parameters():
877
+ if "audio_vae" in name: # freeze VAE weights
878
+ param.requires_grad = False
879
+ continue
880
+ if lora_config is not None:
881
+ if "lora" not in name: # freeze non-LoRA weights
882
+ param.requires_grad = False
883
+ model.audio_vae = model.audio_vae.to(torch.float32)
884
+
885
+ # Try to load from safetensors first, fallback to pytorch_model.bin
886
+ safetensors_path = os.path.join(path, "model.safetensors")
887
+ pytorch_model_path = os.path.join(path, "pytorch_model.bin")
888
+
889
+ if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
890
+ print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
891
+ model_state_dict = load_file(safetensors_path)
892
+ elif os.path.exists(pytorch_model_path):
893
+ print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
894
+ checkpoint = torch.load(
895
+ pytorch_model_path,
896
+ map_location="cpu",
897
+ weights_only=True,
898
+ )
899
+ model_state_dict = checkpoint.get("state_dict", checkpoint)
900
+ else:
901
+ raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
902
+
903
+ for kw, val in vae_state_dict.items():
904
+ model_state_dict[f"audio_vae.{kw}"] = val
905
+
906
+ # LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
907
+ # Using strict=False since pretrained weights don't contain lora_A/lora_B.
908
+ model.load_state_dict(model_state_dict, strict=False)
909
+ if training:
910
+ return model
911
+ return model.to(model.device).eval().optimize(disable=not optimize)
912
+
913
+ # ------------------------------------------------------------------ #
914
+ # LoRA Weight Management
915
+ # ------------------------------------------------------------------ #
916
+ def _iter_lora_modules(self):
917
+ """Iterate over all LoRA modules."""
918
+ from ..modules.layers.lora import LoRALinear
919
+
920
+ for module in self.modules():
921
+ if isinstance(module, LoRALinear):
922
+ yield module
923
+
924
+ def load_lora_weights(self, lora_path: str, device: str = None):
925
+ """
926
+ Load LoRA weights from file, supports calling after torch.compile.
927
+ Uses named_parameters() to handle compile's _orig_mod wrapper.
928
+ Supports both safetensors and pytorch formats.
929
+
930
+ Args:
931
+ lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
932
+ device: Target device, defaults to model's current device
933
+ Returns:
934
+ tuple: (loaded_keys, skipped_keys)
935
+ """
936
+ from pathlib import Path
937
+
938
+ device = device or self.device
939
+ lora_p = Path(lora_path)
940
+
941
+ # Try safetensors first, then fallback to .ckpt
942
+ if lora_p.is_dir():
943
+ safetensors_file = lora_p / "lora_weights.safetensors"
944
+ ckpt_file = lora_p / "lora_weights.ckpt"
945
+ else:
946
+ safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
947
+ ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
948
+
949
+ # Load from safetensors if available
950
+ if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
951
+ state_dict = load_file(str(safetensors_file), device=device)
952
+ elif ckpt_file and ckpt_file.exists():
953
+ ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
954
+ state_dict = ckpt.get("state_dict", ckpt)
955
+ else:
956
+ raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
957
+
958
+ # Build param mapping (handle torch.compile's _orig_mod prefix)
959
+ model_params = dict(self.named_parameters())
960
+ key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
961
+
962
+ loaded_keys, skipped_keys = [], []
963
+ for key, value in state_dict.items():
964
+ target_key = key if key in model_params else key_mapping.get(key)
965
+ if target_key:
966
+ model_params[target_key].data.copy_(value.to(device))
967
+ loaded_keys.append(key)
968
+ else:
969
+ skipped_keys.append(key)
970
+
971
+ return loaded_keys, skipped_keys
972
+
973
+ def set_lora_enabled(self, enabled: bool):
974
+ """Enable/disable all LoRA layers."""
975
+ for module in self._iter_lora_modules():
976
+ module.set_enabled(enabled)
977
+
978
+ def reset_lora_weights(self):
979
+ """Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
980
+ for module in self._iter_lora_modules():
981
+ module.reset_lora_parameters()
982
+
983
+ def get_lora_state_dict(self) -> dict:
984
+ """Get all LoRA parameters (lora_A/lora_B)."""
985
+ return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
voxcpm/model/voxcpm2.py ADDED
@@ -0,0 +1,1224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoxCPM: A Tokenizer-free speech generation model
3
+
4
+ This module contains the main VoxCPM model implementation, including configuration classes
5
+ and the core VoxCPMModel for text-to-speech generation.
6
+
7
+ Copyright 2026 OpenBMB
8
+ Licensed under the Apache License, Version 2.0 (the "License");
9
+ you may not use this file except in compliance with the License.
10
+ You may obtain a copy of the License at
11
+
12
+ http://www.apache.org/licenses/LICENSE-2.0
13
+
14
+ Unless required by applicable law or agreed to in writing, software
15
+ distributed under the License is distributed on an "AS IS" BASIS,
16
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ See the License for the specific language governing permissions and
18
+ limitations under the License.
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ from typing import Tuple, Union, Generator, List, Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import warnings
28
+ import librosa
29
+ import numpy as np
30
+ from einops import rearrange
31
+ from pydantic import BaseModel
32
+
33
+ try:
34
+ from safetensors.torch import load_file
35
+
36
+ SAFETENSORS_AVAILABLE = True
37
+ except ImportError:
38
+ SAFETENSORS_AVAILABLE = False
39
+ from tqdm import tqdm
40
+ from transformers import LlamaTokenizerFast
41
+
42
+ from ..modules.audiovae import AudioVAEV2, AudioVAEConfigV2
43
+ from ..modules.layers import ScalarQuantizationLayer
44
+ from ..modules.layers.lora import apply_lora_to_named_linear_modules
45
+ from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
46
+ from ..modules.locenc import VoxCPMLocEnc
47
+ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
48
+ from .utils import get_dtype, mask_multichar_chinese_tokens
49
+
50
+
51
def _trim_audio_silence_vad(
    audio: torch.Tensor,
    sample_rate: int,
    max_silence_ms: float = 200.0,
    top_db: float = 35.0,
) -> torch.Tensor:
    """Trim leading/trailing silence with an energy-threshold VAD.

    Besides the plain head/tail trim, the tail is additionally cut back to the
    last frame with sustained energy, so long trailing pseudo-silence (low
    energy but not exactly zero, e.g. a noise floor) is removed as well. Up to
    ``max_silence_ms`` milliseconds of silence are kept at each end.

    Args:
        audio: (1, T) audio tensor.
        sample_rate: Sample rate of ``audio`` in Hz.
        max_silence_ms: Maximum silence kept at each end, in milliseconds.
        top_db: Frames this many dB below the reference peak count as silence.

    Returns:
        The trimmed (1, T') tensor.
    """
    if audio.numel() == 0:
        return audio

    samples = audio.squeeze(0).numpy()
    total = len(samples)
    frame_length = 2048
    hop_length = 512

    peak = np.max(np.abs(samples))
    if peak <= 0:
        return audio  # pure digital silence: no reference level to trim against
    threshold = peak * (10.0 ** (-top_db / 20.0))

    # Coarse head/tail trim; keep the full span if librosa is unavailable/fails.
    try:
        _, (start, end) = librosa.effects.trim(
            samples, top_db=top_db, ref=np.max, frame_length=frame_length, hop_length=hop_length
        )
    except Exception:
        start, end = 0, total

    # Frame-wise RMS scan: remember the last frame with real energy so the
    # tail can be cut back past long pseudo-silent stretches.
    last_voiced = -1
    for frame_idx, offset in enumerate(range(0, total - frame_length + 1, hop_length)):
        window = samples[offset : offset + frame_length]
        if np.sqrt(np.mean(window**2)) >= threshold:
            last_voiced = frame_idx
    if last_voiced >= 0:
        vad_end = min(total, (last_voiced + 1) * hop_length + (frame_length - hop_length))
        end = min(end, vad_end)

    keep = int(max_silence_ms * sample_rate / 1000.0)
    return audio[:, max(0, start - keep) : min(total, end + keep)]
106
+
107
+
108
class VoxCPMEncoderConfig(BaseModel):
    """Hyper-parameters of the local (patch) encoder transformer."""

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # Fix: default is None, so the annotation must be Optional[int], not int.
    # Semantics of None: presumably "use the transformer default" — verify downstream.
    kv_channels: Optional[int] = None
114
+
115
+
116
class VoxCPMDitConfig(BaseModel):
    """Hyper-parameters of the local DiT (diffusion transformer) decoder."""

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # Fix: default is None, so the annotation must be Optional[int], not int.
    kv_channels: Optional[int] = None
    dit_mean_mode: bool = False

    cfm_config: CfmConfig
125
+
126
+
127
class VoxCPMConfig(BaseModel):
    """Top-level VoxCPM model configuration, composing all sub-module configs."""

    lm_config: MiniCPM4Config  # backbone text-semantic LM configuration
    patch_size: int = 4  # latent frames grouped into one LM step (the P axis)
    feat_dim: int = 64  # per-frame audio latent dimension (the D axis)
    residual_lm_num_layers: int = 8  # depth of the residual acoustic LM
    residual_lm_no_rope: bool = False  # disable RoPE in the residual LM
    scalar_quantization_latent_dim: int = 512  # FSQ bottleneck width
    scalar_quantization_scale: int = 9  # FSQ quantization scale

    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    audio_vae_config: Optional[AudioVAEConfigV2] = None  # None -> default VAE config

    max_length: int = 8192  # KV-cache / sequence-length budget
    device: str = "cuda"  # requested device; runtime falls back to mps/cpu
    dtype: str = "bfloat16"  # compute dtype name, resolved via get_dtype()
143
+
144
+
145
class LoRAConfig(BaseModel):
    """Switches and hyper-parameters controlling optional LoRA adapter injection."""

    enable_lm: bool = False  # Apply LoRA to base_lm + residual_lm
    enable_dit: bool = False  # Apply LoRA to VoxCPMLocDiT
    enable_proj: bool = False  # Apply LoRA to projection Linear layers

    # Standard LoRA hyper-parameters: rank, scaling alpha, adapter dropout.
    r: int = 8
    alpha: int = 16
    dropout: float = 0.0

    # Target linear layer names for LM & DiT (matched by attribute name)
    # NOTE: mutable list defaults are safe on pydantic models — pydantic copies
    # field defaults per instance, unlike plain Python default arguments.
    target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    # Projection layer attribute names to find on VoxCPM2Model
    target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj", "fusion_concat_proj"]
159
+
160
+
161
+ VoxCPMConfig.model_rebuild()
162
+
163
+
164
class VoxCPM2Model(nn.Module):
    """VoxCPM v2: tokenizer-free TTS model.

    Composes a text-semantic LM (``base_lm``), a residual acoustic LM, a local
    patch encoder, a CFM-based local DiT decoder and an audio VAE whose latent
    space the model generates in.
    """

    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAEV2,
        lora_config: LoRAConfig = None,
    ):
        """Assemble all sub-modules from ``config``.

        Args:
            config: Full model configuration.
            tokenizer: Text tokenizer; multi-char Chinese tokens are masked out.
            audio_vae: Pretrained audio VAE providing the latent feature space.
            lora_config: Optional LoRA config; adapters are injected when given.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size
        self.device = config.device
        # Fall back to MPS, then CPU, when CUDA is not available.
        if not torch.cuda.is_available():
            if torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        # Hard-coded special token ids (presumably reserved in the tokenizer
        # vocabulary — verify against the tokenizer files).
        self.audio_start_token = 101
        self.audio_end_token = 102
        self.ref_audio_start_token = 103
        self.ref_audio_end_token = 104

        # Residual Acoustic LM: same architecture as the base LM but shallower,
        # with no embedding table (vocab_size=0) and optionally no RoPE.
        residual_lm_config = config.lm_config.model_copy(deep=True)
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        residual_lm_config.no_rope = config.residual_lm_no_rope
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        # Local Encoder: maps (T, P, D) latent patches into LM-sized embeddings.
        encoder_config = config.lm_config.model_copy(deep=True)
        encoder_config.hidden_size = config.encoder_config.hidden_dim
        encoder_config.intermediate_size = config.encoder_config.ffn_dim
        encoder_config.num_attention_heads = config.encoder_config.num_heads
        encoder_config.num_hidden_layers = config.encoder_config.num_layers
        encoder_config.kv_channels = config.encoder_config.kv_channels
        encoder_config.vocab_size = 0
        self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)

        # Local DiT: CFM decoder that denoises one latent patch per LM step.
        decoder_config = config.lm_config.model_copy(deep=True)
        decoder_config.hidden_size = config.dit_config.hidden_dim
        decoder_config.intermediate_size = config.dit_config.ffn_dim
        decoder_config.num_attention_heads = config.dit_config.num_heads
        decoder_config.num_hidden_layers = config.dit_config.num_layers
        decoder_config.kv_channels = config.dit_config.kv_channels
        decoder_config.vocab_size = 0
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=VoxCPMLocDiTV2(decoder_config, in_channels=config.feat_dim),
            mean_mode=config.dit_config.dit_mean_mode,
        )

        # Projection layers
        self.fsq_layer = ScalarQuantizationLayer(
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale,
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
        self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.fusion_concat_proj = nn.Linear(config.lm_config.hidden_size * 2, config.lm_config.hidden_size)

        # Stop Predictor: binary head (continue / stop) over the LM hidden state.
        self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
        self.stop_loss = nn.CrossEntropyLoss(reduction="none")

        # Audio VAE
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self._encode_sample_rate = audio_vae.sample_rate
        # Output sample rate may differ from the encode rate (VAE v2).
        self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)

        if self.lora_config is not None:
            self._apply_lora()
254
+
255
    def _apply_lora(self):
        """Inject LoRA adapters into the LM / DiT / projection layers."""
        cfg = self.lora_config
        lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)

        # LM: base_lm + residual_lm
        if cfg.enable_lm:
            for lm in [self.base_lm, self.residual_lm]:
                apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)

        # DiT: feat_decoder.estimator
        if cfg.enable_dit:
            apply_lora_to_named_linear_modules(
                self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
            )

        # Projection layers: wrap the plain nn.Linear attributes in LoRALinear.
        if cfg.enable_proj:
            from ..modules.layers.lora import LoRALinear

            for attr_name in cfg.target_proj_modules:
                module = getattr(self, attr_name, None)
                if isinstance(module, nn.Linear):
                    setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
279
+
280
+ def optimize(self, disable: bool = False):
281
+ if disable:
282
+ return self
283
+ try:
284
+ if self.device != "cuda":
285
+ raise ValueError("VoxCPMModel can only be optimized on CUDA device")
286
+ try:
287
+ import triton # noqa: F401
288
+ except ImportError:
289
+ raise ValueError("triton is not installed")
290
+ self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
291
+ self.residual_lm.forward_step = torch.compile(
292
+ self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
293
+ )
294
+ self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
295
+ self.feat_decoder.estimator = torch.compile(
296
+ self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
297
+ )
298
+ except Exception as e:
299
+ print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
300
+ return self
301
+
302
    def forward(
        self,
        text_tokens: torch.Tensor,
        text_mask: torch.Tensor,
        audio_feats: torch.Tensor,
        audio_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        labels: torch.Tensor,
        *,
        progress: float = 0.0,
        sample_generate: bool = False,
    ):
        """Training forward pass: diffusion (CFM) loss + stop-prediction loss.

        Args:
            text_tokens: (B, T) int token ids.
            text_mask: (B, T) 1 where the step carries text.
            audio_feats: (B, T, P, D) audio latent patches.
            audio_mask: (B, T) 1 where the step carries audio.
            loss_mask: (B, T) steps contributing to both losses.
            position_ids: Unused for now (kept for interface compatibility).
            labels: (B, T) stop labels for the binary stop head.
            progress: Training progress, forwarded to the CFM loss.
            sample_generate: When True, also decode a sample prediction.

        Returns:
            Dict with "loss/diff", "loss/stop", "feat_gt" shaped (B, D, T*P)
            and "feat_pred" (same shape, or None unless sample_generate).
        """
        del position_ids  # not used yet

        text_tokens = text_tokens.to(self.device, dtype=torch.long)
        text_mask = text_mask.to(self.device, dtype=self._dtype())
        audio_feats = audio_feats.to(self.device, dtype=self._dtype())
        audio_mask = audio_mask.to(self.device, dtype=self._dtype())
        loss_mask = loss_mask.to(self.device, dtype=self._dtype())
        labels = labels.to(self.device, dtype=torch.long)

        B, T, P, D = audio_feats.shape
        feat_embed = self.feat_encoder(audio_feats)
        feat_embed = self.enc_to_lm_proj(feat_embed)

        # muP models scale token embeddings by scale_emb; others leave them as-is.
        scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
        if not getattr(self.config.lm_config, "use_mup", False):
            scale_emb = 1.0
        text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
        # Each step is either text or audio; the two masks select the embedding.
        combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed

        enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
        enc_outputs = enc_outputs.to(self._dtype())
        # FSQ-quantize only the audio positions; text positions pass through.
        enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        # Shift right by one step so position t conditions on outputs < t.
        lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)

        residual_inputs = self.fusion_concat_proj(
            torch.cat((enc_outputs, audio_mask.unsqueeze(-1) * feat_embed), dim=-1)
        )
        residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
        residual_outputs = residual_outputs.to(self._dtype())
        # Same one-step right shift for the residual stream.
        residual_hidden = torch.cat(
            (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
            dim=1,
        )

        dit_hidden = torch.cat((self.lm_to_dit_proj(lm_hidden), self.res_to_dit_proj(residual_hidden)), dim=-1)
        dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")

        # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
        target_dtype = self._dtype()

        feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
        # Condition each patch on the previous ground-truth patch (teacher forcing).
        feat_cond = torch.cat(
            (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
            dim=1,
        )
        feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")

        loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
        loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)

        diff_loss = self.feat_decoder.compute_loss(
            feat_gt.transpose(1, 2).contiguous(),
            dit_hidden,
            cond=feat_cond.transpose(1, 2).contiguous(),
            tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
            progress=progress,
        )

        stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
        stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
        # Clamp avoids a 0/0 when no step is selected by loss_mask.
        denom = torch.clamp(loss_mask.sum(), min=1.0)
        stop_loss = (stop_losses * loss_mask).sum() / denom

        feat_pred = None
        if sample_generate:
            feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
            feat_pred_seq = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=feat_cond_for_sample,
                # NOTE(review): inference_cfg_rate reads like a CFG *rate*, not a
                # step count — passing it as n_timesteps looks like a bug; confirm
                # against UnifiedCFM's signature before relying on sampled output.
                n_timesteps=(
                    self.config.dit_config.cfm_config.inference_cfg_rate
                    if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
                    else 10
                ),
            )
            feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)

        feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)

        return {
            "loss/diff": diff_loss,
            "loss/stop": stop_loss,
            "feat_gt": feat_gt_tensor,
            "feat_pred": feat_pred,
        }
401
+
402
+ def _dtype(self):
403
+ return get_dtype(self.config.dtype)
404
+
405
+ def _encode_wav(self, wav_path: str, padding_mode: str = "right") -> torch.Tensor:
406
+ """Load, trim, pad and VAE-encode an audio file.
407
+
408
+ Args:
409
+ wav_path: path to the audio file.
410
+ padding_mode: "right" (default) or "left" padding for alignment.
411
+
412
+ Returns:
413
+ audio_feat: (T, P, D) tensor of latent patches.
414
+ """
415
+ audio, _ = librosa.load(wav_path, sr=self._encode_sample_rate, mono=True)
416
+ audio = torch.from_numpy(audio).unsqueeze(0)
417
+ audio = _trim_audio_silence_vad(audio, self._encode_sample_rate, max_silence_ms=200.0)
418
+ patch_len = self.patch_size * self.chunk_size
419
+ if audio.size(1) % patch_len != 0:
420
+ padding_size = patch_len - audio.size(1) % patch_len
421
+ pad = (padding_size, 0) if padding_mode == "left" else (0, padding_size)
422
+ audio = torch.nn.functional.pad(audio, pad)
423
+ feat = self.audio_vae.encode(audio.to(self.device), self._encode_sample_rate).cpu()
424
+ return feat.view(self.audio_vae.latent_dim, -1, self.patch_size).permute(1, 2, 0)
425
+
426
+ def _make_ref_prefix(self, ref_feat: torch.Tensor, device: torch.device):
427
+ """Build the [ref_start ref_audio ref_end] prefix segments.
428
+
429
+ Returns:
430
+ tokens, feats, text_mask, audio_mask (all 1-D / 2-D tensors)
431
+ """
432
+ ref_len = ref_feat.size(0)
433
+ z1 = torch.zeros((1, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32, device=device)
434
+ tokens = torch.cat(
435
+ [
436
+ torch.tensor([self.ref_audio_start_token], dtype=torch.int32, device=device),
437
+ torch.zeros(ref_len, dtype=torch.int32, device=device),
438
+ torch.tensor([self.ref_audio_end_token], dtype=torch.int32, device=device),
439
+ ]
440
+ )
441
+ feats = torch.cat([z1, ref_feat, z1], dim=0)
442
+ t_mask = torch.cat(
443
+ [
444
+ torch.tensor([1], dtype=torch.int32),
445
+ torch.zeros(ref_len, dtype=torch.int32),
446
+ torch.tensor([1], dtype=torch.int32),
447
+ ]
448
+ ).to(device)
449
+ a_mask = torch.cat(
450
+ [
451
+ torch.tensor([0], dtype=torch.int32),
452
+ torch.ones(ref_len, dtype=torch.int32),
453
+ torch.tensor([0], dtype=torch.int32),
454
+ ]
455
+ ).to(device)
456
+ return tokens, feats, t_mask, a_mask
457
+
458
+ def generate(self, *args, **kwargs) -> torch.Tensor:
459
+ return next(self._generate(*args, streaming=False, **kwargs))
460
+
461
+ def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
462
+ return self._generate(*args, streaming=True, **kwargs)
463
+
464
+ @torch.inference_mode()
465
+ def _generate(
466
+ self,
467
+ target_text: str,
468
+ prompt_text: str = "",
469
+ prompt_wav_path: str = "",
470
+ reference_wav_path: str = "",
471
+ min_len: int = 2,
472
+ max_len: int = 2000,
473
+ inference_timesteps: int = 10,
474
+ cfg_value: float = 2.0,
475
+ retry_badcase: bool = False,
476
+ retry_badcase_max_times: int = 3,
477
+ retry_badcase_ratio_threshold: float = 6.0,
478
+ streaming: bool = False,
479
+ streaming_prefix_len: int = 4,
480
+ ) -> Generator[torch.Tensor, None, None]:
481
+ if retry_badcase and streaming:
482
+ warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
483
+ retry_badcase = False
484
+
485
+ if reference_wav_path and prompt_wav_path:
486
+ # Combined mode: reference isolation prefix + continuation suffix
487
+ text = prompt_text + target_text
488
+ text_token = torch.LongTensor(self.text_tokenizer(text))
489
+ text_token = torch.cat(
490
+ [
491
+ text_token,
492
+ torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
493
+ ],
494
+ dim=-1,
495
+ )
496
+ text_length = text_token.shape[0]
497
+
498
+ ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
499
+ prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
500
+ prompt_audio_length = prompt_feat.size(0)
501
+
502
+ ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
503
+
504
+ prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
505
+ text_pad_feat = torch.zeros(
506
+ (text_length, self.patch_size, self.audio_vae.latent_dim),
507
+ dtype=torch.float32,
508
+ device=text_token.device,
509
+ )
510
+
511
+ text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
512
+ audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_feat], dim=0)
513
+ text_mask = torch.cat(
514
+ [
515
+ ref_t_mask,
516
+ torch.ones(text_length, dtype=torch.int32).to(text_token.device),
517
+ torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),
518
+ ]
519
+ )
520
+ audio_mask = torch.cat(
521
+ [
522
+ ref_a_mask,
523
+ torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
524
+ torch.ones(prompt_audio_length, dtype=torch.int32).to(text_token.device),
525
+ ]
526
+ )
527
+
528
+ elif reference_wav_path:
529
+ # Reference-only mode (prompt isolation)
530
+ text = target_text
531
+ text_token = torch.LongTensor(self.text_tokenizer(text))
532
+ text_token = torch.cat(
533
+ [
534
+ text_token,
535
+ torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
536
+ ],
537
+ dim=-1,
538
+ )
539
+ text_length = text_token.shape[0]
540
+
541
+ ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
542
+ ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
543
+
544
+ text_pad_feat = torch.zeros(
545
+ (text_length, self.patch_size, self.audio_vae.latent_dim),
546
+ dtype=torch.float32,
547
+ device=text_token.device,
548
+ )
549
+ text_token = torch.cat([ref_tokens, text_token])
550
+ audio_feat = torch.cat([ref_feats, text_pad_feat], dim=0)
551
+ text_mask = torch.cat(
552
+ [
553
+ ref_t_mask,
554
+ torch.ones(text_length, dtype=torch.int32).to(text_token.device),
555
+ ]
556
+ )
557
+ audio_mask = torch.cat(
558
+ [
559
+ ref_a_mask,
560
+ torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
561
+ ]
562
+ )
563
+
564
+ elif len(prompt_wav_path) == 0:
565
+ # Zero-shot mode
566
+ text = target_text
567
+ text_token = torch.LongTensor(self.text_tokenizer(text))
568
+ text_token = torch.cat(
569
+ [
570
+ text_token,
571
+ torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
572
+ ],
573
+ dim=-1,
574
+ )
575
+ text_length = text_token.shape[0]
576
+
577
+ audio_feat = torch.zeros(
578
+ (text_length, self.patch_size, self.audio_vae.latent_dim),
579
+ dtype=torch.float32,
580
+ device=text_token.device,
581
+ )
582
+ text_mask = torch.ones(text_length, dtype=torch.int32).to(text_token.device)
583
+ audio_mask = torch.zeros(text_length, dtype=torch.int32).to(text_token.device)
584
+
585
+ else:
586
+ # Continuation-only mode
587
+ text = prompt_text + target_text
588
+ text_token = torch.LongTensor(self.text_tokenizer(text))
589
+ text_token = torch.cat(
590
+ [
591
+ text_token,
592
+ torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
593
+ ],
594
+ dim=-1,
595
+ )
596
+ text_length = text_token.shape[0]
597
+
598
+ prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
599
+ prompt_audio_length = prompt_feat.size(0)
600
+ prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
601
+ text_pad_feat = torch.zeros(
602
+ (text_length, self.patch_size, self.audio_vae.latent_dim),
603
+ dtype=torch.float32,
604
+ device=text_token.device,
605
+ )
606
+ text_token = torch.cat([text_token, prompt_pad_token])
607
+ audio_feat = torch.cat([text_pad_feat, prompt_feat], dim=0)
608
+ text_mask = torch.cat(
609
+ [
610
+ torch.ones(text_length, dtype=torch.int32),
611
+ torch.zeros(prompt_audio_length, dtype=torch.int32),
612
+ ]
613
+ ).to(text_token.device)
614
+ audio_mask = torch.cat(
615
+ [
616
+ torch.zeros(text_length, dtype=torch.int32),
617
+ torch.ones(prompt_audio_length, dtype=torch.int32),
618
+ ]
619
+ ).to(text_token.device)
620
+
621
+ text_token = text_token.unsqueeze(0).to(self.device)
622
+ text_mask = text_mask.unsqueeze(0).to(self.device)
623
+ audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
624
+ audio_mask = audio_mask.unsqueeze(0).to(self.device)
625
+
626
+ target_text_length = len(self.text_tokenizer(target_text))
627
+
628
+ retry_badcase_times = 0
629
+ while retry_badcase_times < retry_badcase_max_times:
630
+ inference_result = self._inference(
631
+ text_token,
632
+ text_mask,
633
+ audio_feat,
634
+ audio_mask,
635
+ min_len=min_len,
636
+ max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
637
+ inference_timesteps=inference_timesteps,
638
+ cfg_value=cfg_value,
639
+ streaming=streaming,
640
+ streaming_prefix_len=streaming_prefix_len,
641
+ )
642
+ if streaming:
643
+ out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
644
+ for latent_pred, _ in inference_result:
645
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
646
+ decode_audio = decode_audio[..., -out_patch_len:].squeeze(1).cpu()
647
+ yield decode_audio
648
+ break
649
+ else:
650
+ latent_pred, pred_audio_feat = next(inference_result)
651
+ if retry_badcase:
652
+ if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
653
+ print(
654
+ f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
655
+ file=sys.stderr,
656
+ )
657
+ retry_badcase_times += 1
658
+ continue
659
+ else:
660
+ break
661
+ else:
662
+ break
663
+
664
+ if not streaming:
665
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
666
+ out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
667
+ has_continuation = bool(prompt_wav_path)
668
+ if has_continuation:
669
+ decode_audio = decode_audio[..., out_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
670
+ else:
671
+ decode_audio = decode_audio.squeeze(1).cpu()
672
+ yield decode_audio
673
+
674
+ @torch.inference_mode()
675
+ def build_prompt_cache(
676
+ self,
677
+ prompt_text: str = None,
678
+ prompt_wav_path: str = None,
679
+ reference_wav_path: str = None,
680
+ ):
681
+ """
682
+ Build prompt cache for subsequent generation.
683
+
684
+ Supports the same parameter combinations as ``generate()``:
685
+ - ``reference_wav_path`` only -> reference mode (voice cloning, isolated)
686
+ - ``prompt_text`` + ``prompt_wav_path`` -> continuation mode
687
+ - all three -> combined ref + continuation mode
688
+
689
+ Args:
690
+ prompt_text: prompt text for continuation mode.
691
+ Must be paired with ``prompt_wav_path``.
692
+ prompt_wav_path: prompt audio path for continuation mode.
693
+ Must be paired with ``prompt_text``.
694
+ reference_wav_path: reference audio path for voice cloning
695
+ (structurally isolated via ref_audio tokens).
696
+
697
+ Returns:
698
+ prompt_cache: dict used by ``_generate_with_prompt_cache``.
699
+ """
700
+ if (prompt_wav_path is None) != (prompt_text is None):
701
+ raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
702
+ if prompt_wav_path is None and reference_wav_path is None:
703
+ raise ValueError("At least one of prompt_wav_path or reference_wav_path must be provided")
704
+
705
+ cache = {}
706
+
707
+ if reference_wav_path:
708
+ cache["ref_audio_feat"] = self._encode_wav(reference_wav_path, padding_mode="right")
709
+
710
+ if prompt_wav_path and prompt_text is not None:
711
+ cache["prompt_text"] = prompt_text
712
+ cache["audio_feat"] = self._encode_wav(prompt_wav_path, padding_mode="left")
713
+
714
+ has_ref = "ref_audio_feat" in cache
715
+ has_prompt = "audio_feat" in cache
716
+ if has_ref and has_prompt:
717
+ cache["mode"] = "ref_continuation"
718
+ elif has_ref:
719
+ cache["mode"] = "reference"
720
+ else:
721
+ cache["mode"] = "continuation"
722
+
723
+ return cache
724
+
725
+ def merge_prompt_cache(
726
+ self,
727
+ original_cache: dict,
728
+ new_text: str,
729
+ new_audio_feat: torch.Tensor,
730
+ ):
731
+ """
732
+ Merge original prompt cache with newly generated content to stabilize voice.
733
+
734
+ Args:
735
+ original_cache: original prompt cache (any mode)
736
+ new_text: newly generated text
737
+ new_audio_feat: newly generated audio features
738
+
739
+ Returns:
740
+ merged_cache: merged cache with prompt_text and audio_feat
741
+ """
742
+ if original_cache is None:
743
+ return {
744
+ "prompt_text": new_text,
745
+ "audio_feat": new_audio_feat,
746
+ "mode": "continuation",
747
+ }
748
+ merged = {}
749
+ if "ref_audio_feat" in original_cache:
750
+ merged["ref_audio_feat"] = original_cache["ref_audio_feat"]
751
+ merged["prompt_text"] = original_cache.get("prompt_text", "") + new_text
752
+ old_feat = original_cache.get("audio_feat", new_audio_feat.new_empty(0, *new_audio_feat.shape[1:]))
753
+ merged["audio_feat"] = torch.cat([old_feat, new_audio_feat], dim=0)
754
+ merged["mode"] = "ref_continuation" if "ref_audio_feat" in merged else "continuation"
755
+ return merged
756
+
757
+ def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
758
+ return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
759
+
760
+ def generate_with_prompt_cache_streaming(
761
+ self, *args, **kwargs
762
+ ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
763
+ return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
764
+
765
    @torch.inference_mode()
    def _generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
        streaming: bool = False,
        streaming_prefix_len: int = 4,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """
        Generate audio using pre-built prompt cache.

        Args:
            target_text: Text to convert to speech
            prompt_cache: Cache built by ``build_prompt_cache()``. Can be None
                for zero-shot generation.
            min_len: Minimum audio length to avoid very short audio
            max_len: Maximum audio length
            inference_timesteps: Number of diffusion sampling steps
            cfg_value: Classifier-free guidance value
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
            streaming: Whether to return a generator of audio chunks
            streaming_prefix_len: Number of prefix audio patches to use for streaming mode

        Returns:
            Generator of Tuple containing:
                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
                - Tensor of new text tokens
                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
        """
        # Retry needs the full (non-streaming) result to measure the ratio, so it
        # cannot coexist with streaming.
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False

        # Determine mode from cache
        if prompt_cache is None:
            mode = "zero_shot"
            text = target_text
        else:
            mode = prompt_cache.get("mode", "continuation")
            if mode in ("continuation", "ref_continuation"):
                # Continuation: the LM conditions on prompt transcript + target text.
                prompt_text = prompt_cache.get("prompt_text", "")
                text = prompt_text + target_text
            else:
                text = target_text

        # Tokenize and append the audio-start sentinel token.
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )

        target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
        text_length = text_token.shape[0]

        # Build the interleaved (text, audio) sequence + complementary masks.
        # Layout per mode:
        #   zero_shot/continuation: [text | prompt audio]
        #   reference:              [ref prefix | text]
        #   ref_continuation:       [ref prefix | text | prompt audio]
        if mode in ("zero_shot", "continuation"):
            prompt_audio_feat = (
                prompt_cache["audio_feat"]
                if prompt_cache
                else torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
            )
            audio_length = prompt_audio_feat.size(0)
            # Zero tokens / zero features fill the positions covered by the other modality.
            text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
            text_pad_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            text_token = torch.cat([text_token, text_pad_token])
            audio_feat = torch.cat([text_pad_feat, prompt_audio_feat], dim=0)
            text_mask = torch.cat(
                [torch.ones(text_length, dtype=torch.int32), torch.zeros(audio_length, dtype=torch.int32)]
            ).to(text_token.device)
            audio_mask = torch.cat(
                [torch.zeros(text_length, dtype=torch.int32), torch.ones(audio_length, dtype=torch.int32)]
            ).to(text_token.device)

        elif mode == "reference":
            ref_audio_feat = prompt_cache["ref_audio_feat"]
            ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_audio_feat, text_token.device)
            text_pad_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            text_token = torch.cat([ref_tokens, text_token])
            audio_feat = torch.cat([ref_feats, text_pad_feat], dim=0)
            text_mask = torch.cat([ref_t_mask, torch.ones(text_length, dtype=torch.int32).to(text_token.device)])
            audio_mask = torch.cat([ref_a_mask, torch.zeros(text_length, dtype=torch.int32).to(text_token.device)])

        else:
            # ref_continuation mode
            ref_audio_feat = prompt_cache["ref_audio_feat"]
            prompt_audio_feat = prompt_cache["audio_feat"]
            prompt_audio_length = prompt_audio_feat.size(0)

            ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_audio_feat, text_token.device)

            prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
            text_pad_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )

            text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
            audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_audio_feat], dim=0)
            text_mask = torch.cat(
                [
                    ref_t_mask,
                    torch.ones(text_length, dtype=torch.int32).to(text_token.device),
                    torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),
                ]
            )
            audio_mask = torch.cat(
                [
                    ref_a_mask,
                    torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
                    torch.ones(prompt_audio_length, dtype=torch.int32).to(text_token.device),
                ]
            )

        # Add batch dim and move to the model device/dtype.
        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        # run inference
        target_text_length = len(self.text_tokenizer(target_text))
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                # Cap length by the badcase ratio so runaway generations stop early.
                max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                streaming=streaming,
                streaming_prefix_len=streaming_prefix_len,
            )
            if streaming:
                # Samples produced per latent patch after VAE decoding + resampling.
                out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
                for latent_pred, pred_audio_feat in inference_result:
                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                    # Only the newest patch is emitted; earlier patches were context.
                    decode_audio = decode_audio[..., -out_patch_len:].squeeze(1).cpu()
                    yield (decode_audio, target_text_token, pred_audio_feat)
                break
            else:
                latent_pred, pred_audio_feat = next(inference_result)
                if retry_badcase:
                    # An overly long audio-to-text ratio signals a degenerate sample.
                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                        print(
                            f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                            file=sys.stderr,
                        )
                        retry_badcase_times += 1
                        continue
                    else:
                        break
                else:
                    break
        if not streaming:
            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
            out_patch_len = self.patch_size * self.chunk_size * (self.sample_rate // self._encode_sample_rate)
            if mode in ("continuation", "ref_continuation"):
                # Drop the decoded prompt-context patches that were only prepended
                # for smooth decoding.
                decode_audio = decode_audio[..., out_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
            else:
                decode_audio = decode_audio[..., :].squeeze(1).cpu()
            yield (decode_audio, target_text_token, pred_audio_feat)
948
+
949
+ def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
950
+ return next(self._inference(*args, streaming=False, **kwargs))
951
+
952
+ def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
953
+ return self._inference(*args, streaming=True, **kwargs)
954
+
955
    @torch.inference_mode()
    def _inference(
        self,
        text: torch.Tensor,
        text_mask: torch.Tensor,
        feat: torch.Tensor,
        feat_mask: torch.Tensor,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        streaming: bool = False,
        streaming_prefix_len: int = 4,
    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """Core inference method for audio generation.

        This is the main inference loop that generates audio features
        using the language model and diffusion transformer.

        Args:
            text: Input text tokens
            text_mask: Mask for text tokens
            feat: Input audio features
            feat_mask: Mask for audio features
            min_len: Minimum generation length
            max_len: Maximum generation length
            inference_timesteps: Number of diffusion steps
            cfg_value: Classifier-free guidance value
            streaming: Whether to yield each step latent feature or just the final result

        Returns:
            Generator of Tuple containing:
                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
        """
        # feat is (batch, time, patch, latent_dim).
        B, T, P, D = feat.shape

        feat_embed = self.feat_encoder(feat)  # [b, t, h_feat]
        feat_embed = self.enc_to_lm_proj(feat_embed)

        # muP scaling of the token embeddings when the LM was trained with it.
        if self.config.lm_config.use_mup:
            scale_emb = self.config.lm_config.scale_emb
        else:
            scale_emb = 1.0

        # Per-position blend: each position is either a text embedding or an
        # audio-feature embedding (masks are complementary).
        text_embed = self.base_lm.embed_tokens(text) * scale_emb
        combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

        prefix_feat_cond = feat[:, -1, ...]  # b, p, d
        pred_feat_seq = []  # b, t, p, d
        curr_embed = None

        # Prepare prompt context patches for streaming mode
        # - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1)
        #   trailing audio patches as initial context so the VAE can decode smoothly.
        # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
        has_continuation_audio = feat_mask[0, -1].item() == 1
        if has_continuation_audio:
            audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
            context_len = min(streaming_prefix_len - 1, len(audio_indices))
            last_audio_indices = audio_indices[-context_len:]
            pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
        else:
            pred_feat_seq = []

        # Full prefix pass through the base LM; cache keys/values for the
        # incremental forward_step calls below.
        enc_outputs, kv_cache_tuple = self.base_lm(
            inputs_embeds=combined_embed,
            is_causal=True,
        )
        self.base_lm.kv_cache.fill_caches(kv_cache_tuple)

        # Quantize (FSQ) only the audio positions; text positions pass through.
        enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        lm_hidden = enc_outputs[:, -1, :]

        # Residual LM consumes the fused (LM output, audio embedding) stream.
        residual_enc_inputs = self.fusion_concat_proj(
            torch.cat((enc_outputs, feat_mask.unsqueeze(-1) * feat_embed), dim=-1)
        )
        residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
            inputs_embeds=residual_enc_inputs,
            is_causal=True,
        )
        self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
        residual_hidden = residual_enc_outputs[:, -1, :]

        # Autoregressive loop: one latent patch per step, denoised by the DiT.
        for i in tqdm(range(max_len)):
            dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
            dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
            dit_hidden = torch.cat((dit_hidden_1, dit_hidden_2), dim=-1)

            # Diffusion decoder predicts the next patch, conditioned on the
            # previous one (prefix_feat_cond) and the LM state (mu).
            pred_feat = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=prefix_feat_cond.transpose(1, 2).contiguous(),
                n_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            ).transpose(
                1, 2
            )  # [b, p, d]

            curr_embed = self.feat_encoder(pred_feat.unsqueeze(1))  # b, 1, c
            curr_embed = self.enc_to_lm_proj(curr_embed)

            pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
            prefix_feat_cond = pred_feat

            if streaming:
                # return the last three predicted latent features to provide enough context for smooth decoding
                pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)

                yield feat_pred, pred_feat_seq

            # Stop classifier on the current LM state; min_len guards very short output.
            stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
            if i > min_len and stop_flag == 1:
                break

            # Incremental LM step with the freshly encoded patch.
            lm_hidden = self.base_lm.forward_step(
                curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

            lm_hidden = self.fsq_layer(lm_hidden)
            curr_residual_input = self.fusion_concat_proj(torch.cat((lm_hidden, curr_embed[:, 0, :]), dim=-1))
            residual_hidden = self.residual_lm.forward_step(
                curr_residual_input, torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

        if not streaming:
            pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
            yield feat_pred, pred_feat_seq.squeeze(0).cpu()
1085
+
1086
+ @classmethod
1087
+ def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
1088
+ config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
1089
+ tokenizer = LlamaTokenizerFast.from_pretrained(path)
1090
+ audio_vae_config = getattr(config, "audio_vae_config", None)
1091
+ audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
1092
+ # Try to load AudioVAE from safetensors first, fallback to pytorch
1093
+ audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
1094
+ audiovae_pth_path = os.path.join(path, "audiovae.pth")
1095
+ if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
1096
+ print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
1097
+ vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
1098
+ elif os.path.exists(audiovae_pth_path):
1099
+ print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
1100
+ checkpoint = torch.load(
1101
+ audiovae_pth_path,
1102
+ map_location="cpu",
1103
+ weights_only=True,
1104
+ )
1105
+ vae_state_dict = checkpoint.get("state_dict", checkpoint)
1106
+ else:
1107
+ raise FileNotFoundError(
1108
+ f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
1109
+ )
1110
+ model = cls(config, tokenizer, audio_vae, lora_config)
1111
+ if not training:
1112
+ lm_dtype = get_dtype(model.config.dtype)
1113
+ model = model.to(lm_dtype)
1114
+ else: # training mode
1115
+ for name, param in model.named_parameters():
1116
+ if "audio_vae" in name: # freeze VAE weights
1117
+ param.requires_grad = False
1118
+ continue
1119
+ if lora_config is not None:
1120
+ if "lora" not in name: # freeze non-LoRA weights
1121
+ param.requires_grad = False
1122
+ model.audio_vae = model.audio_vae.to(torch.float32)
1123
+
1124
+ # Try to load from safetensors first, fallback to pytorch_model.bin
1125
+ safetensors_path = os.path.join(path, "model.safetensors")
1126
+ pytorch_model_path = os.path.join(path, "pytorch_model.bin")
1127
+
1128
+ if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
1129
+ print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
1130
+ model_state_dict = load_file(safetensors_path)
1131
+ elif os.path.exists(pytorch_model_path):
1132
+ print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
1133
+ checkpoint = torch.load(
1134
+ pytorch_model_path,
1135
+ map_location="cpu",
1136
+ weights_only=True,
1137
+ )
1138
+ model_state_dict = checkpoint.get("state_dict", checkpoint)
1139
+ else:
1140
+ raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
1141
+
1142
+ for kw, val in vae_state_dict.items():
1143
+ model_state_dict[f"audio_vae.{kw}"] = val
1144
+
1145
+ # LoRALinear keeps weight/bias compatible with nn.Linear but adds
1146
+ # lora_A/lora_B, which are absent from base pretrained checkpoints.
1147
+ model.load_state_dict(model_state_dict, strict=False)
1148
+ if training:
1149
+ return model
1150
+ return model.to(model.device).eval().optimize(disable=not optimize)
1151
+
1152
+ # ------------------------------------------------------------------ #
1153
+ # LoRA Weight Management
1154
+ # ------------------------------------------------------------------ #
1155
+ def _iter_lora_modules(self):
1156
+ """Iterate over all LoRA modules."""
1157
+ from ..modules.layers.lora import LoRALinear
1158
+
1159
+ for module in self.modules():
1160
+ if isinstance(module, LoRALinear):
1161
+ yield module
1162
+
1163
    def load_lora_weights(self, lora_path: str, device: str = None):
        """
        Load LoRA weights from file, supports calling after torch.compile.
        Uses named_parameters() to handle compile's _orig_mod wrapper.
        Supports both safetensors and pytorch formats.

        Args:
            lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
            device: Target device, defaults to model's current device
        Returns:
            tuple: (loaded_keys, skipped_keys)
        """
        from pathlib import Path

        # NOTE(review): self.device may be a torch.device; safetensors'
        # load_file expects a device *string* — confirm callers pass str here.
        device = device or self.device
        lora_p = Path(lora_path)

        # Try safetensors first, then fallback to .ckpt
        if lora_p.is_dir():
            safetensors_file = lora_p / "lora_weights.safetensors"
            ckpt_file = lora_p / "lora_weights.ckpt"
        else:
            # An explicit file path selects its format by extension.
            safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
            ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None

        # Load from safetensors if available
        if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
            state_dict = load_file(str(safetensors_file), device=device)
        elif ckpt_file and ckpt_file.exists():
            # SECURITY: weights_only=False unpickles arbitrary objects — only
            # load .ckpt/.pth files from trusted sources.
            ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
        else:
            raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")

        # Build param mapping (handle torch.compile's _orig_mod prefix)
        model_params = dict(self.named_parameters())
        key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}

        loaded_keys, skipped_keys = [], []
        for key, value in state_dict.items():
            # Accept either the plain key or its compiled (_orig_mod) alias.
            target_key = key if key in model_params else key_mapping.get(key)
            if target_key:
                # In-place copy keeps optimizer/compile references intact.
                model_params[target_key].data.copy_(value.to(device))
                loaded_keys.append(key)
            else:
                skipped_keys.append(key)

        return loaded_keys, skipped_keys
1211
+
1212
+ def set_lora_enabled(self, enabled: bool):
1213
+ """Enable/disable all LoRA layers."""
1214
+ for module in self._iter_lora_modules():
1215
+ module.set_enabled(enabled)
1216
+
1217
+ def reset_lora_weights(self):
1218
+ """Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
1219
+ for module in self._iter_lora_modules():
1220
+ module.reset_lora_parameters()
1221
+
1222
+ def get_lora_state_dict(self) -> dict:
1223
+ """Get all LoRA parameters (lora_A/lora_B)."""
1224
+ return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
voxcpm/modules/__init__.py ADDED
File without changes
voxcpm/modules/audiovae/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .audio_vae import AudioVAE, AudioVAEConfig
2
+ from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
voxcpm/modules/audiovae/audio_vae.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import weight_norm
9
+ from pydantic import BaseModel
10
+
11
+
12
def WNConv1d(*args, **kwargs):
    """Build a 1-D convolution wrapped in weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
14
+
15
+
16
def WNConvTranspose1d(*args, **kwargs):
    """Build a 1-D transposed convolution wrapped in weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
18
+
19
+
20
class CausalConv1d(nn.Conv1d):
    """Conv1d that pads only on the left, so no output depends on future samples.

    The ``padding`` argument is intercepted (not forwarded to ``nn.Conv1d``)
    and applied as ``2 * padding`` zeros on the left side in ``forward``.
    """

    def __init__(self, *args, padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding

    def forward(self, x):
        # All padding goes to the left; zero right padding keeps causality.
        padded = F.pad(x, (self.__padding * 2, 0))
        return super().forward(padded)
28
+
29
+
30
class CausalTransposeConv1d(nn.ConvTranspose1d):
    """Transposed Conv1d trimmed on the right so the mapping stays causal.

    ``padding``/``output_padding`` are intentionally NOT forwarded to the
    underlying ``nn.ConvTranspose1d``; instead the full-length output is
    computed and its trailing ``2 * padding - output_padding`` samples are
    cut off.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        out = super().forward(x)
        trim = self.__padding * 2 - self.__output_padding
        # Fix: the original slice `[..., :-(trim)]` returned an EMPTY tensor
        # when trim == 0 (e.g. default padding=0, output_padding=0), because
        # `[..., :-0]` is `[..., :0]`. Guard the no-trim case explicitly.
        return out[..., :-trim] if trim > 0 else out
38
+
39
+
40
def WNCausalConv1d(*args, **kwargs):
    """Build a causal Conv1d and wrap it in weight normalization."""
    causal_conv = CausalConv1d(*args, **kwargs)
    return weight_norm(causal_conv)
42
+
43
+
44
def WNCausalTransposeConv1d(*args, **kwargs):
    """Build a causal transposed Conv1d and wrap it in weight normalization."""
    causal_deconv = CausalTransposeConv1d(*args, **kwargs)
    return weight_norm(causal_deconv)
46
+
47
+
48
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1/alpha) * sin^2(alpha * x), applied elementwise
    # with a per-channel alpha that broadcasts over batch and time.
    shape = x.shape
    # Flatten everything after the channel dim so alpha (1, C, 1) broadcasts,
    # then restore the original shape afterwards.
    x = x.reshape(shape[0], shape[1], -1)
    # 1e-9 guards the reciprocal against alpha values at/near zero.
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
56
+
57
+
58
class Snake1d(nn.Module):
    """Snake activation with one learnable alpha per channel."""

    def __init__(self, channels):
        super().__init__()
        # Shape (1, C, 1) so alpha broadcasts over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        # Delegates to the TorchScript-compiled `snake` for speed.
        return snake(x, self.alpha)
65
+
66
+
67
def init_weights(m):
    """Initialize Conv1d weights with a truncated normal and zero the bias.

    Intended for use with ``module.apply(init_weights)``; non-Conv1d modules
    are left untouched.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.trunc_normal_(m.weight, std=0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
72
+
73
+
74
class CausalResidualUnit(nn.Module):
    """Residual block: Snake -> dilated causal conv -> Snake -> 1x1 causal conv, plus skip.

    Args:
        dim: channel count (input and output are identical).
        dilation: dilation factor of the inner convolution.
        kernel: kernel size of the inner convolution.
        groups: conv groups (pass ``dim`` for a depthwise convolution).
    """

    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        # Fix: padding was computed from a hard-coded 7 instead of the actual
        # `kernel` parameter, silently mis-padding any non-default kernel.
        # Identical for the default kernel=7 used throughout this file.
        pad = ((kernel - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # The causal convs are configured to preserve length, so no center
        # cropping should ever be needed; the assert documents that invariant.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        assert pad == 0
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y
99
+
100
+
101
class CausalEncoderBlock(nn.Module):
    """Three increasingly dilated residual units followed by a strided downsampling conv."""

    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
        super().__init__()
        # By convention each encoder block doubles the channel count, so the
        # default input width is half the output width.
        input_dim = input_dim or output_dim // 2
        self.block = nn.Sequential(
            # Dilations 1/3/9 widen the receptive field geometrically.
            CausalResidualUnit(input_dim, dilation=1, groups=groups),
            CausalResidualUnit(input_dim, dilation=3, groups=groups),
            CausalResidualUnit(input_dim, dilation=9, groups=groups),
            Snake1d(input_dim),
            # Downsample by `stride` while projecting input_dim -> output_dim.
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)
121
+
122
+
123
class CausalEncoder(nn.Module):
    """Causal convolutional encoder producing Gaussian posterior parameters.

    A stem convolution is followed by one CausalEncoderBlock per stride
    (doubling channels at each downsampling stage), then two heads map the
    hidden state to `mu` and `logvar`.
    """

    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: list = [2, 4, 8, 8],
        depthwise: bool = False,
    ):
        super().__init__()
        # Stem convolution from the mono waveform to d_model channels.
        layers = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]

        # Encoder blocks double channels while downsampling by each stride.
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            layers.append(CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups))

        # FIX: removed a dead post-loop assignment
        # (`groups = d_model if depthwise else 1`) that was never read.
        self.block = nn.Sequential(*layers)

        # Two convolution heads for the Gaussian posterior (mu, logvar).
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)

        self.enc_dim = d_model

    def forward(self, x):
        """Encode a waveform; returns the hidden state plus mu/logvar."""
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }
158
+
159
+
160
class NoiseBlock(nn.Module):
    """Adds input-conditioned Gaussian noise: x + randn * (1x1 conv of x)."""

    def __init__(self, dim):
        super().__init__()
        # Bias-free 1x1 conv produces the per-channel noise gate.
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        batch, _, steps = x.shape
        # One shared noise channel, broadcast across feature channels.
        noise = torch.randn((batch, 1, steps), device=x.device, dtype=x.dtype)
        gate = self.linear(x)
        return x + noise * gate
172
+
173
+
174
class CausalDecoderBlock(nn.Module):
    """Strided transposed upsampling conv, optional noise injection, then
    three dilated residual units."""

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        modules = [
            Snake1d(input_dim),
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        ]
        if use_noise_block:
            modules.append(NoiseBlock(output_dim))
        for dilation in (1, 3, 9):
            modules.append(CausalResidualUnit(output_dim, dilation=dilation, groups=groups))
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        return self.block(x)
208
+
209
+
210
class TransposeLastTwoDim(torch.nn.Module):
    """Swaps the last two dimensions of the input tensor."""

    def forward(self, x):
        return x.transpose(-1, -2)
213
+
214
+
215
class CausalDecoder(nn.Module):
    """Causal convolutional decoder: latent sequence -> waveform in [-1, 1].

    Mirrors CausalEncoder: a stem convolution, one upsampling
    CausalDecoderBlock per rate (halving channels each stage), then a final
    convolution squashed by tanh.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
    ):
        super().__init__()

        # Add first conv layer (depthwise variant factorizes into a grouped
        # 7-tap conv followed by a 1x1 channel mixer).
        if depthwise:
            layers = [
                WNCausalConv1d(
                    input_channel,
                    input_channel,
                    kernel_size=7,
                    padding=3,
                    groups=input_channel,
                ),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]
        else:
            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks; channels halve at each stage.
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            groups = output_dim if depthwise else 1
            layers += [
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=groups,
                    use_noise_block=use_noise_block,
                )
            ]

        # Add final conv layer.
        # NOTE(review): `output_dim` here is the loop variable left over from
        # the last rate; an empty `rates` would raise NameError — presumably
        # never the case for this model.
        layers += [
            Snake1d(output_dim),
            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
268
+
269
+
270
class AudioVAEConfig(BaseModel):
    """Hyperparameters for the causal audio VAE (16 kHz variant)."""

    encoder_dim: int = 128                     # stem width of the encoder
    encoder_rates: List[int] = [2, 5, 8, 8]    # per-stage downsampling strides
    latent_dim: int = 64                       # channels of the VAE latent
    decoder_dim: int = 1536                    # stem width of the decoder
    decoder_rates: List[int] = [8, 8, 5, 2]    # per-stage upsampling rates
    depthwise: bool = True                     # use grouped (depthwise) convolutions
    sample_rate: int = 16000                   # expected input sample rate
    use_noise_block: bool = False              # inject conditioned noise in decoder blocks
279
+
280
+
281
class AudioVAE(nn.Module):
    """Causal variational autoencoder over raw waveforms.

    Encodes mono audio into a latent sequence downsampled by
    prod(encoder_rates) and decodes it back to a waveform in [-1, 1].

    Args:
        config: AudioVAEConfig with all model hyperparameters; defaults are
            used when None is given.
    """

    def __init__(
        self,
        config: AudioVAEConfig = None,
    ):
        # Fall back to the default configuration when none is provided.
        if config is None:
            config = AudioVAEConfig()

        super().__init__()

        encoder_dim = config.encoder_dim
        encoder_rates = config.encoder_rates
        latent_dim = config.latent_dim
        decoder_dim = config.decoder_dim
        decoder_rates = config.decoder_rates
        depthwise = config.depthwise
        sample_rate = config.sample_rate
        use_noise_block = config.use_noise_block

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.depthwise = depthwise

        self.use_noise_block = use_noise_block

        # Derive the latent width from the encoder when not specified.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim
        # Total downsampling factor: input samples per latent frame.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = CausalEncoder(
            encoder_dim,
            latent_dim,
            encoder_rates,
            depthwise=depthwise,
        )

        self.decoder = CausalDecoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            depthwise=depthwise,
            use_noise_block=use_noise_block,
        )
        self.sample_rate = sample_rate
        self.chunk_size = math.prod(encoder_rates)

    def preprocess(self, audio_data, sample_rate):
        """Right-pad the waveform so its length is a multiple of hop_length."""
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        pad_to = self.hop_length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Continuous latent representation of the input.

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data in [-1, 1].
        """
        return self.decoder(z)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """
        Args:
            audio_data: Tensor[B x 1 x T] (a 2-D tensor gets a channel dim added)
            sample_rate: int, must match self.sample_rate
        Returns:
            z: Tensor[B x D x T'] — the posterior mean `mu`
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        audio_data = self.preprocess(audio_data, sample_rate)
        return self.encoder(audio_data)["mu"]
voxcpm/modules/audiovae/audio_vae_v2.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import weight_norm
9
+ from pydantic import BaseModel
10
+
11
+
12
def WNConv1d(*args, **kwargs):
    """nn.Conv1d wrapped with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
14
+
15
+
16
def WNConvTranspose1d(*args, **kwargs):
    """nn.ConvTranspose1d wrapped with weight normalization."""
    tconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(tconv)
18
+
19
+
20
class CausalConv1d(nn.Conv1d):
    """Conv1d that left-pads the input so outputs never see the future.

    The `padding` / `output_padding` keyword arguments are intercepted (NOT
    forwarded to nn.Conv1d); the causal left padding applied in forward() is
    padding * 2 - output_padding samples.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        # Precompute the left-pad amount once at construction time.
        self._left_pad = padding * 2 - output_padding

    def forward(self, x):
        return super().forward(F.pad(x, (self._left_pad, 0)))
29
+
30
+
31
class CausalTransposeConv1d(nn.ConvTranspose1d):
    """ConvTranspose1d trimmed on the right to keep the output causal.

    The `padding` / `output_padding` keyword arguments are intercepted (NOT
    forwarded to nn.ConvTranspose1d); after the full transposed convolution
    the output is trimmed by padding * 2 - output_padding samples on the
    right.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        out = super().forward(x)
        trim = self.__padding * 2 - self.__output_padding
        # BUG FIX: when trim == 0 the original slice `[..., :-(0)]` evaluated
        # to `[..., :0]` and silently returned an EMPTY tensor. Guard it so a
        # zero trim is a no-op (all existing call sites use trim > 0, so
        # their behavior is unchanged).
        return out[..., :-trim] if trim > 0 else out
39
+
40
+
41
def WNCausalConv1d(*args, **kwargs):
    """Build a CausalConv1d and wrap it with weight normalization."""
    conv = CausalConv1d(*args, **kwargs)
    return weight_norm(conv)
43
+
44
+
45
def WNCausalTransposeConv1d(*args, **kwargs):
    """Build a CausalTransposeConv1d and wrap it with weight normalization."""
    tconv = CausalTransposeConv1d(*args, **kwargs)
    return weight_norm(tconv)
47
+
48
+
49
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + sin^2(alpha * x) / alpha, applied elementwise."""
    original_shape = x.shape
    # Flatten trailing dims so alpha (1, C, 1) broadcasts over time.
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # 1e-9 keeps the reciprocal finite if alpha ever reaches zero.
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
57
+
58
+
59
class Snake1d(nn.Module):
    """Snake activation module with a learnable per-channel alpha."""

    def __init__(self, channels):
        super().__init__()
        # One frequency parameter per channel, broadcast over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
66
+
67
+
68
+ def init_weights(m):
69
+ if isinstance(m, nn.Conv1d):
70
+ nn.init.trunc_normal_(m.weight, std=0.02)
71
+ if m.bias is not None:
72
+ nn.init.constant_(m.bias, 0)
73
+
74
+
75
class CausalResidualUnit(nn.Module):
    """Residual block: Snake -> dilated causal conv -> Snake -> 1x1 causal conv.

    All convolutions are causal (fully left-padded), so the block preserves
    temporal length and the residual add needs no cropping.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        # BUG FIX: padding was previously computed from a hard-coded kernel
        # size of 7 (`((7 - 1) * dilation) // 2`), silently ignoring the
        # `kernel` argument. Use the actual kernel size; identical for the
        # default kernel=7, so existing callers are unaffected.
        pad = ((kernel - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        """Apply the block and add the residual."""
        y = self.block(x)
        # Causal convolutions preserve length; the assert guards that
        # invariant (the crop branch below is then never taken).
        pad = (x.shape[-1] - y.shape[-1]) // 2
        assert pad == 0
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y
100
+
101
+
102
class CausalEncoderBlock(nn.Module):
    """Three dilated residual units followed by a strided downsampling conv."""

    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
        super().__init__()
        # By convention each block doubles the channel count.
        input_dim = input_dim or output_dim // 2
        units = [
            CausalResidualUnit(input_dim, dilation=d, groups=groups)
            for d in (1, 3, 9)
        ]
        self.block = nn.Sequential(
            *units,
            Snake1d(input_dim),
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        )

    def forward(self, x):
        return self.block(x)
123
+
124
+
125
class CausalEncoder(nn.Module):
    """Causal convolutional encoder producing Gaussian posterior parameters.

    A stem convolution is followed by one CausalEncoderBlock per stride
    (doubling channels at each downsampling stage), then two heads map the
    hidden state to `mu` and `logvar`.
    """

    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: list = [2, 4, 8, 8],
        depthwise: bool = False,
    ):
        super().__init__()
        # Stem convolution from the mono waveform to d_model channels.
        layers = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]

        # Encoder blocks double channels while downsampling by each stride.
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            layers.append(CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups))

        # FIX: removed a dead post-loop assignment
        # (`groups = d_model if depthwise else 1`) that was never read.
        self.block = nn.Sequential(*layers)

        # Two convolution heads for the Gaussian posterior (mu, logvar).
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)

        self.enc_dim = d_model

    def forward(self, x):
        """Encode a waveform; returns the hidden state plus mu/logvar."""
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }
160
+
161
+
162
class NoiseBlock(nn.Module):
    """Adds input-conditioned Gaussian noise: x + randn * (1x1 conv of x)."""

    def __init__(self, dim):
        super().__init__()
        # Bias-free 1x1 conv produces the per-channel noise gate.
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        batch, _, steps = x.shape
        # One shared noise channel, broadcast across feature channels.
        noise = torch.randn((batch, 1, steps), device=x.device, dtype=x.dtype)
        gate = self.linear(x)
        return x + noise * gate
174
+
175
+
176
class CausalDecoderBlock(nn.Module):
    """Strided transposed upsampling conv, optional noise injection, then
    three dilated residual units."""

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        modules = [
            Snake1d(input_dim),
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        ]
        if use_noise_block:
            modules.append(NoiseBlock(output_dim))
        for dilation in (1, 3, 9):
            modules.append(CausalResidualUnit(output_dim, dilation=dilation, groups=groups))
        self.block = nn.Sequential(*modules)
        # Recorded so CausalDecoder can size the sample-rate condition layers.
        self.input_channels = input_dim

    def forward(self, x):
        return self.block(x)
211
+
212
+
213
class TransposeLastTwoDim(torch.nn.Module):
    """Swaps the last two dimensions of the input tensor."""

    def forward(self, x):
        return x.transpose(-1, -2)
216
+
217
+
218
class SampleRateConditionLayer(nn.Module):
    """Conditions decoder features on a (bucketized) target sample rate.

    Supported cond_type values:
      - "scale_bias":      per-bucket channel-wise scale and bias, initialized
                           to the identity (scale=1, bias=0)
      - "scale_bias_init": same layout, normally-distributed init
      - "add":             additive per-bucket channel embedding
      - "concat":          concatenate a cond_dim embedding along channels
                           (requires out_layer=True to project back down)
    """

    def __init__(
        self,
        input_dim: int,
        sr_bin_buckets: int = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        out_layer: bool = False,
    ):
        super().__init__()

        self.cond_type, out_layer_in_dim = cond_type, input_dim

        if cond_type == "scale_bias":
            # Identity init: the layer is a no-op until trained.
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.ones_(self.scale_embed.weight)
            nn.init.zeros_(self.bias_embed.weight)
        elif cond_type == "scale_bias_init":
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.scale_embed.weight, mean=1)
            nn.init.normal_(self.bias_embed.weight)
        elif cond_type == "add":
            self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.cond_embed.weight)
        elif cond_type == "concat":
            self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
            assert out_layer, "out_layer must be True for concat cond_type"
            # The output projection must absorb the extra cond_dim channels.
            out_layer_in_dim = input_dim + cond_dim
        else:
            raise ValueError(f"Invalid cond_type: {cond_type}")

        if out_layer:
            # Projects (possibly concatenated) features back to input_dim.
            self.out_layer = nn.Sequential(
                Snake1d(out_layer_in_dim),
                WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
            )
        else:
            self.out_layer = nn.Identity()

    def forward(self, x, sr_cond):
        # x: (B, C, T); sr_cond: integer bucket indices (one per batch item).
        if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
            x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
        elif self.cond_type == "add":
            x = x + self.cond_embed(sr_cond).unsqueeze(-1)
        elif self.cond_type == "concat":
            # Broadcast the embedding across time before concatenating.
            x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)

        return self.out_layer(x)
268
+
269
+
270
class CausalDecoder(nn.Module):
    """Causal decoder with optional sample-rate conditioning.

    Same stem / upsampling / final-conv structure as the v1 decoder; when
    `sr_bin_boundaries` is given, one SampleRateConditionLayer is inserted in
    front of every CausalDecoderBlock and the layers are kept in an
    nn.ModuleList so the conditioning can be interleaved at forward time.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
        sr_bin_boundaries: List[int] = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        cond_out_layer: bool = False,
    ):
        super().__init__()

        # Add first conv layer (depthwise variant factorizes into a grouped
        # 7-tap conv followed by a 1x1 channel mixer).
        if depthwise:
            layers = [
                WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]
        else:
            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks; channels halve at each stage.
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            groups = output_dim if depthwise else 1
            layers += [
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=groups,
                    use_noise_block=use_noise_block,
                )
            ]

        # Add final conv layer.
        # NOTE(review): `output_dim` here is the loop variable from the last
        # rate; an empty `rates` would raise NameError — presumably never
        # the case for this model.
        layers += [
            Snake1d(output_dim),
            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        if sr_bin_boundaries is None:
            # No conditioning: a plain sequential stack suffices.
            self.model = nn.Sequential(*layers)
            self.sr_bin_boundaries = None
        else:
            # ModuleList so conditioning layers can be interleaved in forward.
            self.model = nn.ModuleList(layers)

            self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
            # N boundaries partition sample rates into N + 1 buckets.
            self.sr_bin_buckets = len(sr_bin_boundaries) + 1

            # One conditioning layer per decoder block; None placeholders keep
            # sr_cond_model aligned 1:1 with self.model for the forward zip.
            cond_layers = []
            for layer in self.model:
                if layer.__class__.__name__ == "CausalDecoderBlock":
                    cond_layers.append(
                        SampleRateConditionLayer(
                            input_dim=layer.input_channels,
                            sr_bin_buckets=self.sr_bin_buckets,
                            cond_type=cond_type,
                            cond_dim=cond_dim,
                            out_layer=cond_out_layer,
                        )
                    )
                else:
                    cond_layers.append(None)
            self.sr_cond_model = nn.ModuleList(cond_layers)

    def get_sr_idx(self, sr):
        """Map raw sample-rate values to bucket indices via the boundaries."""
        return torch.bucketize(sr, self.sr_bin_boundaries)

    def forward(self, x, sr_cond=None):
        if self.sr_bin_boundaries is not None:
            # assert sr_cond is not None
            sr_cond = self.get_sr_idx(sr_cond)

            # Interleave: condition the features right before each block.
            for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
                if sr_cond_layer is not None:
                    x = sr_cond_layer(x, sr_cond)
                x = layer(x)
            return x
        else:
            return self.model(x)
357
+
358
+
359
class AudioVAEConfig(BaseModel):
    """Hyperparameters for the sample-rate-conditioned audio VAE (v2).

    Encodes at `sample_rate` and decodes at a configurable output rate
    selected through bucketized sample-rate conditioning.
    """

    encoder_dim: int = 128
    encoder_rates: List[int] = [2, 5, 8, 8]         # per-stage downsampling strides
    latent_dim: int = 64
    decoder_dim: int = 2048
    decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]   # per-stage upsampling rates
    depthwise: bool = True
    sample_rate: int = 16000                        # input (encoder) sample rate
    out_sample_rate: int = 48000                    # default decoder output sample rate
    use_noise_block: bool = False
    sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]  # bucket edges for sr conditioning
    cond_type: str = "scale_bias"                   # see SampleRateConditionLayer
    cond_dim: int = 128                             # embedding width for "concat" conditioning
    cond_out_layer: bool = False                    # add an output projection after conditioning
373
+
374
+
375
class AudioVAE(nn.Module):
    """Causal audio VAE with sample-rate-conditioned decoding (v2).

    Encodes mono audio at `sample_rate` into a latent sequence and decodes it
    at a target output sample rate chosen via `sr_cond`.

    Args:
        config: AudioVAEConfig with all model hyperparameters; defaults are
            used when None is given.
    """

    def __init__(
        self,
        config: AudioVAEConfig = None,
    ):
        # Fall back to the default configuration when none is provided.
        if config is None:
            config = AudioVAEConfig()

        super().__init__()

        encoder_dim = config.encoder_dim
        encoder_rates = config.encoder_rates
        latent_dim = config.latent_dim
        decoder_dim = config.decoder_dim
        decoder_rates = config.decoder_rates
        depthwise = config.depthwise
        sample_rate = config.sample_rate
        out_sample_rate = config.out_sample_rate
        use_noise_block = config.use_noise_block
        sr_bin_boundaries = config.sr_bin_boundaries
        cond_type = config.cond_type
        cond_dim = config.cond_dim
        cond_out_layer = config.cond_out_layer

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.depthwise = depthwise

        self.use_noise_block = use_noise_block

        # Derive the latent width from the encoder when not specified.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim
        # Total downsampling factor: input samples per latent frame.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = CausalEncoder(
            encoder_dim,
            latent_dim,
            encoder_rates,
            depthwise=depthwise,
        )

        self.decoder = CausalDecoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            depthwise=depthwise,
            use_noise_block=use_noise_block,
            sr_bin_boundaries=sr_bin_boundaries,
            cond_type=cond_type,
            cond_dim=cond_dim,
            cond_out_layer=cond_out_layer,
        )
        self.sample_rate = sample_rate
        self.out_sample_rate = out_sample_rate
        self.sr_bin_boundaries = sr_bin_boundaries
        self.chunk_size = math.prod(encoder_rates)

    def preprocess(self, audio_data, sample_rate):
        """Right-pad the waveform so its length is a multiple of hop_length."""
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        pad_to = self.hop_length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
        """Decode given latent codes and return audio data.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Continuous latent representation of the input.
        sr_cond : Tensor, optional
            Target output sample rate(s); defaults to self.out_sample_rate
            when the decoder is sample-rate conditioned.

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data in [-1, 1].
        """
        if self.sr_bin_boundaries is not None:
            # use default output sample rate
            if sr_cond is None:
                sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
        return self.decoder(z, sr_cond)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """
        Args:
            audio_data: Tensor[B x 1 x T] (a 2-D tensor gets a channel dim added)
            sample_rate: int, must match self.sample_rate
        Returns:
            z: Tensor[B x D x T'] — the posterior mean `mu`
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        audio_data = self.preprocess(audio_data, sample_rate)
        return self.encoder(audio_data)["mu"]
voxcpm/modules/layers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .scalar_quantization_layer import ScalarQuantizationLayer
voxcpm/modules/layers/lora.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
class LoRALinear(nn.Module):
    """
    LoRA linear layer that owns `weight`/`bias` directly, keeping the same
    state_dict key layout as nn.Linear.

    state_dict layout:
        - weight: original weight (same as nn.Linear)
        - bias:   original bias   (same as nn.Linear)
        - lora_A: LoRA low-rank matrix A
        - lora_B: LoRA low-rank matrix B

    Benefit of this design: pretrained weights load without any key remapping.
    """

    def __init__(
        self,
        base: nn.Linear,
        r: int,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."

        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.alpha = alpha
        self._base_scaling = alpha / r if r > 0 else 0.0

        # Store scaling in a buffer so mutating its value does not trigger
        # torch.compile recompilation.
        # persistent=False keeps it out of the state_dict, avoiding
        # missing-key errors on load.
        self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)

        # Own weight and bias directly (transferred from the wrapped Linear).
        self.weight = base.weight
        self.bias = base.bias  # may be None

        # LoRA parameters
        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
            self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
            # A: Kaiming init; B: zeros, so the LoRA delta starts at zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        else:
            self.register_parameter("lora_A", None)
            self.register_parameter("lora_B", None)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base Linear computation.
        result = F.linear(x, self.weight, self.bias)
        if self.r <= 0 or self.lora_A is None:
            return result
        # LoRA: result + dropout(x @ A^T @ B^T) * scaling
        lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
        return result + self.dropout(lora_out) * self.scaling

    def reset_lora_parameters(self):
        """Reset the LoRA parameters to their initial state."""
        if self.r > 0 and self.lora_A is not None:
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def set_enabled(self, enabled: bool):
        """Enable/disable LoRA (via scaling, torch.compile friendly)."""
        # fill_ mutates the buffer in place and does not force recompilation.
        self.scaling.fill_(self._base_scaling if enabled else 0.0)

    @property
    def enabled(self) -> bool:
        # LoRA is effectively off whenever the scaling factor is zero.
        return self.scaling.item() != 0.0
81
+
82
+
83
+ def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
84
+ """
85
+ 根据类似 'layers.0.self_attn.q_proj' 的全名,返回 parent module(即 q_proj 的上一级)。
86
+ """
87
+ parts = name.split(".")
88
+ if len(parts) == 1:
89
+ return root
90
+ parent = root
91
+ for p in parts[:-1]:
92
+ if not hasattr(parent, p):
93
+ return None
94
+ parent = getattr(parent, p)
95
+ return parent
96
+
97
+
98
def apply_lora_to_named_linear_modules(
    root: nn.Module,
    *,
    target_submodule_names: list[str],
    r: int,
    alpha: float,
    dropout: float,
) -> None:
    """
    Inject LoRA into every nn.Linear under `root` whose attribute name is in
    `target_submodule_names`.

    For example, target_submodule_names=["q_proj", "v_proj"] replaces every
    nn.Linear named *.q_proj / *.v_proj with a LoRALinear.
    """
    wanted = set(target_submodule_names)
    # Snapshot the module list so replacements don't disturb iteration.
    for full_name, module in list(root.named_modules()):
        leaf_name = full_name.split(".")[-1]
        if not isinstance(module, nn.Linear) or leaf_name not in wanted:
            continue

        parent = _get_parent_module(root, full_name)
        if parent is None:
            continue

        # Swap the original Linear for a LoRA-wrapped one, in place.
        setattr(
            parent,
            leaf_name,
            LoRALinear(base=module, r=r, alpha=alpha, dropout=dropout),
        )
voxcpm/modules/layers/scalar_quantization_layer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class ScalarQuantizationLayer(nn.Module):
    """Scalar quantization bottleneck: project -> tanh -> round to a uniform
    grid of 2*scale+1 levels -> project back.

    During training the rounding uses a straight-through estimator so
    gradients flow through the continuous values.
    """

    def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.scale = scale

        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden):
        # tanh bounds the latent to [-1, 1] before quantization.
        hidden = torch.tanh(self.in_proj(hidden))
        quantized = torch.round(hidden * self.scale) / self.scale

        if self.training:
            # Straight-through estimator: forward uses quantized values,
            # backward sees the identity.
            hidden = hidden + (quantized - hidden).detach()
        else:
            hidden = quantized

        return self.out_proj(hidden)
voxcpm/modules/locdit/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .unified_cfm import UnifiedCFM, CfmConfig
2
+ from .local_dit import VoxCPMLocDiT
3
+ from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
voxcpm/modules/locdit/local_dit.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ..minicpm4 import MiniCPMModel, MiniCPM4Config
3
+ import torch.nn as nn
4
+ import math
5
+
6
+
7
class SinusoidalPosEmb(torch.nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # sin/cos halves must split the dimension evenly.
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to ~1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=x.device) * -step)
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
23
+
24
+
25
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) over timestep embeddings."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act = nn.SiLU()
        # Output width defaults to the hidden width when out_dim is omitted.
        self.linear_2 = nn.Linear(
            time_embed_dim,
            out_dim if out_dim is not None else time_embed_dim,
            bias=True,
        )

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))
48
+
49
+
50
class VoxCPMLocDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    A MiniCPM transformer is run non-causally over a sequence assembled as
    [mu + timestep embedding, prefix condition frames, noisy target frames];
    only the positions corresponding to the target frames are projected back
    to latent channels.
    """

    def __init__(
        self,
        config: MiniCPM4Config,
        in_channels: int = 64,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Output dimensionality matches the input latent channels.
        self.out_channels = in_channels
        self.config = config

        # Projections between latent channels and the transformer width.
        self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)

        # Timestep embeddings: t and dt get separate MLPs and are summed.
        self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
        self.time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )
        self.delta_time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )

        # The backbone consumes pre-embedded vectors, so it must not own an
        # input token embedding table.
        assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
        self.decoder = MiniCPMModel(config)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        dt: torch.Tensor,
    ):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of inputs
        mu: (N, C) tensor of hidden embedding
        t: (N,) tensor of diffusion timesteps
        cond: (N, C, T') tensor of prefix conditions
        dt: (N,) used for mean velocity (may be supported in the future...)
        """
        # Channel-last for the transformer: (N, T, hidden).
        x = self.in_proj(x.transpose(1, 2).contiguous())

        cond = self.cond_proj(cond.transpose(1, 2).contiguous())
        prefix = cond.size(1)

        # Combine the two timestep embeddings additively.
        t = self.time_embeddings(t).to(x.dtype)
        t = self.time_mlp(t)
        dt = self.time_embeddings(dt).to(x.dtype)
        dt = self.delta_time_mlp(dt)
        t = t + dt

        # Sequence layout: [global token (mu + t), cond frames, noisy frames].
        x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
        hidden, _ = self.decoder(x, is_causal=False)
        # Keep only the positions corresponding to the noisy target frames.
        hidden = hidden[:, prefix + 1 :, :]
        hidden = self.out_proj(hidden)

        return hidden.transpose(1, 2).contiguous()
voxcpm/modules/locdit/local_dit_v2.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ..minicpm4 import MiniCPMModel, MiniCPM4Config
3
+ import torch.nn as nn
4
+ import math
5
+
6
+
7
class SinusoidalPosEmb(torch.nn.Module):
    """Classic transformer sinusoidal embedding for scalar timesteps/positions.

    The first half of the output dimension carries sines, the second half
    cosines, over a geometric frequency ladder with base 10000.
    """

    def __init__(self, dim):
        super().__init__()
        assert dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
        self.dim = dim

    def forward(self, x, scale=1000):
        # Promote a 0-d scalar to a one-element batch.
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half = self.dim // 2
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, dtype=x.dtype, device=x.device) * -step)
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
23
+
24
+
25
class TimestepEmbedding(nn.Module):
    """Two-layer SiLU MLP that maps a sinusoidal timestep embedding to the model width.

    Args:
        in_channels: width of the incoming embedding.
        time_embed_dim: hidden (and default output) width.
        out_dim: optional override for the output width.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int = None,
    ):
        super().__init__()
        # Output width falls back to the hidden width when not overridden.
        final_dim = time_embed_dim if out_dim is None else out_dim
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, final_dim, bias=True)

    def forward(self, sample):
        return self.linear_2(self.act(self.linear_1(sample)))
48
+
49
+
50
class VoxCPMLocDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    v2 variant: unlike v1 (which fuses the hidden embedding ``mu`` into the
    timestep token), this version reshapes ``mu`` into one or more standalone
    prefix tokens placed before the timestep token.
    """

    def __init__(
        self,
        config: MiniCPM4Config,
        in_channels: int = 64,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.config = config

        # Project feature channels into / out of the transformer width.
        self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)

        self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
        self.time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )
        # Separate embedding MLP for the r/t gap ("dt") used by mean-velocity training.
        self.delta_time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )

        # The backbone consumes embeddings directly, never token ids.
        assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
        self.decoder = MiniCPMModel(config)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        dt: torch.Tensor,
    ):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of inputs
        mu: (N, C) tensor of hidden embedding
        t: (N,) tensor of diffusion timesteps
        cond: (N, C, T') tensor of prefix conditions
        dt: (N,) used for mean velocity (may be supported in the future...)
        """
        x = self.in_proj(x.transpose(1, 2).contiguous())

        cond = self.cond_proj(cond.transpose(1, 2).contiguous())
        prefix = cond.size(1)

        t = self.time_embeddings(t).to(x.dtype)
        t = self.time_mlp(t)
        dt = self.time_embeddings(dt).to(x.dtype)
        dt = self.delta_time_mlp(dt)
        # Fold the r/t-gap embedding into the timestep embedding.
        t = t + dt

        # Split mu into hidden-size-wide tokens; token layout is
        # [mu tokens..., timestep token, condition prefix, noisy input].
        mu = mu.view(x.size(0), -1, x.size(-1))
        x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)

        hidden, _ = self.decoder(x, is_causal=False)
        # Drop the mu / timestep / condition prefix; keep only input positions.
        hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
        hidden = self.out_proj(hidden)

        return hidden.transpose(1, 2).contiguous()
+ return hidden.transpose(1, 2).contiguous()
voxcpm/modules/locdit/unified_cfm.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.func import jvp
6
+ from pydantic import BaseModel
7
+
8
+ from .local_dit import VoxCPMLocDiT
9
+
10
+
11
class CfmConfig(BaseModel):
    """Hyperparameters for conditional flow matching (training and sampling)."""

    sigma_min: float = 1e-6  # minimum noise level of the flow
    solver: str = "euler"  # ODE solver used at inference
    t_scheduler: str = "log-norm"  # timestep sampler: "log-norm" or "uniform"
    training_cfg_rate: float = 0.1  # prob. of dropping mu during training (CFG)
    inference_cfg_rate: float = 1.0  # default guidance strength at inference
    reg_loss_type: str = "l1"
    # (start, end) schedule for the fraction of samples with r != t (mean-velocity mode).
    ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
    # (start, end) schedule for the probability of noising the condition.
    noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
    noise_cond_scale: float = 0.0  # stddev multiplier for condition noise
21
+
22
+
23
class UnifiedCFM(torch.nn.Module):
    """Conditional flow matching wrapper around a local DiT velocity estimator.

    Supports standard flow matching and, when ``mean_mode`` is True, a
    mean-velocity objective where the network predicts an average velocity
    over the interval (r, t) and the target involves a JVP of the estimator.
    """

    def __init__(
        self,
        in_channels: int,
        cfm_params: CfmConfig,
        estimator: VoxCPMLocDiT,
        mean_mode: bool = False,
    ):
        super().__init__()
        self.solver = cfm_params.solver
        self.sigma_min = cfm_params.sigma_min
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        self.reg_loss_type = cfm_params.reg_loss_type
        self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
        self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
        self.noise_cond_scale = cfm_params.noise_cond_scale

        self.in_channels = in_channels
        self.mean_mode = mean_mode

        self.estimator = estimator

    # ------------------------------------------------------------------ #
    # Inference
    # ------------------------------------------------------------------ #
    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        patch_size: int,
        cond: torch.Tensor,
        temperature: float = 1.0,
        cfg_value: float = 1.0,
        sway_sampling_coef: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """Sample a (B, C, patch_size) patch by integrating the flow ODE from noise.

        Args:
            mu: (B, C) hidden embedding conditioning the estimator.
            n_timesteps: number of Euler steps.
            patch_size: temporal length of the generated patch.
            cond: (B, C, T') prefix condition features.
            temperature: scale of the initial Gaussian noise.
            cfg_value: classifier-free guidance strength.
            sway_sampling_coef: coefficient of the sway timestep warping.
            use_cfg_zero_star: enable CFG-Zero* (optimized negative scaling and
                zeroed velocity on the first few steps).
        """
        b, _ = mu.shape
        t = patch_size
        z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature

        # Sway sampling: warp the uniform 1 -> 0 schedule so more steps are
        # spent near the clean-data end of the trajectory.
        t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)

        return self.solve_euler(
            x=z,
            t_span=t_span,
            mu=mu,
            cond=cond,
            cfg_value=cfg_value,
            use_cfg_zero_star=use_cfg_zero_star,
        )

    def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
        # CFG-Zero*: per-sample least-squares scale projecting the conditional
        # prediction onto the unconditional one.
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
        st_star = dot_product / squared_norm
        return st_star

    def solve_euler(
        self,
        x: torch.Tensor,
        t_span: torch.Tensor,
        mu: torch.Tensor,
        cond: torch.Tensor,
        cfg_value: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """Integrate the probability-flow ODE from t = 1 (noise) down to t = 0."""
        t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]

        sol = []
        # CFG-Zero*: skip the model velocity for roughly the first 4% of steps.
        zero_init_steps = max(1, int(len(t_span) * 0.04))
        for step in range(1, len(t_span)):
            if use_cfg_zero_star and step <= zero_init_steps:
                dphi_dt = torch.zeros_like(x)
            else:
                # Classifier-Free Guidance inference introduced in VoiceBox
                # Batch the conditional pass (first half) and the unconditional
                # pass (mu zeroed, second half) through the estimator together.
                b = x.size(0)
                x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
                t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
                x_in[:b], x_in[b:] = x, x
                mu_in[:b] = mu
                t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
                dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
                # not used now
                if not self.mean_mode:
                    dt_in = torch.zeros_like(dt_in)
                cond_in[:b], cond_in[b:] = cond, cond

                dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)

                if use_cfg_zero_star:
                    positive_flat = dphi_dt.view(b, -1)
                    negative_flat = cfg_dphi_dt.view(b, -1)
                    st_star = self.optimized_scale(positive_flat, negative_flat)
                    st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
                else:
                    st_star = 1.0

                dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)

            # Euler step toward t = 0 (velocity points from data to noise).
            x = x - dt * dphi_dt
            t = t - dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t - t_span[step + 1]

        return sol[-1]

    # ------------------------------------------------------------------ #
    # Training loss
    # ------------------------------------------------------------------ #
    def adaptive_loss_weighting(
        self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
    ):
        # Inverse-power weighting of per-position losses; p == 0 yields uniform
        # weights. Detached so the weights carry no gradients.
        weights = 1.0 / ((losses + epsilon).pow(p))
        if mask is not None:
            weights = weights * mask
        return weights.detach()

    def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
        """Draw per-sample (r, t) pairs; with prob. ratio_r_neq_t, r <= t differ."""
        batch_size = x.shape[0]
        if self.t_scheduler == "log-norm":
            s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
            s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
            r = torch.sigmoid(s_r)
            t = torch.sigmoid(s_t)
        elif self.t_scheduler == "uniform":
            r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
            t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        else:
            raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")

        # Where mask is set keep an ordered pair (min, max); elsewhere r == t.
        mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
        r, t = torch.where(
            mask,
            torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
            torch.stack([t, t], dim=0),
        )

        return r.squeeze(), t.squeeze()

    def compute_loss(
        self,
        x1: torch.Tensor,
        mu: torch.Tensor,
        cond: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
        progress: float = 0.0,
    ):
        """Flow-matching training loss.

        Args:
            x1: (B, C, T) clean target features.
            mu: (B, C) conditioning embedding.
            cond: optional (B, C, T) prefix condition; zeros when absent.
            tgt_mask: optional per-position validity mask.
            progress: training progress in [0, 1]; drives the r != t ratio and
                condition-noise schedules.
        """
        b, _, _ = x1.shape

        # Randomly drop mu per sample for classifier-free guidance training.
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1)

        if cond is None:
            cond = torch.zeros_like(x1)

        # Optionally corrupt the condition with Gaussian noise (scheduled prob.).
        noisy_mask = torch.rand(b, device=x1.device) > (
            1.0
            - (
                self.noise_cond_prob_range[0]
                + progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
            )
        )
        cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale

        ratio_r_neq_t = (
            self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
            if self.mean_mode
            else 0.0
        )

        r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
        r_ = r.detach().clone()
        t_ = t.detach().clone()
        z = torch.randn_like(x1)
        # Linear interpolation path: y = (1 - t) x1 + t z, with velocity v = z - x1.
        y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
        v = z - x1

        def model_fn(z_sample, r_sample, t_sample):
            return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)

        if self.mean_mode:
            # Mean-velocity target: u = v - (t - r) * du/dt, du/dt via forward-mode JVP.
            v_r = torch.zeros_like(r)
            v_t = torch.ones_like(t)
            from torch.backends.cuda import sdp_kernel

            # JVP is not supported through flash / mem-efficient SDPA kernels;
            # fall back to the math implementation for this call.
            with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
                u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
            u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
        else:
            u_pred = model_fn(y, r, t)
            u_tgt = v

        losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
        if tgt_mask is not None:
            weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
            loss = (weights * losses).sum() / torch.sum(tgt_mask)
        else:
            loss = losses.mean()

        return loss
voxcpm/modules/locenc/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .local_encoder import VoxCPMLocEnc
voxcpm/modules/locenc/local_encoder.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..minicpm4 import MiniCPMModel, MiniCPM4Config
4
+ from einops import rearrange
5
+
6
+
7
class VoxCPMLocEnc(nn.Module):
    """Patch encoder: summarizes each length-P feature patch into one vector
    via a learned CLS token and a bidirectional MiniCPM encoder."""

    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        super().__init__()
        self.config = config
        # Learned CLS token prepended to every patch; broadcast over (B, T).
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)

        # The encoder consumes embeddings directly, never token ids.
        assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
        self.encoder = MiniCPMModel(config)

    def forward(self, x):
        """
        x: [B, T, P, D]
        """
        B, T, P, D = x.shape

        x = self.in_proj(x)
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        # Fold (B, T) together so each patch is encoded independently.
        x = rearrange(x, "b t p c -> (b t) p c")
        outputs, _ = self.encoder(x, is_causal=False)
        # The CLS position (index 0) summarizes the patch.
        cls_output = outputs[:, 0, :]

        return rearrange(cls_output, "(b t) c -> b t c", b=B)
voxcpm/modules/minicpm4/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .config import MiniCPM4Config
2
+ from .model import MiniCPMModel
3
+ from .cache import StaticKVCache
voxcpm/modules/minicpm4/cache.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ import torch
3
+
4
+
5
class StaticKVCache:
    """Preallocated key/value cache for incremental transformer decoding.

    Storage layout: ``kv_cache[0]`` holds keys and ``kv_cache[1]`` holds values,
    each of shape (num_layers, batch, num_kv_heads, max_length, dim_kv_head).
    """

    def __init__(
        self,
        num_layers: int,
        num_kv_heads: int,
        dim_kv_head: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        max_length: int = 8192,
    ):
        self.max_length = max_length
        self.num_layers = num_layers

        # One contiguous buffer for every layer's K and V.
        full_shape = (2, num_layers, batch_size, num_kv_heads, max_length, dim_kv_head)
        self.kv_cache = torch.zeros(full_shape, device=device, dtype=dtype)
        self.current_length = 0

    def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (key_cache, value_cache) views for one layer; writes go through."""
        layer = self.kv_cache[:, layer_idx]
        return layer[0], layer[1]

    def step(self) -> int:
        """Reserve the next time slot and return its index."""
        if self.current_length >= self.max_length:
            raise ValueError("KV cache is full")

        slot = self.current_length
        self.current_length = slot + 1
        return slot

    def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
        """Seed the cache from per-layer (key, value) tensors of shape (B, H, T, D)."""
        prefill = kv_caches[0][0].size(2)
        self.current_length = prefill
        self.kv_cache.zero_()
        for layer_idx in range(self.num_layers):
            keys, values = kv_caches[layer_idx]
            self.kv_cache[0, layer_idx, :, :, :prefill, :] = keys
            self.kv_cache[1, layer_idx, :, :, :prefill, :] = values
voxcpm/modules/minicpm4/config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List, Optional

from pydantic import BaseModel
3
+
4
+
5
class RopeScalingConfig(BaseModel):
    """LongRoPE-style rescaling parameters for rotary position embeddings."""

    type: str
    # Per-frequency rescale factors used beyond the original context length.
    long_factor: List[float]
    # Per-frequency rescale factors used within the original context length.
    short_factor: List[float]
    original_max_position_embeddings: int
10
+
11
+
12
class MiniCPM4Config(BaseModel):
    """Architecture hyperparameters for the MiniCPM4 transformer backbone."""

    bos_token_id: int
    eos_token_id: int
    hidden_size: int
    intermediate_size: int
    max_position_embeddings: int
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    rms_norm_eps: float
    rope_scaling: RopeScalingConfig
    # vocab_size == 0 means the model consumes embeddings, not token ids.
    vocab_size: int
    # muP-style scaling knobs (residual depth scaling, embedding scaling).
    use_mup: bool = True
    scale_emb: float
    dim_model_base: int
    scale_depth: float
    rope_theta: float
    # Optional override for the per-head K/V width; falls back to
    # hidden_size // num_attention_heads when None.
    # Fix: the field is nullable, so annotate it Optional[int] instead of a
    # bare `int` with a None default (pydantic v2 rejects an explicit None).
    kv_channels: Optional[int] = None
    no_rope: bool = False
voxcpm/modules/minicpm4/model.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .config import MiniCPM4Config
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import List, Tuple
5
+ import math
6
+ from .cache import StaticKVCache
7
+
8
+
9
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    """RMSNorm: divide *hidden* by the RMS of its last dim, then apply *weight*.

    The mean-square is accumulated in float32 for numerical stability, and the
    normalized result is cast back to the input dtype before the learned gain.
    """
    input_dtype = hidden.dtype
    mean_square = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    normalized = (hidden * torch.rsqrt(mean_square + eps)).to(input_dtype)
    return normalized * weight
14
+
15
+
16
class MiniCPMRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniCPMRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Learned per-channel gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Delegates to the functional rms_layernorm (float32 accumulation).
        return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
27
+
28
+
29
def rotate_half(x):
    """Rotate the last-dim halves: (a, b) -> (-b, a), as used by RoPE."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
33
+
34
+
35
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """
    Apply rotary position embeddings to query/key tensors.

    Args:
        q: Tensor(batch_size, num_heads, seq_len, head_dim)
        k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
        cos: Tensor(seq_len, head_dim)
        sin: Tensor(seq_len, head_dim)
    Returns:
        Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
    """
    # Rotate in float32 for precision, then restore the caller's dtype.
    input_dtype = q.dtype
    q32 = q.to(torch.float32)
    k32 = k.to(torch.float32)
    q_rotated = q32 * cos + rotate_half(q32) * sin
    k_rotated = k32 * cos + rotate_half(k32) * sin
    return q_rotated.to(input_dtype), k_rotated.to(input_dtype)
51
+
52
+
53
class MiniCPMLongRoPE(nn.Module):
    """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.config = config
        self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
        self.base = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings

        self.short_factor = config.rope_scaling.short_factor
        self.long_factor = config.rope_scaling.long_factor
        self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings

        # LongRoPE attention-scaling factor derived from the context extension ratio.
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.max_seq_len_cached = 0

        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)

        # Precompute tables for the full supported context length up front.
        self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Precompute and cache the cos/sin tables for up to *seq_len* positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        # Pick per-frequency rescale factors depending on whether we exceed the
        # original training context length.
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)

        freqs = torch.mul(
            torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
        )

        # Duplicate frequencies so the table spans the full head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
        self.sin_cached = emb.sin().to(dtype) * self.scaling_factor

    def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            position_ids: Tensor(seq_len) or Tensor(batch_size, seq_len)
        Returns:
            Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
        """
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        return cos, sin
110
+
111
+
112
class MiniCPMAttention(nn.Module):
    """Multi-head attention with grouped-query KV heads and optional RoPE."""

    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Per-head width; kv_channels overrides the default hidden/heads split.
        self.head_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = 10000.0

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Full-sequence attention.

        Args:
            hidden_states: (batch, seq_len, hidden_size).
            position_emb: (cos, sin) RoPE tables, or None to skip rotation.
            is_causal: apply a causal mask when True.
        Returns:
            (attention output, this layer's (key, value) pair for cache prefill).
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (B, T, H*D) -> (B, H, T, D)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_emb is not None:
            cos, sin = position_emb
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # ref: https://github.com/pytorch/pytorch/issues/163597
        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()
        # enable_gqa lets SDPA expand the fewer KV heads across query heads.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=is_causal,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        past_key_value = (key_states, value_states)
        return attn_output, past_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: int,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Single-token incremental attention against a preallocated KV cache.

        Args:
            hidden_states: (batch, hidden_size) — exactly one new position.
            position_emb: (cos, sin) tables indexed at position_id, or None.
            position_id: slot in the cache to write this step's K/V into.
            kv_cache: (key_cache, value_cache) views for this layer.
        """
        bsz, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_emb is not None:
            cos, sin = position_emb
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        key_cache, value_cache = kv_cache

        # Write this step's K/V into the cache slot.
        # NOTE(review): key_states is (B, H, 1, D) while the indexed slot is
        # (B, H, D); this relies on broadcasting during assignment — confirm it
        # holds for the batch sizes used at inference.
        key_cache[:, :, position_id, :] = key_states
        value_cache[:, :, position_id, :] = value_states

        # Attend only to positions written so far (<= position_id).
        attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)

        # ref: https://github.com/pytorch/pytorch/issues/163597
        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
        query_states = query_states.contiguous()
        key_cache = key_cache.contiguous()
        value_cache = value_cache.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_cache,
            value_cache,
            attn_mask=attn_mask,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
219
+
220
+
221
class MiniCPMMLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(SiLU(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
234
+
235
+
236
class MiniCPMDecoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention + SwiGLU MLP, with optional
    muP residual scaling (scale_depth / sqrt(num_hidden_layers))."""

    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)

        self.mlp = MiniCPMMLP(config)
        self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.scale_depth = config.scale_depth
        self.num_hidden_layers = config.num_hidden_layers
        self.use_mup = config.use_mup

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_emb: (cos, sin) RoPE tables, or None to skip rotation
            is_causal (`bool`): whether the attention mask is causal
        Returns:
            (layer output, this layer's (key, value) pair for cache prefill)
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_emb=position_emb,
            is_causal=is_causal,
        )

        # muP: scale each residual branch by scale_depth / sqrt(num_layers).
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states, present_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Single-token incremental variant of forward(); writes into kv_cache."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states = self.self_attn.forward_step(
            hidden_states=hidden_states,
            position_emb=position_emb,
            position_id=position_id,
            kv_cache=kv_cache,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states
321
+
322
+
323
class MiniCPMModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]

    Args:
        config: MiniCPMConfig
    """

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.config = config

        # vocab_size == 0 means callers feed embeddings directly (no token ids),
        # so the embedding table degenerates to an identity mapping.
        if config.vocab_size > 0:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        else:
            self.embed_tokens = nn.Identity()

        self.layers = nn.ModuleList(
            [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if config.no_rope:
            self.rope_emb = None
        else:
            self.rope_emb = MiniCPMLongRoPE(config)

        # Populated by setup_cache() before incremental decoding via forward_step().
        self.kv_cache = None

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
            is_causal: bool, whether the attention mask is causal
        Returns:
            hidden_states: Tensor(batch_size, seq_length, hidden_size)
            next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
        """
        if self.rope_emb is not None:
            position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
            position_emb = self.rope_emb(position_ids)
        else:
            position_emb = None
        hidden_states = inputs_embeds

        # Collect each layer's (K, V) so callers can prefill a StaticKVCache.
        next_decoder_cache = []

        for decoder_layer in self.layers:

            hidden_states, this_cache = decoder_layer(
                hidden_states,
                position_emb,
                is_causal,
            )
            next_decoder_cache.append(this_cache)
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_decoder_cache

    def forward_step(
        self,
        inputs_embeds: torch.Tensor,
        position_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Incremental single-position decode using the preallocated KV cache.

        Args:
            inputs_embeds: Tensor(batch_size, hidden_size)
            position_id: cache slot / RoPE position of this step
        Returns:
            hidden_states: Tensor(batch_size, hidden_size)
        """
        assert self.kv_cache is not None, "KV cache is not setup"

        if self.rope_emb is not None:
            position_emb = self.rope_emb(position_id)
        else:
            position_emb = None
        hidden_states = inputs_embeds

        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer.forward_step(
                hidden_states,
                position_emb,
                position_id,
                self.kv_cache.get_layer_cache(i),
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
        # Allocate a fresh static KV cache sized for incremental decoding;
        # head width mirrors MiniCPMAttention's kv_channels fallback.
        self.kv_cache = StaticKVCache(
            num_layers=self.config.num_hidden_layers,
            num_kv_heads=self.config.num_key_value_heads,
            dim_kv_head=(
                self.config.hidden_size // self.config.num_attention_heads
                if self.config.kv_channels is None
                else self.config.kv_channels
            ),
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            max_length=max_length,
        )
voxcpm/training/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training utilities for VoxCPM fine-tuning.
3
+
4
+ This package mirrors the training mechanics used in the minicpm-audio
5
+ tooling while relying solely on local audio-text datasets managed via
6
+ the HuggingFace ``datasets`` library.
7
+ """
8
+
9
+ from .accelerator import Accelerator
10
+ from .tracker import TrainingTracker
11
+ from .data import (
12
+ load_audio_text_datasets,
13
+ HFVoxCPMDataset,
14
+ build_dataloader,
15
+ BatchProcessor,
16
+ )
17
+ from .state import TrainingState
18
+
19
+ __all__ = [
20
+ "Accelerator",
21
+ "TrainingTracker",
22
+ "HFVoxCPMDataset",
23
+ "BatchProcessor",
24
+ "TrainingState",
25
+ "load_audio_text_datasets",
26
+ "build_dataloader",
27
+ ]
voxcpm/training/accelerator.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import os
5
+ import random
6
+ import typing
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.utils.data
12
+ from torch.nn.parallel import DistributedDataParallel
13
+
14
+
15
class Accelerator:
    """
    Simplified accelerator that mirrors the behaviour of the minicpm-audio
    training utilities. It initializes a distributed process group when
    ``torchrun`` is used and exposes helpers for AMP, gradient scaling and
    preparing models/dataloaders for DDP.
    """

    def __init__(self, amp: bool = False, seed: int = 42):
        # WORLD_SIZE / LOCAL_RANK are injected by torchrun; defaults give
        # single-process behaviour when launched directly.
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))

        if self.world_size > 1 and not dist.is_initialized():
            dist.init_process_group("nccl", init_method="env://")

        self.rank = dist.get_rank() if dist.is_initialized() else 0
        self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        self.amp = amp

        # Set random seed to ensure model initialization consistency
        self._set_seed(seed)

        # No-op stand-in so the training loop can call scaler methods
        # unconditionally even when AMP is disabled or CUDA is absent.
        class DummyScaler:
            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
        self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        self._ddp_model = None  # For no_sync support

    def _set_seed(self, seed: int):
        """Set random seed to ensure model initialization consistency across multiple GPUs"""
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def __enter__(self):
        # Entering the accelerator pins the CUDA device for this process.
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def barrier(self):
        """Synchronize all processes"""
        if dist.is_initialized():
            dist.barrier()

    def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
        """All-reduce tensor across processes"""
        # No-op (returns the tensor unchanged) in single-process runs.
        if dist.is_initialized():
            dist.all_reduce(tensor, op=op)
        return tensor

    # ------------------------------------------------------------------ #
    # Model helpers
    # ------------------------------------------------------------------ #
    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Move *model* to this process's device and DDP-wrap it when distributed."""
        if hasattr(model, "device"):  # make sure the matrix will be moved to the correct device
            model.device = self.device
        model = model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
            self._ddp_model = model  # Save DDP model reference for no_sync support
        return model

    @contextlib.contextmanager
    def no_sync(self):
        """
        Context manager to skip gradient synchronization during gradient accumulation.
        Only used outside the last micro-batch.
        """
        if self._ddp_model is not None:
            with self._ddp_model.no_sync():
                yield
        else:
            yield

    @property
    def device(self):
        # Preference order: CUDA (per local rank) > Apple MPS > CPU.
        if torch.cuda.is_available():
            return torch.device("cuda", self.local_rank)
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    # ------------------------------------------------------------------ #
    # AMP helpers
    # ------------------------------------------------------------------ #
    def autocast(self, *args, **kwargs):
        """Autocast context honouring the accelerator's ``amp`` flag.

        NOTE(review): the device type is hard-coded to "cuda" — presumably
        intentional since AMP is gated on CUDA availability; confirm for MPS/CPU runs.
        """
        return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backward pass through the (possibly no-op) grad scaler."""
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Optimizer step through the (possibly no-op) grad scaler."""
        self.scaler.step(optimizer)

    def update(self):
        """Update the grad scaler's scale factor (no-op without AMP)."""
        self.scaler.update()

    # ------------------------------------------------------------------ #
    # Data helpers
    # ------------------------------------------------------------------ #
    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        *,
        batch_size: int,
        num_workers: int = 0,
        shuffle: bool = True,
        collate_fn=None,
        drop_last: bool = False,
    ) -> torch.utils.data.DataLoader:
        """Build a DataLoader, attaching a DistributedSampler in multi-process runs."""
        if self.world_size > 1:
            # The sampler owns shuffling in distributed mode; the DataLoader
            # itself must then be created with shuffle=False.
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
            )
            shuffle = False
        else:
            sampler = None

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            drop_last=drop_last,
            pin_memory=True,
        )

    @staticmethod
    def unwrap(model: torch.nn.Module) -> torch.nn.Module:
        """Return the underlying module of a DDP-wrapped model (identity otherwise)."""
        return model.module if hasattr(model, "module") else model
voxcpm/training/config.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argbind
4
+ import yaml
5
+ from pathlib import Path
6
+ from typing import Dict, Any
7
+
8
+
9
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
    """
    Read *path* as YAML and return its top-level mapping for argbind.

    Raises:
        ValueError: if the file's top level is not a mapping (e.g. a list,
            scalar, or an empty document).
    """
    path = Path(path)
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    if isinstance(data, dict):
        return data
    raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
19
+
20
+
21
def parse_args_with_config(config_path: str | Path | None = None):
    """
    Helper to unify CLI arguments and YAML configuration.

    CLI arguments are parsed first; when *config_path* is given, the YAML
    values are parsed inside the CLI argbind scope and then merged into the
    CLI namespace. NOTE(review): because of ``cli_args.update(yaml_args)``,
    YAML-derived entries overwrite CLI duplicates — confirm that precedence
    is intended.

    Usage mirrors minicpm-audio:
        args = parse_args_with_config("conf/voxcpm/finetune.yml")
        with argbind.scope(args):
            ...
    """
    cli_args = argbind.parse_args()
    if config_path is None:
        return cli_args

    yaml_args = load_yaml_config(config_path)
    with argbind.scope(cli_args):
        # Re-parse with empty argv so only the YAML values contribute here.
        yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
    cli_args.update(yaml_args)
    return cli_args
voxcpm/training/data.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import argbind
5
+ import torch
6
+ from datasets import Audio, Dataset, DatasetDict, load_dataset
7
+ from torch.utils.data import Dataset as TorchDataset
8
+
9
+ from ..model.voxcpm import VoxCPMConfig
10
+ from ..modules.audiovae import AudioVAE
11
+ from .packers import AudioFeatureProcessingPacker
12
+
13
+ DEFAULT_TEXT_COLUMN = "text"
14
+ DEFAULT_AUDIO_COLUMN = "audio"
15
+ DEFAULT_ID_COLUMN = "dataset_id"
16
+
17
+
18
@argbind.bind()
def load_audio_text_datasets(
    train_manifest: str,
    val_manifest: str = "",
    text_column: str = DEFAULT_TEXT_COLUMN,
    audio_column: str = DEFAULT_AUDIO_COLUMN,
    dataset_id_column: str = DEFAULT_ID_COLUMN,
    sample_rate: int = 16_000,
    num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
    """
    Load JSON manifests into HuggingFace datasets with a normalized schema.

    Each returned dataset exposes the canonical columns ``text``, ``audio``
    (decoded on the fly at *sample_rate*) and ``dataset_id`` (filled with 0
    when the manifest has no such column).

    Args:
        train_manifest: path to the training manifest (JSON/JSON-lines).
        val_manifest: optional validation manifest; empty string skips it.
        text_column / audio_column / dataset_id_column: source column names,
            renamed to the canonical defaults when they differ.
        sample_rate: target sampling rate for audio decoding.
        num_proc: NOTE(review) accepted but never used below — confirm intent.

    Returns:
        (train_dataset, validation_dataset_or_None)
    """
    data_files = {"train": train_manifest}
    if val_manifest:
        data_files["validation"] = val_manifest

    dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)

    def prepare(ds: Dataset) -> Dataset:
        # Normalize one split to the canonical column layout.
        if audio_column not in ds.column_names:
            raise ValueError(f"Expected '{audio_column}' column in manifest.")
        # We cast to Audio to ensure proper handling during training,
        # but for length calculation we might need raw path or duration if available.
        # HF datasets usually don't compute duration automatically for 'Audio' column.
        ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
        if audio_column != DEFAULT_AUDIO_COLUMN:
            ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
        if text_column != DEFAULT_TEXT_COLUMN:
            ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
        if dataset_id_column and dataset_id_column in ds.column_names:
            if dataset_id_column != DEFAULT_ID_COLUMN:
                ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
        else:
            # Manifest without dataset ids: treat everything as dataset 0.
            ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
        return ds

    train_ds = prepare(dataset_dict["train"])
    val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
    return train_ds, val_ds
55
+
56
+
57
+ def compute_sample_lengths(
58
+ ds: Dataset,
59
+ audio_vae_fps: int = 25,
60
+ patch_size: int = 1,
61
+ ) -> List[int]:
62
+ """
63
+ 预估每个样本经过 packer 之后的大致序列长度(text+audio),用于过滤超长样本。
64
+
65
+ 逻辑与 AudioFeatureProcessingPacker / AudioVAE 一致:
66
+ - 文本长度: len(text_ids)
67
+ - 音频长度:
68
+ duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
69
+ t_seq = ceil(t_vae / patch_size)
70
+ - 序列总长约为: text_len + t_seq + 2
71
+
72
+ Optimized: Use batch column access instead of iterating item by item.
73
+ """
74
+ # Batch access columns - much faster than per-item access
75
+ text_ids_list = ds["text_ids"]
76
+ text_lens = [len(t) for t in text_ids_list]
77
+
78
+ has_duration = "duration" in ds.column_names
79
+ if has_duration:
80
+ durations = ds["duration"]
81
+ else:
82
+ # Fallback: need to compute from audio (slow, but unavoidable without duration column)
83
+ durations = []
84
+ for i in range(len(ds)):
85
+ audio = ds[i][DEFAULT_AUDIO_COLUMN]
86
+ durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
87
+
88
+ # Vectorized length computation
89
+ lengths = []
90
+ for text_len, duration in zip(text_lens, durations):
91
+ t_vae = math.ceil(float(duration) * audio_vae_fps)
92
+ t_seq = math.ceil(t_vae / patch_size)
93
+ total_len = text_len + t_seq + 2
94
+ lengths.append(total_len)
95
+
96
+ return lengths
97
+
98
+
99
class HFVoxCPMDataset(TorchDataset):
    """
    PyTorch-facing adapter over a tokenized HuggingFace dataset that yields
    plain-Python samples ready for ``collate_fn``.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        row = self.dataset[idx]
        clip = row[DEFAULT_AUDIO_COLUMN]
        return {
            "text_ids": row["text_ids"],
            "audio_array": clip["array"],
            "audio_sampling_rate": clip["sampling_rate"],
            "dataset_id": row.get(DEFAULT_ID_COLUMN, 0),
            "is_prompt": row.get("is_prompt", False),
        }

    @staticmethod
    def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
        """Right-pad 1-D tensors to a common length and stack into [B, T]."""
        if not seqs:
            return torch.empty(0)
        target = max(seq.shape[0] for seq in seqs)
        rows = []
        for seq in seqs:
            short_by = target - seq.shape[0]
            if short_by > 0:
                seq = torch.nn.functional.pad(seq, (0, short_by), value=pad_value)
            rows.append(seq)
        return torch.stack(rows)

    @classmethod
    def collate_fn(cls, batch: List[Dict]):
        """Pad text (-100) and audio (-100.0) and assemble the batch dict."""
        texts = cls.pad_sequences(
            [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch],
            pad_value=-100,
        )
        audios = cls.pad_sequences(
            [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch],
            pad_value=-100.0,
        )
        return {
            "text_tokens": texts,
            "audio_tokens": audios,
            # Only the TTS task (id 1) exists in this pipeline.
            "task_ids": torch.ones(texts.size(0), dtype=torch.int32),
            "dataset_ids": torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32),
            "is_prompts": [bool(sample.get("is_prompt", False)) for sample in batch],
        }
153
+
154
+
155
class BatchProcessor:
    """
    Thin adapter around ``AudioFeatureProcessingPacker`` so the training
    loop can hand over collated batches directly, mirroring the
    minicpm-audio mechanics.
    """

    def __init__(
        self,
        *,
        config: VoxCPMConfig,
        audio_vae: AudioVAE,
        dataset_cnt: int,
        device: torch.device,
    ):
        self.device = device
        self.dataset_cnt = dataset_cnt
        # Keep the VAE on the training device; the packer encodes audio with it.
        self.audio_vae = audio_vae
        self.audio_vae.to(device)
        self.packer = AudioFeatureProcessingPacker(
            dataset_cnt=dataset_cnt,
            max_len=config.max_length,
            patch_size=config.patch_size,
            feat_dim=config.feat_dim,
            audio_vae=self.audio_vae,
        )

    def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Move the collated tensors onto the device and pack them."""
        on_device = {
            key: batch[key].to(self.device)
            for key in ("audio_tokens", "text_tokens", "task_ids", "dataset_ids")
        }
        return self.packer(
            audio_tokens=on_device["audio_tokens"],
            text_tokens=on_device["text_tokens"],
            task_ids=on_device["task_ids"],
            dataset_ids=on_device["dataset_ids"],
            is_prompts=batch["is_prompts"],
        )
195
+
196
+
197
def build_dataloader(
    hf_dataset: Dataset,
    *,
    accelerator,
    batch_size: int,
    num_workers: int,
    drop_last: bool = False,
) -> torch.utils.data.DataLoader:
    """
    Wrap *hf_dataset* in :class:`HFVoxCPMDataset` and let the accelerator
    build a padding-based DataLoader (attaching a DistributedSampler when
    running multi-process).
    """
    wrapped = HFVoxCPMDataset(hf_dataset)
    return accelerator.prepare_dataloader(
        wrapped,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=HFVoxCPMDataset.collate_fn,
        drop_last=drop_last,
    )
voxcpm/training/packers.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+
7
+
8
class AudioFeatureProcessingPacker:
    """
    Adapted from the minicpm-audio training utilities. It converts raw text and
    audio tokens into the packed multimodal representation required by VoxCPM.
    """

    def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
        # Special text-token ids delimiting the audio region of a sequence.
        self.audio_start_id = 101
        self.audio_end_id = 102
        # unused now
        self.audio_prompt_start_id = 103
        self.audio_prompt_end_id = 104
        self.text_eos_token_id = 2

        self.patch_size = patch_size
        # Raw waveform samples covered by one packed audio patch.
        self.patch_len = audio_vae.hop_length * self.patch_size
        self.feat_dim = feat_dim
        self.dataset_cnt = max(dataset_cnt, 1)
        self.max_len = max_len

        self.audio_vae = audio_vae

        # Task dispatch tables (only the TTS task exists at the moment).
        self.process_functions = {"tts": self.process_tts_data}
        self.task_id_map = {"tts": 1}
        self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _first_pad_position(tokens: torch.Tensor):
        # Index of the first -100 padding element, or None when unpadded.
        positions = (tokens == -100).nonzero(as_tuple=True)
        if positions[0].numel() == 0:
            return None
        return int(positions[0][0])

    def unpad_text_tokens(self, tokens: torch.Tensor):
        # Strip trailing -100 padding from a 1-D text-token tensor.
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def unpad_audio_tokens(self, tokens: torch.Tensor):
        # Same logic as unpad_text_tokens; kept separate for call-site clarity.
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def encode_audio(self, wav: torch.Tensor):
        """
        Encode raw waveform into latent features using AudioVAE.

        AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
        We then transpose to [B, T, D] to match downstream expectations.
        """
        wav = wav.unsqueeze(0)  # [1, T]
        wav = wav.unsqueeze(1)  # [1, 1, T]
        wav_len = wav.size(-1)
        if wav_len % self.patch_len != 0:
            # Zero-pad so the waveform is an exact multiple of the patch length.
            padding_size = self.patch_len - wav_len % self.patch_len
            wav = torch.nn.functional.pad(wav, (0, padding_size))

        with torch.no_grad():
            z = self.audio_vae.encode(wav, self.audio_vae.sample_rate)  # [1, D, T']
        feat = z.transpose(1, 2)  # [1, T', D]
        return feat

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #
    def __call__(
        self,
        audio_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        task_ids: torch.Tensor,
        dataset_ids: torch.Tensor,
        is_prompts: List[bool],
    ) -> Dict[str, torch.Tensor]:
        """
        Padding-based batching: each sample in the input batch is processed
        independently and then padded to a common length (capped by ``max_len``).
        The result tensors all have shape [B, T, ...].
        """
        device = audio_tokens.device
        max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
        dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)

        text_tokens_list: List[torch.Tensor] = []
        audio_feats_list: List[torch.Tensor] = []
        text_mask_list: List[torch.Tensor] = []
        audio_mask_list: List[torch.Tensor] = []
        loss_mask_list: List[torch.Tensor] = []
        labels_list: List[torch.Tensor] = []
        audio_task_ids_list: List[torch.Tensor] = []
        audio_dataset_ids_list: List[torch.Tensor] = []
        lengths: List[int] = []

        # Per-dataset bookkeeping of consumed audio seconds / text tokens.
        audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
        text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)

        for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
            audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
        ):
            unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
            unpad_text_token = self.unpad_text_tokens(text_token)
            usage = self.id_to_task[task_id]

            (
                packed_text,
                audio_feat,
                text_mask,
                audio_mask,
                loss_mask,
                labels,
                audio_duration,
                text_token_count,
            ) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)

            audio_duration_consumed[dataset_idx] += audio_duration
            text_token_consumed[dataset_idx] += text_token_count

            # Per-position task id (0 on non-audio positions).
            audio_task_id = torch.zeros_like(audio_mask)
            audio_task_id[audio_mask == 1] = self.task_id_map[usage]

            # Per-position dataset id, shifted by +1 so 0 means "not audio".
            audio_dataset_id = torch.zeros_like(audio_mask)
            audio_dataset_id[audio_mask == 1] = dataset_idx + 1

            text_tokens_list.append(packed_text)
            text_mask_list.append(text_mask)
            audio_feats_list.append(audio_feat)
            audio_mask_list.append(audio_mask)
            loss_mask_list.append(loss_mask)
            labels_list.append(labels)
            audio_task_ids_list.append(audio_task_id)
            audio_dataset_ids_list.append(audio_dataset_id)
            lengths.append(packed_text.shape[0])

        # Determine padded length per batch (cap by self.max_len)
        if lengths:
            max_len = min(self.max_len, max(lengths))
        else:
            max_len = self.max_len

        def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
            # Right-pad (or truncate) a 1-D tensor to max_len.
            if x.size(0) >= max_len:
                return x[:max_len]
            pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)

        def pad_3d(x: torch.Tensor) -> torch.Tensor:
            # x: [T, P, D]
            if x.size(0) >= max_len:
                return x[:max_len]
            pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)

        if lengths:
            text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
            text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
            audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
            audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
            loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
            labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
            audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
            audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)

            # Position ids: [B, T], simple 0..L_i-1 then padded with 0
            position_ids_list = []
            for L in lengths:
                L_clip = min(L, max_len)
                pos = torch.arange(0, L_clip, device=device)
                if L_clip < max_len:
                    pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
                    pos = torch.cat([pos, pad], dim=0)
                position_ids_list.append(pos)
            position_ids = torch.stack(position_ids_list, dim=0)
        else:
            # Empty batch fallback (shouldn't really happen)
            text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
            text_mask_batch = torch.zeros_like(text_tokens_batch)
            audio_feats_batch = torch.zeros(
                (0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
            )
            audio_mask_batch = torch.zeros_like(text_tokens_batch)
            loss_mask_batch = torch.zeros_like(text_tokens_batch)
            labels_batch = torch.zeros_like(text_tokens_batch)
            audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
            audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
            position_ids = torch.zeros_like(text_tokens_batch)

        audio_duration_consumed = audio_duration_consumed.to(torch.long)
        text_token_consumed = text_token_consumed.to(torch.long)

        return {
            "text_tokens": text_tokens_batch,
            "audio_feats": audio_feats_batch,
            "text_mask": text_mask_batch,
            "audio_mask": audio_mask_batch,
            "loss_mask": loss_mask_batch,
            "position_ids": position_ids,
            "labels": labels_batch,
            "audio_task_ids": audio_task_ids_batch,
            "audio_dataset_ids": audio_dataset_ids_batch,
            "audio_duration_consumed": audio_duration_consumed,
            "text_token_consumed": text_token_consumed,
        }

    # ------------------------------------------------------------------ #
    # Feature extraction helpers
    # ------------------------------------------------------------------ #
    def extract_audio_feats(self, audio_data: torch.Tensor):
        """Encode audio, pad to a patch multiple, and group frames into [B, T, P, D]."""
        audio_feats = self.encode_audio(audio_data)
        if audio_feats.size(1) % self.patch_size != 0:
            audio_feats_ = audio_feats.transpose(1, 2)
            padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
            audio_feats = padding.transpose(1, 2)

        # NOTE(review): 25 is presumably the VAE frame rate in fps — confirm
        # against the AudioVAE configuration.
        audio_duration = audio_feats.size(1) / 25
        audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
        return audio_feats, audio_duration

    def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
        """
        Pack one TTS sample into the interleaved layout:
        [text..., audio_start, audio..., audio_end] with matching
        text/audio/loss masks and an end-of-audio label at position -2.
        Prompt samples use the (currently unused) prompt marker ids and are
        excluded from the loss.
        """
        text_token_info = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_prompt_start_id if is_prompt else self.audio_start_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )
        text_token_count = len(text_token)
        text_length = text_token_info.shape[0]
        audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
        audio_feat_info = audio_feat_info.squeeze(0)
        audio_length = audio_feat_info.shape[0]

        # Text stream: real tokens + start marker, zeros under the audio
        # region, then the end marker.
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_token_info = torch.cat(
            [
                text_token_info,
                text_pad_token,
                torch.tensor(
                    [self.audio_prompt_end_id if is_prompt else self.audio_end_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ]
        )
        # Audio stream: zero patches under the text region, real features,
        # and one zero patch under the trailing end marker.
        audio_pad_feat = torch.zeros(
            (text_length, self.patch_size, audio_feat_info.size(-1)),
            dtype=torch.float32,
            device=text_token.device,
        )
        audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)

        text_mask = (
            torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
            .type(torch.int32)
            .to(text_token.device)
        )
        audio_mask = (
            torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
            .type(torch.int32)
            .to(text_token.device)
        )
        # Loss only on the audio region, and never on prompt samples.
        loss_mask = (
            torch.cat(
                [
                    torch.zeros(text_length),
                    torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
                    torch.zeros(1),
                ]
            )
            .type(torch.int32)
            .to(text_token.device)
        )

        # Binary stop labels: 1 marks the final audio position.
        labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
        labels[-2] = 1

        return (
            text_token_info,
            audio_feat_info,
            text_mask,
            audio_mask,
            loss_mask,
            labels,
            audio_duration,
            text_token_count,
        )
voxcpm/training/state.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class TrainingState:
    """
    Container that mirrors the object returned in the minicpm-audio training
    loop. It holds persistent references to the model, optimizer, scheduler,
    dataloaders and tracker.
    """

    # The model being trained (possibly DDP-wrapped by the Accelerator).
    generator: object
    # Optimizer driving `generator`.
    optimizer: object
    # Learning-rate scheduler stepped alongside the optimizer.
    scheduler: object
    # DataLoader over the training split.
    train_loader: object
    # DataLoader over the validation split (presumably None when no
    # validation manifest was given — confirm at the construction site).
    val_loader: object
    # TrainingTracker used for step counting and logging.
    tracker: object
    # BatchProcessor that packs collated batches into model inputs.
    batch_processor: object
voxcpm/training/tracker.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Dict, Optional
8
+
9
+
10
class TrainingTracker:
    """
    Lightweight tracker inspired by the minicpm-audio training workflow.

    Keeps the current global step, prints rank-aware messages, optionally
    mirrors scalar metrics to a TensorBoard-style writer, and appends every
    printed line to a logfile for later inspection.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Timestamp of the previous metrics log, used to report the interval.
        self._last_log_time: float | None = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        """Emit *message* on rank 0 only (stderr plus the optional logfile)."""
        if self.rank != 0:
            return
        print(message, flush=True, file=sys.stderr)
        if self.log_file:
            with self.log_file.open("a", encoding="utf-8") as handle:
                handle.write(message + "\n")

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """Log a metrics dict for *split* at the current step (rank 0 only)."""
        if self.rank != 0:
            return
        now = time.time()
        dt_str = ""
        if self._last_log_time is not None:
            dt = now - self._last_log_time
            dt_str = f", log interval: {dt:.2f}s"
        self._last_log_time = now

        formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
        self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
        if self.writer is not None:
            # Only forward plain scalars to the writer.
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Announce the end of a phase for *split*."""
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        yield
voxcpm/utils/text_normalize.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # some functions are copied from https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/utils/frontend_utils.py
2
+ import re
3
+ import regex
4
+ import inflect
5
+ from wetext import Normalizer
6
+
7
+ chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
8
+
9
+
10
def contains_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph."""
    return re.search(r"[\u4e00-\u9fff]", text) is not None
13
+
14
+
15
def replace_corner_mark(text):
    """Replace math/corner symbols with their spoken Chinese equivalents."""
    # All keys are single characters, so a one-pass translate is equivalent
    # to the chained str.replace calls.
    table = str.maketrans({
        "²": "平方",
        "³": "立方",
        "√": "根号",
        "≈": "约等于",
        "<": "小于",
    })
    return text.translate(table)
23
+
24
+
25
def remove_bracket(text):
    """Strip bracket/backtick/dash symbols that carry no spoken content."""
    # The original applied the backtick replacement twice; once is equivalent.
    for src, dst in (
        ("(", " "),
        (")", " "),
        ("【", " "),
        ("】", " "),
        ("`", ""),
        ("——", " "),
    ):
        text = text.replace(src, dst)
    return text
32
+
33
+
34
def spell_out_number(text: str, inflect_parser):
    """Replace each maximal run of digit characters with its spelled-out form.

    Args:
        text: Input string.
        inflect_parser: Object exposing ``number_to_words(str) -> str``.

    Returns:
        str: *text* with digit runs (per ``str.isdigit``) spelled out.
    """
    pieces = []
    run_start = None  # index where the current digit run began, if any
    for idx, ch in enumerate(text):
        if ch.isdigit():
            if run_start is None:
                run_start = idx
            continue
        # Non-digit: flush any pending digit run, then keep the character.
        if run_start is not None:
            pieces.append(inflect_parser.number_to_words(text[run_start:idx]))
            run_start = None
        pieces.append(ch)
    # Flush a digit run that reaches the end of the string.
    if run_start is not None:
        pieces.append(inflect_parser.number_to_words(text[run_start:]))
    return "".join(pieces)
52
+
53
+
54
# Split a paragraph into sentence-sized utterances:
# 1. each utterance is at most token_max_n units; a short trailing piece
#    (< merge_len) is merged into the previous utterance
# 2. length is counted in characters for zh and in tokens otherwise
# 3. split points are sentence-final punctuation (optionally commas)
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
    """Split *text* into a list of utterances bounded by punctuation.

    Args:
        text: Input paragraph.
        tokenize: Callable used to measure length for non-zh text.
        lang: "zh" counts characters; anything else counts tokens.
        token_max_n: Soft maximum utterance length.
        token_min_n: Minimum accumulated length before a split is allowed.
        merge_len: A trailing piece shorter than this merges backwards.
        comma_split: Also split on (half/full-width) commas.

    Returns:
        list[str]: Utterances. NOTE(review): text after the final punctuation
        mark is silently dropped (matches the upstream CosyVoice helper) —
        confirm this is intended.
    """

    def calc_utt_length(_text: str):
        # zh length is the character count; otherwise tokenizer output length.
        if lang == "zh":
            return len(_text)
        return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
    else:
        pounc = [".", "?", "!", ";", ":"]
    if comma_split:
        pounc.extend([",", ","])

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            # Attach a closing quote right after the punctuation to the
            # utterance it terminates. Guard on ``utts`` being non-empty:
            # the original unconditionally popped, crashing with IndexError
            # when a quote followed punctuation before any utterance existed.
            if utts and i + 1 < len(text) and text[i + 1] in ['"', "”"]:
                utts[-1] = utts[-1] + text[i + 1]
                st = i + 2
            else:
                st = i + 1

    # Fallback: no punctuation at all — treat the whole text as one utterance.
    if len(utts) == 0:
        if lang == "zh":
            utts.append(text + "。")
        else:
            utts.append(text + ".")

    # Greedily pack utterances up to token_max_n, only cutting once the
    # accumulated chunk already exceeds token_min_n.
    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            # Merge a short tail into the previous chunk.
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts
108
+
109
+
110
# remove blank between chinese characters
def replace_blank(text: str):
    """Drop spaces unless both neighbors are non-space ASCII characters.

    Keeps word-separating spaces in Latin text ("hello world") while removing
    spaces between CJK characters ("你好 世界" -> "你好世界").

    Fixes two defects in the original: a space at the end of the string
    raised IndexError (text[i + 1] out of range), and a space at the start
    wrapped around via text[i - 1] to the *last* character.
    """
    out_chars = []
    n = len(text)
    for i, c in enumerate(text):
        if c == " ":
            prev_ok = i > 0 and text[i - 1].isascii() and text[i - 1] != " "
            next_ok = i + 1 < n and text[i + 1].isascii() and text[i + 1] != " "
            if prev_ok and next_ok:
                out_chars.append(c)
        else:
            out_chars.append(c)
    return "".join(out_chars)
120
+
121
+
122
def clean_markdown(md_text: str) -> str:
    """Strip common Markdown/HTML markup, keeping the readable text."""
    # (pattern, replacement, flags) applied in order; order matters — e.g.
    # fenced code blocks must go before inline code, images before links.
    rules = (
        (r"```.*?```", "", re.DOTALL),         # fenced code blocks (multi-line)
        (r"`[^`]*`", "", 0),                   # inline code
        (r"!\[[^\]]*\]\([^\)]+\)", "", 0),     # images ![alt](url)
        (r"\[([^\]]+)\]\([^)]+\)", r"\1", 0),  # links [text](url) -> text
        (r"^(\s*)-\s+", r"\1", re.MULTILINE),  # unordered-list bullets
        (r"<[^>]+>", "", 0),                   # HTML tags
        (r"^#{1,6}\s*", "", re.MULTILINE),     # heading markers (#)
        (r"\n\s*\n", "\n", 0),                 # collapse blank lines
    )
    for pattern, repl, flags in rules:
        md_text = re.sub(pattern, repl, md_text, flags=flags)
    return md_text.strip()
149
+
150
+
151
def clean_text(text):
    """Normalize free text for TTS: drop Markdown, emoji and line breaks."""
    # Strip Markdown/HTML markup first.
    text = clean_markdown(text)
    # Remove emoji (with or without the VS-16 presentation selector).
    emoji_re = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE)
    text = emoji_re.sub("", text)
    # Collapse line breaks/tabs and unify CJK double quotes.
    for src, dst in (("\n", " "), ("\t", " "), ("“", '"'), ("”", '"')):
        text = text.replace(src, dst)
    return text
161
+
162
+
163
class TextNormalizer:
    """Text frontend: cleans and normalizes zh/en text for TTS synthesis."""

    def __init__(self, tokenizer=None):
        """Build wetext rule-based normalizers for zh/en plus an inflect engine.

        Args:
            tokenizer: Optional tokenizer; stored but not used in this class.
        """
        self.tokenizer = tokenizer
        # wetext text-normalization ("tn") models, one per language.
        self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
        self.en_tn_model = Normalizer(lang="en", operator="tn")
        # inflect engine spells out Arabic numerals for English text.
        self.inflect_parser = inflect.engine()

    def normalize(self, text, split=False):
        """Clean and normalize *text*; language is auto-detected per call.

        NOTE(review): when ``split`` is True this method falls off the end
        and implicitly returns None — confirm whether a sentence-splitting
        path (e.g. via split_paragraph) was intended here.
        """
        # Strip Markdown syntax, emoji and line breaks.
        lang = "zh" if contains_chinese(text) else "en"
        text = clean_text(text)
        if lang == "zh":
            text = text.replace(
                "=", "等于"
            )  # fix: "550 + 320 等于 870 千卡。" was mis-normalized as "五百五十加三百二十等于八七十千卡."
            if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text):  # avoid English hyphens being normalized as minus signs
                text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text)  # fix: "x-2" was normalized as "x负2"
            text = self.zh_tn_model.normalize(text)
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = remove_bracket(text)
        else:
            text = self.en_tn_model.normalize(text)
            text = spell_out_number(text, self.inflect_parser)
        if split is False:
            return text
voxcpm/zipenhancer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ZipEnhancer Module - Audio Denoising Enhancer
3
+
4
+ Provides on-demand import ZipEnhancer functionality for audio denoising processing.
5
+ Related dependencies are imported only when denoising functionality is needed.
6
+ """
7
+
8
+ import os
9
+ import tempfile
10
+ from typing import Optional
11
+ import torchaudio
12
+ from modelscope.pipelines import pipeline
13
+ from modelscope.utils.constant import Tasks
14
+
15
+
16
class ZipEnhancer:
    """ZipEnhancer audio denoising enhancer.

    Thin wrapper around the ModelScope acoustic-noise-suppression pipeline,
    with optional loudness normalization of the denoised output.
    """

    def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
        """
        Initialize ZipEnhancer.

        Args:
            model_path: ModelScope model id or local path.
        """
        self.model_path = model_path
        self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)

    def _normalize_loudness(self, wav_path: str):
        """
        Normalize the file at *wav_path* in place to roughly -20 LUFS.

        Args:
            wav_path: Audio file path (overwritten with the normalized audio).
        """
        audio, sr = torchaudio.load(wav_path)
        loudness = torchaudio.functional.loudness(audio, sr)
        normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
        torchaudio.save(wav_path, normalized_audio, sr)

    def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
        """
        Audio denoising enhancement.

        Args:
            input_path: Input audio file path.
            output_path: Output audio file path (optional, creates a temp file by default).
            normalize_loudness: Whether to perform loudness normalization.
        Returns:
            str: Output audio file path.
        Raises:
            FileNotFoundError: If *input_path* does not exist.
            RuntimeError: If denoising or normalization fails.
        """
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
        # Track whether we created the output file ourselves so that failure
        # cleanup never deletes a caller-supplied path (the original code
        # unlinked output_path on failure even when the caller provided it).
        created_temp = output_path is None
        if created_temp:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                output_path = tmp_file.name
        try:
            # Perform the denoising pass.
            self._pipeline(input_path, output_path=output_path)
            # Optional loudness normalization of the denoised file.
            if normalize_loudness:
                self._normalize_loudness(output_path)
            return output_path
        except Exception as e:
            # Clean up only the temporary file we created.
            if created_temp and os.path.exists(output_path):
                try:
                    os.unlink(output_path)
                except OSError:
                    pass
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Audio denoising processing failed: {e}") from e