# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""WorkflowArena event-driven workflow orchestration environment."""

from __future__ import annotations

import math
import random
from typing import Any
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from workflow_arena.generator import generate_episode
from workflow_arena.models import (
    DifficultyPreset,
    EpisodeConfig,
    FailureEventType,
    ProgressSummary,
    RewardBreakdown,
    SuccessMetrics,
    TaskStatus,
    WorkflowArenaAction,
    WorkflowArenaObservation,
    WorkflowEnvStateSnapshot,
    WorkflowEpisodeSpec,
    WorkflowFailureEvent,
    WorkflowTaskSpec,
    WorkflowTaskView,
    WorkflowActionType,
)
from workflow_arena.presets import get_preset_config


class WorkflowArenaEnvironment(Environment):
    """Resource-constrained workflow scheduler with event-driven semantics.

    Each episode is a seeded workflow DAG produced by
    :func:`workflow_arena.generator.generate_episode`. The agent interacts via
    two action types:

    * ``DISPATCH`` — start one or more READY tasks, bounded by free worker
      capacity.
    * ``WAIT`` — advance simulated time to the next task-completion event
      (or to the time budget, if that comes first).

    Rewards combine shaped per-step terms (completion, utilization, deadline,
    criticality, idle penalties) with terminal scores based on a lower-bound
    makespan. HARD-preset episodes may additionally trigger worker outages and
    task retry failures, driven by ``self._event_rng``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Safety step limit: max(floor, task_count * multiplier).
    STEP_LIMIT_FLOOR: int = 32
    STEP_LIMIT_MULTIPLIER: int = 8
    # Flat penalty for a malformed/ill-timed action.
    INVALID_ACTION_PENALTY: float = -0.1
    # Harsher penalty when a dispatch exceeds available worker capacity.
    OVERCAPACITY_INVALID_ACTION_PENALTY: float = -0.25
    # Per idle worker slot that could have been filled before a wait().
    AVOIDABLE_WAIT_PENALTY_PER_SLOT: float = -0.08
    # Per priority point of each task left unfinished at termination.
    UNFINISHED_PRIORITY_PENALTY: float = -0.02
    # Per priority point, per tick past a task's deadline, at termination.
    OVERDUE_PRIORITY_PENALTY_PER_TICK: float = -0.005
    # Only the most recent failure events are surfaced in observations.
    MAX_RECENT_FAILURE_EVENTS: int = 6
    # Benchmark scores are clamped away from the exact 0/1 endpoints.
    MIN_GRADER_SCORE: float = 0.01
    MAX_GRADER_SCORE: float = 0.99

    def __init__(self):
        """Initialize with placeholder state; :meth:`reset` must be called
        before :meth:`step` (enforced by :meth:`_require_episode`)."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._cumulative_reward = 0.0
        self._max_episode_steps = self.STEP_LIMIT_FLOOR
        # Placeholder config until reset() installs the real one.
        self._config = EpisodeConfig(
            preset=DifficultyPreset.EASY,
            seed=0,
            worker_count=2,
        )
        self._episode_spec: WorkflowEpisodeSpec | None = None
        self._env_state: WorkflowEnvStateSnapshot | None = None
        # Dedicated RNG for stochastic failure events; reseeded on reset().
        self._event_rng = random.Random(0)

    def _require_episode(self) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
        """Return the active (spec, state) pair.

        Raises:
            RuntimeError: if :meth:`reset` has not been called yet.
        """
        if self._episode_spec is None or self._env_state is None:
            raise RuntimeError("Environment must be reset before use.")
        return self._episode_spec, self._env_state

    def _preset_config(self):
        """Return the preset configuration attached to the active episode."""
        episode, _ = self._require_episode()
        return episode.preset_config

    def _task_map(self) -> dict[str, WorkflowTaskSpec]:
        """Return a fresh task_id -> WorkflowTaskSpec lookup for the episode."""
        episode, _ = self._require_episode()
        return {task.task_id: task for task in episode.tasks}

    def _effective_worker_capacity(
        self, env_state: WorkflowEnvStateSnapshot | None = None
    ) -> int:
        """Worker capacity after subtracting outage-degraded workers (>= 0)."""
        if env_state is None:
            _, env_state = self._require_episode()
        return max(0, self._config.worker_count - env_state.degraded_workers)

    def _time_remaining(
        self, env_state: WorkflowEnvStateSnapshot | None = None
    ) -> int | None:
        """Ticks left before the time budget, or None when unbudgeted."""
        if env_state is None:
            _, env_state = self._require_episode()
        if env_state.time_budget is None:
            return None
        return max(0, env_state.time_budget - env_state.current_time)

    def _terminal_score(self) -> float:
        """Makespan quality in (0, 1]: lower_bound / max(lower_bound, elapsed).

        Equals 1.0 when the achieved makespan matches the theoretical lower
        bound; 0.0 when no time has elapsed at all.
        """
        episode, env_state = self._require_episode()
        if env_state.current_time <= 0:
            return 0.0
        lower_bound = self._lower_bound_makespan(episode)
        score = lower_bound / max(lower_bound, env_state.current_time)
        return round(score, 4)

    def _bounded_grader_score(self, score: float) -> float:
        """Clamp a grader score into [MIN_GRADER_SCORE, MAX_GRADER_SCORE]."""
        return round(
            min(self.MAX_GRADER_SCORE, max(self.MIN_GRADER_SCORE, score)),
            4,
        )

    def _benchmark_score(self) -> float:
        """Weighted episode grade: 0.5 makespan + 0.3 deadline + 0.2 utilization."""
        makespan_score, deadline_score, utilization_score = self._grade_components(
            include_terminal_makespan=True
        )
        return self._bounded_grader_score(
            (0.5 * makespan_score) + (0.3 * deadline_score) + (0.2 * utilization_score)
        )

    def _grade_components(
        self, *, include_terminal_makespan: bool = False
    ) -> tuple[float, float, float]:
        """Return (makespan_score, deadline_score, utilization_score).

        * utilization: cumulative busy worker-time over total worker-time.
        * deadline: priority-weighted share of completed tasks that finished
          on time (tasks without a deadline count as on time).
        * makespan: the terminal score, or 0.0 unless requested.
        """
        episode, env_state = self._require_episode()
        utilization = (
            env_state.cumulative_busy_time
            / (env_state.current_time * self._config.worker_count)
            if env_state.current_time > 0
            else 0.0
        )
        # `or 1` guards against division by zero when all priorities are 0.
        total_priority = sum(task.priority for task in episode.tasks) or 1
        on_time_priority = 0
        for task in episode.tasks:
            end_time = env_state.task_end_times.get(task.task_id)
            if end_time is None:
                continue  # not completed; contributes nothing
            if task.deadline is None or end_time <= task.deadline:
                on_time_priority += task.priority
        deadline_score = round(on_time_priority / total_priority, 4)
        utilization_score = round(utilization, 4)
        makespan_score = self._terminal_score() if include_terminal_makespan else 0.0
        return makespan_score, deadline_score, utilization_score

    def _unfinished_task_penalty(self, current_time: int) -> float:
        """Negative penalty for each non-completed task, scaled by priority,
        plus an extra per-tick penalty for tasks already past their deadline."""
        episode, env_state = self._require_episode()
        penalty = 0.0
        for task in episode.tasks:
            if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED:
                continue
            penalty += self.UNFINISHED_PRIORITY_PENALTY * task.priority
            if task.deadline is not None and current_time > task.deadline:
                penalty += (
                    self.OVERDUE_PRIORITY_PENALTY_PER_TICK
                    * task.priority
                    * (current_time - task.deadline)
                )
        return round(penalty, 4)

    def _success_metrics(
        self, *, benchmark_score_override: float | None = None
    ) -> SuccessMetrics:
        """Build the SuccessMetrics summary for the current snapshot.

        ``makespan`` is only reported once every task has completed.
        """
        episode, env_state = self._require_episode()
        unfinished_task_count = sum(
            1
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] != TaskStatus.COMPLETED
        )
        # A "deadline miss" is a COMPLETED task that finished after its deadline.
        deadline_miss_count = sum(
            1
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED
            and task.deadline is not None
            and env_state.task_end_times.get(task.task_id, 0) > task.deadline
        )
        _, deadline_score, utilization_score = self._grade_components(
            include_terminal_makespan=False
        )
        all_done = unfinished_task_count == 0
        return SuccessMetrics(
            makespan=env_state.current_time if all_done else None,
            worker_utilization=utilization_score,
            deadline_miss_count=deadline_miss_count,
            unfinished_task_count=unfinished_task_count,
            weighted_priority_completion=deadline_score,
            benchmark_score=benchmark_score_override,
        )

    def _task_view(
        self,
        task: WorkflowTaskSpec,
        status: TaskStatus,
        *,
        include_planner_hints: bool = True,
    ) -> WorkflowTaskView:
        """Project a task spec + runtime state into an observation-facing view.

        When ``include_planner_hints`` is False, criticality/slack are hidden
        and downstream_count is zeroed.
        """
        _, env_state = self._require_episode()
        return WorkflowTaskView(
            task_id=task.task_id,
            status=status,
            duration=task.duration,
            priority=task.priority,
            dependencies=task.dependencies,
            deadline=task.deadline,
            criticality=task.criticality if include_planner_hints else None,
            slack=float(task.slack) if include_planner_hints else None,
            downstream_count=task.downstream_count if include_planner_hints else 0,
            start_time=env_state.task_start_times.get(task.task_id),
            # Prefer the actual end time; fall back to the scheduled finish
            # time for tasks that are still running.
            end_time=(
                env_state.task_end_times.get(task.task_id)
                or env_state.task_assigned_finish_times.get(task.task_id)
            ),
            attempt_count=env_state.task_attempt_counts.get(task.task_id, 0),
        )

    def _task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
        """Views for all tasks currently in ``status``, in episode task order."""
        episode, env_state = self._require_episode()
        return [
            self._task_view(task, status, include_planner_hints=True)
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] == status
        ]

    def debug_task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
        """Public debugging hook exposing :meth:`_task_views_for_status`."""
        return self._task_views_for_status(status)

    def _set_recent_failure_events(
        self,
        env_state: WorkflowEnvStateSnapshot,
        events: list[WorkflowFailureEvent],
    ) -> None:
        """Store at most the last MAX_RECENT_FAILURE_EVENTS events on state."""
        env_state.recent_failure_events = events[-self.MAX_RECENT_FAILURE_EVENTS :]

    def _maybe_end_worker_outage(
        self,
        env_state: WorkflowEnvStateSnapshot,
        events: list[WorkflowFailureEvent],
    ) -> None:
        """If an active outage's end time has passed, restore capacity and
        append a WORKER_OUTAGE_END event."""
        if (
            env_state.active_worker_outage_until is not None
            and env_state.current_time >= env_state.active_worker_outage_until
        ):
            events.append(
                WorkflowFailureEvent(
                    event_type=FailureEventType.WORKER_OUTAGE_END,
                    time=env_state.current_time,
                    worker_delta=1,
                    detail="Worker capacity restored.",
                )
            )
            env_state.active_worker_outage_until = None
            env_state.degraded_workers = 0

    def _maybe_start_worker_outage(
        self,
        env_state: WorkflowEnvStateSnapshot,
        events: list[WorkflowFailureEvent],
    ) -> None:
        """Possibly begin a single-worker outage (HARD preset only).

        Rolls ``self._event_rng`` against the preset's outage rate; at most
        one outage can be active at a time, and it degrades exactly one
        worker for a random duration drawn from the preset's range.
        """
        preset_config = self._preset_config()
        if self._config.preset != DifficultyPreset.HARD:
            return
        if env_state.active_worker_outage_until is not None:
            return  # an outage is already in progress
        if preset_config.worker_outage_rate <= 0.0:
            return
        if self._event_rng.random() >= preset_config.worker_outage_rate:
            return
        duration = self._event_rng.randint(
            preset_config.worker_outage_duration_min,
            preset_config.worker_outage_duration_max,
        )
        if duration <= 0:
            return
        # Degrade one worker (capped by total worker count).
        env_state.degraded_workers = min(1, self._config.worker_count)
        env_state.active_worker_outage_until = env_state.current_time + duration
        events.append(
            WorkflowFailureEvent(
                event_type=FailureEventType.WORKER_OUTAGE_START,
                time=env_state.current_time,
                worker_delta=-env_state.degraded_workers,
                duration=duration,
                detail=f"Worker outage active until t={env_state.active_worker_outage_until}.",
            )
        )

    def _should_retry_fail(self, task_id: str) -> bool:
        """Decide whether a just-finished task fails and must be retried.

        Only possible on the HARD preset, with a positive failure rate, and
        while the task is below the preset's retry cap. Consumes one draw
        from ``self._event_rng`` when all gates pass.
        """
        preset_config = self._preset_config()
        _, env_state = self._require_episode()
        if self._config.preset != DifficultyPreset.HARD:
            return False
        if preset_config.task_retry_failure_rate <= 0.0:
            return False
        if env_state.task_attempt_counts.get(task_id, 0) >= preset_config.max_task_retries:
            return False
        return self._event_rng.random() < preset_config.task_retry_failure_rate

    def _dispatch_potential(
        self,
        env_state: WorkflowEnvStateSnapshot,
        task_map: dict[str, WorkflowTaskSpec],
    ) -> tuple[float, float]:
        """Potential-style shaping terms for the current running set.

        Returns (utilization_component, criticality_component); dispatch
        rewards are computed as the *difference* of these potentials before
        and after dispatching, so shaping stays path-independent.
        """
        if not env_state.running_task_ids:
            return 0.0, 0.0
        episode, _ = self._require_episode()
        max_slack = max((task.slack for task in episode.tasks), default=0)
        utilization_component = 0.06 * (
            len(env_state.running_task_ids) / max(1, self._config.worker_count)
        )
        criticality_component = 0.0
        for task_id in env_state.running_task_ids:
            task = task_map[task_id]
            # Urgency is 1 for zero-slack tasks, falling toward 0 at max slack.
            slack_urgency = 1.0 if max_slack <= 0 else 1.0 - (task.slack / max_slack)
            criticality_component += (0.6 * task.criticality) + (0.4 * slack_urgency)
        criticality_component = 0.04 * (
            criticality_component / max(1, self._config.worker_count)
        )
        return round(utilization_component, 4), round(criticality_component, 4)

    def _base_observation(
        self,
        *,
        reward: float,
        breakdown: RewardBreakdown,
        note: str,
        done: bool,
        benchmark_score_override: float | None = None,
    ) -> WorkflowArenaObservation:
        """Assemble the full observation for the current snapshot.

        Gathers per-status task views, progress counts, reward bookkeeping,
        success metrics, recent failure events, and a metadata dict that
        mirrors several of the top-level fields for clients that only read
        metadata.
        """
        episode, env_state = self._require_episode()
        ready_tasks = self._task_views_for_status(TaskStatus.READY)
        running_tasks = self._task_views_for_status(TaskStatus.RUNNING)
        completed_tasks = self._task_views_for_status(TaskStatus.COMPLETED)
        blocked_tasks = self._task_views_for_status(TaskStatus.BLOCKED)
        effective_workers = self._effective_worker_capacity(env_state)
        return WorkflowArenaObservation(
            done=done,
            reward=reward,
            config=self._config,
            current_time=env_state.current_time,
            total_workers=self._config.worker_count,
            effective_workers=effective_workers,
            degraded_workers=env_state.degraded_workers,
            free_workers=max(0, effective_workers - len(running_tasks)),
            time_budget=env_state.time_budget,
            time_remaining=self._time_remaining(env_state),
            progress=ProgressSummary(
                total=len(episode.tasks),
                blocked=len(blocked_tasks),
                ready=len(ready_tasks),
                running=len(running_tasks),
                completed=len(completed_tasks),
            ),
            ready_tasks=ready_tasks,
            running_tasks=running_tasks,
            completed_tasks=completed_tasks,
            blocked_tasks=blocked_tasks,
            last_reward_breakdown=breakdown,
            cumulative_reward=self._cumulative_reward,
            success_metrics=self._success_metrics(
                benchmark_score_override=benchmark_score_override
            ),
            note=note,
            benchmark_score=benchmark_score_override,
            recent_failure_events=env_state.recent_failure_events,
            metadata={
                "phase": "simulation_active",
                "note": note,
                "effective_workers": effective_workers,
                "degraded_workers": env_state.degraded_workers,
                "time_budget": env_state.time_budget,
                "time_remaining": self._time_remaining(env_state),
                "recent_failure_events": [
                    event.model_dump(mode="json")
                    for event in env_state.recent_failure_events
                ],
                # Short usage primer surfaced to agents on every observation.
                "episode_loop": [
                    "reset generates a seeded workflow DAG episode",
                    "dispatch(task_ids=[...]) starts ready tasks if workers are free",
                    "wait() advances simulated time to the next completion event",
                    "medium and hard episodes may end at a fixed time budget",
                    "hard mode may trigger outages and retry failures",
                ],
            },
        )

    def _lower_bound_makespan(self, episode: WorkflowEpisodeSpec) -> int:
        """Classic makespan lower bound: max of the work bound
        (ceil(total_work / workers)) and the critical-path bound; at least 1."""
        total_work = sum(task.duration for task in episode.tasks)
        # Ceiling division via (a + b - 1) // b.
        work_bound = (total_work + self._config.worker_count - 1) // self._config.worker_count
        path_bound = max(task.critical_path_length for task in episode.tasks)
        return max(1, work_bound, path_bound)

    def _termination_breakdown(
        self,
        *,
        invalid_penalty: float = 0.0,
        idle_penalty: float = 0.0,
        terminal_makespan_score: float = 0.0,
        unfinished_task_penalty: float = 0.0,
    ) -> RewardBreakdown:
        """Build a RewardBreakdown carrying only termination-related terms."""
        return RewardBreakdown(
            invalid_action_penalty=round(invalid_penalty, 4),
            idle_penalty=round(idle_penalty, 4),
            terminal_makespan_score=round(terminal_makespan_score, 4),
            unfinished_task_penalty=round(unfinished_task_penalty, 4),
        )

    def _terminate_episode(
        self,
        *,
        note: str,
        breakdown: RewardBreakdown,
        reward: float,
        reason: str,
        benchmark_score: float | None = None,
    ) -> WorkflowArenaObservation:
        """Emit the final (done=True) observation with a termination reason
        and a benchmark score (computed here unless supplied)."""
        if benchmark_score is None:
            benchmark_score = self._benchmark_score()
        self._cumulative_reward += reward
        observation = self._base_observation(
            reward=reward,
            breakdown=breakdown,
            note=note,
            done=True,
            benchmark_score_override=benchmark_score,
        )
        observation.termination_reason = reason
        observation.benchmark_score = benchmark_score
        # Mirrored into metadata for clients that only inspect metadata.
        observation.metadata["termination_reason"] = reason
        observation.metadata["benchmark_score"] = benchmark_score
        return observation

    def _step_limit_reached(self) -> bool:
        """True once the safety step limit has been consumed."""
        return self._state.step_count >= self._max_episode_steps

    def _maybe_terminate_for_limits(self) -> WorkflowArenaObservation | None:
        """Terminate with reason "step_limit" when the step cap is hit.

        Reward is -1.0 (hard penalty) plus the unfinished-task penalty plus
        whatever terminal makespan credit has been earned so far.
        """
        if not self._step_limit_reached():
            return None
        _, env_state = self._require_episode()
        unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
        terminal_score = self._terminal_score()
        breakdown = self._termination_breakdown(
            terminal_makespan_score=terminal_score,
            unfinished_task_penalty=unfinished_penalty,
        )
        reward = round(-1.0 + unfinished_penalty + terminal_score, 4)
        return self._terminate_episode(
            note="Episode terminated after hitting the safety step limit.",
            breakdown=breakdown,
            reward=reward,
            reason="step_limit",
        )

    def _apply_invalid(
        self,
        message: str,
        *,
        penalty: float | None = None,
    ) -> WorkflowArenaObservation:
        """Penalize an invalid action without advancing simulated time.

        Uses INVALID_ACTION_PENALTY unless a specific ``penalty`` is given;
        clears recent failure events so none are misattributed to this step.
        """
        _, env_state = self._require_episode()
        applied_penalty = (
            self.INVALID_ACTION_PENALTY if penalty is None else float(penalty)
        )
        breakdown = RewardBreakdown(invalid_action_penalty=round(applied_penalty, 4))
        self._cumulative_reward += breakdown.invalid_action_penalty
        self._set_recent_failure_events(env_state, [])
        observation = self._base_observation(
            reward=breakdown.invalid_action_penalty,
            breakdown=breakdown,
            note="Invalid action.",
            done=False,
        )
        observation.validation_error = message
        observation.metadata["validation_error"] = message
        return observation

    def _transition_unlocks(self, completed_task_ids: list[str]) -> list[str]:
        """Propagate completions to dependents; return newly READY task ids.

        Decrements each dependent's remaining-dependency counter; a dependent
        whose counter reaches zero moves BLOCKED -> READY. The ready/blocked
        id lists are kept sorted.
        """
        episode, env_state = self._require_episode()
        task_map = {task.task_id: task for task in episode.tasks}
        unlocked: list[str] = []
        for task_id in completed_task_ids:
            for dependent_id in task_map[task_id].dependents:
                env_state.task_remaining_dependencies[dependent_id] -= 1
                if env_state.task_remaining_dependencies[dependent_id] == 0:
                    env_state.task_statuses[dependent_id] = TaskStatus.READY
                    if dependent_id not in env_state.ready_task_ids:
                        env_state.ready_task_ids.append(dependent_id)
                    if dependent_id in env_state.blocked_task_ids:
                        env_state.blocked_task_ids.remove(dependent_id)
                    unlocked.append(dependent_id)
        env_state.ready_task_ids.sort()
        env_state.blocked_task_ids.sort()
        return unlocked

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> WorkflowArenaObservation:
        """Generate a seeded workflow DAG episode.

        Args:
            seed: RNG seed for episode generation; defaults to 0.
            episode_id: explicit id; a fresh UUID is used when omitted.
            **kwargs: optional ``preset`` (DifficultyPreset or its string
                value) and ``worker_count`` (overrides the preset default).

        Returns:
            The initial observation (reward 0, empty breakdown, done=False).
        """
        preset_raw = kwargs.pop("preset", DifficultyPreset.EASY)
        worker_count_raw = kwargs.pop("worker_count", None)
        del kwargs
        # Accept either the enum itself or any value coercible to one.
        preset = (
            preset_raw
            if isinstance(preset_raw, DifficultyPreset)
            else DifficultyPreset(str(preset_raw))
        )
        preset_config = get_preset_config(preset)
        chosen_seed = 0 if seed is None else seed
        chosen_worker_count = (
            preset_config.worker_count
            if worker_count_raw is None
            else int(worker_count_raw)
        )
        chosen_episode_id = str(uuid4()) if episode_id is None else episode_id
        self._state = State(episode_id=chosen_episode_id, step_count=0)
        self._cumulative_reward = 0.0
        self._config = EpisodeConfig(
            preset=preset,
            seed=chosen_seed,
            worker_count=chosen_worker_count,
        )
        # Derive a distinct failure-event seed from (seed, workers, preset)
        # so the event stream is reproducible but decorrelated from the
        # episode-generation RNG.
        self._event_rng = random.Random(
            (chosen_seed + 1) * 1009
            + (chosen_worker_count * 131)
            + (list(DifficultyPreset).index(preset) + 1)
        )
        self._episode_spec, self._env_state = generate_episode(self._config)
        self._max_episode_steps = max(
            self.STEP_LIMIT_FLOOR,
            len(self._episode_spec.tasks) * self.STEP_LIMIT_MULTIPLIER,
        )
        self._env_state.episode_id = chosen_episode_id
        lower_bound = self._lower_bound_makespan(self._episode_spec)
        # Presets with a multiplier get a hard time budget derived from the
        # theoretical lower-bound makespan.
        if preset_config.time_budget_multiplier is not None:
            self._env_state.time_budget = int(
                math.ceil(lower_bound * preset_config.time_budget_multiplier)
            )
        self._set_recent_failure_events(self._env_state, [])
        note = "Workflow episode generated. Dispatch ready tasks or wait for completions."
        if self._env_state.time_budget is not None:
            note = (
                f"Workflow episode generated. Finish as much as possible before "
                f"t={self._env_state.time_budget}."
            )
        if preset == DifficultyPreset.HARD:
            note += " Hard mode may trigger worker outages and retry failures."
        return self._base_observation(
            reward=0.0,
            breakdown=RewardBreakdown(),
            note=note,
            done=False,
        )

    def _wait_note(
        self,
        *,
        completed_now: list[str],
        failed_now: list[str],
        unlocked: list[str],
        recent_events: list[WorkflowFailureEvent],
        time_budget_hit: bool = False,
    ) -> str:
        """Compose a human-readable summary of what a wait() step produced."""
        chunks: list[str] = []
        if time_budget_hit:
            chunks.append("Time budget exhausted before the next completion event.")
        elif completed_now:
            chunks.append(f"Completed: {', '.join(completed_now)}.")
        else:
            chunks.append("Advanced to next completion event.")
        if failed_now:
            chunks.append(f"Retry required: {', '.join(failed_now)}.")
        if unlocked:
            chunks.append(f"Unlocked: {', '.join(unlocked)}.")
        for event in recent_events:
            if event.event_type == FailureEventType.WORKER_OUTAGE_START:
                chunks.append(event.detail)
            elif event.event_type == FailureEventType.WORKER_OUTAGE_END:
                chunks.append("Worker capacity restored.")
        return " ".join(chunks)

    def step(
        self,
        action: WorkflowArenaAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> WorkflowArenaObservation:
        """Apply a dispatch or wait action using event-driven semantics.

        DISPATCH: validates the request (non-empty, no duplicates, within
        free capacity, all ids known and READY), marks the tasks RUNNING, and
        rewards the change in dispatch potential. Time does not advance.

        WAIT: possibly starts a HARD-mode outage, then jumps simulated time
        to the earliest assigned finish time among running tasks (or the
        time budget, if sooner). Finishing tasks either complete — crediting
        completion/deadline/criticality/utilization rewards and unlocking
        dependents — or fail back to READY in HARD mode. The episode
        terminates on step-limit, time-budget exhaustion, or all-complete.

        ``timeout_s`` and extra kwargs are accepted for interface
        compatibility and ignored.
        """
        del timeout_s, kwargs
        episode, env_state = self._require_episode()
        task_map = {task.task_id: task for task in episode.tasks}
        self._state.step_count += 1
        # Events shown in an observation belong to this step only.
        self._set_recent_failure_events(env_state, [])
        limit_termination = self._maybe_terminate_for_limits()
        if limit_termination is not None:
            return limit_termination
        if action.action_type == WorkflowActionType.WAIT and action.task_ids:
            return self._apply_invalid("wait() must not include task_ids.")
        if action.action_type == WorkflowActionType.DISPATCH:
            # --- DISPATCH validation -------------------------------------
            if not action.task_ids:
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) requires at least one task id."
                )
            if len(set(action.task_ids)) != len(action.task_ids):
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) must not contain duplicate task ids."
                )
            free_workers = self._effective_worker_capacity(env_state) - len(
                env_state.running_task_ids
            )
            if len(action.task_ids) > max(0, free_workers):
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) cannot exceed available worker capacity.",
                    penalty=self.OVERCAPACITY_INVALID_ACTION_PENALTY,
                )
            unknown_tasks = [task_id for task_id in action.task_ids if task_id not in task_map]
            if unknown_tasks:
                return self._apply_invalid(f"Unknown task ids: {unknown_tasks}.")
            not_ready = [
                task_id
                for task_id in action.task_ids
                if env_state.task_statuses[task_id] != TaskStatus.READY
            ]
            if not_ready:
                return self._apply_invalid(
                    f"Only ready tasks can be dispatched: {not_ready}."
                )
            # --- DISPATCH apply: reward = potential delta ----------------
            prev_utilization_potential, prev_criticality_potential = self._dispatch_potential(
                env_state, task_map
            )
            for task_id in action.task_ids:
                task = task_map[task_id]
                env_state.task_statuses[task_id] = TaskStatus.RUNNING
                env_state.task_start_times[task_id] = env_state.current_time
                env_state.task_assigned_finish_times[task_id] = (
                    env_state.current_time + task.duration
                )
                env_state.running_task_ids.append(task_id)
                env_state.ready_task_ids.remove(task_id)
            env_state.running_task_ids.sort()
            next_utilization_potential, next_criticality_potential = self._dispatch_potential(
                env_state, task_map
            )
            breakdown = RewardBreakdown(
                utilization_reward=round(
                    next_utilization_potential - prev_utilization_potential, 4
                ),
                criticality_reward=round(
                    next_criticality_potential - prev_criticality_potential, 4
                ),
            )
            reward = round(
                breakdown.utilization_reward + breakdown.criticality_reward,
                4,
            )
            self._cumulative_reward += reward
            observation = self._base_observation(
                reward=reward,
                breakdown=breakdown,
                note="Tasks dispatched. Use wait() to advance to the next completion event.",
                done=False,
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            return observation
        # --- WAIT path ---------------------------------------------------
        if not env_state.running_task_ids:
            return self._apply_invalid("wait() requires at least one running task.")
        recent_events: list[WorkflowFailureEvent] = []
        # Penalize waiting while dispatchable work and free workers coexist.
        avoidable_wait_penalty = 0.0
        if env_state.ready_task_ids:
            free_workers = self._effective_worker_capacity(env_state) - len(
                env_state.running_task_ids
            )
            if free_workers > 0:
                avoidable_wait_penalty = self.AVOIDABLE_WAIT_PENALTY_PER_SLOT * min(
                    free_workers,
                    len(env_state.ready_task_ids),
                )
        self._maybe_start_worker_outage(env_state, recent_events)
        next_completion_time = min(
            env_state.task_assigned_finish_times[task_id]
            for task_id in env_state.running_task_ids
        )
        target_time = next_completion_time
        budget_hit_before_completion = False
        if env_state.time_budget is not None and env_state.time_budget < next_completion_time:
            # The budget expires before anything finishes: clamp the jump.
            target_time = env_state.time_budget
            budget_hit_before_completion = True
        elapsed = target_time - env_state.current_time
        # Every running task is busy for the whole jump.
        env_state.cumulative_busy_time += elapsed * len(env_state.running_task_ids)
        env_state.current_time = target_time
        self._maybe_end_worker_outage(env_state, recent_events)
        if budget_hit_before_completion:
            # Terminal path: budget exhausted with tasks still in flight.
            unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
            terminal_score = self._terminal_score()
            breakdown = RewardBreakdown(
                idle_penalty=round(avoidable_wait_penalty, 4),
                terminal_makespan_score=round(terminal_score, 4),
                unfinished_task_penalty=round(unfinished_penalty, 4),
            )
            reward = round(
                breakdown.idle_penalty
                + breakdown.terminal_makespan_score
                + breakdown.unfinished_task_penalty,
                4,
            )
            self._set_recent_failure_events(env_state, recent_events)
            note = self._wait_note(
                completed_now=[],
                failed_now=[],
                unlocked=[],
                recent_events=recent_events,
                time_budget_hit=True,
            )
            observation = self._terminate_episode(
                note=note,
                breakdown=breakdown,
                reward=reward,
                reason="time_budget",
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            return observation
        # All running tasks whose finish time equals the event time resolve now.
        completed_candidates = sorted(
            [
                task_id
                for task_id in env_state.running_task_ids
                if env_state.task_assigned_finish_times[task_id] == next_completion_time
            ]
        )
        completed_now: list[str] = []
        failed_now: list[str] = []
        for task_id in completed_candidates:
            env_state.running_task_ids.remove(task_id)
            del env_state.task_assigned_finish_times[task_id]
            if self._should_retry_fail(task_id):
                # HARD-mode retry failure: back to READY, attempt counted.
                env_state.task_attempt_counts[task_id] += 1
                env_state.task_statuses[task_id] = TaskStatus.READY
                env_state.task_start_times.pop(task_id, None)
                env_state.task_end_times.pop(task_id, None)
                if task_id not in env_state.ready_task_ids:
                    env_state.ready_task_ids.append(task_id)
                failed_now.append(task_id)
                recent_events.append(
                    WorkflowFailureEvent(
                        event_type=FailureEventType.TASK_RETRY_FAILURE,
                        time=next_completion_time,
                        task_id=task_id,
                        detail=f"{task_id} failed and returned to ready.",
                    )
                )
            else:
                env_state.task_statuses[task_id] = TaskStatus.COMPLETED
                env_state.task_end_times[task_id] = next_completion_time
                env_state.completed_task_ids.append(task_id)
                completed_now.append(task_id)
        env_state.completed_task_ids.sort()
        env_state.ready_task_ids.sort()
        unlocked = self._transition_unlocks(completed_now)
        # --- Shaped wait() rewards --------------------------------------
        completion_reward = sum(
            0.04 + 0.01 * task_map[task_id].priority for task_id in completed_now
        )
        deadline_reward = 0.0
        criticality_reward = 0.0
        for task_id in completed_now:
            task = task_map[task_id]
            if task.deadline is not None:
                lateness = next_completion_time - task.deadline
                # Bonus when on time; linear penalty per tick of lateness.
                deadline_reward += 0.05 if lateness <= 0 else -0.02 * lateness
            criticality_reward += 0.03 * task.criticality
        # Busy tasks during the jump = those that just finished + survivors.
        utilization_reward = 0.06 * (
            elapsed
            * (len(completed_candidates) + len(env_state.running_task_ids))
            / max(1, self._config.worker_count)
        )
        idle_penalty = 0.0
        if not env_state.running_task_ids and env_state.ready_task_ids:
            idle_penalty = -0.03 * len(env_state.ready_task_ids)
        done = len(env_state.completed_task_ids) == len(episode.tasks)
        breakdown = RewardBreakdown(
            completion_reward=round(completion_reward, 4),
            utilization_reward=round(utilization_reward, 4),
            deadline_reward=round(deadline_reward, 4),
            criticality_reward=round(criticality_reward, 4),
            idle_penalty=round(idle_penalty + avoidable_wait_penalty, 4),
            terminal_makespan_score=round(self._terminal_score() if done else 0.0, 4),
        )
        reward = round(
            breakdown.completion_reward
            + breakdown.utilization_reward
            + breakdown.deadline_reward
            + breakdown.criticality_reward
            + breakdown.idle_penalty
            + breakdown.terminal_makespan_score,
            4,
        )
        # The jump may have landed exactly on the budget with work remaining.
        budget_exhausted_now = (
            not done
            and env_state.time_budget is not None
            and env_state.current_time >= env_state.time_budget
        )
        if budget_exhausted_now:
            unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
            breakdown.unfinished_task_penalty = round(unfinished_penalty, 4)
            breakdown.terminal_makespan_score = round(self._terminal_score(), 4)
            reward = round(
                reward
                + breakdown.unfinished_task_penalty
                + breakdown.terminal_makespan_score,
                4,
            )
            self._set_recent_failure_events(env_state, recent_events)
            note = self._wait_note(
                completed_now=completed_now,
                failed_now=failed_now,
                unlocked=unlocked,
                recent_events=recent_events,
                time_budget_hit=True,
            )
            observation = self._terminate_episode(
                note=note,
                breakdown=breakdown,
                reward=reward,
                reason="time_budget",
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            observation.metadata["completed_now"] = completed_now
            observation.metadata["unlocked_now"] = unlocked
            observation.metadata["failed_now"] = failed_now
            return observation
        # --- Normal (possibly all-complete) wait() observation -----------
        self._cumulative_reward += reward
        self._set_recent_failure_events(env_state, recent_events)
        observation = self._base_observation(
            reward=reward,
            breakdown=breakdown,
            note=self._wait_note(
                completed_now=completed_now,
                failed_now=failed_now,
                unlocked=unlocked,
                recent_events=recent_events,
            ),
            done=done,
            benchmark_score_override=self._benchmark_score() if done else None,
        )
        observation.received_action = action.model_dump(mode="json")
        observation.metadata["received_action"] = action.model_dump(mode="json")
        observation.metadata["completed_now"] = completed_now
        observation.metadata["unlocked_now"] = unlocked
        observation.metadata["failed_now"] = failed_now
        if done:
            observation.benchmark_score = observation.success_metrics.benchmark_score
        return observation

    @property
    def state(self) -> State:
        """Expose generic OpenEnv state metadata."""
        return self._state