# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ Typed models for WorkflowArena. """ from __future__ import annotations from enum import Enum from openenv.core.env_server.types import Action, Observation from pydantic import BaseModel, Field class TaskStatus(str, Enum): """Allowed lifecycle states for a workflow task.""" BLOCKED = "blocked" READY = "ready" RUNNING = "running" COMPLETED = "completed" class DifficultyPreset(str, Enum): """Initial task presets required by the hackathon.""" EASY = "easy" MEDIUM = "medium" HARD = "hard" class WorkflowActionType(str, Enum): """Explicit action space for the scheduler agent.""" DISPATCH = "dispatch" WAIT = "wait" class RewardBreakdown(BaseModel): """Named reward channels for shaped feedback.""" completion_reward: float = Field( default=0.0, description="Reward for completing tasks." ) utilization_reward: float = Field( default=0.0, description="Reward for keeping workers busy." ) deadline_reward: float = Field( default=0.0, description="Reward or penalty tied to deadlines." ) criticality_reward: float = Field( default=0.0, description="Reward for prioritizing critical-path work appropriately.", ) idle_penalty: float = Field( default=0.0, description="Penalty for leaving workers idle." ) invalid_action_penalty: float = Field( default=0.0, description="Penalty for malformed or infeasible actions.", ) terminal_makespan_score: float = Field( default=0.0, description="Terminal score based on final schedule quality.", ) unfinished_task_penalty: float = Field( default=0.0, description="Terminal penalty for unfinished work at episode end.", ) class FailureEventType(str, Enum): """Failure events surfaced to agents and the UI.""" WORKER_OUTAGE_START = "worker_outage_start" WORKER_OUTAGE_END = "worker_outage_end" TASK_RETRY_FAILURE = "task_retry_failure" class WorkflowFailureEvent(BaseModel): """Structured failure event emitted by the environment.""" event_type: FailureEventType = Field(..., description="Failure category.") time: int = Field(..., ge=0, description="Simulated time when the event was observed.") task_id: str | None = Field(default=None, description="Task affected by the event, if any.") worker_delta: int = Field(default=0, description="Net temporary change in usable workers.") duration: int | None = Field(default=None, ge=0, description="Outage duration when applicable.") detail: str = Field(default="", description="Short human-readable summary.") class WorkflowTaskView(BaseModel): """Compact task payload used in observations and the future UI.""" task_id: str = Field(..., description="Stable task identifier.") status: TaskStatus = Field(..., description="Current task lifecycle state.") duration: int = Field( ..., ge=1, description="Task runtime in simulated time units." ) priority: int = Field(..., ge=0, description="Priority weight for the task.") dependencies: list[str] = Field( default_factory=list, description="Upstream task ids that must complete first.", ) deadline: int | None = Field( default=None, ge=0, description="Optional deadline in simulated time units.", ) criticality: float | None = Field( default=None, description="Derived importance score from the DAG structure.", ) slack: float | None = Field( default=None, description="Derived slack estimate for scheduling decisions.", ) downstream_count: int = Field( default=0, ge=0, description="Count of downstream dependents reachable from this task.", ) start_time: int | None = Field( default=None, ge=0, description="Simulated start time if the task is running or completed.", ) end_time: int | None = Field( default=None, ge=0, description="Simulated end time if the task is completed or scheduled to finish.", ) attempt_count: int = Field( default=0, ge=0, description="Number of retry attempts already consumed by this task.", ) class WorkflowTaskSpec(BaseModel): """Static task specification generated at episode reset.""" task_id: str = Field(..., description="Stable task identifier.") duration: int = Field(..., ge=1, description="Task runtime in simulated time units.") priority: int = Field(..., ge=0, description="Priority weight for the task.") dependencies: list[str] = Field( default_factory=list, description="Upstream task ids that must complete first.", ) dependents: list[str] = Field( default_factory=list, description="Downstream task ids that depend on this task.", ) deadline: int | None = Field( default=None, ge=0, description="Optional deadline in simulated time units.", ) downstream_count: int = Field( default=0, ge=0, description="Number of downstream tasks reachable from this node.", ) critical_path_length: int = Field( default=0, ge=0, description="Duration-weighted path length from this task to a sink.", ) earliest_start: int = Field( default=0, ge=0, description="Earliest feasible start time under dependency constraints.", ) slack: int = Field( default=0, ge=0, description="Scheduling slack measured in simulated time units.", ) criticality: float = Field( default=0.0, description="Normalized importance score derived from critical path and downstream impact.", ) class ProgressSummary(BaseModel): """Counts by task lifecycle state.""" total: int = Field(default=0, ge=0) blocked: int = Field(default=0, ge=0) ready: int = Field(default=0, ge=0) running: int = Field(default=0, ge=0) completed: int = Field(default=0, ge=0) class EpisodeConfig(BaseModel): """Reset-time knobs that define the episode.""" preset: DifficultyPreset = Field( default=DifficultyPreset.EASY, description="Difficulty preset for the episode generator.", ) seed: int = Field( default=0, description="Seed for deterministic episode generation." ) worker_count: int = Field( default=2, ge=1, description="Number of identical workers available to the scheduler.", ) class GraderTarget(BaseModel): """High-level target bands for each preset's grader.""" description: str = Field(..., description="What good performance means for the preset.") score_band_hint: str = Field(..., description="Human-readable interpretation of scores.") class DifficultyPresetConfig(BaseModel): """Concrete generator knobs for a preset.""" preset: DifficultyPreset = Field(..., description="Preset identifier.") min_tasks: int = Field(..., ge=2) max_tasks: int = Field(..., ge=2) edge_probability: float = Field(..., ge=0.0, le=1.0) duration_min: int = Field(..., ge=1) duration_max: int = Field(..., ge=1) priority_min: int = Field(..., ge=0) priority_max: int = Field(..., ge=0) worker_count: int = Field(..., ge=1) deadline_tightness: float = Field( ..., ge=0.0, description="Larger values mean tighter deadlines.", ) time_budget_multiplier: float | None = Field( default=None, gt=0.0, description="Optional multiplier over the theoretical lower-bound makespan.", ) worker_outage_rate: float = Field( default=0.0, ge=0.0, le=1.0, description="Chance of a hard-mode worker outage being sampled on a wait transition.", ) worker_outage_duration_min: int = Field( default=0, ge=0, description="Minimum outage duration in simulated time units.", ) worker_outage_duration_max: int = Field( default=0, ge=0, description="Maximum outage duration in simulated time units.", ) task_retry_failure_rate: float = Field( default=0.0, ge=0.0, le=1.0, description="Chance that a hard-mode task completion becomes a retry failure.", ) max_task_retries: int = Field( default=0, ge=0, description="Maximum number of retry failures a task may suffer before it must complete.", ) grader_target: GraderTarget = Field( ..., description="Preset-specific grader interpretation.", ) class WorkflowEpisodeSpec(BaseModel): """Static episode description produced by the generator.""" config: EpisodeConfig = Field(..., description="Reset-time configuration.") preset_config: DifficultyPresetConfig = Field(..., description="Resolved preset parameters.") tasks: list[WorkflowTaskSpec] = Field(..., description="Generated workflow tasks.") class WorkflowEnvStateSnapshot(BaseModel): """Serializable environment state for the current episode.""" episode_id: str = Field(..., description="Stable current episode identifier.") current_time: int = Field(default=0, ge=0, description="Current simulated time.") task_statuses: dict[str, TaskStatus] = Field( default_factory=dict, description="Current task status by task id.", ) running_task_ids: list[str] = Field( default_factory=list, description="Tasks currently consuming workers.", ) completed_task_ids: list[str] = Field( default_factory=list, description="Tasks that have completed.", ) ready_task_ids: list[str] = Field( default_factory=list, description="Tasks currently ready for dispatch.", ) blocked_task_ids: list[str] = Field( default_factory=list, description="Tasks still blocked on dependencies.", ) task_start_times: dict[str, int] = Field( default_factory=dict, description="Simulated start time by task id.", ) task_end_times: dict[str, int] = Field( default_factory=dict, description="Simulated completion time by task id.", ) task_remaining_dependencies: dict[str, int] = Field( default_factory=dict, description="Remaining unfinished prerequisites by task id.", ) task_assigned_finish_times: dict[str, int] = Field( default_factory=dict, description="Predicted completion times for currently running tasks.", ) task_attempt_counts: dict[str, int] = Field( default_factory=dict, description="Retry attempts consumed by each task.", ) cumulative_busy_time: int = Field( default=0, ge=0, description="Aggregate worker busy time accrued so far.", ) time_budget: int | None = Field( default=None, ge=0, description="Optional terminal time budget for the episode.", ) degraded_workers: int = Field( default=0, ge=0, description="Workers temporarily removed from usable capacity.", ) active_worker_outage_until: int | None = Field( default=None, ge=0, description="Time when the current worker outage expires, if any.", ) recent_failure_events: list[WorkflowFailureEvent] = Field( default_factory=list, description="Failure events generated on the latest transition.", ) class SuccessMetrics(BaseModel): """Primary quality metrics used for grading and demos.""" makespan: int | None = Field( default=None, description="Total simulated completion time." ) worker_utilization: float | None = Field( default=None, description="Fraction of available worker time that was used.", ) deadline_miss_count: int = Field( default=0, ge=0, description="Missed task deadlines." ) unfinished_task_count: int = Field( default=0, ge=0, description="Tasks left incomplete at terminal time." ) weighted_priority_completion: float | None = Field( default=None, description="Priority-weighted on-time completion score.", ) benchmark_score: float | None = Field( default=None, description="Deterministic terminal benchmark score in the 0.0-1.0 range.", ) class WorkflowArenaAction(Action): """Strict action space for the workflow scheduler.""" action_type: WorkflowActionType = Field( ..., description="Dispatch ready tasks or wait for the next completion event.", ) task_ids: list[str] = Field( default_factory=list, description="Task ids to dispatch. Must be empty for wait().", ) class WorkflowArenaObservation(Observation): """Compact, typed observation contract for WorkflowArena.""" instruction: str = Field( default=( "Schedule dependency-constrained workflow tasks on limited workers using " "dispatch(task_ids=[...]) or wait()." ), description="Short prompt shown to inference agents.", ) config: EpisodeConfig = Field( default_factory=EpisodeConfig, description="Episode generation settings.", ) current_time: int = Field(default=0, ge=0, description="Current simulated time.") total_workers: int = Field(default=2, ge=1, description="Total identical workers.") effective_workers: int = Field( default=2, ge=0, description="Usable workers after temporary degradation is applied.", ) degraded_workers: int = Field( default=0, ge=0, description="Workers currently unavailable due to outages.", ) free_workers: int = Field(default=2, ge=0, description="Currently idle workers.") time_budget: int | None = Field( default=None, ge=0, description="Optional terminal time budget for the current episode.", ) time_remaining: int | None = Field( default=None, description="Remaining time until the episode budget expires, if budgeted.", ) progress: ProgressSummary = Field( default_factory=ProgressSummary, description="Task counts by lifecycle state.", ) ready_tasks: list[WorkflowTaskView] = Field( default_factory=list, description="Ready tasks eligible for dispatch.", ) running_tasks: list[WorkflowTaskView] = Field( default_factory=list, description="Tasks currently consuming workers.", ) completed_tasks: list[WorkflowTaskView] = Field( default_factory=list, description="Tasks already completed.", ) blocked_tasks: list[WorkflowTaskView] = Field( default_factory=list, description="Tasks still waiting on dependencies.", ) last_reward_breakdown: RewardBreakdown = Field( default_factory=RewardBreakdown, description="Per-step reward channel breakdown.", ) cumulative_reward: float = Field(default=0.0, description="Running total reward.") success_metrics: SuccessMetrics = Field( default_factory=SuccessMetrics, description="Primary schedule quality metrics.", ) note: str | None = Field( default=None, description="Short environment note about the latest transition.", ) validation_error: str | None = Field( default=None, description="Explicit invalid-action explanation when the previous action failed.", ) termination_reason: str | None = Field( default=None, description="Terminal reason when the episode ended unsuccessfully.", ) benchmark_score: float | None = Field( default=None, description="Top-level bounded benchmark score for easier client access.", ) recent_failure_events: list[WorkflowFailureEvent] = Field( default_factory=list, description="Failure events generated on the latest accepted transition.", ) received_action: dict[str, object] | None = Field( default=None, description="Last action accepted by the server for logging and prompting.", )