# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Seeded workflow DAG generator and derived static metrics for WorkflowArena.""" from __future__ import annotations import random from workflow_arena.models import ( EpisodeConfig, TaskStatus, WorkflowEnvStateSnapshot, WorkflowEpisodeSpec, WorkflowTaskSpec, ) from workflow_arena.presets import get_preset_config def _task_id(index: int) -> str: return f"task_{index:02d}" def _compute_earliest_start(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int: task = task_map[task_id] if not task.dependencies: return 0 return max( _compute_earliest_start(task_map, dep_id) + task_map[dep_id].duration for dep_id in task.dependencies ) def _compute_critical_path(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int: task = task_map[task_id] if not task.dependents: return task.duration return task.duration + max( _compute_critical_path(task_map, child_id) for child_id in task.dependents ) def _compute_downstream_count( task_map: dict[str, WorkflowTaskSpec], task_id: str, seen: set[str] | None = None ) -> int: task = task_map[task_id] local_seen = set() if seen is None else seen count = 0 for child_id in task.dependents: if child_id in local_seen: continue local_seen.add(child_id) count += 1 + _compute_downstream_count(task_map, child_id, local_seen) return count def _estimate_deadline( task: WorkflowTaskSpec, workflow_critical_path: int, rng: random.Random, tightness: float, ) -> int: slack_allowance = max(1, int(round((workflow_critical_path - task.earliest_start) * (1.15 - tightness)))) jitter = rng.randint(0, max(1, task.duration // 2)) return task.earliest_start + task.duration + slack_allowance + jitter def generate_episode( config: EpisodeConfig, ) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]: """Generate a deterministic workflow episode from a preset and seed.""" preset_config = get_preset_config(config.preset) worker_count = config.worker_count or preset_config.worker_count resolved_config = config.model_copy(update={"worker_count": worker_count}) rng = random.Random(resolved_config.seed) task_count = rng.randint(preset_config.min_tasks, preset_config.max_tasks) dependency_map: dict[str, list[str]] = {} dependent_map: dict[str, list[str]] = {} task_ids = [_task_id(index + 1) for index in range(task_count)] for index, task_id in enumerate(task_ids): candidates = task_ids[:index] dependencies: list[str] = [] if candidates: for candidate in candidates: if rng.random() < preset_config.edge_probability: dependencies.append(candidate) if not dependencies and index > 0 and rng.random() < 0.6: dependencies.append(rng.choice(candidates)) dependency_map[task_id] = sorted(set(dependencies), key=task_ids.index) dependent_map[task_id] = [] for task_id, dependencies in dependency_map.items(): for dependency in dependencies: dependent_map[dependency].append(task_id) tasks = [ WorkflowTaskSpec( task_id=task_id, duration=rng.randint(preset_config.duration_min, preset_config.duration_max), priority=rng.randint(preset_config.priority_min, preset_config.priority_max), dependencies=dependency_map[task_id], dependents=sorted(dependent_map[task_id], key=task_ids.index), deadline=None, ) for task_id in task_ids ] task_map = {task.task_id: task for task in tasks} workflow_critical_path = 0 for task in tasks: task.earliest_start = _compute_earliest_start(task_map, task.task_id) task.critical_path_length = _compute_critical_path(task_map, task.task_id) task.downstream_count = _compute_downstream_count(task_map, task.task_id) workflow_critical_path = max( workflow_critical_path, task.earliest_start + task.duration ) workflow_critical_path = max( workflow_critical_path, max(task.critical_path_length for task in tasks), ) max_downstream = max(task.downstream_count for task in tasks) if tasks else 1 max_critical_path = max(task.critical_path_length for task in tasks) if tasks else 1 for task in tasks: latest_start = max( task.earliest_start, workflow_critical_path - task.critical_path_length ) task.slack = max(0, latest_start - task.earliest_start) task.criticality = round( 0.7 * (task.critical_path_length / max_critical_path) + 0.3 * (task.downstream_count / max(1, max_downstream)), 4, ) task.deadline = _estimate_deadline( task=task, workflow_critical_path=workflow_critical_path, rng=rng, tightness=preset_config.deadline_tightness, ) episode = WorkflowEpisodeSpec( config=resolved_config, preset_config=preset_config, tasks=tasks, ) ready_task_ids = [task.task_id for task in tasks if not task.dependencies] blocked_task_ids = [task.task_id for task in tasks if task.dependencies] state = WorkflowEnvStateSnapshot( episode_id=f"seed-{resolved_config.seed}", current_time=0, task_statuses={ task.task_id: ( TaskStatus.READY if not task.dependencies else TaskStatus.BLOCKED ) for task in tasks }, running_task_ids=[], completed_task_ids=[], ready_task_ids=ready_task_ids, blocked_task_ids=blocked_task_ids, task_start_times={}, task_end_times={}, task_remaining_dependencies={ task.task_id: len(task.dependencies) for task in tasks }, task_assigned_finish_times={}, task_attempt_counts={task.task_id: 0 for task in tasks}, cumulative_busy_time=0, time_budget=None, degraded_workers=0, active_worker_outage_until=None, recent_failure_events=[], ) return episode, state