""" Advanced Mode 实验服务 处理专家模式分阶段训练的业务逻辑 """ import uuid from datetime import datetime from typing import AsyncGenerator, Dict, List, Optional, Any from ..core.adapters import ( get_database_adapter, get_task_queue_adapter, get_progress_adapter, ) from ..models.schemas.experiment import ( ExperimentCreate, ExperimentUpdate, ExperimentResponse, ExperimentListResponse, StageStatus, StageExecuteResponse, StagesListResponse, STAGE_DEPENDENCIES, ) # 阶段类型列表(按执行顺序) STAGE_TYPES = [ "audio_slice", "asr", "text_feature", "hubert_feature", "semantic_token", "sovits_train", "gpt_train", ] class ExperimentService: """ Advanced Mode 实验服务 提供专家模式的分阶段训练管理: - 创建实验 - 查询实验/阶段状态 - 执行/取消单个阶段 - 检查阶段依赖 Example: >>> service = ExperimentService() >>> exp = await service.create_experiment(request) >>> await service.execute_stage(exp.id, "audio_slice", {}) >>> stages = await service.get_all_stages(exp.id) """ def __init__(self): """初始化服务""" self._db = None self._queue = None self._progress = None @property def db(self): """延迟获取数据库适配器""" if self._db is None: self._db = get_database_adapter() return self._db @property def queue(self): """延迟获取任务队列适配器""" if self._queue is None: self._queue = get_task_queue_adapter() return self._queue @property def progress_adapter(self): """延迟获取进度适配器""" if self._progress is None: self._progress = get_progress_adapter() return self._progress async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse: """ 创建实验 创建实验但不立即执行,用户可以逐阶段控制训练流程。 Args: request: 创建实验请求 Returns: ExperimentResponse """ exp_id = f"exp-{uuid.uuid4().hex[:8]}" experiment_data = { "id": exp_id, "exp_name": request.exp_name, "version": request.version, "gpu_numbers": request.gpu_numbers, "is_half": request.is_half, "audio_file_id": request.audio_file_id, "status": "created", } # 创建实验(会自动创建所有阶段) experiment = await self.db.create_experiment(experiment_data) return self._experiment_to_response(experiment) async def get_experiment(self, exp_id: str) -> Optional[ExperimentResponse]: """ 获取实验详情 Args: exp_id: 实验ID Returns: ExperimentResponse 或 None """ experiment = await self.db.get_experiment(exp_id) if not experiment: return None return self._experiment_to_response(experiment) async def list_experiments( self, status: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> ExperimentListResponse: """ 获取实验列表 Args: status: 按状态筛选 limit: 每页数量 offset: 偏移量 Returns: ExperimentListResponse """ experiments = await self.db.list_experiments( status=status, limit=limit, offset=offset ) # 获取每个实验的完整信息(包含 stages) full_experiments = [] for exp in experiments: full_exp = await self.db.get_experiment(exp["id"]) if full_exp: full_experiments.append(full_exp) return ExperimentListResponse( items=[self._experiment_to_response(e) for e in full_experiments], total=len(experiments), # TODO: 添加 count 方法 limit=limit, offset=offset, ) async def update_experiment( self, exp_id: str, request: ExperimentUpdate ) -> Optional[ExperimentResponse]: """ 更新实验基础配置 Args: exp_id: 实验ID request: 更新请求 Returns: ExperimentResponse 或 None """ updates = {} if request.exp_name is not None: updates["exp_name"] = request.exp_name if request.gpu_numbers is not None: updates["gpu_numbers"] = request.gpu_numbers if request.is_half is not None: updates["is_half"] = request.is_half if not updates: return await self.get_experiment(exp_id) experiment = await self.db.update_experiment(exp_id, updates) if not experiment: return None return self._experiment_to_response(experiment) async def delete_experiment(self, exp_id: str) -> bool: """ 删除实验 Args: exp_id: 实验ID Returns: 是否成功删除 """ # 先取消所有运行中的阶段 stages = await self.db.get_all_stages(exp_id) for stage in stages: if stage.get("status") == "running" and stage.get("job_id"): await self.queue.cancel(stage["job_id"]) return await self.db.delete_experiment(exp_id) async def check_stage_dependencies( self, exp_id: str, stage_type: str ) -> Dict[str, Any]: """ 检查阶段依赖是否满足 Args: exp_id: 实验ID stage_type: 阶段类型 Returns: {"satisfied": bool, "missing": List[str]} """ experiment = await self.db.get_experiment(exp_id) if not experiment: return {"satisfied": False, "missing": [], "error": "实验不存在"} dependencies = STAGE_DEPENDENCIES.get(stage_type, []) stages = experiment.get("stages", {}) missing = [] for dep in dependencies: dep_stage = stages.get(dep, {}) if dep_stage.get("status") != "completed": missing.append(dep) return { "satisfied": len(missing) == 0, "missing": missing, } async def execute_stage( self, exp_id: str, stage_type: str, params: Dict[str, Any] ) -> Optional[StageExecuteResponse]: """ 执行指定阶段 Args: exp_id: 实验ID stage_type: 阶段类型 params: 阶段参数 Returns: StageExecuteResponse 或 None """ # 获取实验 experiment = await self.db.get_experiment(exp_id) if not experiment: return None stages = experiment.get("stages", {}) current_stage = stages.get(stage_type, {}) # 检查是否是重新执行 is_rerun = current_stage.get("status") == "completed" previous_run = None if is_rerun: previous_run = { "completed_at": current_stage.get("completed_at"), "outputs": current_stage.get("outputs"), } # 构建阶段配置 stage_config = { "exp_id": exp_id, "exp_name": experiment["exp_name"], "version": experiment.get("version", "v2"), "gpu_numbers": experiment.get("gpu_numbers", "0"), "is_half": experiment.get("is_half", True), "audio_file_id": experiment.get("audio_file_id"), "stage_type": stage_type, "params": params, # 只执行单个阶段 "stages": [stage_type], } # 生成任务ID(用于进度追踪) task_id = f"{exp_id}-{stage_type}-{uuid.uuid4().hex[:4]}" # 入队执行 job_id = await self.queue.enqueue(task_id, stage_config) # 更新阶段状态 now = datetime.utcnow() await self.db.update_stage(exp_id, stage_type, { "status": "running", "config": params, "job_id": job_id, "started_at": now, "completed_at": None, "error_message": None, "outputs": None, "progress": 0.0, }) # 更新实验状态为运行中 await self.db.update_experiment(exp_id, {"status": "running"}) return StageExecuteResponse( exp_id=exp_id, stage_type=stage_type, status="running", job_id=job_id, config=params, rerun=is_rerun, previous_run=previous_run, started_at=now, ) async def get_stage( self, exp_id: str, stage_type: str ) -> Optional[StageStatus]: """ 获取阶段状态 Args: exp_id: 实验ID stage_type: 阶段类型 Returns: StageStatus 或 None """ stage = await self.db.get_stage(exp_id, stage_type) if not stage: return None return self._stage_to_status(stage) async def get_all_stages(self, exp_id: str) -> Optional[StagesListResponse]: """ 获取所有阶段状态 Args: exp_id: 实验ID Returns: StagesListResponse 或 None """ stages = await self.db.get_all_stages(exp_id) if not stages: # 检查实验是否存在 experiment = await self.db.get_experiment(exp_id) if not experiment: return None stages = [] return StagesListResponse( exp_id=exp_id, stages=[self._stage_to_status(s) for s in stages], ) async def cancel_stage(self, exp_id: str, stage_type: str) -> bool: """ 取消正在执行的阶段 Args: exp_id: 实验ID stage_type: 阶段类型 Returns: 是否成功取消 """ stage = await self.db.get_stage(exp_id, stage_type) if not stage: return False # 只有运行中的阶段可以取消 if stage.get("status") != "running": return False # 取消任务 job_id = stage.get("job_id") if job_id: await self.queue.cancel(job_id) # 更新状态 await self.db.update_stage(exp_id, stage_type, { "status": "cancelled", "completed_at": datetime.utcnow(), "message": "阶段已取消", }) return True async def subscribe_stage_progress( self, exp_id: str, stage_type: str ) -> AsyncGenerator[Dict[str, Any], None]: """ 订阅阶段进度(SSE 流) Args: exp_id: 实验ID stage_type: 阶段类型 Yields: 进度信息字典 """ # 获取阶段信息 stage = await self.db.get_stage(exp_id, stage_type) if not stage: yield {"type": "error", "message": "阶段不存在"} return # 如果阶段已结束,直接返回最终状态 if stage.get("status") in ("completed", "failed", "cancelled"): yield { "type": "final", "status": stage.get("status"), "message": stage.get("message") or stage.get("error_message"), "progress": stage.get("progress", 0.0), "outputs": stage.get("outputs"), } return # 如果阶段未开始 if stage.get("status") == "pending": yield {"type": "info", "message": "阶段尚未开始"} return # 使用任务ID订阅进度 # 任务ID格式: {exp_id}-{stage_type}-{random} # 由于我们不知道确切的任务ID,使用 job_id job_id = stage.get("job_id") if not job_id: yield {"type": "error", "message": "无法获取任务ID"} return # 订阅进度 # 注意:这里需要根据实际的进度适配器实现来调整 # 当前使用 task_id 格式为 "{exp_id}-{stage_type}" task_id = f"{exp_id}-{stage_type}" async for progress in self.progress_adapter.subscribe(task_id): yield progress # 检查是否为终态 if progress.get("status") in ("completed", "failed", "cancelled"): break def _experiment_to_response(self, experiment: Dict[str, Any]) -> ExperimentResponse: """将实验数据转换为响应模型""" stages_data = experiment.get("stages", {}) stages = {} for stage_type, stage_info in stages_data.items(): stages[stage_type] = self._stage_to_status(stage_info) # 解析日期时间 created_at = experiment.get("created_at") if isinstance(created_at, str): created_at = datetime.fromisoformat(created_at) elif created_at is None: created_at = datetime.utcnow() updated_at = experiment.get("updated_at") if isinstance(updated_at, str): updated_at = datetime.fromisoformat(updated_at) return ExperimentResponse( id=experiment["id"], exp_name=experiment["exp_name"], version=experiment.get("version", "v2"), status=experiment.get("status", "created"), gpu_numbers=experiment.get("gpu_numbers", "0"), is_half=experiment.get("is_half", True), audio_file_id=experiment.get("audio_file_id", ""), stages=stages, created_at=created_at, updated_at=updated_at, ) def _stage_to_status(self, stage: Dict[str, Any]) -> StageStatus: """将阶段数据转换为状态模型""" # 解析日期时间 started_at = stage.get("started_at") if isinstance(started_at, str): started_at = datetime.fromisoformat(started_at) completed_at = stage.get("completed_at") if isinstance(completed_at, str): completed_at = datetime.fromisoformat(completed_at) return StageStatus( stage_type=stage.get("stage_type", ""), status=stage.get("status", "pending"), progress=stage.get("progress"), message=stage.get("message"), started_at=started_at, completed_at=completed_at, config=stage.get("config"), outputs=stage.get("outputs"), error_message=stage.get("error_message"), )