liumaolin committed on
Commit
8f68d0a
·
1 Parent(s): 60123c5

refactor(config): centralize configuration management in `project_config`

Browse files

- Move `config.py` from `core` to root-level `project_config.py`
- Update imports throughout the project to use the new module
- Adjust paths to use centralized constants like `DEFAULT_MODEL_DIR`
- Improve maintainability by unifying configuration for TTS, ASR, and pipelines
- Remove redundant configuration instances in submodules

GPT_SoVITS/download.py CHANGED
@@ -1,13 +1,15 @@
1
  import os
2
  import sys
3
 
 
 
4
  now_dir = os.getcwd()
5
  sys.path.insert(0, now_dir)
6
  from text.g2pw import G2PWPinyin
7
 
8
  g2pw = G2PWPinyin(
9
- model_dir="GPT_SoVITS/text/G2PWModel",
10
- model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
11
  v_to_u=False,
12
  neutral_tone_with_five=True,
13
  )
 
1
  import os
2
  import sys
3
 
4
+ from project_config import settings, DEFAULT_MODEL_DIR
5
+
6
  now_dir = os.getcwd()
7
  sys.path.insert(0, now_dir)
8
  from text.g2pw import G2PWPinyin
9
 
10
  g2pw = G2PWPinyin(
11
+ model_dir=DEFAULT_MODEL_DIR / "G2PWModel",
12
+ model_source=settings.BERT_PRETRAINED_DIR,
13
  v_to_u=False,
14
  neutral_tone_with_five=True,
15
  )
GPT_SoVITS/text/chinese2.py CHANGED
@@ -8,6 +8,7 @@ from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
8
  from GPT_SoVITS.text.symbols import punctuation
9
  from GPT_SoVITS.text.tone_sandhi import ToneSandhi
10
  from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
 
11
 
12
  normalizer = lambda x: cn2an.transform(x, "an2cn")
13
 
@@ -32,8 +33,8 @@ if is_g2pw:
32
 
33
  parent_directory = os.path.dirname(current_file_path)
34
  g2pw = G2PWPinyin(
35
- model_dir="GPT_SoVITS/text/G2PWModel",
36
- model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
37
  v_to_u=False,
38
  neutral_tone_with_five=True,
39
  )
 
8
  from GPT_SoVITS.text.symbols import punctuation
9
  from GPT_SoVITS.text.tone_sandhi import ToneSandhi
10
  from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
11
+ from project_config import settings, DEFAULT_MODEL_DIR
12
 
13
  normalizer = lambda x: cn2an.transform(x, "an2cn")
14
 
 
33
 
34
  parent_directory = os.path.dirname(current_file_path)
35
  g2pw = G2PWPinyin(
36
+ model_dir=DEFAULT_MODEL_DIR / "G2PWModel",
37
+ model_source=os.environ.get("bert_path", settings.BERT_PRETRAINED_DIR),
38
  v_to_u=False,
39
  neutral_tone_with_five=True,
40
  )
api_server/app/adapters/local/database.py CHANGED
@@ -13,11 +13,10 @@ from typing import Any, Dict, List, Optional
13
 
14
  import aiosqlite
15
 
 
16
  from ..base import DatabaseAdapter
17
- from ...core.config import settings
18
  from ...models.domain import Task, TaskStatus
19
 
20
-
21
  # 阶段类型列表
22
  STAGE_TYPES = [
23
  "audio_slice",
 
13
 
14
  import aiosqlite
15
 
16
+ from project_config import settings
17
  from ..base import DatabaseAdapter
 
18
  from ...models.domain import Task, TaskStatus
19
 
 
20
  # 阶段类型列表
21
  STAGE_TYPES = [
22
  "audio_slice",
api_server/app/adapters/local/storage.py CHANGED
@@ -13,8 +13,8 @@ from typing import Any, Dict, List, Optional
13
 
14
  import aiofiles
15
 
 
16
  from ..base import StorageAdapter
17
- from ...core.config import settings
18
 
19
 
20
  class LocalStorageAdapter(StorageAdapter):
 
13
 
14
  import aiofiles
15
 
16
+ from project_config import settings
17
  from ..base import StorageAdapter
 
18
 
19
 
20
  class LocalStorageAdapter(StorageAdapter):
api_server/app/adapters/local/task_queue.py CHANGED
@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict, Optional, AsyncGenerator, List
17
 
18
  import aiosqlite
19
 
 
20
  from ..base import TaskQueueAdapter
21
- from ...core.config import settings, PROJECT_ROOT, get_pythonpath
22
 
23
  if TYPE_CHECKING:
24
  from ..base import DatabaseAdapter
 
17
 
18
  import aiosqlite
19
 
20
+ from project_config import settings, PROJECT_ROOT, get_pythonpath
21
  from ..base import TaskQueueAdapter
 
22
 
23
  if TYPE_CHECKING:
24
  from ..base import DatabaseAdapter
api_server/app/core/__init__.py CHANGED
@@ -4,6 +4,6 @@
4
  包含配置、枚举等核心组件
5
  """
6
 
7
- from .config import settings, PROJECT_ROOT, API_SERVER_ROOT
8
 
9
  __all__ = ["settings", "PROJECT_ROOT", "API_SERVER_ROOT"]
 
4
  包含配置、枚举等核心组件
5
  """
6
 
7
+ from project_config import settings, PROJECT_ROOT, API_SERVER_ROOT
8
 
9
  __all__ = ["settings", "PROJECT_ROOT", "API_SERVER_ROOT"]
api_server/app/core/adapters.py CHANGED
@@ -12,7 +12,7 @@ Example:
12
  from functools import lru_cache
13
  from typing import TYPE_CHECKING
14
 
15
- from .config import settings
16
 
17
  if TYPE_CHECKING:
18
  from ..adapters.base import (
 
12
  from functools import lru_cache
13
  from typing import TYPE_CHECKING
14
 
15
+ from project_config import settings
16
 
17
  if TYPE_CHECKING:
18
  from ..adapters.base import (
api_server/app/main.py CHANGED
@@ -13,8 +13,8 @@ from typing import AsyncGenerator
13
  from fastapi import FastAPI
14
  from fastapi.middleware.cors import CORSMiddleware
15
 
 
16
  from .api.v1.router import api_router
17
- from .core.config import settings, ensure_data_dirs
18
 
19
 
20
  @asynccontextmanager
 
13
  from fastapi import FastAPI
14
  from fastapi.middleware.cors import CORSMiddleware
15
 
16
+ from project_config import settings, ensure_data_dirs
17
  from .api.v1.router import api_router
 
18
 
19
 
20
  @asynccontextmanager
api_server/app/scripts/run_pipeline.py CHANGED
@@ -28,7 +28,7 @@ _PROJECT_ROOT = _API_SERVER_ROOT.parent
28
  sys.path.insert(0, str(_PROJECT_ROOT))
29
 
30
  # 导入配置模块
31
- from api_server.app.core.config import settings, PROJECT_ROOT, get_pythonpath
32
 
33
 
34
  # 进度消息前缀和后缀,用于主进程解析
 
28
  sys.path.insert(0, str(_PROJECT_ROOT))
29
 
30
  # 导入配置模块
31
+ from project_config import settings, PROJECT_ROOT, get_pythonpath
32
 
33
 
34
  # 进度消息前缀和后缀,用于主进程解析
api_server/app/services/task_service.py CHANGED
@@ -10,8 +10,8 @@ from datetime import datetime
10
  from pathlib import Path
11
  from typing import AsyncGenerator, Dict, Optional, Any, Tuple
12
 
 
13
  from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter
14
- from ..core.config import settings
15
  from ..models.domain import Task, TaskStatus
16
  from ..models.schemas.task import (
17
  QuickModeRequest,
 
10
  from pathlib import Path
11
  from typing import AsyncGenerator, Dict, Optional, Any, Tuple
12
 
13
+ from project_config import settings
14
  from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter
 
15
  from ..models.domain import Task, TaskStatus
16
  from ..models.schemas.task import (
17
  QuickModeRequest,
infer.py CHANGED
@@ -17,8 +17,9 @@ from pathlib import Path
17
  import click
18
  import soundfile as sf
19
 
20
- from training_pipeline.stages.inference import create_tts_module, create_inference_config
21
  from training_pipeline.configs import InferenceConfig
 
22
 
23
 
24
  @click.command()
@@ -58,13 +59,13 @@ from training_pipeline.configs import InferenceConfig
58
  )
59
  @click.option(
60
  '--bert-path',
61
- default='GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
62
  type=click.Path(exists=True),
63
  help='BERT 预训练模型路径'
64
  )
65
  @click.option(
66
  '--cnhubert-path',
67
- default='GPT_SoVITS/pretrained_models/chinese-hubert-base',
68
  type=click.Path(exists=True),
69
  help='Chinese HuBERT 预训练模型路径'
70
  )
@@ -79,16 +80,16 @@ from training_pipeline.configs import InferenceConfig
79
  help='参考文本语言(默认: zh)'
80
  )
81
  def main(
82
- target_text: str,
83
- ref_text: str,
84
- ref_audio: str,
85
- gpt_model: str,
86
- sovits_model: str,
87
- output: str,
88
- bert_path: str,
89
- cnhubert_path: str,
90
- text_lang: str,
91
- prompt_lang: str,
92
  ):
93
  """GPT-SoVITS 命令行推理工具
94
 
@@ -103,11 +104,11 @@ def main(
103
  click.echo(f" SoVITS 模型: {sovits_model}")
104
  click.echo(f" 输出路径: {output}")
105
  click.echo()
106
-
107
  # 确保输出目录存在
108
  output_path = Path(output)
109
  output_path.parent.mkdir(parents=True, exist_ok=True)
110
-
111
  # 创建推理配置
112
  cfg = InferenceConfig(
113
  exp_name="cli_inference",
@@ -119,12 +120,12 @@ def main(
119
  ref_audio_path=ref_audio,
120
  target_text=target_text,
121
  )
122
-
123
  click.echo("⏳ 正在加载模型...")
124
  try:
125
  # 创建 TTS 模块
126
  tts_module = create_tts_module(cfg)
127
-
128
  # 创建推理配置
129
  inference_config = create_inference_config(
130
  text=target_text,
@@ -133,7 +134,7 @@ def main(
133
  text_lang=text_lang,
134
  prompt_lang=prompt_lang,
135
  )
136
-
137
  click.echo("🔊 正在合成语音...")
138
  # 执行推理
139
  for item in tts_module.run(inference_config):
@@ -141,9 +142,9 @@ def main(
141
  # 保存音频
142
  sf.write(str(output_path), audio_data, sample_rate, subtype='PCM_16')
143
  break # 只取第一个结果
144
-
145
  click.echo(f"✅ 成功!音频已保存至: {output_path.absolute()}")
146
-
147
  except Exception as e:
148
  click.echo(f"❌ 推理失败: {e}", err=True)
149
  sys.exit(1)
 
17
  import click
18
  import soundfile as sf
19
 
20
+ from project_config import settings
21
  from training_pipeline.configs import InferenceConfig
22
+ from training_pipeline.stages.inference import create_tts_module, create_inference_config
23
 
24
 
25
  @click.command()
 
59
  )
60
  @click.option(
61
  '--bert-path',
62
+ default=settings.BERT_PRETRAINED_DIR,
63
  type=click.Path(exists=True),
64
  help='BERT 预训练模型路径'
65
  )
66
  @click.option(
67
  '--cnhubert-path',
68
+ default=settings.SSL_PRETRAINED_DIR,
69
  type=click.Path(exists=True),
70
  help='Chinese HuBERT 预训练模型路径'
71
  )
 
80
  help='参考文本语言(默认: zh)'
81
  )
82
  def main(
83
+ target_text: str,
84
+ ref_text: str,
85
+ ref_audio: str,
86
+ gpt_model: str,
87
+ sovits_model: str,
88
+ output: str,
89
+ bert_path: str,
90
+ cnhubert_path: str,
91
+ text_lang: str,
92
+ prompt_lang: str,
93
  ):
94
  """GPT-SoVITS 命令行推理工具
95
 
 
104
  click.echo(f" SoVITS 模型: {sovits_model}")
105
  click.echo(f" 输出路径: {output}")
106
  click.echo()
107
+
108
  # 确保输出目录存在
109
  output_path = Path(output)
110
  output_path.parent.mkdir(parents=True, exist_ok=True)
111
+
112
  # 创建推理配置
113
  cfg = InferenceConfig(
114
  exp_name="cli_inference",
 
120
  ref_audio_path=ref_audio,
121
  target_text=target_text,
122
  )
123
+
124
  click.echo("⏳ 正在加载模型...")
125
  try:
126
  # 创建 TTS 模块
127
  tts_module = create_tts_module(cfg)
128
+
129
  # 创建推理配置
130
  inference_config = create_inference_config(
131
  text=target_text,
 
134
  text_lang=text_lang,
135
  prompt_lang=prompt_lang,
136
  )
137
+
138
  click.echo("🔊 正在合成语音...")
139
  # 执行推理
140
  for item in tts_module.run(inference_config):
 
142
  # 保存音频
143
  sf.write(str(output_path), audio_data, sample_rate, subtype='PCM_16')
144
  break # 只取第一个结果
145
+
146
  click.echo(f"✅ 成功!音频已保存至: {output_path.absolute()}")
147
+
148
  except Exception as e:
149
  click.echo(f"❌ 推理失败: {e}", err=True)
150
  sys.exit(1)
api_server/app/core/config.py → project_config.py RENAMED
@@ -2,6 +2,7 @@
2
  环境变量和配置模块
3
 
4
  统一管理项目路径、环境配置等
 
5
  """
6
 
7
  import os
@@ -14,18 +15,24 @@ from typing import Literal
14
 
15
  USER_HOME_ROOT = Path.home()
16
 
17
- # api_server/app/core/config.py -> api_server/app/core -> api_server/app -> api_server -> 项目根目录
18
- API_SERVER_ROOT = Path(__file__).parent.parent.parent.resolve()
19
- PROJECT_ROOT = API_SERVER_ROOT.parent.resolve()
 
 
20
 
21
  # GPT_SoVITS 模块路径
22
  GPT_SOVITS_ROOT = PROJECT_ROOT / "GPT_SoVITS"
23
 
 
 
24
  # 默认数据目录
25
- DEFAULT_DATA_DIR = USER_HOME_ROOT / '.moyoyo-tts' / "data"
 
 
26
 
27
  # 预训练模型目录
28
- PRETRAINED_MODELS_DIR = GPT_SOVITS_ROOT / "pretrained_models"
29
 
30
  # 日志目录
31
  LOGS_DIR = PROJECT_ROOT / "logs"
@@ -42,7 +49,7 @@ class Settings:
42
  支持从环境变量读取配置,提供合理的默认值
43
 
44
  Example:
45
- >>> from api_server.app.core.config import settings
46
  >>> print(settings.PROJECT_ROOT)
47
  >>> print(settings.DEPLOYMENT_MODE)
48
  """
 
2
  环境变量和配置模块
3
 
4
  统一管理项目路径、环境配置等
5
+ 供整个项目共用
6
  """
7
 
8
  import os
 
15
 
16
  USER_HOME_ROOT = Path.home()
17
 
18
+ # project_config.py 位于项目根目录
19
+ PROJECT_ROOT = Path(__file__).parent.resolve()
20
+
21
+ # api_server 目录路径
22
+ API_SERVER_ROOT = PROJECT_ROOT / "api_server"
23
 
24
  # GPT_SoVITS 模块路径
25
  GPT_SOVITS_ROOT = PROJECT_ROOT / "GPT_SoVITS"
26
 
27
+ DEFAULT_APP_DIR = USER_HOME_ROOT / '.moyoyo-tts'
28
+
29
  # 默认数据目录
30
+ DEFAULT_DATA_DIR = DEFAULT_APP_DIR / "data"
31
+
32
+ DEFAULT_MODEL_DIR = DEFAULT_APP_DIR / "models"
33
 
34
  # 预训练模型目录
35
+ PRETRAINED_MODELS_DIR = DEFAULT_MODEL_DIR / "pretrained_models"
36
 
37
  # 日志目录
38
  LOGS_DIR = PROJECT_ROOT / "logs"
 
49
  支持从环境变量读取配置,提供合理的默认值
50
 
51
  Example:
52
+ >>> from project_config import settings
53
  >>> print(settings.PROJECT_ROOT)
54
  >>> print(settings.DEPLOYMENT_MODE)
55
  """
tools/asr/fasterwhisper_asr.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import snapshot_download as snapshot_download_hf
9
  from modelscope import snapshot_download as snapshot_download_ms
10
  from tqdm import tqdm
11
 
 
12
  from tools.asr.config import get_models
13
  from tools.asr.funasr_asr import only_asr
14
  from tools.my_utils import load_cudnn
@@ -52,20 +53,20 @@ def download_model(model_size: str):
52
  if "distil" in model_size:
53
  if "3.5" in model_size:
54
  repo_id = "distil-whisper/distil-large-v3.5-ct2"
55
- model_path = "tools/asr/models/faster-distil-whisper-large-v3.5"
56
  else:
57
  repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
58
  elif model_size == "large-v3-turbo":
59
  repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
60
- model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
61
  else:
62
  repo_id = f"Systran/faster-whisper-{model_size}"
63
  model_path = (
64
- model_path or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}"
65
  )
66
  else:
67
  repo_id = "XXXXRT/faster-whisper"
68
- model_path = "tools/asr/models"
69
 
70
  files: list[str] = [
71
  "config.json",
@@ -83,21 +84,21 @@ def download_model(model_size: str):
83
  files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
84
 
85
  if source == "HF":
86
- print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
87
  snapshot_download_hf(
88
  repo_id,
89
- local_dir=model_path,
90
  local_dir_use_symlinks=False,
91
  allow_patterns=files,
92
  )
93
  else:
94
- print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
95
  snapshot_download_ms(
96
  repo_id,
97
- local_dir=model_path,
98
  allow_patterns=files,
99
  )
100
- return model_path + f"/faster-whisper-{model_size}".replace("whisper-distil", "distil-whisper")
101
  return model_path
102
 
103
 
 
9
  from modelscope import snapshot_download as snapshot_download_ms
10
  from tqdm import tqdm
11
 
12
+ from project_config import DEFAULT_MODEL_DIR
13
  from tools.asr.config import get_models
14
  from tools.asr.funasr_asr import only_asr
15
  from tools.my_utils import load_cudnn
 
53
  if "distil" in model_size:
54
  if "3.5" in model_size:
55
  repo_id = "distil-whisper/distil-large-v3.5-ct2"
56
+ model_path = DEFAULT_MODEL_DIR / "faster-distil-whisper-large-v3.5"
57
  else:
58
  repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
59
  elif model_size == "large-v3-turbo":
60
  repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
61
+ model_path = DEFAULT_MODEL_DIR / "faster-whisper-large-v3-turbo"
62
  else:
63
  repo_id = f"Systran/faster-whisper-{model_size}"
64
  model_path = (
65
+ model_path or DEFAULT_MODEL_DIR / f"{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}"
66
  )
67
  else:
68
  repo_id = "XXXXRT/faster-whisper"
69
+ model_path = DEFAULT_MODEL_DIR
70
 
71
  files: list[str] = [
72
  "config.json",
 
84
  files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
85
 
86
  if source == "HF":
87
+ print(f"Downloading model from HuggingFace: {repo_id} to {model_path.as_posix()}")
88
  snapshot_download_hf(
89
  repo_id,
90
+ local_dir=model_path.as_posix(),
91
  local_dir_use_symlinks=False,
92
  allow_patterns=files,
93
  )
94
  else:
95
+ print(f"Downloading model from ModelScope: {repo_id} to {model_path.as_posix()}")
96
  snapshot_download_ms(
97
  repo_id,
98
+ local_dir=model_path.as_posix(),
99
  allow_patterns=files,
100
  )
101
+ return model_path.as_posix() + f"/faster-whisper-{model_size}".replace("whisper-distil", "distil-whisper")
102
  return model_path
103
 
104
 
tools/asr/funasr_asr.py CHANGED
@@ -8,6 +8,8 @@ from funasr import AutoModel
8
  from modelscope import snapshot_download
9
  from tqdm import tqdm
10
 
 
 
11
  funasr_models = {} # 存储模型避免重复加载
12
 
13
 
@@ -23,27 +25,27 @@ def only_asr(input_file, language):
23
 
24
  def create_model(language="zh"):
25
  if language == "zh":
26
- path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
27
- path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
28
- path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
29
  snapshot_download(
30
  "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
31
- local_dir="tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch",
32
  )
33
  snapshot_download(
34
  "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
35
- local_dir="tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
36
  )
37
  snapshot_download(
38
  "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
39
- local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
40
  )
41
  model_revision = "v2.0.4"
42
  elif language == "yue":
43
- path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
44
  snapshot_download(
45
  "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
46
- local_dir="tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
47
  )
48
  path_vad = path_punc = None
49
  vad_model_revision = punc_model_revision = ""
@@ -57,11 +59,11 @@ def create_model(language="zh"):
57
  return funasr_models[language]
58
  else:
59
  model = AutoModel(
60
- model=path_asr,
61
  model_revision=model_revision,
62
- vad_model=path_vad,
63
  vad_model_revision=vad_model_revision,
64
- punc_model=path_punc,
65
  punc_model_revision=punc_model_revision,
66
  )
67
  print(f"FunASR 模型加载完成: {language.upper()}")
 
8
  from modelscope import snapshot_download
9
  from tqdm import tqdm
10
 
11
+ from project_config import DEFAULT_MODEL_DIR
12
+
13
  funasr_models = {} # 存储模型避免重复加载
14
 
15
 
 
25
 
26
  def create_model(language="zh"):
27
  if language == "zh":
28
+ path_vad = DEFAULT_MODEL_DIR / "speech_fsmn_vad_zh-cn-16k-common-pytorch"
29
+ path_punc = DEFAULT_MODEL_DIR / "punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
30
+ path_asr = DEFAULT_MODEL_DIR / "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
31
  snapshot_download(
32
  "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
33
+ local_dir=path_vad.as_posix(),
34
  )
35
  snapshot_download(
36
  "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
37
+ local_dir=path_punc.as_posix(),
38
  )
39
  snapshot_download(
40
  "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
41
+ local_dir=path_asr.as_posix(),
42
  )
43
  model_revision = "v2.0.4"
44
  elif language == "yue":
45
+ path_asr = DEFAULT_MODEL_DIR / "speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
46
  snapshot_download(
47
  "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
48
+ local_dir=path_asr.as_posix(),
49
  )
50
  path_vad = path_punc = None
51
  vad_model_revision = punc_model_revision = ""
 
59
  return funasr_models[language]
60
  else:
61
  model = AutoModel(
62
+ model=path_asr.as_posix(),
63
  model_revision=model_revision,
64
+ vad_model=path_vad.as_posix(),
65
  vad_model_revision=vad_model_revision,
66
+ punc_model=path_punc.as_posix(),
67
  punc_model_revision=punc_model_revision,
68
  )
69
  print(f"FunASR 模型加载完成: {language.upper()}")