刘鑫 committed on
Commit
0203b68
·
1 Parent(s): 5351045

fix: remove unsupported inference_timesteps parameter and dit_steps slider

Browse files

nanovllm-voxcpm does not support inference_timesteps; remove the parameter
to prevent accidental server rebuilds when users adjust the slider.

Made-with: Cursor

Files changed (1) hide show
  1. app.py +13 -66
app.py CHANGED
@@ -60,7 +60,6 @@ _configure_cache_dirs()
60
  _asr_model = None
61
  _voxcpm_server = None
62
  _model_info = None
63
- _server_inference_timesteps = None
64
  _denoiser = None
65
  _server_lock = Lock()
66
  _prewarm_lock = Lock()
@@ -102,9 +101,6 @@ def _get_devices_env() -> list[int]:
102
  return [int(part) for part in values]
103
 
104
 
105
- DEFAULT_INFERENCE_TIMESTEPS = _get_int_env("NANOVLLM_INFERENCE_TIMESTEPS", 10)
106
-
107
-
108
  def _resolve_model_ref() -> str:
109
  for env_name in ("NANOVLLM_MODEL", "NANOVLLM_MODEL_PATH", "HF_REPO_ID"):
110
  value = os.environ.get(env_name, "").strip()
@@ -258,17 +254,10 @@ def _safe_prompt_wav_recognition(use_prompt_text: bool, prompt_wav: Optional[str
258
  return ""
259
 
260
 
261
- def _validate_reference_audio_upload(
262
- audio_path: Optional[str], request: gr.Request
263
- ) -> Optional[str]:
264
- if audio_path is None or not audio_path.strip():
265
- return audio_path
266
- _validate_reference_audio_duration(audio_path, request)
267
- return audio_path
268
 
269
 
270
  def _stop_server_if_needed() -> None:
271
- global _voxcpm_server, _model_info, _server_inference_timesteps
272
  if _voxcpm_server is None:
273
  return
274
 
@@ -281,7 +270,6 @@ def _stop_server_if_needed() -> None:
281
 
282
  _voxcpm_server = None
283
  _model_info = None
284
- _server_inference_timesteps = None
285
 
286
 
287
  atexit.register(_stop_server_if_needed)
@@ -381,8 +369,6 @@ _I18N_TRANSLATIONS = {
381
  "normalize_info": "Normalize numbers, dates, and abbreviations via wetext",
382
  "cfg_label": "CFG (guidance scale)",
383
  "cfg_info": "Higher → closer to the prompt / reference; lower → more creative variation",
384
- "dit_steps_label": "LocDiT flow-matching steps",
385
- "dit_steps_info": "LocDiT flow-matching steps — more steps → maybe better audio quality, but slower",
386
  "reference_audio_too_long_error": "Reference audio is too long. Please upload audio no longer than 50 seconds.",
387
  "usage_instructions": _USAGE_INSTRUCTIONS_EN,
388
  "examples_footer": _EXAMPLES_FOOTER_EN,
@@ -405,8 +391,6 @@ _I18N_TRANSLATIONS = {
405
  "normalize_info": "自动规范化数字、日期及缩写(基于 wetext)",
406
  "cfg_label": "CFG(引导强度)",
407
  "cfg_info": "数值越高 → 越贴合提示/参考音色;数值越低 → 生成风格更自由",
408
- "dit_steps_label": "LocDiT 流匹配迭代步数",
409
- "dit_steps_info": "LocDiT 流匹配生成迭代步数 — 步数越多 → 可能生成更好的音频质量,但速度变慢",
410
  "reference_audio_too_long_error": "参考音频太长了,请上传不超过 50 秒的音频。",
411
  "usage_instructions": _USAGE_INSTRUCTIONS_ZH,
412
  "examples_footer": _EXAMPLES_FOOTER_ZH,
@@ -534,33 +518,22 @@ def get_asr_model():
534
  return _asr_model
535
 
536
 
537
- def get_voxcpm_server(inference_timesteps: int):
538
- global _voxcpm_server, _model_info, _server_inference_timesteps
539
- if _voxcpm_server is not None and _server_inference_timesteps == inference_timesteps:
540
  return _voxcpm_server
541
 
542
  with _server_lock:
543
- if _voxcpm_server is not None and _server_inference_timesteps == inference_timesteps:
544
  return _voxcpm_server
545
 
546
- if _voxcpm_server is not None and _server_inference_timesteps != inference_timesteps:
547
- logger.info(
548
- f"Rebuilding nano-vLLM server for inference_timesteps={inference_timesteps} "
549
- f"(previous={_server_inference_timesteps})"
550
- )
551
- _stop_server_if_needed()
552
-
553
  _log_runtime_diagnostics_once()
554
  from nanovllm_voxcpm import VoxCPM
555
 
556
  model_ref = _resolve_model_ref()
557
- logger.info(
558
- f"Loading nano-vLLM VoxCPM server from {model_ref} "
559
- f"with inference_timesteps={inference_timesteps} ..."
560
- )
561
  _voxcpm_server = VoxCPM.from_pretrained(
562
  model=model_ref,
563
- inference_timesteps=int(inference_timesteps),
564
  max_num_batched_tokens=_get_int_env("NANOVLLM_SERVERPOOL_MAX_NUM_BATCHED_TOKENS", 8192),
565
  max_num_seqs=_get_int_env("NANOVLLM_SERVERPOOL_MAX_NUM_SEQS", 16),
566
  max_model_len=_get_int_env("NANOVLLM_SERVERPOOL_MAX_MODEL_LEN", 4096),
@@ -569,25 +542,22 @@ def get_voxcpm_server(inference_timesteps: int):
569
  devices=_get_devices_env(),
570
  )
571
  _model_info = _voxcpm_server.get_model_info()
572
- _server_inference_timesteps = inference_timesteps
573
  logger.info(f"nano-vLLM VoxCPM server loaded: {_model_info}")
574
  return _voxcpm_server
575
 
576
 
577
- def get_model_info(inference_timesteps: int) -> dict:
578
  global _model_info
579
- if _model_info is None or _server_inference_timesteps != inference_timesteps:
580
- get_voxcpm_server(inference_timesteps)
581
  assert _model_info is not None
582
  return _model_info
583
 
584
 
585
  def _prewarm_backend() -> None:
586
  try:
587
- logger.info(
588
- f"Starting backend prewarm with inference_timesteps={DEFAULT_INFERENCE_TIMESTEPS} ..."
589
- )
590
- get_voxcpm_server(DEFAULT_INFERENCE_TIMESTEPS)
591
  logger.info("Backend prewarm completed.")
592
  except Exception as exc:
593
  logger.warning(f"Backend prewarm failed: {exc}")
@@ -632,14 +602,12 @@ def _generate_tts_audio_once(
632
  cfg_value_input: float = 2.0,
633
  do_normalize: bool = True,
634
  denoise: bool = True,
635
- inference_timesteps: int = 10,
636
  request: Optional[gr.Request] = None,
637
  ) -> Tuple[int, np.ndarray]:
638
  temp_audio_path = None
639
  try:
640
- timesteps = int(inference_timesteps)
641
- server = get_voxcpm_server(timesteps)
642
- model_info = get_model_info(timesteps)
643
 
644
  text = (text_input or "").strip()
645
  if len(text) == 0:
@@ -720,7 +688,6 @@ def generate_tts_audio(
720
  cfg_value_input: float = 2.0,
721
  do_normalize: bool = True,
722
  denoise: bool = True,
723
- inference_timesteps: int = 10,
724
  request: Optional[gr.Request] = None,
725
  ) -> Tuple[int, np.ndarray]:
726
  request_payload = {
@@ -733,7 +700,6 @@ def generate_tts_audio(
733
  "cfg_value": float(cfg_value_input),
734
  "do_normalize": bool(do_normalize),
735
  "denoise": bool(denoise),
736
- "inference_timesteps": int(inference_timesteps),
737
  "has_reference_audio": bool(reference_wav_path_input and reference_wav_path_input.strip()),
738
  }
739
  if request_payload["has_reference_audio"]:
@@ -754,7 +720,6 @@ def generate_tts_audio(
754
  cfg_value_input=cfg_value_input,
755
  do_normalize=do_normalize,
756
  denoise=denoise,
757
- inference_timesteps=inference_timesteps,
758
  request=request,
759
  )
760
  try:
@@ -781,7 +746,6 @@ def generate_tts_audio(
781
  cfg_value_input=cfg_value_input,
782
  do_normalize=do_normalize,
783
  denoise=denoise,
784
- inference_timesteps=inference_timesteps,
785
  request=request,
786
  )
787
  try:
@@ -881,28 +845,12 @@ def create_demo_interface():
881
  label=I18N("cfg_label"),
882
  info=I18N("cfg_info"),
883
  )
884
- dit_steps = gr.Slider(
885
- minimum=1,
886
- maximum=50,
887
- value=DEFAULT_INFERENCE_TIMESTEPS,
888
- step=1,
889
- label=I18N("dit_steps_label"),
890
- info=I18N("dit_steps_info"),
891
- )
892
-
893
  run_btn = gr.Button(I18N("generate_btn"), variant="primary", size="lg")
894
 
895
  with gr.Column():
896
  audio_output = gr.Audio(label=I18N("generated_audio_label"))
897
  gr.Markdown(I18N("examples_footer"))
898
 
899
- reference_wav.change(
900
- fn=_validate_reference_audio_upload,
901
- inputs=[reference_wav],
902
- outputs=[reference_wav],
903
- show_progress=False,
904
- )
905
-
906
  show_prompt_text.change(
907
  fn=_on_toggle_instant,
908
  inputs=[show_prompt_text],
@@ -924,7 +872,6 @@ def create_demo_interface():
924
  cfg_value,
925
  DoNormalizeText,
926
  DoDenoisePromptAudio,
927
- dit_steps,
928
  ],
929
  outputs=[audio_output],
930
  show_progress=True,
 
60
  _asr_model = None
61
  _voxcpm_server = None
62
  _model_info = None
 
63
  _denoiser = None
64
  _server_lock = Lock()
65
  _prewarm_lock = Lock()
 
101
  return [int(part) for part in values]
102
 
103
 
 
 
 
104
  def _resolve_model_ref() -> str:
105
  for env_name in ("NANOVLLM_MODEL", "NANOVLLM_MODEL_PATH", "HF_REPO_ID"):
106
  value = os.environ.get(env_name, "").strip()
 
254
  return ""
255
 
256
 
 
 
 
 
 
 
 
257
 
258
 
259
  def _stop_server_if_needed() -> None:
260
+ global _voxcpm_server, _model_info
261
  if _voxcpm_server is None:
262
  return
263
 
 
270
 
271
  _voxcpm_server = None
272
  _model_info = None
 
273
 
274
 
275
  atexit.register(_stop_server_if_needed)
 
369
  "normalize_info": "Normalize numbers, dates, and abbreviations via wetext",
370
  "cfg_label": "CFG (guidance scale)",
371
  "cfg_info": "Higher → closer to the prompt / reference; lower → more creative variation",
 
 
372
  "reference_audio_too_long_error": "Reference audio is too long. Please upload audio no longer than 50 seconds.",
373
  "usage_instructions": _USAGE_INSTRUCTIONS_EN,
374
  "examples_footer": _EXAMPLES_FOOTER_EN,
 
391
  "normalize_info": "自动规范化数字、日期及缩写(基于 wetext)",
392
  "cfg_label": "CFG(引导强度)",
393
  "cfg_info": "数值越高 → 越贴合提示/参考音色;数值越低 → 生成风格更自由",
 
 
394
  "reference_audio_too_long_error": "参考音频太长了,请上传不超过 50 秒的音频。",
395
  "usage_instructions": _USAGE_INSTRUCTIONS_ZH,
396
  "examples_footer": _EXAMPLES_FOOTER_ZH,
 
518
  return _asr_model
519
 
520
 
521
+ def get_voxcpm_server():
522
+ global _voxcpm_server, _model_info
523
+ if _voxcpm_server is not None:
524
  return _voxcpm_server
525
 
526
  with _server_lock:
527
+ if _voxcpm_server is not None:
528
  return _voxcpm_server
529
 
 
 
 
 
 
 
 
530
  _log_runtime_diagnostics_once()
531
  from nanovllm_voxcpm import VoxCPM
532
 
533
  model_ref = _resolve_model_ref()
534
+ logger.info(f"Loading nano-vLLM VoxCPM server from {model_ref} ...")
 
 
 
535
  _voxcpm_server = VoxCPM.from_pretrained(
536
  model=model_ref,
 
537
  max_num_batched_tokens=_get_int_env("NANOVLLM_SERVERPOOL_MAX_NUM_BATCHED_TOKENS", 8192),
538
  max_num_seqs=_get_int_env("NANOVLLM_SERVERPOOL_MAX_NUM_SEQS", 16),
539
  max_model_len=_get_int_env("NANOVLLM_SERVERPOOL_MAX_MODEL_LEN", 4096),
 
542
  devices=_get_devices_env(),
543
  )
544
  _model_info = _voxcpm_server.get_model_info()
 
545
  logger.info(f"nano-vLLM VoxCPM server loaded: {_model_info}")
546
  return _voxcpm_server
547
 
548
 
549
+ def get_model_info() -> dict:
550
  global _model_info
551
+ if _model_info is None:
552
+ get_voxcpm_server()
553
  assert _model_info is not None
554
  return _model_info
555
 
556
 
557
  def _prewarm_backend() -> None:
558
  try:
559
+ logger.info("Starting backend prewarm ...")
560
+ get_voxcpm_server()
 
 
561
  logger.info("Backend prewarm completed.")
562
  except Exception as exc:
563
  logger.warning(f"Backend prewarm failed: {exc}")
 
602
  cfg_value_input: float = 2.0,
603
  do_normalize: bool = True,
604
  denoise: bool = True,
 
605
  request: Optional[gr.Request] = None,
606
  ) -> Tuple[int, np.ndarray]:
607
  temp_audio_path = None
608
  try:
609
+ server = get_voxcpm_server()
610
+ model_info = get_model_info()
 
611
 
612
  text = (text_input or "").strip()
613
  if len(text) == 0:
 
688
  cfg_value_input: float = 2.0,
689
  do_normalize: bool = True,
690
  denoise: bool = True,
 
691
  request: Optional[gr.Request] = None,
692
  ) -> Tuple[int, np.ndarray]:
693
  request_payload = {
 
700
  "cfg_value": float(cfg_value_input),
701
  "do_normalize": bool(do_normalize),
702
  "denoise": bool(denoise),
 
703
  "has_reference_audio": bool(reference_wav_path_input and reference_wav_path_input.strip()),
704
  }
705
  if request_payload["has_reference_audio"]:
 
720
  cfg_value_input=cfg_value_input,
721
  do_normalize=do_normalize,
722
  denoise=denoise,
 
723
  request=request,
724
  )
725
  try:
 
746
  cfg_value_input=cfg_value_input,
747
  do_normalize=do_normalize,
748
  denoise=denoise,
 
749
  request=request,
750
  )
751
  try:
 
845
  label=I18N("cfg_label"),
846
  info=I18N("cfg_info"),
847
  )
 
 
 
 
 
 
 
 
 
848
  run_btn = gr.Button(I18N("generate_btn"), variant="primary", size="lg")
849
 
850
  with gr.Column():
851
  audio_output = gr.Audio(label=I18N("generated_audio_label"))
852
  gr.Markdown(I18N("examples_footer"))
853
 
 
 
 
 
 
 
 
854
  show_prompt_text.change(
855
  fn=_on_toggle_instant,
856
  inputs=[show_prompt_text],
 
872
  cfg_value,
873
  DoNormalizeText,
874
  DoDenoisePromptAudio,
 
875
  ],
876
  outputs=[audio_output],
877
  show_progress=True,