xjsc0 committed on
Commit
99cf7e1
·
1 Parent(s): 61e6f25
Files changed (2) hide show
  1. app.py +6 -3
  2. initialization.py +73 -0
app.py CHANGED
@@ -6,12 +6,13 @@ A singing voice synthesis system powered by YingMusicSinger,
6
  with built-in vocal/accompaniment separation via MelBandRoformer.
7
  """
8
 
 
 
 
9
  import gradio as gr
10
  import torch
11
  import torchaudio
12
- import tempfile
13
- import os
14
- import numpy as np
15
 
16
  # ---------------------------------------------------------------------------
17
  # Model loading (lazy, singleton) / 模型懒加载(单例)
@@ -22,6 +23,7 @@ _separator = None
22
 
23
  def get_model(device: str = "cuda:0"):
24
  """加载 YingMusicSinger 模型 / Load YingMusicSinger model."""
 
25
  global _model
26
  if _model is None:
27
  from src.YingMusicSinger.infer.YingMusicSinger import YingMusicSinger
@@ -35,6 +37,7 @@ def get_separator(device: str = "cuda:0"):
35
  加载 MelBandRoformer 分离模型 / Load MelBandRoformer separator.
36
  Returns a Separator instance ready for inference.
37
  """
 
38
  global _separator
39
  if _separator is None:
40
  from src.third_party.MusicSourceSeparationTraining.inference_api import (
 
6
  with built-in vocal/accompaniment separation via MelBandRoformer.
7
  """
8
 
9
+ import os
10
+ import tempfile
11
+
12
  import gradio as gr
13
  import torch
14
  import torchaudio
15
+ from initialization import download_files
 
 
16
 
17
  # ---------------------------------------------------------------------------
18
  # Model loading (lazy, singleton) / 模型懒加载(单例)
 
23
 
24
  def get_model(device: str = "cuda:0"):
25
  """加载 YingMusicSinger 模型 / Load YingMusicSinger model."""
26
+ download_files(task="infer")
27
  global _model
28
  if _model is None:
29
  from src.YingMusicSinger.infer.YingMusicSinger import YingMusicSinger
 
37
  加载 MelBandRoformer 分离模型 / Load MelBandRoformer separator.
38
  Returns a Separator instance ready for inference.
39
  """
40
+ download_files(task="infer")
41
  global _separator
42
  if _separator is None:
43
  from src.third_party.MusicSourceSeparationTraining.inference_api import (
initialization.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ YingMusic-Singer Initialization Script
3
+
4
+ Downloads required checkpoints from HuggingFace based on task type.
5
+
6
+ Usage:
7
+ python initialization.py --task infer
8
+ python initialization.py --task train
9
+ """
10
+
11
+ import argparse
12
+ import os
13
+
14
+ from huggingface_hub import hf_hub_download
15
+
16
# HuggingFace Hub repository that hosts the checkpoints.
REPO_ID = "ASLP-lab/YingMusic-Singer"

# Local directory (relative to the current working directory) where the
# checkpoints are expected to live after download.
CKPT_DIR = "ckpts"

# Repo-relative paths of the files required for inference.
INFER_FILES = [
    "ckpts/MelBandRoformer.ckpt",
    "ckpts/config_vocals_mel_band_roformer_kj.yaml",
]

# Additional repo-relative paths that are only fetched for training.
TRAIN_EXTRA_FILES = [
    "ckpts/YingMusicSinger_model.pt",
    "ckpts/model_ckpt_steps_100000_simplified.ckpt",
    "ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
]

# Task name -> full list of files to download for that task.
# "train" is a superset of "infer".
TASK_FILES = {
    "infer": INFER_FILES,
    "train": INFER_FILES + TRAIN_EXTRA_FILES,
}
35
+
36
+
37
def download_files(task: str):
    """Download the checkpoints required for ``task`` from the HuggingFace Hub.

    Files that already exist under ``CKPT_DIR`` are skipped, so calling this
    function repeatedly only re-checks the file system.

    Args:
        task: A key of ``TASK_FILES`` (``"infer"`` or ``"train"``).

    Raises:
        KeyError: If ``task`` is not a key of ``TASK_FILES``.
    """
    try:
        files = TASK_FILES[task]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message that
        # tells the caller which task names are accepted.
        valid = ", ".join(sorted(TASK_FILES))
        raise KeyError(
            f"Unknown task {task!r}; expected one of: {valid}"
        ) from None

    os.makedirs(CKPT_DIR, exist_ok=True)

    print(f"Task: {task} | Downloading {len(files)} file(s) to {CKPT_DIR}/")
    for remote_path in files:
        filename = os.path.basename(remote_path)
        # The skip check looks at CKPT_DIR/<basename>, while the download below
        # writes to ./<remote_path>. The two agree only because every listed
        # remote path has the form "ckpts/<name>" and CKPT_DIR is "ckpts".
        local_path = os.path.join(CKPT_DIR, filename)

        if os.path.exists(local_path):
            print(f" [skip] {filename} already exists")
            continue

        print(f" [download] {filename} ...")
        hf_hub_download(
            repo_id=REPO_ID,
            filename=remote_path,
            local_dir=".",
        )
        print(f" [done] {filename}")

    print("All downloads complete.")
59
+
60
+
61
if __name__ == "__main__":
    # Command-line entry point: `python initialization.py --task {infer,train}`.
    parser = argparse.ArgumentParser(
        description="Download YingMusic-Singer checkpoints"
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        # Restrict to the known task names so argparse rejects typos up front.
        choices=sorted(TASK_FILES),
        help="Task type: 'infer' for inference, 'train' for training",
    )
    args = parser.parse_args()
    download_files(args.task)