# workflow_arena/models.py
# Cyber-Machine's picture
# init: WorkFlowArena
# aea0016 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Typed models for WorkflowArena.
"""
from __future__ import annotations

from enum import Enum

from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field, model_validator
class TaskStatus(str, Enum):
    """Allowed lifecycle states for a workflow task."""

    # ``str`` is mixed in, so every member is also a plain string equal to
    # its value.

    # Still waiting on dependencies.
    BLOCKED = "blocked"
    # Eligible for dispatch.
    READY = "ready"
    # Currently consuming a worker.
    RUNNING = "running"
    # Finished.
    COMPLETED = "completed"
class DifficultyPreset(str, Enum):
    """Initial task presets required by the hackathon."""

    # ``str`` is mixed in, so every member is also a plain string equal to
    # its value.
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
class WorkflowActionType(str, Enum):
    """Explicit action space for the scheduler agent."""

    # ``str`` is mixed in, so every member is also a plain string equal to
    # its value.
    DISPATCH = "dispatch"
    WAIT = "wait"
class RewardBreakdown(BaseModel):
    """Named reward channels for shaped feedback."""

    # Every channel is a float that defaults to 0.0.

    # --- Reward channels ---
    completion_reward: float = Field(default=0.0, description="Reward for completing tasks.")
    utilization_reward: float = Field(default=0.0, description="Reward for keeping workers busy.")
    deadline_reward: float = Field(default=0.0, description="Reward or penalty tied to deadlines.")
    criticality_reward: float = Field(
        default=0.0, description="Reward for prioritizing critical-path work appropriately."
    )

    # --- Penalty channels ---
    idle_penalty: float = Field(default=0.0, description="Penalty for leaving workers idle.")
    invalid_action_penalty: float = Field(
        default=0.0, description="Penalty for malformed or infeasible actions."
    )

    # --- Terminal channels ---
    terminal_makespan_score: float = Field(
        default=0.0, description="Terminal score based on final schedule quality."
    )
    unfinished_task_penalty: float = Field(
        default=0.0, description="Terminal penalty for unfinished work at episode end."
    )
class FailureEventType(str, Enum):
    """Failure events surfaced to agents and the UI."""

    # ``str`` is mixed in, so every member is also a plain string equal to
    # its value.

    # Worker outage lifecycle.
    WORKER_OUTAGE_START = "worker_outage_start"
    WORKER_OUTAGE_END = "worker_outage_end"
    # Task-level retry failure.
    TASK_RETRY_FAILURE = "task_retry_failure"
class WorkflowFailureEvent(BaseModel):
    """Structured failure event emitted by the environment."""

    # Required fields: the event category and the non-negative time at which
    # it was observed.
    event_type: FailureEventType = Field(
        ...,
        description="Failure category.",
    )
    time: int = Field(
        ...,
        ge=0,
        description="Simulated time when the event was observed.",
    )

    # Optional context; each of these has a default.
    task_id: str | None = Field(
        default=None,
        description="Task affected by the event, if any.",
    )
    worker_delta: int = Field(
        default=0,
        description="Net temporary change in usable workers.",
    )
    duration: int | None = Field(
        default=None,
        ge=0,
        description="Outage duration when applicable.",
    )
    detail: str = Field(
        default="",
        description="Short human-readable summary.",
    )
class WorkflowTaskView(BaseModel):
    """Compact task payload used in observations and the future UI."""

    # --- Required: identity, lifecycle state and static attributes ---
    task_id: str = Field(..., description="Stable task identifier.")
    status: TaskStatus = Field(..., description="Current task lifecycle state.")
    duration: int = Field(..., ge=1, description="Task runtime in simulated time units.")
    priority: int = Field(..., ge=0, description="Priority weight for the task.")

    # --- Optional scheduling context ---
    dependencies: list[str] = Field(
        default_factory=list, description="Upstream task ids that must complete first."
    )
    deadline: int | None = Field(
        default=None, ge=0, description="Optional deadline in simulated time units."
    )

    # --- Derived hints (None when not provided) ---
    criticality: float | None = Field(
        default=None, description="Derived importance score from the DAG structure."
    )
    slack: float | None = Field(
        default=None, description="Derived slack estimate for scheduling decisions."
    )
    downstream_count: int = Field(
        default=0, ge=0, description="Count of downstream dependents reachable from this task."
    )

    # --- Execution bookkeeping ---
    start_time: int | None = Field(
        default=None,
        ge=0,
        description="Simulated start time if the task is running or completed.",
    )
    end_time: int | None = Field(
        default=None,
        ge=0,
        description="Simulated end time if the task is completed or scheduled to finish.",
    )
    attempt_count: int = Field(
        default=0,
        ge=0,
        description="Number of retry attempts already consumed by this task.",
    )
class WorkflowTaskSpec(BaseModel):
    """Static task specification generated at episode reset."""

    # --- Required task attributes ---
    task_id: str = Field(
        ...,
        description="Stable task identifier.",
    )
    duration: int = Field(
        ...,
        ge=1,
        description="Task runtime in simulated time units.",
    )
    priority: int = Field(
        ...,
        ge=0,
        description="Priority weight for the task.",
    )

    # --- Graph edges (both directions are stored) ---
    dependencies: list[str] = Field(
        default_factory=list, description="Upstream task ids that must complete first."
    )
    dependents: list[str] = Field(
        default_factory=list, description="Downstream task ids that depend on this task."
    )

    deadline: int | None = Field(
        default=None, ge=0, description="Optional deadline in simulated time units."
    )

    # --- Derived metrics; all default to zero ---
    downstream_count: int = Field(
        default=0, ge=0, description="Number of downstream tasks reachable from this node."
    )
    critical_path_length: int = Field(
        default=0, ge=0, description="Duration-weighted path length from this task to a sink."
    )
    earliest_start: int = Field(
        default=0, ge=0, description="Earliest feasible start time under dependency constraints."
    )
    slack: int = Field(
        default=0, ge=0, description="Scheduling slack measured in simulated time units."
    )
    criticality: float = Field(
        default=0.0,
        description="Normalized importance score derived from critical path and downstream impact.",
    )
class ProgressSummary(BaseModel):
    """Counts by task lifecycle state."""

    # All counters are non-negative and start at zero.
    total: int = Field(
        default=0,
        ge=0,
    )

    # One counter per ``TaskStatus`` member.
    blocked: int = Field(
        default=0,
        ge=0,
    )
    ready: int = Field(
        default=0,
        ge=0,
    )
    running: int = Field(
        default=0,
        ge=0,
    )
    completed: int = Field(
        default=0,
        ge=0,
    )
class EpisodeConfig(BaseModel):
    """Reset-time knobs that define the episode."""

    # Every knob has a default, so ``EpisodeConfig()`` is a valid configuration
    # (easy preset, seed 0, two workers).
    preset: DifficultyPreset = Field(
        default=DifficultyPreset.EASY, description="Difficulty preset for the episode generator."
    )
    seed: int = Field(default=0, description="Seed for deterministic episode generation.")
    worker_count: int = Field(
        default=2, ge=1, description="Number of identical workers available to the scheduler."
    )
class GraderTarget(BaseModel):
    """High-level target bands for each preset's grader."""

    # Both fields are required free-text strings.
    description: str = Field(
        ...,
        description="What good performance means for the preset.",
    )
    score_band_hint: str = Field(
        ...,
        description="Human-readable interpretation of scores.",
    )
class DifficultyPresetConfig(BaseModel):
    """Concrete generator knobs for a preset."""

    # Besides the per-field bounds declared below, each ``*_min`` / ``*_max``
    # pair is cross-checked after field validation so that an inverted range
    # is rejected when the preset is constructed.

    preset: DifficultyPreset = Field(..., description="Preset identifier.")

    # --- Required generator ranges ---
    min_tasks: int = Field(..., ge=2)
    max_tasks: int = Field(..., ge=2)
    edge_probability: float = Field(..., ge=0.0, le=1.0)
    duration_min: int = Field(..., ge=1)
    duration_max: int = Field(..., ge=1)
    priority_min: int = Field(..., ge=0)
    priority_max: int = Field(..., ge=0)
    worker_count: int = Field(..., ge=1)
    deadline_tightness: float = Field(
        ...,
        ge=0.0,
        description="Larger values mean tighter deadlines.",
    )
    time_budget_multiplier: float | None = Field(
        default=None,
        gt=0.0,
        description="Optional multiplier over the theoretical lower-bound makespan.",
    )

    # --- Failure-injection knobs; the defaults disable them ---
    worker_outage_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance of a hard-mode worker outage being sampled on a wait transition.",
    )
    worker_outage_duration_min: int = Field(
        default=0,
        ge=0,
        description="Minimum outage duration in simulated time units.",
    )
    worker_outage_duration_max: int = Field(
        default=0,
        ge=0,
        description="Maximum outage duration in simulated time units.",
    )
    task_retry_failure_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance that a hard-mode task completion becomes a retry failure.",
    )
    max_task_retries: int = Field(
        default=0,
        ge=0,
        description="Maximum number of retry failures a task may suffer before it must complete.",
    )
    grader_target: GraderTarget = Field(
        ...,
        description="Preset-specific grader interpretation.",
    )

    @model_validator(mode="after")
    def check_range_bounds(self) -> DifficultyPresetConfig:
        """Reject presets whose lower bound exceeds the matching upper bound.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If any ``*_min`` knob is greater than its ``*_max``
                partner. Pydantic reports this to the caller as a
                ``ValidationError``.
        """
        bounded_pairs = (
            ("min_tasks", "max_tasks"),
            ("duration_min", "duration_max"),
            ("priority_min", "priority_max"),
            ("worker_outage_duration_min", "worker_outage_duration_max"),
        )
        for low_name, high_name in bounded_pairs:
            low = getattr(self, low_name)
            high = getattr(self, high_name)
            if low > high:
                raise ValueError(
                    f"{low_name} ({low}) must not exceed {high_name} ({high})."
                )
        return self
class WorkflowEpisodeSpec(BaseModel):
    """Static episode description produced by the generator."""

    # All three parts are required.
    config: EpisodeConfig = Field(
        ...,
        description="Reset-time configuration.",
    )
    preset_config: DifficultyPresetConfig = Field(
        ...,
        description="Resolved preset parameters.",
    )
    tasks: list[WorkflowTaskSpec] = Field(
        ...,
        description="Generated workflow tasks.",
    )
class WorkflowEnvStateSnapshot(BaseModel):
    """Serializable environment state for the current episode."""

    # Only ``episode_id`` is required; every other field has a default.
    episode_id: str = Field(..., description="Stable current episode identifier.")
    current_time: int = Field(default=0, ge=0, description="Current simulated time.")

    # --- Status by task id, plus one id list per lifecycle state ---
    task_statuses: dict[str, TaskStatus] = Field(
        default_factory=dict, description="Current task status by task id."
    )
    running_task_ids: list[str] = Field(
        default_factory=list, description="Tasks currently consuming workers."
    )
    completed_task_ids: list[str] = Field(
        default_factory=list, description="Tasks that have completed."
    )
    ready_task_ids: list[str] = Field(
        default_factory=list, description="Tasks currently ready for dispatch."
    )
    blocked_task_ids: list[str] = Field(
        default_factory=list, description="Tasks still blocked on dependencies."
    )

    # --- Per-task bookkeeping maps keyed by task id ---
    task_start_times: dict[str, int] = Field(
        default_factory=dict, description="Simulated start time by task id."
    )
    task_end_times: dict[str, int] = Field(
        default_factory=dict, description="Simulated completion time by task id."
    )
    task_remaining_dependencies: dict[str, int] = Field(
        default_factory=dict, description="Remaining unfinished prerequisites by task id."
    )
    task_assigned_finish_times: dict[str, int] = Field(
        default_factory=dict,
        description="Predicted completion times for currently running tasks.",
    )
    task_attempt_counts: dict[str, int] = Field(
        default_factory=dict, description="Retry attempts consumed by each task."
    )

    # --- Episode-level accounting ---
    cumulative_busy_time: int = Field(
        default=0, ge=0, description="Aggregate worker busy time accrued so far."
    )
    time_budget: int | None = Field(
        default=None, ge=0, description="Optional terminal time budget for the episode."
    )

    # --- Worker outage state ---
    degraded_workers: int = Field(
        default=0, ge=0, description="Workers temporarily removed from usable capacity."
    )
    active_worker_outage_until: int | None = Field(
        default=None,
        ge=0,
        description="Time when the current worker outage expires, if any.",
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        default_factory=list, description="Failure events generated on the latest transition."
    )
class SuccessMetrics(BaseModel):
    """Primary quality metrics used for grading and demos."""

    # Metrics typed ``... | None`` default to None; the two counters default to 0.
    makespan: int | None = Field(default=None, description="Total simulated completion time.")
    worker_utilization: float | None = Field(
        default=None, description="Fraction of available worker time that was used."
    )
    deadline_miss_count: int = Field(default=0, ge=0, description="Missed task deadlines.")
    unfinished_task_count: int = Field(
        default=0,
        ge=0,
        description="Tasks left incomplete at terminal time.",
    )
    weighted_priority_completion: float | None = Field(
        default=None, description="Priority-weighted on-time completion score."
    )
    # NOTE(review): the description states a 0.0-1.0 range, but no ge/le bound
    # is declared on this field, so the range is not enforced here.
    benchmark_score: float | None = Field(
        default=None,
        description="Deterministic terminal benchmark score in the 0.0-1.0 range.",
    )
class WorkflowArenaAction(Action):
    """Strict action space for the workflow scheduler."""

    # The action kind is required; ``task_ids`` defaults to an empty list.
    action_type: WorkflowActionType = Field(
        ..., description="Dispatch ready tasks or wait for the next completion event."
    )
    # NOTE(review): the "must be empty for wait()" rule is stated only in the
    # description; this model declares no validator for it.
    task_ids: list[str] = Field(
        default_factory=list, description="Task ids to dispatch. Must be empty for wait()."
    )
class WorkflowArenaObservation(Observation):
    """Compact, typed observation contract for WorkflowArena."""

    # Every field declared here has a default, so none of them is required
    # at construction time.

    # --- Prompt and episode settings ---
    instruction: str = Field(
        default=(
            "Schedule dependency-constrained workflow tasks on limited workers using "
            "dispatch(task_ids=[...]) or wait()."
        ),
        description="Short prompt shown to inference agents.",
    )
    config: EpisodeConfig = Field(
        default_factory=EpisodeConfig, description="Episode generation settings."
    )

    # --- Clock and worker capacity ---
    current_time: int = Field(
        default=0,
        ge=0,
        description="Current simulated time.",
    )
    total_workers: int = Field(
        default=2,
        ge=1,
        description="Total identical workers.",
    )
    effective_workers: int = Field(
        default=2, ge=0, description="Usable workers after temporary degradation is applied."
    )
    degraded_workers: int = Field(
        default=0, ge=0, description="Workers currently unavailable due to outages."
    )
    free_workers: int = Field(
        default=2,
        ge=0,
        description="Currently idle workers.",
    )

    # --- Time budget ---
    time_budget: int | None = Field(
        default=None, ge=0, description="Optional terminal time budget for the current episode."
    )
    # Unlike ``time_budget``, no lower bound is declared on this field.
    time_remaining: int | None = Field(
        default=None,
        description="Remaining time until the episode budget expires, if budgeted.",
    )

    # --- Task views, one list per lifecycle state ---
    progress: ProgressSummary = Field(
        default_factory=ProgressSummary, description="Task counts by lifecycle state."
    )
    ready_tasks: list[WorkflowTaskView] = Field(
        default_factory=list, description="Ready tasks eligible for dispatch."
    )
    running_tasks: list[WorkflowTaskView] = Field(
        default_factory=list, description="Tasks currently consuming workers."
    )
    completed_tasks: list[WorkflowTaskView] = Field(
        default_factory=list, description="Tasks already completed."
    )
    blocked_tasks: list[WorkflowTaskView] = Field(
        default_factory=list, description="Tasks still waiting on dependencies."
    )

    # --- Reward and quality feedback ---
    last_reward_breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown, description="Per-step reward channel breakdown."
    )
    cumulative_reward: float = Field(
        default=0.0,
        description="Running total reward.",
    )
    success_metrics: SuccessMetrics = Field(
        default_factory=SuccessMetrics, description="Primary schedule quality metrics."
    )

    # --- Diagnostics about the latest transition ---
    note: str | None = Field(
        default=None, description="Short environment note about the latest transition."
    )
    validation_error: str | None = Field(
        default=None,
        description="Explicit invalid-action explanation when the previous action failed.",
    )
    termination_reason: str | None = Field(
        default=None, description="Terminal reason when the episode ended unsuccessfully."
    )
    benchmark_score: float | None = Field(
        default=None, description="Top-level bounded benchmark score for easier client access."
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        default_factory=list,
        description="Failure events generated on the latest accepted transition.",
    )
    received_action: dict[str, object] | None = Field(
        default=None,
        description="Last action accepted by the server for logging and prompting.",
    )