"""
Pipeline 包装脚本

此脚本作为独立子进程运行,执行 TrainingPipeline 并将进度以 JSON 格式输出到 stdout。
主进程(AsyncTrainingManager)通过解析 stdout 来获取实时进度。

进度消息格式:
    ##PROGRESS##{"type": "progress", "stage": "...", ...}##

Usage:
    python run_pipeline.py --config /path/to/config.json --task-id task-123
"""
|
|
import argparse
import json
import os
import sys
import traceback
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict

# Make the project root importable before pulling in project modules.
_SCRIPT_DIR = Path(__file__).parent.resolve()
_API_SERVER_ROOT = _SCRIPT_DIR.parent.parent
_PROJECT_ROOT = _API_SERVER_ROOT.parent
sys.path.insert(0, str(_PROJECT_ROOT))

from project_config import settings, PROJECT_ROOT, get_pythonpath
|
|
|
|
| |
# Markers framing every progress line so the parent process can pick the
# JSON payload out of an otherwise noisy stdout stream.
PROGRESS_PREFIX = "##PROGRESS##"
PROGRESS_SUFFIX = "##"


def emit_progress(progress_info: Dict[str, Any]) -> None:
    """
    Emit a progress message to stdout.

    The message is framed as ``##PROGRESS##{json}##`` so the parent process
    (AsyncTrainingManager) can parse it out of the stream.

    Args:
        progress_info: Progress payload dictionary. A ``timestamp`` field is
            added in place when the caller did not provide one.
    """
    # Stamp the message if the caller did not provide a timestamp.
    # Use an aware UTC datetime: datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive datetime.
    if "timestamp" not in progress_info:
        progress_info["timestamp"] = datetime.now(timezone.utc).isoformat()

    json_str = json.dumps(progress_info, ensure_ascii=False)
    # flush=True so the parent sees the line immediately despite pipe buffering.
    print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True)
|
|
|
|
def emit_log(level: str, message: str, **extra) -> None:
    """
    Emit a log message through the progress channel.

    Args:
        level: Log level (info, warning, error).
        message: Log text.
        **extra: Additional payload fields merged into the message.
    """
    payload = {"type": "log", "level": level, "message": message}
    payload.update(extra)
    emit_progress(payload)
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load a JSON configuration file.

    Args:
        config_path: Path to the configuration file.

    Returns:
        The parsed configuration dictionary.
    """
    raw = Path(config_path).read_text(encoding="utf-8")
    return json.loads(raw)
|
|
|
|
def build_pipeline(config: Dict[str, Any]):
    """
    Build a TrainingPipeline from a configuration dictionary.

    Args:
        config: Configuration containing:
            - exp_name: experiment name (required)
            - version: model version string, default "v2"
            - stages: list of stages to run; each item is either a stage-type
              string (stage reads its settings from the top-level config) or a
              dict with a "type" key plus stage-specific overrides
            - per-stage settings (see the builders below for defaults)

    Returns:
        A TrainingPipeline instance with the configured stages added.
        Unknown or malformed stage entries are skipped with a warning log.
    """
    from training_pipeline import (
        TrainingPipeline,
        ModelVersion,

        AudioSliceConfig,
        ASRConfig,
        DenoiseConfig,
        FeatureExtractionConfig,
        SoVITSTrainConfig,
        GPTTrainConfig,
        InferenceConfig,

        AudioSliceStage,
        ASRStage,
        DenoiseStage,
        TextFeatureStage,
        HuBERTFeatureStage,
        SemanticTokenStage,
        SoVITSTrainStage,
        GPTTrainStage,
        InferenceStage,
    )

    pipeline = TrainingPipeline()

    exp_name = config["exp_name"]
    version_str = config.get("version", "v2")
    version = ModelVersion(version_str) if isinstance(version_str, str) else version_str

    # Parameters shared by every stage config.
    base_params = {
        "exp_name": exp_name,
        "exp_root": config.get("exp_root", "logs"),
        "gpu_numbers": config.get("gpu_numbers", "0"),
        "is_half": config.get("is_half", True),
    }

    # Default pretrained-model paths (shared by several builders below).
    _BERT_DIR = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
    _SSL_DIR = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
    _S2G = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
    _S2D = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth"
    _S1 = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"

    def _feature_config(cfg: Dict[str, Any]) -> "FeatureExtractionConfig":
        # The three feature-extraction stages (text / HuBERT / semantic token)
        # share one config shape; build it once instead of repeating it.
        return FeatureExtractionConfig(
            **base_params,
            version=version,
            bert_pretrained_dir=cfg.get("bert_pretrained_dir", _BERT_DIR),
            ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", _SSL_DIR),
            pretrained_s2G=cfg.get("pretrained_s2G", _S2G),
        )

    # Dispatch table: stage-type string -> builder(cfg) -> stage instance.
    stage_builders = {
        "audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig(
            **base_params,
            input_path=cfg.get("input_path", ""),
            threshold=cfg.get("threshold", -34),
            min_length=cfg.get("min_length", 4000),
            min_interval=cfg.get("min_interval", 300),
            hop_size=cfg.get("hop_size", 10),
            max_sil_kept=cfg.get("max_sil_kept", 500),
            max_amp=cfg.get("max_amp", 0.9),
            alpha=cfg.get("alpha", 0.25),
            n_parts=cfg.get("n_parts", 4),
        )),

        "asr": lambda cfg: ASRStage(ASRConfig(
            **base_params,
            model=cfg.get("model", "达摩 ASR (中文)"),
            model_size=cfg.get("model_size", "large"),
            language=cfg.get("language", "zh"),
            precision=cfg.get("precision", "float32"),
        )),

        "denoise": lambda cfg: DenoiseStage(DenoiseConfig(
            **base_params,
            input_dir=cfg.get("input_dir", ""),
            output_dir=cfg.get("output_dir", "output/denoise_opt"),
        )),

        "text_feature": lambda cfg: TextFeatureStage(_feature_config(cfg)),
        "hubert_feature": lambda cfg: HuBERTFeatureStage(_feature_config(cfg)),
        "semantic_token": lambda cfg: SemanticTokenStage(_feature_config(cfg)),

        "sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 8),
            text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4),
            save_every_epoch=cfg.get("save_every_epoch", 4),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            pretrained_s2G=cfg.get("pretrained_s2G", _S2G),
            pretrained_s2D=cfg.get("pretrained_s2D", _S2D),
            if_grad_ckpt=cfg.get("if_grad_ckpt", False),
            lora_rank=cfg.get("lora_rank", 32),
        )),

        "gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 15),
            save_every_epoch=cfg.get("save_every_epoch", 5),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            if_dpo=cfg.get("if_dpo", False),
            pretrained_s1=cfg.get("pretrained_s1", _S1),
        )),

        "inference": lambda cfg: InferenceStage(InferenceConfig(
            **base_params,
            version=version,
            gpt_path=cfg.get("gpt_path", ""),
            sovits_path=cfg.get("sovits_path", ""),
            ref_text=cfg.get("ref_text", ""),
            ref_audio_path=cfg.get("ref_audio_path", ""),
            target_text=cfg.get("target_text", ""),
            text_split_method=cfg.get("text_split_method", "cut1"),
        )),
    }

    for stage_item in config.get("stages", []):
        if isinstance(stage_item, str):
            # Bare stage name: the stage reads its settings from the top-level config.
            stage_type = stage_item
            stage_config = config
        elif isinstance(stage_item, dict):
            stage_type = stage_item.get("type")
            # Item-level settings override top-level ones.
            stage_config = {**config, **stage_item}
        else:
            emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}")
            continue

        if stage_type in stage_builders:
            stage = stage_builders[stage_type](stage_config)
            pipeline.add_stage(stage)
            emit_log("info", f"已添加阶段: {stage.name}")
        else:
            emit_log("warning", f"未知阶段类型: {stage_type}")

    return pipeline
|
|
|
|
def run_pipeline(config: Dict[str, Any], task_id: str) -> bool:
    """
    Execute the training pipeline, streaming progress to stdout.

    Args:
        config: Configuration dictionary.
        task_id: Task identifier echoed back in every progress message.

    Returns:
        True when all stages completed, False on any failure.
    """
    emit_progress({
        "type": "progress",
        "status": "running",
        "message": "正在初始化训练流水线...",
        "task_id": task_id,
        "progress": 0.0,
        "overall_progress": 0.0,
    })

    try:
        pipeline = build_pipeline(config)

        stage_list = pipeline.get_stages()
        if not stage_list:
            # Nothing to run: report failure rather than a no-op "success".
            emit_progress({
                "type": "progress",
                "status": "failed",
                "message": "没有配置任何训练阶段",
                "task_id": task_id,
            })
            return False

        emit_log("info", f"训练流水线已初始化,共 {len(stage_list)} 个阶段")

        for update in pipeline.run():
            # Forward every pipeline update to the parent process.
            emit_progress({
                "type": "progress",
                "status": "running",
                "stage": update.get("stage"),
                "stage_index": update.get("stage_index"),
                "total_stages": update.get("total_stages"),
                "progress": update.get("progress", 0.0),
                "overall_progress": update.get("overall_progress", 0.0),
                "message": update.get("message"),
                "task_id": task_id,
                "data": update.get("data", {}),
            })

            if update.get("status") == "failed":
                # A stage reported failure: surface it and stop immediately.
                emit_progress({
                    "type": "progress",
                    "status": "failed",
                    "stage": update.get("stage"),
                    "message": update.get("message", "阶段执行失败"),
                    "task_id": task_id,
                })
                return False

        emit_progress({
            "type": "progress",
            "status": "completed",
            "message": "训练流水线执行完成",
            "task_id": task_id,
            "progress": 1.0,
            "overall_progress": 1.0,
        })
        return True

    except Exception as exc:
        # Report the exception through the progress channel so the parent
        # process gets the message and the traceback, then signal failure.
        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"执行出错: {exc}",
            "error": str(exc),
            "traceback": traceback.format_exc(),
            "task_id": task_id,
        })
        return False
|
|
|
|
def main():
    """Entry point: parse CLI arguments, load the config, run the pipeline."""
    arg_parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线")
    arg_parser.add_argument("--config", required=True, help="配置文件路径 (JSON)")
    arg_parser.add_argument("--task-id", required=True, help="任务ID")
    cli = arg_parser.parse_args()

    emit_log("info", f"启动训练任务: {cli.task_id}")
    emit_log("info", f"配置文件: {cli.config}")

    try:
        config = load_config(cli.config)
    except Exception as exc:
        # A bad/missing config file is fatal: report it and exit non-zero.
        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"加载配置文件失败: {exc}",
            "task_id": cli.task_id,
        })
        sys.exit(1)

    ok = run_pipeline(config, cli.task_id)
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()
|
|