| """ |
| Advanced Mode 实验服务 |
| |
| 处理专家模式分阶段训练的业务逻辑 |
| """ |
|
|
| import uuid |
| from datetime import datetime |
| from typing import AsyncGenerator, Dict, List, Optional, Any |
|
|
| from ..core.adapters import ( |
| get_database_adapter, |
| get_task_queue_adapter, |
| get_progress_adapter, |
| ) |
| from ..models.schemas.experiment import ( |
| ExperimentCreate, |
| ExperimentUpdate, |
| ExperimentResponse, |
| ExperimentListResponse, |
| StageStatus, |
| StageExecuteResponse, |
| StagesListResponse, |
| STAGE_DEPENDENCIES, |
| ) |
|
|
|
|
| |
# Identifiers of the training stages handled by the Advanced Mode service.
# Kept as a list (not a tuple/set) so existing callers see the same type.
STAGE_TYPES = [
    "audio_slice",
    "asr",
    "text_feature",
    "hubert_feature",
    "semantic_token",
    "sovits_train",
    "gpt_train",
]
|
|
|
|
class ExperimentService:
    """
    Advanced Mode experiment service.

    Manages the stage-by-stage training workflow of expert mode:
    - create experiments
    - query experiment / stage state
    - execute / cancel a single stage
    - check stage dependencies

    The database, task-queue and progress adapters are resolved lazily on
    first property access, so constructing the service performs no adapter
    lookup.

    Example:
        >>> service = ExperimentService()
        >>> exp = await service.create_experiment(request)
        >>> await service.execute_stage(exp.id, "audio_slice", {})
        >>> stages = await service.get_all_stages(exp.id)
    """

    def __init__(self):
        """Initialise the service with all adapters unresolved."""
        self._db = None
        self._queue = None
        self._progress = None

    @property
    def db(self):
        """Database adapter, fetched via ``get_database_adapter()`` on first use."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task-queue adapter, fetched via ``get_task_queue_adapter()`` on first use."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def progress_adapter(self):
        """Progress adapter, fetched via ``get_progress_adapter()`` on first use."""
        if self._progress is None:
            self._progress = get_progress_adapter()
        return self._progress

    async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse:
        """
        Create an experiment without starting any stage.

        The record is stored with status ``"created"``; stages are started
        later, one at a time, through ``execute_stage``.

        Args:
            request: Experiment creation request.

        Returns:
            ExperimentResponse built from the record returned by the database.
        """
        # Short ids: "exp-" + first 8 hex chars of a random UUID.
        exp_id = f"exp-{uuid.uuid4().hex[:8]}"

        experiment_data = {
            "id": exp_id,
            "exp_name": request.exp_name,
            "version": request.version,
            "gpu_numbers": request.gpu_numbers,
            "is_half": request.is_half,
            "audio_file_id": request.audio_file_id,
            "status": "created",
        }

        experiment = await self.db.create_experiment(experiment_data)
        return self._experiment_to_response(experiment)

    async def get_experiment(self, exp_id: str) -> Optional[ExperimentResponse]:
        """
        Fetch a single experiment.

        Args:
            exp_id: Experiment id.

        Returns:
            ExperimentResponse, or None when the database has no such record.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return None
        return self._experiment_to_response(experiment)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> ExperimentListResponse:
        """
        List experiments, optionally filtered by status.

        Args:
            status: Status filter passed through to the database adapter.
            limit: Page size.
            offset: Page offset.

        Returns:
            ExperimentListResponse for the requested page.
        """
        experiments = await self.db.list_experiments(
            status=status, limit=limit, offset=offset
        )
        # Treat a None result from the adapter like an empty page.
        experiments = experiments or []

        # Each listed row is re-read through get_experiment before conversion;
        # rows that can no longer be loaded are skipped.
        # NOTE(review): presumably the list rows lack the per-stage data that
        # get_experiment returns -- confirm against the adapter.
        full_experiments = []
        for exp in experiments:
            full_exp = await self.db.get_experiment(exp["id"])
            if full_exp:
                full_experiments.append(full_exp)

        return ExperimentListResponse(
            items=[self._experiment_to_response(e) for e in full_experiments],
            # NOTE(review): this is the number of rows in the current page,
            # not the overall number of matching experiments.
            total=len(experiments),
            limit=limit,
            offset=offset,
        )

    async def update_experiment(
        self,
        exp_id: str,
        request: ExperimentUpdate
    ) -> Optional[ExperimentResponse]:
        """
        Update the basic configuration of an experiment.

        Only ``exp_name``, ``gpu_numbers`` and ``is_half`` are considered, and
        only when they are not None on the request. With nothing to update the
        current experiment is returned unchanged.

        Args:
            exp_id: Experiment id.
            request: Update request.

        Returns:
            ExperimentResponse, or None when the experiment was not found.
        """
        updates = {}
        for field in ("exp_name", "gpu_numbers", "is_half"):
            value = getattr(request, field)
            if value is not None:
                updates[field] = value

        if not updates:
            return await self.get_experiment(exp_id)

        experiment = await self.db.update_experiment(exp_id, updates)
        if not experiment:
            return None
        return self._experiment_to_response(experiment)

    async def delete_experiment(self, exp_id: str) -> bool:
        """
        Delete an experiment, cancelling its running stage jobs first.

        Args:
            exp_id: Experiment id.

        Returns:
            Whatever ``db.delete_experiment`` reports (True on success).
        """
        # The adapter may return None/empty for an experiment without stages;
        # that must not prevent the deletion itself.
        stages = await self.db.get_all_stages(exp_id) or []
        for stage in stages:
            if stage.get("status") == "running" and stage.get("job_id"):
                await self.queue.cancel(stage["job_id"])

        return await self.db.delete_experiment(exp_id)

    async def check_stage_dependencies(
        self,
        exp_id: str,
        stage_type: str
    ) -> Dict[str, Any]:
        """
        Check whether every prerequisite stage of ``stage_type`` is completed.

        Prerequisites come from ``STAGE_DEPENDENCIES``; a stage type with no
        entry there has no prerequisites.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type to check.

        Returns:
            ``{"satisfied": bool, "missing": List[str]}``; when the experiment
            does not exist the dict additionally carries an ``"error"`` key.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return {"satisfied": False, "missing": [], "error": "实验不存在"}

        dependencies = STAGE_DEPENDENCIES.get(stage_type, [])
        # "stages" may be absent or None; individual entries may be None too.
        stages = experiment.get("stages") or {}

        missing = [
            dep
            for dep in dependencies
            if (stages.get(dep) or {}).get("status") != "completed"
        ]

        return {
            "satisfied": not missing,
            "missing": missing,
        }

    async def execute_stage(
        self,
        exp_id: str,
        stage_type: str,
        params: Dict[str, Any]
    ) -> Optional[StageExecuteResponse]:
        """
        Enqueue a single stage and mark it (and the experiment) as running.

        If the stage was already completed, this counts as a re-run and the
        previous completion time and outputs are echoed back in the response.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type to execute.
            params: Stage parameters.

        Returns:
            StageExecuteResponse, or None when the experiment does not exist.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return None

        # Tolerate a missing/None "stages" mapping and a None stage entry.
        stages = experiment.get("stages") or {}
        current_stage = stages.get(stage_type) or {}

        # Capture the previous result before the stage record is reset below.
        is_rerun = current_stage.get("status") == "completed"
        previous_run = None
        if is_rerun:
            previous_run = {
                "completed_at": current_stage.get("completed_at"),
                "outputs": current_stage.get("outputs"),
            }

        # NOTE(review): dependencies and an already-running stage are not
        # checked here; callers are presumably expected to do that first.
        stage_config = {
            "exp_id": exp_id,
            "exp_name": experiment["exp_name"],
            "version": experiment.get("version", "v2"),
            "gpu_numbers": experiment.get("gpu_numbers", "0"),
            "is_half": experiment.get("is_half", True),
            "audio_file_id": experiment.get("audio_file_id"),
            "stage_type": stage_type,
            "params": params,
            # Single-element list: only this stage is requested.
            "stages": [stage_type],
        }

        # Random suffix makes the task id unique per execution.
        task_id = f"{exp_id}-{stage_type}-{uuid.uuid4().hex[:4]}"
        job_id = await self.queue.enqueue(task_id, stage_config)

        # Reset the stage record to a fresh "running" state.
        now = datetime.utcnow()
        await self.db.update_stage(exp_id, stage_type, {
            "status": "running",
            "config": params,
            "job_id": job_id,
            "started_at": now,
            "completed_at": None,
            "error_message": None,
            "outputs": None,
            "progress": 0.0,
        })

        await self.db.update_experiment(exp_id, {"status": "running"})

        return StageExecuteResponse(
            exp_id=exp_id,
            stage_type=stage_type,
            status="running",
            job_id=job_id,
            config=params,
            rerun=is_rerun,
            previous_run=previous_run,
            started_at=now,
        )

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[StageStatus]:
        """
        Fetch the state of one stage.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type.

        Returns:
            StageStatus, or None when the database has no such stage.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            return None
        return self._stage_to_status(stage)

    async def get_all_stages(self, exp_id: str) -> Optional[StagesListResponse]:
        """
        Fetch the state of every stage of an experiment.

        Args:
            exp_id: Experiment id.

        Returns:
            StagesListResponse (possibly with an empty stage list), or None
            when the experiment itself does not exist.
        """
        stages = await self.db.get_all_stages(exp_id)
        if not stages:
            # No stages: distinguish "experiment has none yet" from
            # "experiment does not exist".
            experiment = await self.db.get_experiment(exp_id)
            if not experiment:
                return None
            stages = []

        return StagesListResponse(
            exp_id=exp_id,
            stages=[self._stage_to_status(s) for s in stages],
        )

    async def cancel_stage(self, exp_id: str, stage_type: str) -> bool:
        """
        Cancel a stage that is currently running.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type.

        Returns:
            True when the stage was running and has been marked cancelled;
            False when the stage does not exist or is not running.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            return False

        # Only running stages can be cancelled.
        if stage.get("status") != "running":
            return False

        job_id = stage.get("job_id")
        if job_id:
            await self.queue.cancel(job_id)

        await self.db.update_stage(exp_id, stage_type, {
            "status": "cancelled",
            "completed_at": datetime.utcnow(),
            "message": "阶段已取消",
        })

        return True

    async def subscribe_stage_progress(
        self,
        exp_id: str,
        stage_type: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to the progress of a stage (SSE stream).

        Yields a single ``error`` / ``final`` / ``info`` event and stops when
        the stage is missing, already finished, or still pending. Otherwise
        progress events from the progress adapter are relayed until one
        reports a terminal status.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type.

        Yields:
            Progress information dictionaries.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            yield {"type": "error", "message": "阶段不存在"}
            return

        # Already finished: emit the stored final state once.
        if stage.get("status") in ("completed", "failed", "cancelled"):
            yield {
                "type": "final",
                "status": stage.get("status"),
                "message": stage.get("message") or stage.get("error_message"),
                "progress": stage.get("progress", 0.0),
                "outputs": stage.get("outputs"),
            }
            return

        if stage.get("status") == "pending":
            yield {"type": "info", "message": "阶段尚未开始"}
            return

        # A running stage without a job id cannot be followed.
        job_id = stage.get("job_id")
        if not job_id:
            yield {"type": "error", "message": "无法获取任务ID"}
            return

        # NOTE(review): execute_stage enqueues with an extra random suffix
        # ("{exp_id}-{stage_type}-xxxx"), whereas the subscription key here has
        # no suffix -- confirm which id progress events are published under.
        task_id = f"{exp_id}-{stage_type}"

        async for progress in self.progress_adapter.subscribe(task_id):
            yield progress

            # Stop relaying once a terminal status has been delivered.
            if progress.get("status") in ("completed", "failed", "cancelled"):
                break

    @staticmethod
    def _parse_datetime(value: Any) -> Any:
        """Parse ISO-format strings into ``datetime``; return other values unchanged."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    def _experiment_to_response(self, experiment: Dict[str, Any]) -> ExperimentResponse:
        """
        Convert an experiment record (dict) into an ExperimentResponse.

        ``created_at`` falls back to the current UTC time when absent;
        ``updated_at`` may stay None.
        """
        # "stages" may be absent or None.
        stages_data = experiment.get("stages") or {}
        stages = {}

        for stage_type, stage_info in stages_data.items():
            info = stage_info or {}
            # The mapping key is the stage type; use it when the record itself
            # does not carry one, instead of reporting an empty stage_type.
            if not info.get("stage_type"):
                info = {**info, "stage_type": stage_type}
            stages[stage_type] = self._stage_to_status(info)

        created_at = self._parse_datetime(experiment.get("created_at"))
        if created_at is None:
            created_at = datetime.utcnow()

        updated_at = self._parse_datetime(experiment.get("updated_at"))

        return ExperimentResponse(
            id=experiment["id"],
            exp_name=experiment["exp_name"],
            version=experiment.get("version", "v2"),
            status=experiment.get("status", "created"),
            gpu_numbers=experiment.get("gpu_numbers", "0"),
            is_half=experiment.get("is_half", True),
            audio_file_id=experiment.get("audio_file_id", ""),
            stages=stages,
            created_at=created_at,
            updated_at=updated_at,
        )

    def _stage_to_status(self, stage: Dict[str, Any]) -> StageStatus:
        """Convert a stage record (dict) into a StageStatus model."""
        return StageStatus(
            stage_type=stage.get("stage_type", ""),
            status=stage.get("status", "pending"),
            progress=stage.get("progress"),
            message=stage.get("message"),
            started_at=self._parse_datetime(stage.get("started_at")),
            completed_at=self._parse_datetime(stage.get("completed_at")),
            config=stage.get("config"),
            outputs=stage.get("outputs"),
            error_message=stage.get("error_message"),
        )
|
|