# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""WorkflowArena event-driven workflow orchestration environment."""
from __future__ import annotations
import math
import random
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from workflow_arena.generator import generate_episode
from workflow_arena.models import (
DifficultyPreset,
EpisodeConfig,
FailureEventType,
ProgressSummary,
RewardBreakdown,
SuccessMetrics,
TaskStatus,
WorkflowArenaAction,
WorkflowArenaObservation,
WorkflowEnvStateSnapshot,
WorkflowEpisodeSpec,
WorkflowFailureEvent,
WorkflowTaskSpec,
WorkflowTaskView,
WorkflowActionType,
)
from workflow_arena.presets import get_preset_config
class WorkflowArenaEnvironment(Environment):
"""Resource-constrained workflow scheduler with event-driven semantics."""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
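    # The safety step limit scales with task count (see reset()). The constants
    # below are applied per invalid action, per idle worker slot on an avoidable
    # wait, and per unit of unfinished or overdue task priority; grader scores
    # are clamped into [MIN_GRADER_SCORE, MAX_GRADER_SCORE].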
STEP_LIMIT_FLOOR: int = 32
STEP_LIMIT_MULTIPLIER: int = 8
INVALID_ACTION_PENALTY: float = -0.1
OVERCAPACITY_INVALID_ACTION_PENALTY: float = -0.25
AVOIDABLE_WAIT_PENALTY_PER_SLOT: float = -0.08
UNFINISHED_PRIORITY_PENALTY: float = -0.02
OVERDUE_PRIORITY_PENALTY_PER_TICK: float = -0.005
MAX_RECENT_FAILURE_EVENTS: int = 6
MIN_GRADER_SCORE: float = 0.01
MAX_GRADER_SCORE: float = 0.99
def __init__(self):
self._state = State(episode_id=str(uuid4()), step_count=0)
self._cumulative_reward = 0.0
self._max_episode_steps = self.STEP_LIMIT_FLOOR
self._config = EpisodeConfig(
preset=DifficultyPreset.EASY,
seed=0,
worker_count=2,
)
self._episode_spec: WorkflowEpisodeSpec | None = None
self._env_state: WorkflowEnvStateSnapshot | None = None
self._event_rng = random.Random(0)
def _require_episode(self) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
if self._episode_spec is None or self._env_state is None:
raise RuntimeError("Environment must be reset before use.")
return self._episode_spec, self._env_state
def _preset_config(self):
episode, _ = self._require_episode()
return episode.preset_config
def _task_map(self) -> dict[str, WorkflowTaskSpec]:
episode, _ = self._require_episode()
return {task.task_id: task for task in episode.tasks}
def _effective_worker_capacity(
self, env_state: WorkflowEnvStateSnapshot | None = None
) -> int:
if env_state is None:
_, env_state = self._require_episode()
return max(0, self._config.worker_count - env_state.degraded_workers)
def _time_remaining(
self, env_state: WorkflowEnvStateSnapshot | None = None
) -> int | None:
if env_state is None:
_, env_state = self._require_episode()
if env_state.time_budget is None:
return None
return max(0, env_state.time_budget - env_state.current_time)
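    # Terminal makespan score: the theoretical lower-bound makespan divided by
    # the achieved makespan, so 1.0 indicates a provably optimal schedule.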
def _terminal_score(self) -> float:
episode, env_state = self._require_episode()
if env_state.current_time <= 0:
return 0.0
lower_bound = self._lower_bound_makespan(episode)
score = lower_bound / max(lower_bound, env_state.current_time)
return round(score, 4)
def _bounded_grader_score(self, score: float) -> float:
return round(
min(self.MAX_GRADER_SCORE, max(self.MIN_GRADER_SCORE, score)),
4,
)
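    # Final benchmark score: 50% makespan, 30% deadline adherence, 20% worker
    # utilization, clamped into [MIN_GRADER_SCORE, MAX_GRADER_SCORE].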
def _benchmark_score(self) -> float:
makespan_score, deadline_score, utilization_score = self._grade_components(
include_terminal_makespan=True
)
return self._bounded_grader_score(
(0.5 * makespan_score) + (0.3 * deadline_score) + (0.2 * utilization_score)
)
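    # Utilization = cumulative busy worker-time / (elapsed time * total workers).
    # Deadline score = priority weight of tasks completed on time (tasks without
    # a deadline count as on time) over the total priority of all tasks.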
def _grade_components(
self, *, include_terminal_makespan: bool = False
) -> tuple[float, float, float]:
episode, env_state = self._require_episode()
utilization = (
env_state.cumulative_busy_time
/ (env_state.current_time * self._config.worker_count)
if env_state.current_time > 0
else 0.0
)
total_priority = sum(task.priority for task in episode.tasks) or 1
on_time_priority = 0
for task in episode.tasks:
end_time = env_state.task_end_times.get(task.task_id)
if end_time is None:
continue
if task.deadline is None or end_time <= task.deadline:
on_time_priority += task.priority
deadline_score = round(on_time_priority / total_priority, 4)
utilization_score = round(utilization, 4)
makespan_score = self._terminal_score() if include_terminal_makespan else 0.0
return makespan_score, deadline_score, utilization_score
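    # Flat penalty per unit of priority for each unfinished task, plus a
    # per-tick penalty for unfinished tasks already past their deadline.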
def _unfinished_task_penalty(self, current_time: int) -> float:
episode, env_state = self._require_episode()
penalty = 0.0
for task in episode.tasks:
if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED:
continue
penalty += self.UNFINISHED_PRIORITY_PENALTY * task.priority
if task.deadline is not None and current_time > task.deadline:
penalty += (
self.OVERDUE_PRIORITY_PENALTY_PER_TICK
* task.priority
* (current_time - task.deadline)
)
return round(penalty, 4)
def _success_metrics(
self, *, benchmark_score_override: float | None = None
) -> SuccessMetrics:
episode, env_state = self._require_episode()
unfinished_task_count = sum(
1
for task in episode.tasks
if env_state.task_statuses[task.task_id] != TaskStatus.COMPLETED
)
deadline_miss_count = sum(
1
for task in episode.tasks
if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED
and task.deadline is not None
and env_state.task_end_times.get(task.task_id, 0) > task.deadline
)
_, deadline_score, utilization_score = self._grade_components(
include_terminal_makespan=False
)
all_done = unfinished_task_count == 0
return SuccessMetrics(
makespan=env_state.current_time if all_done else None,
worker_utilization=utilization_score,
deadline_miss_count=deadline_miss_count,
unfinished_task_count=unfinished_task_count,
weighted_priority_completion=deadline_score,
benchmark_score=benchmark_score_override,
)
def _task_view(
self,
task: WorkflowTaskSpec,
status: TaskStatus,
*,
include_planner_hints: bool = True,
) -> WorkflowTaskView:
_, env_state = self._require_episode()
return WorkflowTaskView(
task_id=task.task_id,
status=status,
duration=task.duration,
priority=task.priority,
dependencies=task.dependencies,
deadline=task.deadline,
criticality=task.criticality if include_planner_hints else None,
slack=float(task.slack) if include_planner_hints else None,
downstream_count=task.downstream_count if include_planner_hints else 0,
start_time=env_state.task_start_times.get(task.task_id),
end_time=(
env_state.task_end_times.get(task.task_id)
or env_state.task_assigned_finish_times.get(task.task_id)
),
attempt_count=env_state.task_attempt_counts.get(task.task_id, 0),
)
def _task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
episode, env_state = self._require_episode()
return [
self._task_view(task, status, include_planner_hints=True)
for task in episode.tasks
if env_state.task_statuses[task.task_id] == status
]
def debug_task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
return self._task_views_for_status(status)
def _set_recent_failure_events(
self,
env_state: WorkflowEnvStateSnapshot,
events: list[WorkflowFailureEvent],
) -> None:
env_state.recent_failure_events = events[-self.MAX_RECENT_FAILURE_EVENTS :]
def _maybe_end_worker_outage(
self,
env_state: WorkflowEnvStateSnapshot,
events: list[WorkflowFailureEvent],
) -> None:
if (
env_state.active_worker_outage_until is not None
and env_state.current_time >= env_state.active_worker_outage_until
):
events.append(
WorkflowFailureEvent(
event_type=FailureEventType.WORKER_OUTAGE_END,
time=env_state.current_time,
worker_delta=1,
detail="Worker capacity restored.",
)
)
env_state.active_worker_outage_until = None
env_state.degraded_workers = 0
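    # Hard preset only: occasionally take one worker offline for a random
    # duration, gated by the preset's worker_outage_rate.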
def _maybe_start_worker_outage(
self,
env_state: WorkflowEnvStateSnapshot,
events: list[WorkflowFailureEvent],
) -> None:
preset_config = self._preset_config()
if self._config.preset != DifficultyPreset.HARD:
return
if env_state.active_worker_outage_until is not None:
return
if preset_config.worker_outage_rate <= 0.0:
return
if self._event_rng.random() >= preset_config.worker_outage_rate:
return
duration = self._event_rng.randint(
preset_config.worker_outage_duration_min,
preset_config.worker_outage_duration_max,
)
if duration <= 0:
return
env_state.degraded_workers = min(1, self._config.worker_count)
env_state.active_worker_outage_until = env_state.current_time + duration
events.append(
WorkflowFailureEvent(
event_type=FailureEventType.WORKER_OUTAGE_START,
time=env_state.current_time,
worker_delta=-env_state.degraded_workers,
duration=duration,
detail=f"Worker outage active until t={env_state.active_worker_outage_until}.",
)
)
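    # Hard preset only: a finishing task may fail with probability
    # task_retry_failure_rate and return to READY, up to max_task_retries attempts.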
def _should_retry_fail(self, task_id: str) -> bool:
preset_config = self._preset_config()
_, env_state = self._require_episode()
if self._config.preset != DifficultyPreset.HARD:
return False
if preset_config.task_retry_failure_rate <= 0.0:
return False
if env_state.task_attempt_counts.get(task_id, 0) >= preset_config.max_task_retries:
return False
return self._event_rng.random() < preset_config.task_retry_failure_rate
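    # Potential function for dispatch shaping: one term for worker utilization
    # and one for the criticality/slack urgency of running tasks. Dispatch
    # rewards are the difference between the post- and pre-dispatch potentials
    # (see step()).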
def _dispatch_potential(
self,
env_state: WorkflowEnvStateSnapshot,
task_map: dict[str, WorkflowTaskSpec],
) -> tuple[float, float]:
if not env_state.running_task_ids:
return 0.0, 0.0
episode, _ = self._require_episode()
max_slack = max((task.slack for task in episode.tasks), default=0)
utilization_component = 0.06 * (
len(env_state.running_task_ids) / max(1, self._config.worker_count)
)
criticality_component = 0.0
for task_id in env_state.running_task_ids:
task = task_map[task_id]
slack_urgency = 1.0 if max_slack <= 0 else 1.0 - (task.slack / max_slack)
criticality_component += (0.6 * task.criticality) + (0.4 * slack_urgency)
criticality_component = 0.04 * (
criticality_component / max(1, self._config.worker_count)
)
return round(utilization_component, 4), round(criticality_component, 4)
def _base_observation(
self,
*,
reward: float,
breakdown: RewardBreakdown,
note: str,
done: bool,
benchmark_score_override: float | None = None,
) -> WorkflowArenaObservation:
episode, env_state = self._require_episode()
ready_tasks = self._task_views_for_status(TaskStatus.READY)
running_tasks = self._task_views_for_status(TaskStatus.RUNNING)
completed_tasks = self._task_views_for_status(TaskStatus.COMPLETED)
blocked_tasks = self._task_views_for_status(TaskStatus.BLOCKED)
effective_workers = self._effective_worker_capacity(env_state)
return WorkflowArenaObservation(
done=done,
reward=reward,
config=self._config,
current_time=env_state.current_time,
total_workers=self._config.worker_count,
effective_workers=effective_workers,
degraded_workers=env_state.degraded_workers,
free_workers=max(0, effective_workers - len(running_tasks)),
time_budget=env_state.time_budget,
time_remaining=self._time_remaining(env_state),
progress=ProgressSummary(
total=len(episode.tasks),
blocked=len(blocked_tasks),
ready=len(ready_tasks),
running=len(running_tasks),
completed=len(completed_tasks),
),
ready_tasks=ready_tasks,
running_tasks=running_tasks,
completed_tasks=completed_tasks,
blocked_tasks=blocked_tasks,
last_reward_breakdown=breakdown,
cumulative_reward=self._cumulative_reward,
success_metrics=self._success_metrics(
benchmark_score_override=benchmark_score_override
),
note=note,
benchmark_score=benchmark_score_override,
recent_failure_events=env_state.recent_failure_events,
metadata={
"phase": "simulation_active",
"note": note,
"effective_workers": effective_workers,
"degraded_workers": env_state.degraded_workers,
"time_budget": env_state.time_budget,
"time_remaining": self._time_remaining(env_state),
"recent_failure_events": [
event.model_dump(mode="json") for event in env_state.recent_failure_events
],
"episode_loop": [
"reset generates a seeded workflow DAG episode",
"dispatch(task_ids=[...]) starts ready tasks if workers are free",
"wait() advances simulated time to the next completion event",
"medium and hard episodes may end at a fixed time budget",
"hard mode may trigger outages and retry failures",
],
},
)
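    # Lower bound on the makespan: the larger of total work spread evenly across
    # workers (ceiling division) and the longest critical path in the DAG.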
def _lower_bound_makespan(self, episode: WorkflowEpisodeSpec) -> int:
total_work = sum(task.duration for task in episode.tasks)
work_bound = (total_work + self._config.worker_count - 1) // self._config.worker_count
path_bound = max(task.critical_path_length for task in episode.tasks)
return max(1, work_bound, path_bound)
def _termination_breakdown(
self,
*,
invalid_penalty: float = 0.0,
idle_penalty: float = 0.0,
terminal_makespan_score: float = 0.0,
unfinished_task_penalty: float = 0.0,
) -> RewardBreakdown:
return RewardBreakdown(
invalid_action_penalty=round(invalid_penalty, 4),
idle_penalty=round(idle_penalty, 4),
terminal_makespan_score=round(terminal_makespan_score, 4),
unfinished_task_penalty=round(unfinished_task_penalty, 4),
)
def _terminate_episode(
self,
*,
note: str,
breakdown: RewardBreakdown,
reward: float,
reason: str,
benchmark_score: float | None = None,
) -> WorkflowArenaObservation:
if benchmark_score is None:
benchmark_score = self._benchmark_score()
self._cumulative_reward += reward
observation = self._base_observation(
reward=reward,
breakdown=breakdown,
note=note,
done=True,
benchmark_score_override=benchmark_score,
)
observation.termination_reason = reason
observation.benchmark_score = benchmark_score
observation.metadata["termination_reason"] = reason
observation.metadata["benchmark_score"] = benchmark_score
return observation
def _step_limit_reached(self) -> bool:
return self._state.step_count >= self._max_episode_steps
def _maybe_terminate_for_limits(self) -> WorkflowArenaObservation | None:
if not self._step_limit_reached():
return None
_, env_state = self._require_episode()
unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
terminal_score = self._terminal_score()
breakdown = self._termination_breakdown(
terminal_makespan_score=terminal_score,
unfinished_task_penalty=unfinished_penalty,
)
reward = round(-1.0 + unfinished_penalty + terminal_score, 4)
return self._terminate_episode(
note="Episode terminated after hitting the safety step limit.",
breakdown=breakdown,
reward=reward,
reason="step_limit",
)
def _apply_invalid(
self,
message: str,
*,
penalty: float | None = None,
) -> WorkflowArenaObservation:
_, env_state = self._require_episode()
applied_penalty = (
self.INVALID_ACTION_PENALTY if penalty is None else float(penalty)
)
breakdown = RewardBreakdown(invalid_action_penalty=round(applied_penalty, 4))
self._cumulative_reward += breakdown.invalid_action_penalty
self._set_recent_failure_events(env_state, [])
observation = self._base_observation(
reward=breakdown.invalid_action_penalty,
breakdown=breakdown,
note="Invalid action.",
done=False,
)
observation.validation_error = message
observation.metadata["validation_error"] = message
return observation
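    # Decrement dependency counters of each completed task's dependents and
    # move any now-unblocked tasks to READY.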
def _transition_unlocks(self, completed_task_ids: list[str]) -> list[str]:
episode, env_state = self._require_episode()
task_map = {task.task_id: task for task in episode.tasks}
unlocked: list[str] = []
for task_id in completed_task_ids:
for dependent_id in task_map[task_id].dependents:
env_state.task_remaining_dependencies[dependent_id] -= 1
if env_state.task_remaining_dependencies[dependent_id] == 0:
env_state.task_statuses[dependent_id] = TaskStatus.READY
if dependent_id not in env_state.ready_task_ids:
env_state.ready_task_ids.append(dependent_id)
if dependent_id in env_state.blocked_task_ids:
env_state.blocked_task_ids.remove(dependent_id)
unlocked.append(dependent_id)
env_state.ready_task_ids.sort()
env_state.blocked_task_ids.sort()
return unlocked
def reset(
self,
seed: int | None = None,
episode_id: str | None = None,
**kwargs: Any,
) -> WorkflowArenaObservation:
"""Generate a seeded workflow DAG episode."""
preset_raw = kwargs.pop("preset", DifficultyPreset.EASY)
worker_count_raw = kwargs.pop("worker_count", None)
del kwargs
preset = (
preset_raw
if isinstance(preset_raw, DifficultyPreset)
else DifficultyPreset(str(preset_raw))
)
preset_config = get_preset_config(preset)
chosen_seed = 0 if seed is None else seed
chosen_worker_count = (
preset_config.worker_count
if worker_count_raw is None
else int(worker_count_raw)
)
chosen_episode_id = str(uuid4()) if episode_id is None else episode_id
self._state = State(episode_id=chosen_episode_id, step_count=0)
self._cumulative_reward = 0.0
self._config = EpisodeConfig(
preset=preset,
seed=chosen_seed,
worker_count=chosen_worker_count,
)
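        # Seed the failure-event RNG from seed, worker count, and preset so that
        # hard-mode outages and retry failures are reproducible per configuration.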
self._event_rng = random.Random(
(chosen_seed + 1) * 1009
+ (chosen_worker_count * 131)
+ (list(DifficultyPreset).index(preset) + 1)
)
self._episode_spec, self._env_state = generate_episode(self._config)
self._max_episode_steps = max(
self.STEP_LIMIT_FLOOR,
len(self._episode_spec.tasks) * self.STEP_LIMIT_MULTIPLIER,
)
self._env_state.episode_id = chosen_episode_id
lower_bound = self._lower_bound_makespan(self._episode_spec)
if preset_config.time_budget_multiplier is not None:
self._env_state.time_budget = int(
math.ceil(lower_bound * preset_config.time_budget_multiplier)
)
self._set_recent_failure_events(self._env_state, [])
note = "Workflow episode generated. Dispatch ready tasks or wait for completions."
if self._env_state.time_budget is not None:
note = (
f"Workflow episode generated. Finish as much as possible before "
f"t={self._env_state.time_budget}."
)
if preset == DifficultyPreset.HARD:
note += " Hard mode may trigger worker outages and retry failures."
return self._base_observation(
reward=0.0,
breakdown=RewardBreakdown(),
note=note,
done=False,
)
def _wait_note(
self,
*,
completed_now: list[str],
failed_now: list[str],
unlocked: list[str],
recent_events: list[WorkflowFailureEvent],
time_budget_hit: bool = False,
) -> str:
chunks: list[str] = []
if time_budget_hit:
chunks.append("Time budget exhausted before the next completion event.")
elif completed_now:
chunks.append(f"Completed: {', '.join(completed_now)}.")
else:
chunks.append("Advanced to next completion event.")
if failed_now:
chunks.append(f"Retry required: {', '.join(failed_now)}.")
if unlocked:
chunks.append(f"Unlocked: {', '.join(unlocked)}.")
for event in recent_events:
if event.event_type == FailureEventType.WORKER_OUTAGE_START:
chunks.append(event.detail)
elif event.event_type == FailureEventType.WORKER_OUTAGE_END:
chunks.append("Worker capacity restored.")
return " ".join(chunks)
def step(
self,
action: WorkflowArenaAction,
timeout_s: float | None = None,
**kwargs: Any,
) -> WorkflowArenaObservation:
"""Apply a dispatch or wait action using event-driven semantics."""
del timeout_s, kwargs
episode, env_state = self._require_episode()
task_map = {task.task_id: task for task in episode.tasks}
self._state.step_count += 1
self._set_recent_failure_events(env_state, [])
limit_termination = self._maybe_terminate_for_limits()
if limit_termination is not None:
return limit_termination
if action.action_type == WorkflowActionType.WAIT and action.task_ids:
return self._apply_invalid("wait() must not include task_ids.")
if action.action_type == WorkflowActionType.DISPATCH:
if not action.task_ids:
return self._apply_invalid(
"dispatch(task_ids=[...]) requires at least one task id."
)
if len(set(action.task_ids)) != len(action.task_ids):
return self._apply_invalid(
"dispatch(task_ids=[...]) must not contain duplicate task ids."
)
free_workers = self._effective_worker_capacity(env_state) - len(
env_state.running_task_ids
)
if len(action.task_ids) > max(0, free_workers):
return self._apply_invalid(
"dispatch(task_ids=[...]) cannot exceed available worker capacity.",
penalty=self.OVERCAPACITY_INVALID_ACTION_PENALTY,
)
unknown_tasks = [task_id for task_id in action.task_ids if task_id not in task_map]
if unknown_tasks:
return self._apply_invalid(f"Unknown task ids: {unknown_tasks}.")
not_ready = [
task_id
for task_id in action.task_ids
if env_state.task_statuses[task_id] != TaskStatus.READY
]
if not_ready:
return self._apply_invalid(
f"Only ready tasks can be dispatched: {not_ready}."
)
prev_utilization_potential, prev_criticality_potential = self._dispatch_potential(
env_state, task_map
)
for task_id in action.task_ids:
task = task_map[task_id]
env_state.task_statuses[task_id] = TaskStatus.RUNNING
env_state.task_start_times[task_id] = env_state.current_time
env_state.task_assigned_finish_times[task_id] = (
env_state.current_time + task.duration
)
env_state.running_task_ids.append(task_id)
env_state.ready_task_ids.remove(task_id)
env_state.running_task_ids.sort()
next_utilization_potential, next_criticality_potential = self._dispatch_potential(
env_state, task_map
)
breakdown = RewardBreakdown(
utilization_reward=round(
next_utilization_potential - prev_utilization_potential, 4
),
criticality_reward=round(
next_criticality_potential - prev_criticality_potential, 4
),
)
reward = round(
breakdown.utilization_reward + breakdown.criticality_reward,
4,
)
self._cumulative_reward += reward
observation = self._base_observation(
reward=reward,
breakdown=breakdown,
note="Tasks dispatched. Use wait() to advance to the next completion event.",
done=False,
)
observation.received_action = action.model_dump(mode="json")
observation.metadata["received_action"] = action.model_dump(mode="json")
return observation
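        # wait(): advance simulated time to the next completion event (or to the
        # time budget), penalizing waits taken while free workers and ready
        # tasks were both available.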
if not env_state.running_task_ids:
return self._apply_invalid("wait() requires at least one running task.")
recent_events: list[WorkflowFailureEvent] = []
avoidable_wait_penalty = 0.0
if env_state.ready_task_ids:
free_workers = self._effective_worker_capacity(env_state) - len(
env_state.running_task_ids
)
if free_workers > 0:
avoidable_wait_penalty = self.AVOIDABLE_WAIT_PENALTY_PER_SLOT * min(
free_workers,
len(env_state.ready_task_ids),
)
self._maybe_start_worker_outage(env_state, recent_events)
next_completion_time = min(
env_state.task_assigned_finish_times[task_id]
for task_id in env_state.running_task_ids
)
target_time = next_completion_time
budget_hit_before_completion = False
if env_state.time_budget is not None and env_state.time_budget < next_completion_time:
target_time = env_state.time_budget
budget_hit_before_completion = True
elapsed = target_time - env_state.current_time
env_state.cumulative_busy_time += elapsed * len(env_state.running_task_ids)
env_state.current_time = target_time
self._maybe_end_worker_outage(env_state, recent_events)
if budget_hit_before_completion:
unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
terminal_score = self._terminal_score()
breakdown = RewardBreakdown(
idle_penalty=round(avoidable_wait_penalty, 4),
terminal_makespan_score=round(terminal_score, 4),
unfinished_task_penalty=round(unfinished_penalty, 4),
)
reward = round(
breakdown.idle_penalty
+ breakdown.terminal_makespan_score
+ breakdown.unfinished_task_penalty,
4,
)
self._set_recent_failure_events(env_state, recent_events)
note = self._wait_note(
completed_now=[],
failed_now=[],
unlocked=[],
recent_events=recent_events,
time_budget_hit=True,
)
observation = self._terminate_episode(
note=note,
breakdown=breakdown,
reward=reward,
reason="time_budget",
)
observation.received_action = action.model_dump(mode="json")
observation.metadata["received_action"] = action.model_dump(mode="json")
return observation
completed_candidates = sorted(
[
task_id
for task_id in env_state.running_task_ids
if env_state.task_assigned_finish_times[task_id] == next_completion_time
]
)
completed_now: list[str] = []
failed_now: list[str] = []
for task_id in completed_candidates:
env_state.running_task_ids.remove(task_id)
del env_state.task_assigned_finish_times[task_id]
if self._should_retry_fail(task_id):
env_state.task_attempt_counts[task_id] += 1
env_state.task_statuses[task_id] = TaskStatus.READY
env_state.task_start_times.pop(task_id, None)
env_state.task_end_times.pop(task_id, None)
if task_id not in env_state.ready_task_ids:
env_state.ready_task_ids.append(task_id)
failed_now.append(task_id)
recent_events.append(
WorkflowFailureEvent(
event_type=FailureEventType.TASK_RETRY_FAILURE,
time=next_completion_time,
task_id=task_id,
detail=f"{task_id} failed and returned to ready.",
)
)
else:
env_state.task_statuses[task_id] = TaskStatus.COMPLETED
env_state.task_end_times[task_id] = next_completion_time
env_state.completed_task_ids.append(task_id)
completed_now.append(task_id)
env_state.completed_task_ids.sort()
env_state.ready_task_ids.sort()
unlocked = self._transition_unlocks(completed_now)
completion_reward = sum(
0.04 + 0.01 * task_map[task_id].priority for task_id in completed_now
)
deadline_reward = 0.0
criticality_reward = 0.0
for task_id in completed_now:
task = task_map[task_id]
if task.deadline is not None:
lateness = next_completion_time - task.deadline
deadline_reward += 0.05 if lateness <= 0 else -0.02 * lateness
criticality_reward += 0.03 * task.criticality
utilization_reward = 0.06 * (
elapsed
* (len(completed_candidates) + len(env_state.running_task_ids))
/ max(1, self._config.worker_count)
)
idle_penalty = 0.0
if not env_state.running_task_ids and env_state.ready_task_ids:
idle_penalty = -0.03 * len(env_state.ready_task_ids)
done = len(env_state.completed_task_ids) == len(episode.tasks)
breakdown = RewardBreakdown(
completion_reward=round(completion_reward, 4),
utilization_reward=round(utilization_reward, 4),
deadline_reward=round(deadline_reward, 4),
criticality_reward=round(criticality_reward, 4),
idle_penalty=round(idle_penalty + avoidable_wait_penalty, 4),
terminal_makespan_score=round(self._terminal_score() if done else 0.0, 4),
)
reward = round(
breakdown.completion_reward
+ breakdown.utilization_reward
+ breakdown.deadline_reward
+ breakdown.criticality_reward
+ breakdown.idle_penalty
+ breakdown.terminal_makespan_score,
4,
)
budget_exhausted_now = (
not done
and env_state.time_budget is not None
and env_state.current_time >= env_state.time_budget
)
if budget_exhausted_now:
unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
breakdown.unfinished_task_penalty = round(unfinished_penalty, 4)
breakdown.terminal_makespan_score = round(self._terminal_score(), 4)
reward = round(
reward
+ breakdown.unfinished_task_penalty
+ breakdown.terminal_makespan_score,
4,
)
self._set_recent_failure_events(env_state, recent_events)
note = self._wait_note(
completed_now=completed_now,
failed_now=failed_now,
unlocked=unlocked,
recent_events=recent_events,
time_budget_hit=True,
)
observation = self._terminate_episode(
note=note,
breakdown=breakdown,
reward=reward,
reason="time_budget",
)
observation.received_action = action.model_dump(mode="json")
observation.metadata["received_action"] = action.model_dump(mode="json")
observation.metadata["completed_now"] = completed_now
observation.metadata["unlocked_now"] = unlocked
observation.metadata["failed_now"] = failed_now
return observation
self._cumulative_reward += reward
self._set_recent_failure_events(env_state, recent_events)
observation = self._base_observation(
reward=reward,
breakdown=breakdown,
note=self._wait_note(
completed_now=completed_now,
failed_now=failed_now,
unlocked=unlocked,
recent_events=recent_events,
),
done=done,
benchmark_score_override=self._benchmark_score() if done else None,
)
observation.received_action = action.model_dump(mode="json")
observation.metadata["received_action"] = action.model_dump(mode="json")
observation.metadata["completed_now"] = completed_now
observation.metadata["unlocked_now"] = unlocked
observation.metadata["failed_now"] = failed_now
if done:
observation.benchmark_score = observation.success_metrics.benchmark_score
return observation
@property
def state(self) -> State:
"""Expose generic OpenEnv state metadata."""
return self._state
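# A minimal local usage sketch, kept out of the server path behind __main__.
# It is an illustration rather than the canonical client flow: it assumes
# WorkflowArenaAction accepts `action_type` and `task_ids` keyword arguments
# (the fields step() reads) and that an empty task_ids list is valid for WAIT.
if __name__ == "__main__":
    env = WorkflowArenaEnvironment()
    obs = env.reset(seed=7, preset=DifficultyPreset.EASY)
    while not obs.done:
        if obs.ready_tasks and obs.free_workers > 0:
            # Greedily dispatch as many ready tasks as there are free workers.
            batch = [task.task_id for task in obs.ready_tasks[: obs.free_workers]]
            obs = env.step(
                WorkflowArenaAction(
                    action_type=WorkflowActionType.DISPATCH, task_ids=batch
                )
            )
        else:
            # Otherwise advance to the next completion event.
            obs = env.step(
                WorkflowArenaAction(
                    action_type=WorkflowActionType.WAIT, task_ids=[]
                )
            )
    print(
        f"benchmark_score={obs.benchmark_score} "
        f"cumulative_reward={obs.cumulative_reward}"
    )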