MoYoYo.tts / api_server /app /services /experiment_service.py
liumaolin
feat(api): implement local training MVP with adapter pattern
e054d0c
"""
Advanced Mode 实验服务
处理专家模式分阶段训练的业务逻辑
"""
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Any
from ..core.adapters import (
get_database_adapter,
get_task_queue_adapter,
get_progress_adapter,
)
from ..models.schemas.experiment import (
ExperimentCreate,
ExperimentUpdate,
ExperimentResponse,
ExperimentListResponse,
StageStatus,
StageExecuteResponse,
StagesListResponse,
STAGE_DEPENDENCIES,
)
# All stage types of the staged training pipeline, listed in execution order
# (first entry runs first).
STAGE_TYPES: List[str] = [
    "audio_slice",      # 1
    "asr",              # 2
    "text_feature",     # 3
    "hubert_feature",   # 4
    "semantic_token",   # 5
    "sovits_train",     # 6
    "gpt_train",        # 7
]
class ExperimentService:
    """
    Advanced Mode experiment service.

    Business logic for expert-mode, stage-by-stage training:
    - create experiments
    - query experiment / stage status
    - execute / cancel a single stage
    - check stage dependencies

    The database, task-queue and progress adapters are resolved lazily on
    first use through the ``get_*_adapter`` factories.

    Example:
        >>> service = ExperimentService()
        >>> exp = await service.create_experiment(request)
        >>> await service.execute_stage(exp.id, "audio_slice", {})
        >>> stages = await service.get_all_stages(exp.id)
    """

    def __init__(self):
        """Initialize the service; adapters are fetched lazily on first access."""
        self._db = None
        self._queue = None
        self._progress = None

    @property
    def db(self):
        """Database adapter, obtained from ``get_database_adapter`` on first access."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task-queue adapter, obtained from ``get_task_queue_adapter`` on first access."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def progress_adapter(self):
        """Progress adapter, obtained from ``get_progress_adapter`` on first access."""
        if self._progress is None:
            self._progress = get_progress_adapter()
        return self._progress

    async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse:
        """
        Create an experiment without executing anything.

        The caller then drives the training flow stage by stage.

        Args:
            request: experiment creation request

        Returns:
            ExperimentResponse for the newly created experiment.
        """
        exp_id = f"exp-{uuid.uuid4().hex[:8]}"
        experiment_data = {
            "id": exp_id,
            "exp_name": request.exp_name,
            "version": request.version,
            "gpu_numbers": request.gpu_numbers,
            "is_half": request.is_half,
            "audio_file_id": request.audio_file_id,
            "status": "created",
        }
        # Per the original design, the database adapter also creates all of
        # the experiment's stage records as part of this call.
        experiment = await self.db.create_experiment(experiment_data)
        return self._experiment_to_response(experiment)

    async def get_experiment(self, exp_id: str) -> Optional[ExperimentResponse]:
        """
        Get experiment details.

        Args:
            exp_id: experiment ID

        Returns:
            ExperimentResponse, or None if the experiment does not exist.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return None
        return self._experiment_to_response(experiment)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> ExperimentListResponse:
        """
        List experiments.

        Args:
            status: optional status filter
            limit: page size
            offset: number of records to skip

        Returns:
            ExperimentListResponse for the requested page.
        """
        experiments = await self.db.list_experiments(
            status=status, limit=limit, offset=offset
        )

        # The list query does not include stages, so each experiment is
        # re-fetched individually to get its full record (one query per item).
        full_experiments = []
        for exp in experiments:
            full_exp = await self.db.get_experiment(exp["id"])
            if full_exp:
                full_experiments.append(full_exp)

        return ExperimentListResponse(
            items=[self._experiment_to_response(e) for e in full_experiments],
            # TODO: add a count method to the database adapter. This is only
            # the number of rows in the current page, not the overall total.
            total=len(experiments),
            limit=limit,
            offset=offset,
        )

    async def update_experiment(
        self,
        exp_id: str,
        request: ExperimentUpdate
    ) -> Optional[ExperimentResponse]:
        """
        Update an experiment's base configuration.

        Only fields that are not None in ``request`` are written. If nothing
        is to be changed, the current experiment is returned as-is.

        Args:
            exp_id: experiment ID
            request: update request

        Returns:
            ExperimentResponse, or None if the experiment does not exist.
        """
        updates = {}
        if request.exp_name is not None:
            updates["exp_name"] = request.exp_name
        if request.gpu_numbers is not None:
            updates["gpu_numbers"] = request.gpu_numbers
        if request.is_half is not None:
            updates["is_half"] = request.is_half

        if not updates:
            return await self.get_experiment(exp_id)

        experiment = await self.db.update_experiment(exp_id, updates)
        if not experiment:
            return None
        return self._experiment_to_response(experiment)

    async def delete_experiment(self, exp_id: str) -> bool:
        """
        Delete an experiment.

        Any stage whose status is ``"running"`` and that has a ``job_id`` has
        its queue job cancelled before the experiment is deleted.

        Args:
            exp_id: experiment ID

        Returns:
            Whether the database adapter reported a successful deletion.
        """
        stages = await self.db.get_all_stages(exp_id)
        for stage in stages:
            if stage.get("status") == "running" and stage.get("job_id"):
                await self.queue.cancel(stage["job_id"])

        return await self.db.delete_experiment(exp_id)

    async def check_stage_dependencies(
        self,
        exp_id: str,
        stage_type: str
    ) -> Dict[str, Any]:
        """
        Check whether a stage's prerequisite stages have all completed.

        Prerequisites come from ``STAGE_DEPENDENCIES``. A prerequisite counts
        as satisfied only if its stage status is ``"completed"``.

        Args:
            exp_id: experiment ID
            stage_type: stage type to check

        Returns:
            ``{"satisfied": bool, "missing": List[str]}``. If the experiment
            does not exist, ``satisfied`` is False and an ``"error"`` key is
            included.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return {"satisfied": False, "missing": [], "error": "实验不存在"}

        dependencies = STAGE_DEPENDENCIES.get(stage_type, [])
        # `or {}` also guards against a stored value of None, not just a
        # missing key.
        stages = experiment.get("stages") or {}

        missing = [
            dep
            for dep in dependencies
            if (stages.get(dep) or {}).get("status") != "completed"
        ]

        return {
            "satisfied": not missing,
            "missing": missing,
        }

    async def execute_stage(
        self,
        exp_id: str,
        stage_type: str,
        params: Dict[str, Any]
    ) -> Optional[StageExecuteResponse]:
        """
        Enqueue a single stage for execution.

        Dependencies are NOT checked here; see ``check_stage_dependencies``.
        Re-executing a completed stage is treated as a rerun, and the previous
        result is reported in the response. Re-executing a stage that is still
        running cancels its current queue job first.

        Args:
            exp_id: experiment ID
            stage_type: stage type to execute
            params: stage parameters (stored as the stage's ``config``)

        Returns:
            StageExecuteResponse, or None if the experiment does not exist.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return None

        stages = experiment.get("stages") or {}
        current_stage = stages.get(stage_type) or {}

        # A rerun is re-execution of an already completed stage. Keep a
        # snapshot of the previous result for the response.
        is_rerun = current_stage.get("status") == "completed"
        previous_run = None
        if is_rerun:
            previous_run = {
                "completed_at": current_stage.get("completed_at"),
                "outputs": current_stage.get("outputs"),
            }

        # If the stage is still running, cancel its job before starting a new
        # one. The stored job_id is overwritten below, after which
        # cancel_stage/delete_experiment could no longer reach the old job,
        # and it would presumably keep running alongside the new one.
        old_job_id = current_stage.get("job_id")
        if current_stage.get("status") == "running" and old_job_id:
            await self.queue.cancel(old_job_id)

        stage_config = {
            "exp_id": exp_id,
            "exp_name": experiment["exp_name"],
            "version": experiment.get("version", "v2"),
            "gpu_numbers": experiment.get("gpu_numbers", "0"),
            "is_half": experiment.get("is_half", True),
            "audio_file_id": experiment.get("audio_file_id"),
            "stage_type": stage_type,
            "params": params,
            # Run only this single stage.
            "stages": [stage_type],
        }

        # Task ID used for progress tracking. The random suffix keeps reruns of
        # the same stage distinct.
        # NOTE(review): the suffix is not persisted anywhere, so
        # subscribe_stage_progress cannot reconstruct this exact ID. Confirm
        # how the worker/progress adapter keys its events.
        task_id = f"{exp_id}-{stage_type}-{uuid.uuid4().hex[:4]}"

        # NOTE(review): the job is enqueued before the stage is marked running.
        # If a worker could finish and write its own final status before the
        # update below, that status would be overwritten with "running".
        # Confirm against the queue adapter.
        job_id = await self.queue.enqueue(task_id, stage_config)

        now = datetime.utcnow()
        await self.db.update_stage(exp_id, stage_type, {
            "status": "running",
            "config": params,
            "job_id": job_id,
            "started_at": now,
            "completed_at": None,
            "error_message": None,
            "outputs": None,
            "progress": 0.0,
        })

        await self.db.update_experiment(exp_id, {"status": "running"})

        return StageExecuteResponse(
            exp_id=exp_id,
            stage_type=stage_type,
            status="running",
            job_id=job_id,
            config=params,
            rerun=is_rerun,
            previous_run=previous_run,
            started_at=now,
        )

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[StageStatus]:
        """
        Get the status of one stage.

        Args:
            exp_id: experiment ID
            stage_type: stage type

        Returns:
            StageStatus, or None if the stage does not exist.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            return None
        return self._stage_to_status(stage)

    async def get_all_stages(self, exp_id: str) -> Optional[StagesListResponse]:
        """
        Get the status of every stage of an experiment.

        Args:
            exp_id: experiment ID

        Returns:
            StagesListResponse (possibly with an empty stage list), or None if
            the experiment does not exist.
        """
        stages = await self.db.get_all_stages(exp_id)
        if not stages:
            # No stages: distinguish "experiment missing" from "no stages yet".
            experiment = await self.db.get_experiment(exp_id)
            if not experiment:
                return None
            stages = []

        return StagesListResponse(
            exp_id=exp_id,
            stages=[self._stage_to_status(s) for s in stages],
        )

    async def cancel_stage(self, exp_id: str, stage_type: str) -> bool:
        """
        Cancel a running stage.

        Args:
            exp_id: experiment ID
            stage_type: stage type

        Returns:
            True if the stage was marked cancelled. False if it does not exist
            or its status is not ``"running"``.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            return False

        # Only running stages can be cancelled.
        if stage.get("status") != "running":
            return False

        job_id = stage.get("job_id")
        if job_id:
            await self.queue.cancel(job_id)

        # NOTE(review): the experiment-level status set to "running" by
        # execute_stage is not reset here. Confirm whether another component
        # updates it.
        await self.db.update_stage(exp_id, stage_type, {
            "status": "cancelled",
            "completed_at": datetime.utcnow(),
            "message": "阶段已取消",
        })

        return True

    async def subscribe_stage_progress(
        self,
        exp_id: str,
        stage_type: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to a stage's progress (for an SSE stream).

        If the stage has already ended, a single ``"final"`` event is yielded.
        If it has not started, a single ``"info"`` event is yielded. Otherwise
        events from the progress adapter are forwarded until one has a
        terminal status.

        Args:
            exp_id: experiment ID
            stage_type: stage type

        Yields:
            Progress event dicts.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            yield {"type": "error", "message": "阶段不存在"}
            return

        # Stage already finished: report its final state once.
        if stage.get("status") in ("completed", "failed", "cancelled"):
            yield {
                "type": "final",
                "status": stage.get("status"),
                "message": stage.get("message") or stage.get("error_message"),
                "progress": stage.get("progress", 0.0),
                "outputs": stage.get("outputs"),
            }
            return

        if stage.get("status") == "pending":
            yield {"type": "info", "message": "阶段尚未开始"}
            return

        # A stage without a job_id was never actually enqueued.
        job_id = stage.get("job_id")
        if not job_id:
            yield {"type": "error", "message": "无法获取任务ID"}
            return

        # NOTE(review): execute_stage enqueues with the ID
        # f"{exp_id}-{stage_type}-{random4}". The random suffix is not stored,
        # so this subscribes to the un-suffixed prefix. Unless the worker or
        # progress adapter publishes under this prefix, no events will arrive.
        # Adjust to the actual progress adapter implementation.
        task_id = f"{exp_id}-{stage_type}"
        async for progress in self.progress_adapter.subscribe(task_id):
            yield progress
            # Stop after forwarding a terminal event.
            if progress.get("status") in ("completed", "failed", "cancelled"):
                break

    @staticmethod
    def _parse_datetime(value: Any) -> Any:
        """Parse ``value`` with ``datetime.fromisoformat`` if it is a string; otherwise return it unchanged."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    def _experiment_to_response(self, experiment: Dict[str, Any]) -> ExperimentResponse:
        """Convert a stored experiment record into an ExperimentResponse."""
        # Stages are keyed by stage type. Fill ``stage_type`` from the key when
        # a stage record lacks it; an explicit value in the record still wins.
        stages = {
            stage_type: self._stage_to_status(
                {"stage_type": stage_type, **(stage_info or {})}
            )
            for stage_type, stage_info in (experiment.get("stages") or {}).items()
        }

        created_at = self._parse_datetime(experiment.get("created_at"))
        if created_at is None:
            created_at = datetime.utcnow()
        updated_at = self._parse_datetime(experiment.get("updated_at"))

        return ExperimentResponse(
            id=experiment["id"],
            exp_name=experiment["exp_name"],
            version=experiment.get("version", "v2"),
            status=experiment.get("status", "created"),
            gpu_numbers=experiment.get("gpu_numbers", "0"),
            is_half=experiment.get("is_half", True),
            audio_file_id=experiment.get("audio_file_id", ""),
            stages=stages,
            created_at=created_at,
            updated_at=updated_at,
        )

    def _stage_to_status(self, stage: Dict[str, Any]) -> StageStatus:
        """Convert a stored stage record into a StageStatus."""
        return StageStatus(
            stage_type=stage.get("stage_type", ""),
            status=stage.get("status", "pending"),
            progress=stage.get("progress"),
            message=stage.get("message"),
            started_at=self._parse_datetime(stage.get("started_at")),
            completed_at=self._parse_datetime(stage.get("completed_at")),
            config=stage.get("config"),
            outputs=stage.get("outputs"),
            error_message=stage.get("error_message"),
        )