# workflow_arena/inference.py
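"""Run WorkflowArena inference episodes across the difficulty presets.

Uses an OpenAI-compatible chat model when credentials are configured and
falls back to a deterministic heuristic scheduler otherwise.
"""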
from __future__ import annotations
import asyncio
import json
import os
import subprocess
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from openai import OpenAI
from workflow_arena import WorkflowArenaAction, WorkflowArenaEnv
from workflow_arena.models import (
DifficultyPreset,
WorkflowActionType,
WorkflowArenaObservation,
WorkflowTaskView,
)
BENCHMARK = "WorkflowArena"
PRESETS = [
DifficultyPreset.EASY,
DifficultyPreset.MEDIUM,
DifficultyPreset.HARD,
]
PROJECT_DIR = Path(__file__).resolve().parent
IMAGE_NAME = "workflow-arena-inference:latest"
DOCKERFILE_PATH = PROJECT_DIR / "server" / "Dockerfile"
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "qwen/qwen3.5-9b")
HF_TOKEN = os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
DEFAULT_BASE_URL = os.getenv("WORKFLOW_ARENA_BASE_URL", "http://localhost:8000")
TEMPERATURE = 0.0
MAX_STEPS = 256
SYSTEM_PROMPT = (
"You are scheduling a dependency-constrained workflow on limited workers. "
"Respond with compact JSON only. "
'Valid formats: {"action_type":"wait","task_ids":[]} or '
'{"action_type":"dispatch","task_ids":["task_01","task_02"]}. '
"Only dispatch task ids that appear in ready_tasks for the current observation. "
"Never exceed free_workers. "
'If free_workers is 0 and running_tasks is non-empty, respond with {"action_type":"wait","task_ids":[]}. '
"If your previous action was invalid, use validation_error to correct it while still reasoning from the current observation. "
"Never repeat a previously dispatched task unless it still appears in ready_tasks."
)
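# Fixed-format [START]/[STEP]/[END] log lines; the layout is deliberately
# stable (see the "preserve log format" and "validator stability" notes
# below), so avoid changing the field order.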
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(
step: int, action: str, reward: float, done: bool, error: str | None
) -> None:
error_val = error if error else "null"
done_val = str(done).lower()
print(
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
flush=True,
)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
flush=True,
)
def log_warning(message: str) -> None:
print(f"[WARN] {message}", flush=True)
def compact_task(task: WorkflowTaskView) -> dict[str, object]:
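    """Project a task view down to the fields the policy needs, keeping prompts compact."""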
return {
"task_id": task.task_id,
"duration": task.duration,
"priority": task.priority,
"deadline": task.deadline,
"criticality": task.criticality,
"slack": task.slack,
"downstream_count": task.downstream_count,
"dependencies": task.dependencies,
"attempt_count": task.attempt_count,
}
def make_user_prompt(observation: WorkflowArenaObservation) -> str:
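    """Serialize the observation as compact JSON for the model; must_wait is
    precomputed so the model does not have to infer it."""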
must_wait = observation.free_workers == 0 and bool(observation.running_tasks)
return json.dumps(
{
"instruction": observation.instruction,
"current_time": observation.current_time,
"effective_workers": observation.effective_workers,
"degraded_workers": observation.degraded_workers,
"free_workers": observation.free_workers,
"time_budget": observation.time_budget,
"time_remaining": observation.time_remaining,
"must_wait": must_wait,
"ready_tasks": [compact_task(task) for task in observation.ready_tasks],
"running_tasks": [compact_task(task) for task in observation.running_tasks],
"progress": observation.progress.model_dump(mode="json"),
"reward_breakdown": observation.last_reward_breakdown.model_dump(
mode="json"
),
"note": observation.note,
"validation_error": observation.validation_error,
"recent_failure_events": [
event.model_dump(mode="json")
for event in observation.recent_failure_events
],
"last_action": observation.received_action,
},
separators=(",", ":"),
)
def heuristic_action(observation: WorkflowArenaObservation) -> WorkflowArenaAction:
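    """Fallback policy: wait when dispatch is impossible, otherwise dispatch the
    highest-ranked feasible tasks up to the number of free workers."""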
    # Wait whenever nothing can be dispatched: no ready tasks or no free
    # workers (the latter also covers the all-workers-busy case).
    if not observation.ready_tasks or observation.free_workers <= 0:
        return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])
    time_remaining = observation.time_remaining
    ranked = sorted(
        observation.ready_tasks,
        key=lambda task: (
            # Tasks that cannot finish within the remaining budget sort last...
            time_remaining is not None and task.duration > time_remaining,
            # ...and among those, by how far they overshoot the budget.
            max(0, task.duration - time_remaining)
            if time_remaining is not None
            else 0,
            # Earliest deadline first; tasks without a deadline sort last.
            task.deadline if task.deadline is not None else 10**9,
            # Then highest criticality, highest priority, shortest duration.
            -(task.criticality or 0.0),
            -task.priority,
            task.duration,
            # Deterministic tie-break.
            task.task_id,
        ),
    )
    # Dispatch as many top-ranked tasks as there are free workers.
    selected = [task.task_id for task in ranked[: observation.free_workers]]
return WorkflowArenaAction(
action_type=WorkflowActionType.DISPATCH,
task_ids=selected,
)
def parse_action(
text: str, observation: WorkflowArenaObservation
) -> WorkflowArenaAction:
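    """Parse the model's reply into a WorkflowArenaAction, rejecting replies
    that dispatch task ids outside the current ready set."""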
    text = text.strip()
    if text.startswith("`"):
        # Tolerate models that wrap their JSON in a markdown code fence.
        text = text.strip("`").strip()
        if text.lower().startswith("json"):
            text = text[4:].lstrip()
    if not text:
        raise ValueError("Model response did not include JSON action")
    action = WorkflowArenaAction.model_validate(json.loads(text))
    ready_ids = {task.task_id for task in observation.ready_tasks}
    if action.task_ids and not set(action.task_ids) <= ready_ids:
        raise ValueError("Dispatched task ids must all appear in ready_tasks")
    return action
def get_model_action(
client: OpenAI,
model_name: str,
observation: WorkflowArenaObservation,
) -> WorkflowArenaAction:
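    """Query the chat model for the next action given the current observation."""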
prompt = make_user_prompt(observation)
completion = client.chat.completions.create(
model=model_name,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
temperature=TEMPERATURE,
max_tokens=120,
)
text = (completion.choices[0].message.content or "").strip()
return parse_action(text, observation)
def action_to_log_string(action: WorkflowArenaAction) -> str:
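    """Render an action as compact JSON for [STEP] logs, dropping empty metadata."""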
payload = action.model_dump(mode="json")
if payload.get("metadata") == {}:
payload.pop("metadata", None)
return json.dumps(payload, separators=(",", ":"))
def resolve_model_client() -> tuple[OpenAI | None, str]:
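    """Build an OpenAI-compatible client from environment variables, or signal
    the heuristic fallback when no API key is configured."""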
    api_key = os.getenv("API_KEY") or HF_TOKEN or os.getenv("OPENAI_API_KEY")
    if not api_key:
        log_warning(
            "Missing model configuration (API_KEY or HF_TOKEN). "
            "Falling back to heuristic policy."
        )
        return None, "heuristic"
try:
return OpenAI(base_url=API_BASE_URL, api_key=api_key), MODEL_NAME
except Exception as exc: # pragma: no cover - defensive initialization fallback
log_warning(
f"Failed to initialize model client: {exc}. Falling back to heuristic policy."
)
return None, "heuristic"
def compute_score(observation: WorkflowArenaObservation) -> float:
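    """Return the episode's benchmark score clamped to [0, 1]."""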
score = observation.benchmark_score
if score is None:
score = observation.success_metrics.benchmark_score
return max(0.0, min(1.0, float(score or 0.0)))
def is_success(observation: WorkflowArenaObservation) -> bool:
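    """An episode succeeds when it finishes with a recorded makespan and no
    termination reason."""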
return bool(
observation.done
and observation.success_metrics.makespan is not None
and observation.termination_reason is None
)
@dataclass
class EpisodeResult:
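    """Summary of one episode: outcome, step count, final score, per-step rewards."""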
success: bool
steps: int
score: float
rewards: list[float]
def ensure_local_image() -> None:
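    """Build the local inference Docker image if it is not already present."""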
local_image_name = LOCAL_IMAGE_NAME or IMAGE_NAME
try:
inspect_result = subprocess.run(
["docker", "image", "inspect", local_image_name],
cwd=PROJECT_DIR,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=False,
)
except OSError as exc:
raise RuntimeError(f"Failed to execute docker: {exc}") from exc
if inspect_result.returncode == 0:
return
try:
build_result = subprocess.run(
["docker", "build", "-t", local_image_name, "-f", str(DOCKERFILE_PATH), "."],
cwd=PROJECT_DIR,
capture_output=True,
text=True,
check=False,
)
except OSError as exc:
raise RuntimeError(f"Failed to execute docker build: {exc}") from exc
if build_result.returncode != 0:
raise RuntimeError(
"Failed to build Docker image for inference.\n"
f"Command: docker build -t {local_image_name} -f {DOCKERFILE_PATH} .\n"
f"Exit code: {build_result.returncode}\n"
f"Stdout: {build_result.stdout}\n"
f"Stderr: {build_result.stderr}"
)
@asynccontextmanager
async def managed_env():
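    """Yield a WorkflowArenaEnv, preferring an already-running server and
    falling back to a locally built Docker image when it is unreachable."""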
    # Guard only the connection attempt, so exceptions raised by the caller's
    # episode code are not mistaken for connection failures (which would
    # trigger the Docker fallback and yield a second time).
    try:
        remote = WorkflowArenaEnv(base_url=DEFAULT_BASE_URL)
        env = await remote.__aenter__()
    except Exception as exc:
        log_warning(
            f"Failed to connect to environment at {DEFAULT_BASE_URL}: {exc}. "
            "Trying local Docker fallback."
        )
    else:
        try:
            yield env
        finally:
            await remote.__aexit__(None, None, None)
        return
    ensure_local_image()
    env = await WorkflowArenaEnv.from_docker_image(LOCAL_IMAGE_NAME or IMAGE_NAME)
    try:
        yield env
    finally:
        try:
            await env.close()
        except Exception as exc:  # pragma: no cover - teardown failures should not fail inference
            log_warning(f"Failed to close Docker environment cleanly: {exc}")
async def run_episode(
env,
client: OpenAI | None,
model_name: str,
preset: DifficultyPreset,
seed: int,
) -> EpisodeResult:
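    """Run one episode on the given preset, using the model policy when a
    client is available and the heuristic otherwise."""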
rewards: list[float] = []
steps_taken = 0
success = False
score = 0.0
log_start(task=preset.value, env=BENCHMARK, model=model_name)
try:
result = await env.reset(
seed=seed,
preset=preset.value,
)
except Exception as exc: # pragma: no cover - env availability failures are external
log_warning(f"Failed to reset preset={preset.value}: {exc}")
log_end(success=False, steps=steps_taken, score=score, rewards=rewards)
return EpisodeResult(
success=success, steps=steps_taken, score=score, rewards=rewards
)
observation = result.observation
while not observation.done and steps_taken < MAX_STEPS:
try:
if client is None:
action = heuristic_action(observation)
else:
action = get_model_action(client, model_name, observation)
        except Exception:  # pragma: no cover - network/model failures are expected sometimes
            action = heuristic_action(observation)
try:
result = await env.step(action)
        except Exception as exc:  # pragma: no cover - preserve log format and continue safely
            fallback_action = heuristic_action(observation)
            if fallback_action == action:
                # The heuristic would repeat the action that just failed, and
                # falling through here would re-log a stale result, so stop.
                log_warning(
                    f"Step failed for preset={preset.value}: {exc}. "
                    "Heuristic fallback is identical; ending episode."
                )
                break
            log_warning(
                f"Step failed for preset={preset.value} with model action: {exc}. "
                "Retrying with heuristic action."
            )
            action = fallback_action
            try:
                result = await env.step(action)
            except Exception as retry_exc:
                log_warning(
                    f"Step failed for preset={preset.value} even with heuristic action: {retry_exc}"
                )
                break
observation = result.observation
reward = float(result.reward or 0.0)
rewards.append(reward)
steps_taken += 1
log_step(
step=steps_taken,
action=action_to_log_string(action),
reward=reward,
done=bool(result.done),
error=observation.validation_error,
)
success = is_success(observation)
score = compute_score(observation) if observation.done else 0.0
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
return EpisodeResult(
success=success, steps=steps_taken, score=score, rewards=rewards
)
async def main() -> None:
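    """Resolve the policy client, acquire an environment, and run one episode
    per difficulty preset."""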
client, model_name = resolve_model_client()
async with managed_env() as env:
for index, preset in enumerate(PRESETS):
await run_episode(
env=env,
client=client,
model_name=model_name,
preset=preset,
seed=100 + index,
)
if __name__ == "__main__":
try:
asyncio.run(main())
except Exception as exc: # pragma: no cover - final safeguard for validator stability
log_warning(f"Fatal inference error: {exc}")