| """ |
| 领域模型模块 |
| |
| 定义训练任务相关的核心数据结构 |
| """ |
|
|
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, Optional, Any
|
|
|
|
class TaskStatus(Enum):
    """
    Lifecycle state of a training task.

    Each member's value is the lowercase string that ``Task.to_dict`` emits
    for ``status`` and that ``Task.from_dict`` accepts back.
    """

    # Default state of a newly constructed Task.
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    INTERRUPTED = "interrupted"
|
|
|
|
@dataclass
class Task:
    """
    Training-task domain model.

    Attributes:
        id: Unique task identifier.
        job_id: Queue job ID (generated by the task queue).
        exp_name: Experiment name.
        status: Task status.
        config: Task configuration (holds all training parameters).
        current_stage: Stage currently being executed.
        progress: Overall progress (0.0-1.0).
        stage_progress: Progress within the current stage (0.0-1.0).
        message: Latest status message.
        error_message: Error details (on failure).
        created_at: Creation time; defaults to the current UTC time as a
            naive datetime.
        started_at: Time execution started.
        completed_at: Completion time.

    Example:
        >>> task = Task(
        ...     id="task-123",
        ...     exp_name="my_voice",
        ...     config={"version": "v2", "batch_size": 4}
        ... )
        >>> task.status
        <TaskStatus.QUEUED: 'queued'>
    """
    id: str
    exp_name: str
    config: Dict[str, Any]
    job_id: Optional[str] = None
    status: TaskStatus = TaskStatus.QUEUED
    current_stage: Optional[str] = None
    progress: float = 0.0
    stage_progress: float = 0.0
    message: Optional[str] = None
    error_message: Optional[str] = None
    # Naive UTC "now": the same value datetime.utcnow() returns, but without
    # the DeprecationWarning utcnow() emits on Python 3.12+.
    created_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the task to a plain dict.

        ``status`` is emitted as its string value and timestamps as ISO-8601
        strings (or None). ``config`` is returned by reference, not copied.
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "exp_name": self.exp_name,
            "status": self.status.value,
            "config": self.config,
            "current_stage": self.current_stage,
            "progress": self.progress,
            "stage_progress": self.stage_progress,
            "message": self.message,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat() if self.created_at is not None else None,
            "started_at": self.started_at.isoformat() if self.started_at is not None else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at is not None else None,
        }

    @staticmethod
    def _parse_datetime(value: Any) -> Optional[datetime]:
        """Parse an ISO-8601 string; datetimes pass through and None stays None."""
        if value is None:
            return None
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Task":
        """
        Build a Task from a dict such as the one produced by ``to_dict``.

        Args:
            data: Mapping with required keys ``id`` and ``exp_name``; every
                other key is optional.

        Returns:
            A new Task instance.

        Raises:
            KeyError: If ``id`` or ``exp_name`` is missing.
            ValueError: If ``status`` is not a valid TaskStatus value, or a
                timestamp string is not valid ISO-8601.
        """
        # Missing or null status falls back to QUEUED, matching the field
        # default. TaskStatus(...) accepts both a value string and a member.
        raw_status = data.get("status")
        status = TaskStatus.QUEUED if raw_status is None else TaskStatus(raw_status)

        optional_kwargs: Dict[str, Any] = {}
        created_at = cls._parse_datetime(data.get("created_at"))
        if created_at is not None:
            optional_kwargs["created_at"] = created_at
        # When absent, created_at is omitted so the default_factory applies
        # instead of storing None in a non-Optional field.

        config = data.get("config")

        return cls(
            id=data["id"],
            job_id=data.get("job_id"),
            exp_name=data["exp_name"],
            status=status,
            config=config if config is not None else {},
            current_stage=data.get("current_stage"),
            progress=data.get("progress", 0.0),
            stage_progress=data.get("stage_progress", 0.0),
            message=data.get("message"),
            error_message=data.get("error_message"),
            started_at=cls._parse_datetime(data.get("started_at")),
            completed_at=cls._parse_datetime(data.get("completed_at")),
            **optional_kwargs,
        )
|
|
|
|
@dataclass
class ProgressInfo:
    """
    Progress update message.

    Carries progress updates between the worker subprocess and the main
    process.

    Attributes:
        type: Message kind ("progress", "log", "error", "heartbeat").
        stage: Name of the current stage.
        stage_index: Index of the current stage.
        total_stages: Total number of stages.
        progress: Progress within the current stage (0.0-1.0).
        overall_progress: Overall progress (0.0-1.0).
        message: Progress message.
        status: Status string.
        data: Extra payload.
    """
    type: str = "progress"
    stage: Optional[str] = None
    stage_index: Optional[int] = None
    total_stages: Optional[int] = None
    progress: float = 0.0
    overall_progress: float = 0.0
    message: Optional[str] = None
    status: Optional[str] = None
    data: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Return every field as a plain dict, in declaration order.

        Values are not copied: the ``data`` entry is the very same object as
        ``self.data``.
        """
        keys = (
            "type", "stage", "stage_index", "total_stages", "progress",
            "overall_progress", "message", "status", "data",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProgressInfo":
        """
        Build an instance from a dict; absent keys take the field defaults.

        Unknown keys are ignored. Keys that are present are used as-is,
        including explicit None values.
        """
        # Built per call so the "data" fallback is a fresh dict every time.
        fallbacks: Dict[str, Any] = {
            "type": "progress",
            "progress": 0.0,
            "overall_progress": 0.0,
            "data": {},
        }
        keys = (
            "type", "stage", "stage_index", "total_stages", "progress",
            "overall_progress", "message", "status", "data",
        )
        kwargs = {key: data.get(key, fallbacks.get(key)) for key in keys}
        return cls(**kwargs)
|
|