Spaces:
Sleeping
Sleeping
feat: enhance model client configuration and local image handling in inference module
86dae99 verified | from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import subprocess | |
| from contextlib import asynccontextmanager | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from openai import OpenAI | |
| from workflow_arena import WorkflowArenaAction, WorkflowArenaEnv | |
| from workflow_arena.models import ( | |
| DifficultyPreset, | |
| WorkflowActionType, | |
| WorkflowArenaObservation, | |
| WorkflowTaskView, | |
| ) | |
# Benchmark identity and the difficulty presets evaluated in order.
BENCHMARK = "WorkflowArena"
PRESETS = [
    DifficultyPreset.EASY,
    DifficultyPreset.MEDIUM,
    DifficultyPreset.HARD,
]

# Filesystem locations for the local Docker fallback image.
PROJECT_DIR = Path(__file__).resolve().parent
IMAGE_NAME = "workflow-arena-inference:latest"
DOCKERFILE_PATH = PROJECT_DIR / "server" / "Dockerfile"

# Model endpoint configuration, all overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "qwen/qwen3.5-9b")
HF_TOKEN = os.getenv("HF_TOKEN")
# Optional override for the fallback image tag (defaults to IMAGE_NAME when unset).
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
DEFAULT_BASE_URL = os.getenv("WORKFLOW_ARENA_BASE_URL", "http://localhost:8000")

# Deterministic decoding and a hard cap on per-episode steps.
TEMPERATURE = 0.0
MAX_STEPS = 256

# System prompt constraining the model to emit one compact JSON action per turn.
SYSTEM_PROMPT = (
    "You are scheduling a dependency-constrained workflow on limited workers. "
    "Respond with compact JSON only. "
    'Valid formats: {"action_type":"wait","task_ids":[]} or '
    '{"action_type":"dispatch","task_ids":["task_01","task_02"]}. '
    "Only dispatch task ids that appear in ready_tasks for the current observation. "
    "Never exceed free_workers. "
    'If free_workers is 0 and running_tasks is non-empty, respond with {"action_type":"wait","task_ids":[]}. '
    "If your previous action was invalid, use validation_error to correct it while still reasoning from the current observation. "
    "Never repeat a previously dispatched task unless it still appears in ready_tasks."
)
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] marker line announcing one episode."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: str | None
) -> None:
    """Emit one [STEP] marker line; a falsy error renders as the literal 'null'."""
    rendered_error = error or "null"
    rendered_done = str(done).lower()
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={rendered_done} error={rendered_error}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the [END] summary line with a comma-joined per-step reward trace."""
    trace = ",".join(format(value, ".2f") for value in rewards)
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={trace}"
    )
    print(line, flush=True)
def log_warning(message: str) -> None:
    """Emit one [WARN] diagnostic line."""
    print("[WARN] " + message, flush=True)
def compact_task(task: WorkflowTaskView) -> dict[str, object]:
    """Project a task view onto the compact dict serialized into the model prompt."""
    fields = (
        "task_id",
        "duration",
        "priority",
        "deadline",
        "criticality",
        "slack",
        "downstream_count",
        "dependencies",
        "attempt_count",
    )
    # Dict comprehension preserves insertion order, so key order matches `fields`.
    return {name: getattr(task, name) for name in fields}
def make_user_prompt(observation: WorkflowArenaObservation) -> str:
    """Serialize the observation into the compact JSON user message for the model."""
    # must_wait is a hint mirrored from the system prompt's wait rule.
    no_capacity = observation.free_workers == 0 and len(observation.running_tasks) > 0
    payload = {
        "instruction": observation.instruction,
        "current_time": observation.current_time,
        "effective_workers": observation.effective_workers,
        "degraded_workers": observation.degraded_workers,
        "free_workers": observation.free_workers,
        "time_budget": observation.time_budget,
        "time_remaining": observation.time_remaining,
        "must_wait": no_capacity,
        "ready_tasks": list(map(compact_task, observation.ready_tasks)),
        "running_tasks": list(map(compact_task, observation.running_tasks)),
        "progress": observation.progress.model_dump(mode="json"),
        "reward_breakdown": observation.last_reward_breakdown.model_dump(mode="json"),
        "note": observation.note,
        "validation_error": observation.validation_error,
        "recent_failure_events": [
            event.model_dump(mode="json")
            for event in observation.recent_failure_events
        ],
        "last_action": observation.received_action,
    }
    # Compact separators keep token usage down.
    return json.dumps(payload, separators=(",", ":"))
def heuristic_action(observation: WorkflowArenaObservation) -> WorkflowArenaAction:
    """Greedy fallback policy used when no model client is available.

    Waits when no worker is free or nothing is ready; otherwise dispatches the
    top-ranked ready tasks, at most one per free worker.
    """
    # Wait whenever there is no capacity or nothing to dispatch.
    if observation.free_workers <= 0 or not observation.ready_tasks:
        return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])

    remaining = observation.time_remaining

    def rank(task) -> tuple:
        # Ordering: fits-in-budget first, smallest overshoot, earliest deadline,
        # highest criticality, highest priority, shortest duration, then id.
        overdue = remaining is not None and task.duration > remaining
        overshoot = max(0, task.duration - remaining) if remaining is not None else 0
        deadline = 10**9 if task.deadline is None else task.deadline
        return (
            overdue,
            overshoot,
            deadline,
            -(task.criticality or 0.0),
            -task.priority,
            task.duration,
            task.task_id,
        )

    ordered = sorted(observation.ready_tasks, key=rank)
    chosen = [task.task_id for task in ordered[: observation.free_workers]]
    return WorkflowArenaAction(
        action_type=WorkflowActionType.DISPATCH,
        task_ids=chosen,
    )
def parse_action(
    text: str, observation: WorkflowArenaObservation
) -> WorkflowArenaAction:
    """Parse the model reply into a validated WorkflowArenaAction.

    Accepts bare JSON as well as replies that wrap the JSON object in markdown
    code fences or surrounding prose (a common chat-model failure mode): the
    payload is taken from the first '{' through the last '}'. A bare-JSON
    reply behaves exactly as before.

    Raises:
        ValueError: if the reply is empty or contains no JSON object.
        json.JSONDecodeError: if the extracted payload is not valid JSON.
    """
    # `observation` is currently unused but kept for interface stability.
    text = text.strip()
    if not text:
        raise ValueError("Model response did not include JSON action")
    # Tolerate ```json fences / leading prose by isolating the outermost object.
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("Model response did not include JSON action")
    payload = json.loads(text[start : end + 1])
    return WorkflowArenaAction.model_validate(payload)
def get_model_action(
    client: OpenAI,
    model_name: str,
    observation: WorkflowArenaObservation,
) -> WorkflowArenaAction:
    """Query the chat model for a single scheduling action on this observation."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": make_user_prompt(observation)},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=TEMPERATURE,
        max_tokens=120,
    )
    # Guard against a null content field before parsing.
    reply = completion.choices[0].message.content or ""
    return parse_action(reply.strip(), observation)
def action_to_log_string(action: WorkflowArenaAction) -> str:
    """Render an action as compact JSON, omitting an empty metadata mapping."""
    data = action.model_dump(mode="json")
    if data.get("metadata") == {}:
        # Drop the noise key so [STEP] lines stay short.
        del data["metadata"]
    return json.dumps(data, separators=(",", ":"))
def resolve_model_client() -> tuple[OpenAI | None, str]:
    """Build the OpenAI-compatible client, or fall back to the heuristic policy.

    The API key is resolved from API_KEY, HF_TOKEN, then OPENAI_API_KEY.
    Returns (client, model_name); (None, "heuristic") when no key is available
    or the client cannot be constructed.
    """
    api_key = os.getenv("API_KEY") or HF_TOKEN or os.getenv("OPENAI_API_KEY")
    if not api_key:
        log_warning(
            "Missing model configuration (API_KEY or HF_TOKEN). "
            "Falling back to heuristic policy."
        )
        return None, "heuristic"
    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=api_key)
    except Exception as exc:  # pragma: no cover - defensive initialization fallback
        log_warning(
            f"Failed to initialize model client: {exc}. Falling back to heuristic policy."
        )
        return None, "heuristic"
    return client, MODEL_NAME
def compute_score(observation: WorkflowArenaObservation) -> float:
    """Clamp the benchmark score into [0, 1], preferring the top-level field."""
    raw = observation.benchmark_score
    if raw is None:
        # Fall back to the nested success-metrics copy of the score.
        raw = observation.success_metrics.benchmark_score
    value = float(raw) if raw else 0.0
    return min(1.0, max(0.0, value))
def is_success(observation: WorkflowArenaObservation) -> bool:
    """An episode succeeds when it finished, produced a makespan, and was not cut short."""
    if not observation.done:
        return False
    if observation.success_metrics.makespan is None:
        return False
    return observation.termination_reason is None
@dataclass
class EpisodeResult:
    """Summary of one evaluated episode.

    BUGFIX: the original declared only class-level annotations with no
    @dataclass decorator, so `EpisodeResult(success=..., steps=..., ...)` as
    called by run_episode raised TypeError (plain `object.__init__` takes no
    keyword arguments). `dataclass` is already imported at the top of the file
    and was otherwise unused, confirming the decorator was intended.
    """

    # True when the episode finished cleanly with a makespan.
    success: bool
    # Number of environment steps actually taken.
    steps: int
    # Benchmark score in [0, 1]; 0.0 when the episode did not finish.
    score: float
    # Per-step reward trace in step order.
    rewards: list[float]
def ensure_local_image() -> None:
    """Ensure the inference Docker image exists locally, building it if absent.

    Raises:
        RuntimeError: when docker cannot be executed or the build fails.
    """
    image = LOCAL_IMAGE_NAME or IMAGE_NAME
    # First check whether the image is already present.
    try:
        inspect = subprocess.run(
            ["docker", "image", "inspect", image],
            cwd=PROJECT_DIR,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError as exc:
        raise RuntimeError(f"Failed to execute docker: {exc}") from exc
    if inspect.returncode == 0:
        return
    # Not present: build it from the bundled Dockerfile.
    build_cmd = ["docker", "build", "-t", image, "-f", str(DOCKERFILE_PATH), "."]
    try:
        build = subprocess.run(
            build_cmd,
            cwd=PROJECT_DIR,
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError as exc:
        raise RuntimeError(f"Failed to execute docker build: {exc}") from exc
    if build.returncode != 0:
        raise RuntimeError(
            "Failed to build Docker image for inference.\n"
            f"Command: docker build -t {image} -f {DOCKERFILE_PATH} .\n"
            f"Exit code: {build.returncode}\n"
            f"Stdout: {build.stdout}\n"
            f"Stderr: {build.stderr}"
        )
@asynccontextmanager
async def managed_env():
    """Async context manager yielding a WorkflowArenaEnv.

    Tries the configured base URL first; when the connection cannot be
    established, falls back to a locally built Docker image.

    BUGFIX: main() uses `async with managed_env()`, but the original was a bare
    async generator — the `@asynccontextmanager` decorator was missing even
    though it is imported at the top of the file and otherwise unused. Also,
    the bare `except Exception` used to catch exceptions raised by the
    caller's body at the yield point, silently re-running the body on the
    Docker fallback; the `connected` flag now re-raises those instead.
    """
    connected = False
    try:
        async with WorkflowArenaEnv(base_url=DEFAULT_BASE_URL) as env:
            connected = True
            yield env
            return
    except Exception as exc:
        if connected:
            # The failure came from the caller's body (or env teardown), not
            # from connecting — propagate rather than retry on Docker.
            raise
        log_warning(
            f"Failed to connect to environment at {DEFAULT_BASE_URL}: {exc}. "
            "Trying local Docker fallback."
        )
    ensure_local_image()
    env = await WorkflowArenaEnv.from_docker_image(LOCAL_IMAGE_NAME or IMAGE_NAME)
    try:
        yield env
    finally:
        try:
            await env.close()
        except Exception as exc:  # pragma: no cover - teardown failures should not fail inference
            log_warning(f"Failed to close Docker environment cleanly: {exc}")
async def run_episode(
    env,
    client: OpenAI | None,
    model_name: str,
    preset: DifficultyPreset,
    seed: int,
) -> EpisodeResult:
    """Run one episode for `preset` and return its summary.

    Uses the model policy when `client` is provided, otherwise the heuristic;
    any model/parse failure falls back to the heuristic action for that step.

    BUGFIX: when `env.step` raised and the heuristic fallback equaled the
    failed action, the original fell through without retrying or breaking and
    reused the stale `result` from the previous iteration — logging a bogus
    step/reward and potentially stalling until MAX_STEPS. That path now logs
    the failure and ends the episode.
    """
    rewards: list[float] = []
    steps_taken = 0
    success = False
    score = 0.0
    log_start(task=preset.value, env=BENCHMARK, model=model_name)
    try:
        result = await env.reset(
            seed=seed,
            preset=preset.value,
        )
    except Exception as exc:  # pragma: no cover - env availability failures are external
        log_warning(f"Failed to reset preset={preset.value}: {exc}")
        log_end(success=False, steps=steps_taken, score=score, rewards=rewards)
        return EpisodeResult(
            success=success, steps=steps_taken, score=score, rewards=rewards
        )
    observation = result.observation
    while not observation.done and steps_taken < MAX_STEPS:
        # Pick an action: model if available, heuristic on any model failure.
        try:
            if client is None:
                action = heuristic_action(observation)
            else:
                action = get_model_action(client, model_name, observation)
        except (
            Exception
        ):  # pragma: no cover - network/model failures are expected sometimes
            action = heuristic_action(observation)
        try:
            result = await env.step(action)
        except Exception as exc:  # pragma: no cover - preserve log format and continue safely
            fallback_action = heuristic_action(observation)
            if fallback_action == action:
                # Nothing different to retry with — end the episode instead of
                # falling through with a stale `result`.
                log_warning(
                    f"Step failed for preset={preset.value} even with heuristic action: {exc}"
                )
                break
            log_warning(
                f"Step failed for preset={preset.value} with model action: {exc}. "
                "Retrying with heuristic action."
            )
            action = fallback_action
            try:
                result = await env.step(action)
            except Exception as retry_exc:
                log_warning(
                    f"Step failed for preset={preset.value} even with heuristic action: {retry_exc}"
                )
                break
        observation = result.observation
        reward = float(result.reward or 0.0)
        rewards.append(reward)
        steps_taken += 1
        log_step(
            step=steps_taken,
            action=action_to_log_string(action),
            reward=reward,
            done=bool(result.done),
            error=observation.validation_error,
        )
    success = is_success(observation)
    # Score only counts when the episode actually reached a terminal state.
    score = compute_score(observation) if observation.done else 0.0
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return EpisodeResult(
        success=success, steps=steps_taken, score=score, rewards=rewards
    )
async def main() -> None:
    """Resolve the policy, acquire an environment, and run one episode per preset."""
    client, model_name = resolve_model_client()
    async with managed_env() as env:
        seed = 100
        for preset in PRESETS:
            # Seeds 100, 101, 102 — one deterministic seed per preset.
            await run_episode(
                env=env,
                client=client,
                model_name=model_name,
                preset=preset,
                seed=seed,
            )
            seed += 1
if __name__ == "__main__":
    # Final safeguard: report any unhandled failure via the [WARN] log
    # format rather than crashing the validator run.
    try:
        asyncio.run(main())
    except Exception as fatal_exc:
        log_warning(f"Fatal inference error: {fatal_exc}")