Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """WorkflowArena event-driven workflow orchestration environment.""" | |
| from __future__ import annotations | |
| import math | |
| import random | |
| from typing import Any | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from workflow_arena.generator import generate_episode | |
| from workflow_arena.models import ( | |
| DifficultyPreset, | |
| EpisodeConfig, | |
| FailureEventType, | |
| ProgressSummary, | |
| RewardBreakdown, | |
| SuccessMetrics, | |
| TaskStatus, | |
| WorkflowArenaAction, | |
| WorkflowArenaObservation, | |
| WorkflowEnvStateSnapshot, | |
| WorkflowEpisodeSpec, | |
| WorkflowFailureEvent, | |
| WorkflowTaskSpec, | |
| WorkflowTaskView, | |
| WorkflowActionType, | |
| ) | |
| from workflow_arena.presets import get_preset_config | |
class WorkflowArenaEnvironment(Environment):
    """Resource-constrained workflow scheduler with event-driven semantics."""

    # Multiple sessions may use this environment type concurrently.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Safety step limit is max(STEP_LIMIT_FLOOR, tasks * STEP_LIMIT_MULTIPLIER);
    # computed in reset().
    STEP_LIMIT_FLOOR: int = 32
    STEP_LIMIT_MULTIPLIER: int = 8
    # Negative reward terms (added to the step reward when they apply).
    INVALID_ACTION_PENALTY: float = -0.1
    # Larger penalty when a dispatch exceeds available worker capacity.
    OVERCAPACITY_INVALID_ACTION_PENALTY: float = -0.25
    # Charged per free worker slot that could have taken a ready task on wait().
    AVOIDABLE_WAIT_PENALTY_PER_SLOT: float = -0.08
    # Per-priority-point penalty for each task left unfinished at termination.
    UNFINISHED_PRIORITY_PENALTY: float = -0.02
    # Additional per-tick, per-priority-point penalty once a deadline is missed.
    OVERDUE_PRIORITY_PENALTY_PER_TICK: float = -0.005
    # Only the newest N failure events are surfaced on observations.
    MAX_RECENT_FAILURE_EVENTS: int = 6
    # Benchmark scores are clamped into [MIN_GRADER_SCORE, MAX_GRADER_SCORE].
    MIN_GRADER_SCORE: float = 0.01
    MAX_GRADER_SCORE: float = 0.99
| def __init__(self): | |
| self._state = State(episode_id=str(uuid4()), step_count=0) | |
| self._cumulative_reward = 0.0 | |
| self._max_episode_steps = self.STEP_LIMIT_FLOOR | |
| self._config = EpisodeConfig( | |
| preset=DifficultyPreset.EASY, | |
| seed=0, | |
| worker_count=2, | |
| ) | |
| self._episode_spec: WorkflowEpisodeSpec | None = None | |
| self._env_state: WorkflowEnvStateSnapshot | None = None | |
| self._event_rng = random.Random(0) | |
| def _require_episode(self) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]: | |
| if self._episode_spec is None or self._env_state is None: | |
| raise RuntimeError("Environment must be reset before use.") | |
| return self._episode_spec, self._env_state | |
| def _preset_config(self): | |
| episode, _ = self._require_episode() | |
| return episode.preset_config | |
| def _task_map(self) -> dict[str, WorkflowTaskSpec]: | |
| episode, _ = self._require_episode() | |
| return {task.task_id: task for task in episode.tasks} | |
| def _effective_worker_capacity( | |
| self, env_state: WorkflowEnvStateSnapshot | None = None | |
| ) -> int: | |
| if env_state is None: | |
| _, env_state = self._require_episode() | |
| return max(0, self._config.worker_count - env_state.degraded_workers) | |
| def _time_remaining( | |
| self, env_state: WorkflowEnvStateSnapshot | None = None | |
| ) -> int | None: | |
| if env_state is None: | |
| _, env_state = self._require_episode() | |
| if env_state.time_budget is None: | |
| return None | |
| return max(0, env_state.time_budget - env_state.current_time) | |
| def _terminal_score(self) -> float: | |
| episode, env_state = self._require_episode() | |
| if env_state.current_time <= 0: | |
| return 0.0 | |
| lower_bound = self._lower_bound_makespan(episode) | |
| score = lower_bound / max(lower_bound, env_state.current_time) | |
| return round(score, 4) | |
| def _bounded_grader_score(self, score: float) -> float: | |
| return round( | |
| min(self.MAX_GRADER_SCORE, max(self.MIN_GRADER_SCORE, score)), | |
| 4, | |
| ) | |
| def _benchmark_score(self) -> float: | |
| makespan_score, deadline_score, utilization_score = self._grade_components( | |
| include_terminal_makespan=True | |
| ) | |
| return self._bounded_grader_score( | |
| (0.5 * makespan_score) + (0.3 * deadline_score) + (0.2 * utilization_score) | |
| ) | |
    def _grade_components(
        self, *, include_terminal_makespan: bool = False
    ) -> tuple[float, float, float]:
        """Return (makespan_score, deadline_score, utilization_score).

        deadline_score is the priority-weighted share of completed tasks that
        met their deadline (tasks without a deadline count as on time).
        utilization_score is busy worker-ticks over total worker-ticks, using
        the configured worker count. makespan_score is the terminal
        lower-bound/makespan ratio when requested, otherwise 0.0.
        """
        episode, env_state = self._require_episode()
        # Fraction of configured capacity kept busy so far; 0.0 before t > 0.
        utilization = (
            env_state.cumulative_busy_time
            / (env_state.current_time * self._config.worker_count)
            if env_state.current_time > 0
            else 0.0
        )
        # `or 1` guards division by zero when every task priority is 0.
        total_priority = sum(task.priority for task in episode.tasks) or 1
        on_time_priority = 0
        for task in episode.tasks:
            end_time = env_state.task_end_times.get(task.task_id)
            if end_time is None:
                continue  # not completed; contributes nothing
            if task.deadline is None or end_time <= task.deadline:
                on_time_priority += task.priority
        deadline_score = round(on_time_priority / total_priority, 4)
        utilization_score = round(utilization, 4)
        makespan_score = self._terminal_score() if include_terminal_makespan else 0.0
        return makespan_score, deadline_score, utilization_score
| def _unfinished_task_penalty(self, current_time: int) -> float: | |
| episode, env_state = self._require_episode() | |
| penalty = 0.0 | |
| for task in episode.tasks: | |
| if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED: | |
| continue | |
| penalty += self.UNFINISHED_PRIORITY_PENALTY * task.priority | |
| if task.deadline is not None and current_time > task.deadline: | |
| penalty += ( | |
| self.OVERDUE_PRIORITY_PENALTY_PER_TICK | |
| * task.priority | |
| * (current_time - task.deadline) | |
| ) | |
| return round(penalty, 4) | |
    def _success_metrics(
        self, *, benchmark_score_override: float | None = None
    ) -> SuccessMetrics:
        """Summarize episode outcome counters for the observation payload."""
        episode, env_state = self._require_episode()
        unfinished_task_count = sum(
            1
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] != TaskStatus.COMPLETED
        )
        # A completed task misses its deadline when its end time exceeds it.
        deadline_miss_count = sum(
            1
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED
            and task.deadline is not None
            and env_state.task_end_times.get(task.task_id, 0) > task.deadline
        )
        _, deadline_score, utilization_score = self._grade_components(
            include_terminal_makespan=False
        )
        all_done = unfinished_task_count == 0
        return SuccessMetrics(
            # Makespan is only meaningful once every task has completed.
            makespan=env_state.current_time if all_done else None,
            worker_utilization=utilization_score,
            deadline_miss_count=deadline_miss_count,
            unfinished_task_count=unfinished_task_count,
            weighted_priority_completion=deadline_score,
            benchmark_score=benchmark_score_override,
        )
    def _task_view(
        self,
        task: WorkflowTaskSpec,
        status: TaskStatus,
        *,
        include_planner_hints: bool = True,
    ) -> WorkflowTaskView:
        """Project a task spec plus runtime state into an observation view.

        Planner hints (criticality, slack, downstream_count) are omitted when
        include_planner_hints is False.
        """
        _, env_state = self._require_episode()
        return WorkflowTaskView(
            task_id=task.task_id,
            status=status,
            duration=task.duration,
            priority=task.priority,
            dependencies=task.dependencies,
            deadline=task.deadline,
            criticality=task.criticality if include_planner_hints else None,
            slack=float(task.slack) if include_planner_hints else None,
            downstream_count=task.downstream_count if include_planner_hints else 0,
            start_time=env_state.task_start_times.get(task.task_id),
            # Actual end time once completed, else the scheduled finish time of
            # a running task. NOTE(review): `or` would skip a real end time of
            # 0 — presumably unreachable with positive durations; confirm.
            end_time=(
                env_state.task_end_times.get(task.task_id)
                or env_state.task_assigned_finish_times.get(task.task_id)
            ),
            attempt_count=env_state.task_attempt_counts.get(task.task_id, 0),
        )
| def _task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]: | |
| episode, env_state = self._require_episode() | |
| return [ | |
| self._task_view(task, status, include_planner_hints=True) | |
| for task in episode.tasks | |
| if env_state.task_statuses[task.task_id] == status | |
| ] | |
    def debug_task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
        """Public debug hook around the private per-status view builder."""
        return self._task_views_for_status(status)
| def _set_recent_failure_events( | |
| self, | |
| env_state: WorkflowEnvStateSnapshot, | |
| events: list[WorkflowFailureEvent], | |
| ) -> None: | |
| env_state.recent_failure_events = events[-self.MAX_RECENT_FAILURE_EVENTS :] | |
| def _maybe_end_worker_outage( | |
| self, | |
| env_state: WorkflowEnvStateSnapshot, | |
| events: list[WorkflowFailureEvent], | |
| ) -> None: | |
| if ( | |
| env_state.active_worker_outage_until is not None | |
| and env_state.current_time >= env_state.active_worker_outage_until | |
| ): | |
| events.append( | |
| WorkflowFailureEvent( | |
| event_type=FailureEventType.WORKER_OUTAGE_END, | |
| time=env_state.current_time, | |
| worker_delta=1, | |
| detail="Worker capacity restored.", | |
| ) | |
| ) | |
| env_state.active_worker_outage_until = None | |
| env_state.degraded_workers = 0 | |
    def _maybe_start_worker_outage(
        self,
        env_state: WorkflowEnvStateSnapshot,
        events: list[WorkflowFailureEvent],
    ) -> None:
        """Possibly begin a one-worker outage (hard preset only).

        RNG draws happen in a fixed order (rate roll, then duration) so that
        seeded episodes stay reproducible. At most one outage is active at a
        time.
        """
        preset_config = self._preset_config()
        if self._config.preset != DifficultyPreset.HARD:
            return
        if env_state.active_worker_outage_until is not None:
            return  # an outage is already in progress
        if preset_config.worker_outage_rate <= 0.0:
            return
        if self._event_rng.random() >= preset_config.worker_outage_rate:
            return  # no outage this wait
        duration = self._event_rng.randint(
            preset_config.worker_outage_duration_min,
            preset_config.worker_outage_duration_max,
        )
        if duration <= 0:
            return
        # Degrade exactly one worker, never more than are configured.
        env_state.degraded_workers = min(1, self._config.worker_count)
        env_state.active_worker_outage_until = env_state.current_time + duration
        events.append(
            WorkflowFailureEvent(
                event_type=FailureEventType.WORKER_OUTAGE_START,
                time=env_state.current_time,
                worker_delta=-env_state.degraded_workers,
                duration=duration,
                detail=f"Worker outage active until t={env_state.active_worker_outage_until}.",
            )
        )
| def _should_retry_fail(self, task_id: str) -> bool: | |
| preset_config = self._preset_config() | |
| _, env_state = self._require_episode() | |
| if self._config.preset != DifficultyPreset.HARD: | |
| return False | |
| if preset_config.task_retry_failure_rate <= 0.0: | |
| return False | |
| if env_state.task_attempt_counts.get(task_id, 0) >= preset_config.max_task_retries: | |
| return False | |
| return self._event_rng.random() < preset_config.task_retry_failure_rate | |
    def _dispatch_potential(
        self,
        env_state: WorkflowEnvStateSnapshot,
        task_map: dict[str, WorkflowTaskSpec],
    ) -> tuple[float, float]:
        """Potential terms used to shape dispatch rewards.

        Returns (utilization_component, criticality_component) computed from
        the currently running tasks; both are 0.0 when nothing is running.
        Dispatch reward in step() is the delta of these potentials.
        """
        if not env_state.running_task_ids:
            return 0.0, 0.0
        episode, _ = self._require_episode()
        max_slack = max((task.slack for task in episode.tasks), default=0)
        # Share of configured workers currently busy, weighted by 0.06.
        utilization_component = 0.06 * (
            len(env_state.running_task_ids) / max(1, self._config.worker_count)
        )
        criticality_component = 0.0
        for task_id in env_state.running_task_ids:
            task = task_map[task_id]
            # Zero-slack tasks are maximally urgent; urgency falls linearly
            # with slack relative to the episode's largest slack.
            slack_urgency = 1.0 if max_slack <= 0 else 1.0 - (task.slack / max_slack)
            criticality_component += (0.6 * task.criticality) + (0.4 * slack_urgency)
        # Average over configured workers, weighted by 0.04.
        criticality_component = 0.04 * (
            criticality_component / max(1, self._config.worker_count)
        )
        return round(utilization_component, 4), round(criticality_component, 4)
    def _base_observation(
        self,
        *,
        reward: float,
        breakdown: RewardBreakdown,
        note: str,
        done: bool,
        benchmark_score_override: float | None = None,
    ) -> WorkflowArenaObservation:
        """Assemble the full observation payload for the current snapshot.

        Collects per-status task views, capacity/time figures, progress
        counters, success metrics, and a metadata dict that mirrors several
        fields for clients that only read dicts.
        """
        episode, env_state = self._require_episode()
        ready_tasks = self._task_views_for_status(TaskStatus.READY)
        running_tasks = self._task_views_for_status(TaskStatus.RUNNING)
        completed_tasks = self._task_views_for_status(TaskStatus.COMPLETED)
        blocked_tasks = self._task_views_for_status(TaskStatus.BLOCKED)
        effective_workers = self._effective_worker_capacity(env_state)
        return WorkflowArenaObservation(
            done=done,
            reward=reward,
            config=self._config,
            current_time=env_state.current_time,
            total_workers=self._config.worker_count,
            effective_workers=effective_workers,
            degraded_workers=env_state.degraded_workers,
            # Free slots are clamped at zero (running can exceed capacity mid-outage).
            free_workers=max(0, effective_workers - len(running_tasks)),
            time_budget=env_state.time_budget,
            time_remaining=self._time_remaining(env_state),
            progress=ProgressSummary(
                total=len(episode.tasks),
                blocked=len(blocked_tasks),
                ready=len(ready_tasks),
                running=len(running_tasks),
                completed=len(completed_tasks),
            ),
            ready_tasks=ready_tasks,
            running_tasks=running_tasks,
            completed_tasks=completed_tasks,
            blocked_tasks=blocked_tasks,
            last_reward_breakdown=breakdown,
            cumulative_reward=self._cumulative_reward,
            success_metrics=self._success_metrics(
                benchmark_score_override=benchmark_score_override
            ),
            note=note,
            benchmark_score=benchmark_score_override,
            recent_failure_events=env_state.recent_failure_events,
            metadata={
                "phase": "simulation_active",
                "note": note,
                "effective_workers": effective_workers,
                "degraded_workers": env_state.degraded_workers,
                "time_budget": env_state.time_budget,
                "time_remaining": self._time_remaining(env_state),
                # JSON-safe copies of the failure events for dict-only clients.
                "recent_failure_events": [
                    event.model_dump(mode="json") for event in env_state.recent_failure_events
                ],
                "episode_loop": [
                    "reset generates a seeded workflow DAG episode",
                    "dispatch(task_ids=[...]) starts ready tasks if workers are free",
                    "wait() advances simulated time to the next completion event",
                    "medium and hard episodes may end at a fixed time budget",
                    "hard mode may trigger outages and retry failures",
                ],
            },
        )
| def _lower_bound_makespan(self, episode: WorkflowEpisodeSpec) -> int: | |
| total_work = sum(task.duration for task in episode.tasks) | |
| work_bound = (total_work + self._config.worker_count - 1) // self._config.worker_count | |
| path_bound = max(task.critical_path_length for task in episode.tasks) | |
| return max(1, work_bound, path_bound) | |
    def _termination_breakdown(
        self,
        *,
        invalid_penalty: float = 0.0,
        idle_penalty: float = 0.0,
        terminal_makespan_score: float = 0.0,
        unfinished_task_penalty: float = 0.0,
    ) -> RewardBreakdown:
        """Bundle terminal reward components, each rounded to 4 decimals."""
        return RewardBreakdown(
            invalid_action_penalty=round(invalid_penalty, 4),
            idle_penalty=round(idle_penalty, 4),
            terminal_makespan_score=round(terminal_makespan_score, 4),
            unfinished_task_penalty=round(unfinished_task_penalty, 4),
        )
    def _terminate_episode(
        self,
        *,
        note: str,
        breakdown: RewardBreakdown,
        reward: float,
        reason: str,
        benchmark_score: float | None = None,
    ) -> WorkflowArenaObservation:
        """Build the terminal observation, folding `reward` into the cumulative total.

        A benchmark score is computed when the caller did not supply one; the
        termination reason and score are also mirrored into the metadata dict.
        """
        if benchmark_score is None:
            benchmark_score = self._benchmark_score()
        self._cumulative_reward += reward
        observation = self._base_observation(
            reward=reward,
            breakdown=breakdown,
            note=note,
            done=True,
            benchmark_score_override=benchmark_score,
        )
        observation.termination_reason = reason
        observation.benchmark_score = benchmark_score
        observation.metadata["termination_reason"] = reason
        observation.metadata["benchmark_score"] = benchmark_score
        return observation
| def _step_limit_reached(self) -> bool: | |
| return self._state.step_count >= self._max_episode_steps | |
    def _maybe_terminate_for_limits(self) -> WorkflowArenaObservation | None:
        """Terminate with a -1.0 base reward when the safety step limit is hit.

        Returns None while the episode is still within its step budget.
        """
        if not self._step_limit_reached():
            return None
        _, env_state = self._require_episode()
        unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
        terminal_score = self._terminal_score()
        breakdown = self._termination_breakdown(
            terminal_makespan_score=terminal_score,
            unfinished_task_penalty=unfinished_penalty,
        )
        # Flat -1.0 for exhausting the step limit, softened by partial progress.
        reward = round(-1.0 + unfinished_penalty + terminal_score, 4)
        return self._terminate_episode(
            note="Episode terminated after hitting the safety step limit.",
            breakdown=breakdown,
            reward=reward,
            reason="step_limit",
        )
| def _apply_invalid( | |
| self, | |
| message: str, | |
| *, | |
| penalty: float | None = None, | |
| ) -> WorkflowArenaObservation: | |
| _, env_state = self._require_episode() | |
| applied_penalty = ( | |
| self.INVALID_ACTION_PENALTY if penalty is None else float(penalty) | |
| ) | |
| breakdown = RewardBreakdown(invalid_action_penalty=round(applied_penalty, 4)) | |
| self._cumulative_reward += breakdown.invalid_action_penalty | |
| self._set_recent_failure_events(env_state, []) | |
| observation = self._base_observation( | |
| reward=breakdown.invalid_action_penalty, | |
| breakdown=breakdown, | |
| note="Invalid action.", | |
| done=False, | |
| ) | |
| observation.validation_error = message | |
| observation.metadata["validation_error"] = message | |
| return observation | |
| def _transition_unlocks(self, completed_task_ids: list[str]) -> list[str]: | |
| episode, env_state = self._require_episode() | |
| task_map = {task.task_id: task for task in episode.tasks} | |
| unlocked: list[str] = [] | |
| for task_id in completed_task_ids: | |
| for dependent_id in task_map[task_id].dependents: | |
| env_state.task_remaining_dependencies[dependent_id] -= 1 | |
| if env_state.task_remaining_dependencies[dependent_id] == 0: | |
| env_state.task_statuses[dependent_id] = TaskStatus.READY | |
| if dependent_id not in env_state.ready_task_ids: | |
| env_state.ready_task_ids.append(dependent_id) | |
| if dependent_id in env_state.blocked_task_ids: | |
| env_state.blocked_task_ids.remove(dependent_id) | |
| unlocked.append(dependent_id) | |
| env_state.ready_task_ids.sort() | |
| env_state.blocked_task_ids.sort() | |
| return unlocked | |
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> WorkflowArenaObservation:
        """Generate a seeded workflow DAG episode.

        Recognized kwargs: ``preset`` (a DifficultyPreset or its string value)
        and ``worker_count`` (int; defaults to the preset's worker count).
        Any other kwargs are silently discarded.
        """
        preset_raw = kwargs.pop("preset", DifficultyPreset.EASY)
        worker_count_raw = kwargs.pop("worker_count", None)
        del kwargs
        # Accept either an enum member or its string value.
        preset = (
            preset_raw
            if isinstance(preset_raw, DifficultyPreset)
            else DifficultyPreset(str(preset_raw))
        )
        preset_config = get_preset_config(preset)
        chosen_seed = 0 if seed is None else seed
        chosen_worker_count = (
            preset_config.worker_count
            if worker_count_raw is None
            else int(worker_count_raw)
        )
        chosen_episode_id = str(uuid4()) if episode_id is None else episode_id
        self._state = State(episode_id=chosen_episode_id, step_count=0)
        self._cumulative_reward = 0.0
        self._config = EpisodeConfig(
            preset=preset,
            seed=chosen_seed,
            worker_count=chosen_worker_count,
        )
        # Derive a distinct event-RNG seed from (seed, workers, preset) so
        # failure events differ across configurations but stay reproducible.
        self._event_rng = random.Random(
            (chosen_seed + 1) * 1009
            + (chosen_worker_count * 131)
            + (list(DifficultyPreset).index(preset) + 1)
        )
        self._episode_spec, self._env_state = generate_episode(self._config)
        # Safety step cap scales with episode size.
        self._max_episode_steps = max(
            self.STEP_LIMIT_FLOOR,
            len(self._episode_spec.tasks) * self.STEP_LIMIT_MULTIPLIER,
        )
        self._env_state.episode_id = chosen_episode_id
        lower_bound = self._lower_bound_makespan(self._episode_spec)
        # Presets with a budget multiplier get a hard simulated-time budget.
        if preset_config.time_budget_multiplier is not None:
            self._env_state.time_budget = int(
                math.ceil(lower_bound * preset_config.time_budget_multiplier)
            )
        self._set_recent_failure_events(self._env_state, [])
        note = "Workflow episode generated. Dispatch ready tasks or wait for completions."
        if self._env_state.time_budget is not None:
            note = (
                f"Workflow episode generated. Finish as much as possible before "
                f"t={self._env_state.time_budget}."
            )
        if preset == DifficultyPreset.HARD:
            note += " Hard mode may trigger worker outages and retry failures."
        return self._base_observation(
            reward=0.0,
            breakdown=RewardBreakdown(),
            note=note,
            done=False,
        )
| def _wait_note( | |
| self, | |
| *, | |
| completed_now: list[str], | |
| failed_now: list[str], | |
| unlocked: list[str], | |
| recent_events: list[WorkflowFailureEvent], | |
| time_budget_hit: bool = False, | |
| ) -> str: | |
| chunks: list[str] = [] | |
| if time_budget_hit: | |
| chunks.append("Time budget exhausted before the next completion event.") | |
| elif completed_now: | |
| chunks.append(f"Completed: {', '.join(completed_now)}.") | |
| else: | |
| chunks.append("Advanced to next completion event.") | |
| if failed_now: | |
| chunks.append(f"Retry required: {', '.join(failed_now)}.") | |
| if unlocked: | |
| chunks.append(f"Unlocked: {', '.join(unlocked)}.") | |
| for event in recent_events: | |
| if event.event_type == FailureEventType.WORKER_OUTAGE_START: | |
| chunks.append(event.detail) | |
| elif event.event_type == FailureEventType.WORKER_OUTAGE_END: | |
| chunks.append("Worker capacity restored.") | |
| return " ".join(chunks) | |
    def step(
        self,
        action: WorkflowArenaAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> WorkflowArenaObservation:
        """Apply a dispatch or wait action using event-driven semantics."""
        # timeout_s / extra kwargs are accepted for interface compatibility only.
        del timeout_s, kwargs
        episode, env_state = self._require_episode()
        task_map = {task.task_id: task for task in episode.tasks}
        self._state.step_count += 1
        # Clear the failure-event feed; this step repopulates it as needed.
        self._set_recent_failure_events(env_state, [])
        # The safety step limit takes precedence over the action itself.
        limit_termination = self._maybe_terminate_for_limits()
        if limit_termination is not None:
            return limit_termination
        if action.action_type == WorkflowActionType.WAIT and action.task_ids:
            return self._apply_invalid("wait() must not include task_ids.")
        if action.action_type == WorkflowActionType.DISPATCH:
            # --- dispatch validation -----------------------------------------
            if not action.task_ids:
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) requires at least one task id."
                )
            if len(set(action.task_ids)) != len(action.task_ids):
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) must not contain duplicate task ids."
                )
            free_workers = self._effective_worker_capacity(env_state) - len(
                env_state.running_task_ids
            )
            if len(action.task_ids) > max(0, free_workers):
                # Over-capacity dispatch carries the larger penalty.
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) cannot exceed available worker capacity.",
                    penalty=self.OVERCAPACITY_INVALID_ACTION_PENALTY,
                )
            unknown_tasks = [task_id for task_id in action.task_ids if task_id not in task_map]
            if unknown_tasks:
                return self._apply_invalid(f"Unknown task ids: {unknown_tasks}.")
            not_ready = [
                task_id
                for task_id in action.task_ids
                if env_state.task_statuses[task_id] != TaskStatus.READY
            ]
            if not_ready:
                return self._apply_invalid(
                    f"Only ready tasks can be dispatched: {not_ready}."
                )
            # --- start tasks; reward is the change in dispatch potential -----
            prev_utilization_potential, prev_criticality_potential = self._dispatch_potential(
                env_state, task_map
            )
            for task_id in action.task_ids:
                task = task_map[task_id]
                env_state.task_statuses[task_id] = TaskStatus.RUNNING
                env_state.task_start_times[task_id] = env_state.current_time
                env_state.task_assigned_finish_times[task_id] = (
                    env_state.current_time + task.duration
                )
                env_state.running_task_ids.append(task_id)
                env_state.ready_task_ids.remove(task_id)
            env_state.running_task_ids.sort()
            next_utilization_potential, next_criticality_potential = self._dispatch_potential(
                env_state, task_map
            )
            breakdown = RewardBreakdown(
                utilization_reward=round(
                    next_utilization_potential - prev_utilization_potential, 4
                ),
                criticality_reward=round(
                    next_criticality_potential - prev_criticality_potential, 4
                ),
            )
            reward = round(
                breakdown.utilization_reward + breakdown.criticality_reward,
                4,
            )
            self._cumulative_reward += reward
            observation = self._base_observation(
                reward=reward,
                breakdown=breakdown,
                note="Tasks dispatched. Use wait() to advance to the next completion event.",
                done=False,
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            return observation
        # --- wait: advance time to the next completion event -----------------
        # (any action other than DISPATCH reaches here; a WAIT carrying
        # task_ids was already rejected above)
        if not env_state.running_task_ids:
            return self._apply_invalid("wait() requires at least one running task.")
        recent_events: list[WorkflowFailureEvent] = []
        # Penalize waiting while free workers could have taken ready tasks.
        avoidable_wait_penalty = 0.0
        if env_state.ready_task_ids:
            free_workers = self._effective_worker_capacity(env_state) - len(
                env_state.running_task_ids
            )
            if free_workers > 0:
                avoidable_wait_penalty = self.AVOIDABLE_WAIT_PENALTY_PER_SLOT * min(
                    free_workers,
                    len(env_state.ready_task_ids),
                )
        self._maybe_start_worker_outage(env_state, recent_events)
        next_completion_time = min(
            env_state.task_assigned_finish_times[task_id]
            for task_id in env_state.running_task_ids
        )
        # Clamp the time advance at the budget when it precedes the completion.
        target_time = next_completion_time
        budget_hit_before_completion = False
        if env_state.time_budget is not None and env_state.time_budget < next_completion_time:
            target_time = env_state.time_budget
            budget_hit_before_completion = True
        elapsed = target_time - env_state.current_time
        env_state.cumulative_busy_time += elapsed * len(env_state.running_task_ids)
        env_state.current_time = target_time
        self._maybe_end_worker_outage(env_state, recent_events)
        if budget_hit_before_completion:
            # Budget ran out before anything finished: terminal accounting only.
            unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
            terminal_score = self._terminal_score()
            breakdown = RewardBreakdown(
                idle_penalty=round(avoidable_wait_penalty, 4),
                terminal_makespan_score=round(terminal_score, 4),
                unfinished_task_penalty=round(unfinished_penalty, 4),
            )
            reward = round(
                breakdown.idle_penalty
                + breakdown.terminal_makespan_score
                + breakdown.unfinished_task_penalty,
                4,
            )
            self._set_recent_failure_events(env_state, recent_events)
            note = self._wait_note(
                completed_now=[],
                failed_now=[],
                unlocked=[],
                recent_events=recent_events,
                time_budget_hit=True,
            )
            observation = self._terminate_episode(
                note=note,
                breakdown=breakdown,
                reward=reward,
                reason="time_budget",
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            return observation
        # --- resolve every task finishing at this event time -----------------
        completed_candidates = sorted(
            [
                task_id
                for task_id in env_state.running_task_ids
                if env_state.task_assigned_finish_times[task_id] == next_completion_time
            ]
        )
        completed_now: list[str] = []
        failed_now: list[str] = []
        for task_id in completed_candidates:
            env_state.running_task_ids.remove(task_id)
            del env_state.task_assigned_finish_times[task_id]
            if self._should_retry_fail(task_id):
                # Hard-mode retry failure: the task returns to READY.
                env_state.task_attempt_counts[task_id] += 1
                env_state.task_statuses[task_id] = TaskStatus.READY
                env_state.task_start_times.pop(task_id, None)
                env_state.task_end_times.pop(task_id, None)
                if task_id not in env_state.ready_task_ids:
                    env_state.ready_task_ids.append(task_id)
                failed_now.append(task_id)
                recent_events.append(
                    WorkflowFailureEvent(
                        event_type=FailureEventType.TASK_RETRY_FAILURE,
                        time=next_completion_time,
                        task_id=task_id,
                        detail=f"{task_id} failed and returned to ready.",
                    )
                )
            else:
                env_state.task_statuses[task_id] = TaskStatus.COMPLETED
                env_state.task_end_times[task_id] = next_completion_time
                env_state.completed_task_ids.append(task_id)
                completed_now.append(task_id)
        env_state.completed_task_ids.sort()
        env_state.ready_task_ids.sort()
        unlocked = self._transition_unlocks(completed_now)
        # --- reward shaping for this wait transition -------------------------
        completion_reward = sum(
            0.04 + 0.01 * task_map[task_id].priority for task_id in completed_now
        )
        deadline_reward = 0.0
        criticality_reward = 0.0
        for task_id in completed_now:
            task = task_map[task_id]
            if task.deadline is not None:
                # Bonus for on-time completion, linear penalty per late tick.
                lateness = next_completion_time - task.deadline
                deadline_reward += 0.05 if lateness <= 0 else -0.02 * lateness
            criticality_reward += 0.03 * task.criticality
        utilization_reward = 0.06 * (
            elapsed
            * (len(completed_candidates) + len(env_state.running_task_ids))
            / max(1, self._config.worker_count)
        )
        # Idle penalty when everything stalled while work was available.
        idle_penalty = 0.0
        if not env_state.running_task_ids and env_state.ready_task_ids:
            idle_penalty = -0.03 * len(env_state.ready_task_ids)
        done = len(env_state.completed_task_ids) == len(episode.tasks)
        breakdown = RewardBreakdown(
            completion_reward=round(completion_reward, 4),
            utilization_reward=round(utilization_reward, 4),
            deadline_reward=round(deadline_reward, 4),
            criticality_reward=round(criticality_reward, 4),
            idle_penalty=round(idle_penalty + avoidable_wait_penalty, 4),
            terminal_makespan_score=round(self._terminal_score() if done else 0.0, 4),
        )
        reward = round(
            breakdown.completion_reward
            + breakdown.utilization_reward
            + breakdown.deadline_reward
            + breakdown.criticality_reward
            + breakdown.idle_penalty
            + breakdown.terminal_makespan_score,
            4,
        )
        # --- budget exhausted exactly at this completion event ---------------
        budget_exhausted_now = (
            not done
            and env_state.time_budget is not None
            and env_state.current_time >= env_state.time_budget
        )
        if budget_exhausted_now:
            unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
            breakdown.unfinished_task_penalty = round(unfinished_penalty, 4)
            breakdown.terminal_makespan_score = round(self._terminal_score(), 4)
            reward = round(
                reward
                + breakdown.unfinished_task_penalty
                + breakdown.terminal_makespan_score,
                4,
            )
            self._set_recent_failure_events(env_state, recent_events)
            note = self._wait_note(
                completed_now=completed_now,
                failed_now=failed_now,
                unlocked=unlocked,
                recent_events=recent_events,
                time_budget_hit=True,
            )
            observation = self._terminate_episode(
                note=note,
                breakdown=breakdown,
                reward=reward,
                reason="time_budget",
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            observation.metadata["completed_now"] = completed_now
            observation.metadata["unlocked_now"] = unlocked
            observation.metadata["failed_now"] = failed_now
            return observation
        # --- normal (possibly final) wait observation ------------------------
        self._cumulative_reward += reward
        self._set_recent_failure_events(env_state, recent_events)
        observation = self._base_observation(
            reward=reward,
            breakdown=breakdown,
            note=self._wait_note(
                completed_now=completed_now,
                failed_now=failed_now,
                unlocked=unlocked,
                recent_events=recent_events,
            ),
            done=done,
            benchmark_score_override=self._benchmark_score() if done else None,
        )
        observation.received_action = action.model_dump(mode="json")
        observation.metadata["received_action"] = action.model_dump(mode="json")
        observation.metadata["completed_now"] = completed_now
        observation.metadata["unlocked_now"] = unlocked
        observation.metadata["failed_now"] = failed_now
        if done:
            observation.benchmark_score = observation.success_metrics.benchmark_score
        return observation
    def state(self) -> State:
        """Expose generic OpenEnv state metadata (episode id and step count)."""
        return self._state