Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Typed models for WorkflowArena. | |
| """ | |
from __future__ import annotations

from enum import Enum

from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field, model_validator
class TaskStatus(str, Enum):
    """Lifecycle states a workflow task can occupy.

    Members mix in ``str``, so each compares equal to its lowercase value.
    """

    # Declaration order is preserved by Enum iteration.
    BLOCKED = "blocked"
    READY = "ready"
    RUNNING = "running"
    COMPLETED = "completed"
class DifficultyPreset(str, Enum):
    """Named difficulty presets an episode can be generated from.

    These are the initial presets required by the hackathon; each member
    compares equal to its lowercase string value.
    """

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
class WorkflowActionType(str, Enum):
    """The complete, explicit action space available to the scheduler agent."""

    # Start one or more ready tasks.
    DISPATCH = "dispatch"
    # Take no dispatch action this step.
    WAIT = "wait"
class RewardBreakdown(BaseModel):
    """Named reward channels that together make up the shaped feedback.

    Every channel defaults to ``0.0``. Penalties are kept in dedicated
    fields instead of being folded into the reward channels.
    """

    # --- shaping channels -------------------------------------------------
    completion_reward: float = Field(
        description="Reward for completing tasks.", default=0.0
    )
    utilization_reward: float = Field(
        description="Reward for keeping workers busy.", default=0.0
    )
    deadline_reward: float = Field(
        description="Reward or penalty tied to deadlines.", default=0.0
    )
    criticality_reward: float = Field(
        description="Reward for prioritizing critical-path work appropriately.",
        default=0.0,
    )
    idle_penalty: float = Field(
        description="Penalty for leaving workers idle.", default=0.0
    )
    invalid_action_penalty: float = Field(
        description="Penalty for malformed or infeasible actions.",
        default=0.0,
    )
    # --- terminal channels (their descriptions tie them to episode end) ---
    terminal_makespan_score: float = Field(
        description="Terminal score based on final schedule quality.",
        default=0.0,
    )
    unfinished_task_penalty: float = Field(
        description="Terminal penalty for unfinished work at episode end.",
        default=0.0,
    )
class FailureEventType(str, Enum):
    """Categories of failure event surfaced to agents and the UI."""

    # A worker outage begins / ends.
    WORKER_OUTAGE_START = "worker_outage_start"
    WORKER_OUTAGE_END = "worker_outage_end"
    # A task attempt failed and must be retried.
    TASK_RETRY_FAILURE = "task_retry_failure"
class WorkflowFailureEvent(BaseModel):
    """One structured failure event emitted by the environment.

    ``event_type`` and ``time`` are required; the remaining fields are
    optional details whose relevance depends on the event category.
    """

    event_type: FailureEventType = Field(..., description="Failure category.")
    time: int = Field(
        ...,
        description="Simulated time when the event was observed.",
        ge=0,
    )
    task_id: str | None = Field(
        description="Task affected by the event, if any.",
        default=None,
    )
    worker_delta: int = Field(
        description="Net temporary change in usable workers.",
        default=0,
    )
    duration: int | None = Field(
        description="Outage duration when applicable.",
        default=None,
        ge=0,
    )
    detail: str = Field(
        description="Short human-readable summary.",
        default="",
    )
class WorkflowTaskView(BaseModel):
    """Compact per-task payload used in observations and the future UI.

    Only ``task_id``, ``status``, ``duration`` and ``priority`` are required;
    every other field has a default.
    """

    # Identity and lifecycle.
    task_id: str = Field(..., description="Stable task identifier.")
    status: TaskStatus = Field(..., description="Current task lifecycle state.")

    # Static task attributes.
    duration: int = Field(
        ..., description="Task runtime in simulated time units.", ge=1
    )
    priority: int = Field(
        ..., description="Priority weight for the task.", ge=0
    )
    dependencies: list[str] = Field(
        description="Upstream task ids that must complete first.",
        default_factory=list,
    )
    deadline: int | None = Field(
        description="Optional deadline in simulated time units.",
        default=None,
        ge=0,
    )

    # Derived scheduling features (optional in this view).
    criticality: float | None = Field(
        description="Derived importance score from the DAG structure.",
        default=None,
    )
    slack: float | None = Field(
        description="Derived slack estimate for scheduling decisions.",
        default=None,
    )
    downstream_count: int = Field(
        description="Count of downstream dependents reachable from this task.",
        default=0,
        ge=0,
    )

    # Runtime bookkeeping.
    start_time: int | None = Field(
        description="Simulated start time if the task is running or completed.",
        default=None,
        ge=0,
    )
    end_time: int | None = Field(
        description="Simulated end time if the task is completed or scheduled to finish.",
        default=None,
        ge=0,
    )
    attempt_count: int = Field(
        description="Number of retry attempts already consumed by this task.",
        default=0,
        ge=0,
    )
class WorkflowTaskSpec(BaseModel):
    """Static specification of a single task, generated at episode reset.

    Holds the task's intrinsic attributes, its graph neighbourhood, and
    derived scheduling statistics (critical path, earliest start, slack,
    criticality).
    """

    # Intrinsic attributes.
    task_id: str = Field(..., description="Stable task identifier.")
    duration: int = Field(
        ..., description="Task runtime in simulated time units.", ge=1
    )
    priority: int = Field(
        ..., description="Priority weight for the task.", ge=0
    )

    # Graph neighbourhood.
    dependencies: list[str] = Field(
        description="Upstream task ids that must complete first.",
        default_factory=list,
    )
    dependents: list[str] = Field(
        description="Downstream task ids that depend on this task.",
        default_factory=list,
    )
    deadline: int | None = Field(
        description="Optional deadline in simulated time units.",
        default=None,
        ge=0,
    )

    # Derived statistics.
    downstream_count: int = Field(
        description="Number of downstream tasks reachable from this node.",
        default=0,
        ge=0,
    )
    critical_path_length: int = Field(
        description="Duration-weighted path length from this task to a sink.",
        default=0,
        ge=0,
    )
    earliest_start: int = Field(
        description="Earliest feasible start time under dependency constraints.",
        default=0,
        ge=0,
    )
    slack: int = Field(
        description="Scheduling slack measured in simulated time units.",
        default=0,
        ge=0,
    )
    criticality: float = Field(
        description="Normalized importance score derived from critical path and downstream impact.",
        default=0.0,
    )
class ProgressSummary(BaseModel):
    """Counts of tasks by lifecycle state.

    The four per-state counters mirror the members of ``TaskStatus``
    (blocked, ready, running, completed). ``total`` is the overall task
    count. All counters are non-negative and default to zero.
    """

    # Field descriptions added so the exported schema documents this model
    # like every sibling model does; names, types, defaults and bounds
    # are unchanged.
    total: int = Field(default=0, ge=0, description="Total number of tasks.")
    blocked: int = Field(
        default=0, ge=0, description="Tasks still blocked on dependencies."
    )
    ready: int = Field(
        default=0, ge=0, description="Tasks currently ready for dispatch."
    )
    running: int = Field(
        default=0, ge=0, description="Tasks currently consuming workers."
    )
    completed: int = Field(
        default=0, ge=0, description="Tasks that have completed."
    )
class EpisodeConfig(BaseModel):
    """Knobs supplied at reset time that define an episode.

    Every field has a default, so ``EpisodeConfig()`` is a valid
    configuration (EASY preset, seed 0, two workers).
    """

    preset: DifficultyPreset = Field(
        description="Difficulty preset for the episode generator.",
        default=DifficultyPreset.EASY,
    )
    seed: int = Field(
        description="Seed for deterministic episode generation.",
        default=0,
    )
    worker_count: int = Field(
        description="Number of identical workers available to the scheduler.",
        default=2,
        ge=1,
    )
class GraderTarget(BaseModel):
    """Human-readable target bands describing how a preset is graded.

    Both fields are required free-form strings.
    """

    description: str = Field(
        ...,
        description="What good performance means for the preset.",
    )
    score_band_hint: str = Field(
        ...,
        description="Human-readable interpretation of scores.",
    )
class DifficultyPresetConfig(BaseModel):
    """Concrete generator knobs for a difficulty preset.

    Besides the per-field bounds, an after-model validator rejects presets
    whose paired ``*_min``/``min_*`` value exceeds the matching maximum.
    Without that check an inverted range would be accepted here and only
    surface later, when the preset is used.

    Raises:
        pydantic.ValidationError: if any field bound is violated or any
            min/max pair is inverted.
    """

    preset: DifficultyPreset = Field(..., description="Preset identifier.")

    # Task-count range (checked min <= max below).
    min_tasks: int = Field(..., ge=2)
    max_tasks: int = Field(..., ge=2)
    # NOTE(review): presumably the probability of sampling a dependency
    # edge between tasks — confirm against the generator.
    edge_probability: float = Field(..., ge=0.0, le=1.0)

    # Duration and priority ranges (each checked min <= max below).
    duration_min: int = Field(..., ge=1)
    duration_max: int = Field(..., ge=1)
    priority_min: int = Field(..., ge=0)
    priority_max: int = Field(..., ge=0)

    worker_count: int = Field(..., ge=1)
    deadline_tightness: float = Field(
        ...,
        ge=0.0,
        description="Larger values mean tighter deadlines.",
    )
    time_budget_multiplier: float | None = Field(
        default=None,
        gt=0.0,
        description="Optional multiplier over the theoretical lower-bound makespan.",
    )

    # Failure-injection knobs; all default to "off" (zero).
    worker_outage_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance of a hard-mode worker outage being sampled on a wait transition.",
    )
    worker_outage_duration_min: int = Field(
        default=0,
        ge=0,
        description="Minimum outage duration in simulated time units.",
    )
    worker_outage_duration_max: int = Field(
        default=0,
        ge=0,
        description="Maximum outage duration in simulated time units.",
    )
    task_retry_failure_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance that a hard-mode task completion becomes a retry failure.",
    )
    max_task_retries: int = Field(
        default=0,
        ge=0,
        description="Maximum number of retry failures a task may suffer before it must complete.",
    )
    grader_target: GraderTarget = Field(
        ...,
        description="Preset-specific grader interpretation.",
    )

    @model_validator(mode="after")
    def check_bounds(self) -> DifficultyPresetConfig:
        """Reject presets whose lower bound exceeds the paired upper bound.

        Equal bounds are allowed; the outage-duration defaults are 0/0.

        Raises:
            ValueError: naming the offending pair. Pydantic wraps it in a
                ``ValidationError``.
        """
        paired_bounds = (
            ("min_tasks", "max_tasks"),
            ("duration_min", "duration_max"),
            ("priority_min", "priority_max"),
            ("worker_outage_duration_min", "worker_outage_duration_max"),
        )
        for low_name, high_name in paired_bounds:
            low = getattr(self, low_name)
            high = getattr(self, high_name)
            if low > high:
                raise ValueError(
                    f"{low_name} ({low}) must not exceed {high_name} ({high})"
                )
        return self
class WorkflowEpisodeSpec(BaseModel):
    """Static description of an episode as produced by the generator.

    Bundles the reset-time config, the resolved preset parameters, and the
    generated task list. All three fields are required.
    """

    config: EpisodeConfig = Field(
        ...,
        description="Reset-time configuration.",
    )
    preset_config: DifficultyPresetConfig = Field(
        ...,
        description="Resolved preset parameters.",
    )
    tasks: list[WorkflowTaskSpec] = Field(
        ...,
        description="Generated workflow tasks.",
    )
class WorkflowEnvStateSnapshot(BaseModel):
    """Serializable snapshot of the environment state for the current episode.

    Only ``episode_id`` is required. Collections default to fresh empty
    containers via ``default_factory``, so instances never share mutable
    state.
    """

    episode_id: str = Field(
        ..., description="Stable current episode identifier."
    )
    current_time: int = Field(
        description="Current simulated time.", default=0, ge=0
    )

    # Status bookkeeping, keyed or grouped by task id.
    task_statuses: dict[str, TaskStatus] = Field(
        description="Current task status by task id.",
        default_factory=dict,
    )
    running_task_ids: list[str] = Field(
        description="Tasks currently consuming workers.",
        default_factory=list,
    )
    completed_task_ids: list[str] = Field(
        description="Tasks that have completed.",
        default_factory=list,
    )
    ready_task_ids: list[str] = Field(
        description="Tasks currently ready for dispatch.",
        default_factory=list,
    )
    blocked_task_ids: list[str] = Field(
        description="Tasks still blocked on dependencies.",
        default_factory=list,
    )

    # Per-task timing and counters.
    task_start_times: dict[str, int] = Field(
        description="Simulated start time by task id.",
        default_factory=dict,
    )
    task_end_times: dict[str, int] = Field(
        description="Simulated completion time by task id.",
        default_factory=dict,
    )
    task_remaining_dependencies: dict[str, int] = Field(
        description="Remaining unfinished prerequisites by task id.",
        default_factory=dict,
    )
    task_assigned_finish_times: dict[str, int] = Field(
        description="Predicted completion times for currently running tasks.",
        default_factory=dict,
    )
    task_attempt_counts: dict[str, int] = Field(
        description="Retry attempts consumed by each task.",
        default_factory=dict,
    )

    # Episode-level accounting.
    cumulative_busy_time: int = Field(
        description="Aggregate worker busy time accrued so far.",
        default=0,
        ge=0,
    )
    time_budget: int | None = Field(
        description="Optional terminal time budget for the episode.",
        default=None,
        ge=0,
    )

    # Failure / outage state.
    degraded_workers: int = Field(
        description="Workers temporarily removed from usable capacity.",
        default=0,
        ge=0,
    )
    active_worker_outage_until: int | None = Field(
        description="Time when the current worker outage expires, if any.",
        default=None,
        ge=0,
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        description="Failure events generated on the latest transition.",
        default_factory=list,
    )
class SuccessMetrics(BaseModel):
    """Primary schedule-quality metrics used for grading and demos.

    Optional metrics default to ``None``. The two miss/unfinished counters
    default to zero and are non-negative.
    """

    makespan: int | None = Field(
        description="Total simulated completion time.", default=None
    )
    worker_utilization: float | None = Field(
        description="Fraction of available worker time that was used.",
        default=None,
    )
    deadline_miss_count: int = Field(
        description="Missed task deadlines.", default=0, ge=0
    )
    unfinished_task_count: int = Field(
        description="Tasks left incomplete at terminal time.", default=0, ge=0
    )
    weighted_priority_completion: float | None = Field(
        description="Priority-weighted on-time completion score.",
        default=None,
    )
    # NOTE(review): the 0.0-1.0 range is documented but not enforced by
    # ge/le constraints here.
    benchmark_score: float | None = Field(
        description="Deterministic terminal benchmark score in the 0.0-1.0 range.",
        default=None,
    )
class WorkflowArenaAction(Action):
    """A single scheduler action: dispatch ready tasks, or wait.

    The "empty ``task_ids`` for wait()" rule is only stated in the field
    description. It is not enforced by this model.
    """

    action_type: WorkflowActionType = Field(
        ...,
        description="Dispatch ready tasks or wait for the next completion event.",
    )
    task_ids: list[str] = Field(
        description="Task ids to dispatch. Must be empty for wait().",
        default_factory=list,
    )
class WorkflowArenaObservation(Observation):
    """Typed observation contract returned to WorkflowArena agents.

    Every field has a default, so a bare ``WorkflowArenaObservation()`` is
    valid. Nested models and lists use ``default_factory`` so each instance
    gets its own containers.
    """

    # Prompt and episode settings.
    instruction: str = Field(
        description="Short prompt shown to inference agents.",
        default=(
            "Schedule dependency-constrained workflow tasks on limited workers using "
            "dispatch(task_ids=[...]) or wait()."
        ),
    )
    config: EpisodeConfig = Field(
        description="Episode generation settings.",
        default_factory=EpisodeConfig,
    )

    # Clock and worker capacity.
    current_time: int = Field(
        description="Current simulated time.", default=0, ge=0
    )
    total_workers: int = Field(
        description="Total identical workers.", default=2, ge=1
    )
    effective_workers: int = Field(
        description="Usable workers after temporary degradation is applied.",
        default=2,
        ge=0,
    )
    degraded_workers: int = Field(
        description="Workers currently unavailable due to outages.",
        default=0,
        ge=0,
    )
    free_workers: int = Field(
        description="Currently idle workers.", default=2, ge=0
    )

    # Time budget.
    time_budget: int | None = Field(
        description="Optional terminal time budget for the current episode.",
        default=None,
        ge=0,
    )
    # No lower bound declared on this field.
    time_remaining: int | None = Field(
        description="Remaining time until the episode budget expires, if budgeted.",
        default=None,
    )

    # Task views grouped by lifecycle state.
    progress: ProgressSummary = Field(
        description="Task counts by lifecycle state.",
        default_factory=ProgressSummary,
    )
    ready_tasks: list[WorkflowTaskView] = Field(
        description="Ready tasks eligible for dispatch.",
        default_factory=list,
    )
    running_tasks: list[WorkflowTaskView] = Field(
        description="Tasks currently consuming workers.",
        default_factory=list,
    )
    completed_tasks: list[WorkflowTaskView] = Field(
        description="Tasks already completed.",
        default_factory=list,
    )
    blocked_tasks: list[WorkflowTaskView] = Field(
        description="Tasks still waiting on dependencies.",
        default_factory=list,
    )

    # Reward and quality metrics.
    last_reward_breakdown: RewardBreakdown = Field(
        description="Per-step reward channel breakdown.",
        default_factory=RewardBreakdown,
    )
    cumulative_reward: float = Field(
        description="Running total reward.", default=0.0
    )
    success_metrics: SuccessMetrics = Field(
        description="Primary schedule quality metrics.",
        default_factory=SuccessMetrics,
    )

    # Transition diagnostics.
    note: str | None = Field(
        description="Short environment note about the latest transition.",
        default=None,
    )
    validation_error: str | None = Field(
        description="Explicit invalid-action explanation when the previous action failed.",
        default=None,
    )
    termination_reason: str | None = Field(
        description="Terminal reason when the episode ended unsuccessfully.",
        default=None,
    )
    benchmark_score: float | None = Field(
        description="Top-level bounded benchmark score for easier client access.",
        default=None,
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        description="Failure events generated on the latest accepted transition.",
        default_factory=list,
    )
    received_action: dict[str, object] | None = Field(
        description="Last action accepted by the server for logging and prompting.",
        default=None,
    )