#!/usr/bin/env python3 """InferenceGym submission runner. Expected environment variables for judged LLM path: - API_BASE_URL - MODEL_NAME - HF_TOKEN """ from __future__ import annotations import json import os import re import sys from typing import Any from openai import OpenAI sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from llmserve_env.models import ServeAction, default_action # noqa: E402 from server.grader import GraderEngine # noqa: E402 from server.llmserve_environment import LLMServeEnvironment # noqa: E402 API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") HF_TOKEN = os.getenv("HF_TOKEN") LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") DEFAULT_SEED = int(os.getenv("SEED", "42")) MAX_STEPS = int(os.getenv("MAX_STEPS", "60")) ENV_NAME = "InferenceGym" TASKS = ["static_workload", "bursty_workload", "adversarial_multitenant"] SYSTEM_PROMPT = ( "You are controlling an LLM serving environment. " "Return exactly one JSON object with these keys: " "batch_cap (1..512), kv_budget_fraction (0.1..1.0), speculation_depth (0..8), " "quantization_tier (FP16|INT8|INT4), prefill_decode_split (bool), priority_routing (bool). " "Do not include markdown or extra text." ) def _action_dict(action: ServeAction) -> dict[str, Any]: payload = action.model_dump(mode="json") payload.pop("metadata", None) return payload def _create_fallback_agent(task_id: str): try: from agents.ppo_agent import PPOAgent, find_weights weights_path = find_weights(task_id) if weights_path: return PPOAgent(weights_path) except Exception: pass from server.baseline_agent import HeuristicPolicy return HeuristicPolicy() def _create_client() -> OpenAI | None: if not HF_TOKEN: return None return OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL) def _parse_action_payload(raw: str) -> dict[str, Any] | None: candidate = raw.strip() if candidate.startswith("```"): candidate = re.sub(r"^```(?:json)?\s*|\s*```$", "", candidate, flags=re.IGNORECASE | re.DOTALL).strip() start = candidate.find("{") end = candidate.rfind("}") if start != -1 and end != -1 and end > start: candidate = candidate[start : end + 1] try: parsed = json.loads(candidate) except json.JSONDecodeError: return None return parsed if isinstance(parsed, dict) else None def _llm_action(client: OpenAI, task_id: str, observation: Any, previous_action: dict[str, Any] | None) -> ServeAction: user_payload = { "task_id": task_id, "observation": observation.model_dump(mode="json"), "previous_action": previous_action, } response = client.chat.completions.create( model=MODEL_NAME, temperature=0, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))}, ], response_format={"type": "json_object"}, ) raw = response.choices[0].message.content or "{}" payload = _parse_action_payload(raw) if payload is None: return default_action() try: return ServeAction.model_validate(payload) except Exception: return default_action() def _sanitize_error(error: Exception | str | None) -> str: if error is None: return "null" text = str(error).strip() if not text: return "null" return text.replace("\n", " ").replace("\r", " ")[:220] def _log_start(task: str, env_name: str, model: str) -> None: print(f"[START] task={task} env={env_name} model={model}", flush=True) def _log_step(step: int, action: str, reward: float, done: bool, error: str) -> None: print( f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error}", flush=True, ) def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: rewards_str = ",".join(f"{reward:.2f}" for reward in rewards) print( f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True, ) def _run_task(task_id: str, client: OpenAI | None) -> bool: model_label = MODEL_NAME if client is not None else "heuristic" _log_start(task=task_id, env_name=ENV_NAME, model=model_label) rewards: list[float] = [] steps_taken = 0 score = 0.0 success = False previous_action: dict[str, Any] | None = None env: LLMServeEnvironment | None = None grader: GraderEngine | None = None fallback_agent: Any = None try: env = LLMServeEnvironment(seed=DEFAULT_SEED, mode="sim") grader = GraderEngine() fallback_agent = _create_fallback_agent(task_id) if hasattr(fallback_agent, "reset"): fallback_agent.reset() observation = env.reset(seed=DEFAULT_SEED, task_id=task_id) task_cfg = env.task_config or {} configured_max_steps = int(task_cfg.get("max_steps", MAX_STEPS)) max_steps = min(configured_max_steps, MAX_STEPS) for step_idx in range(1, max_steps + 1): if client is not None: try: action = _llm_action(client, task_id, observation, previous_action) except Exception as exc: action = fallback_agent.act(observation, task_id) else: action = fallback_agent.act(observation, task_id) action_json = json.dumps(_action_dict(action), separators=(",", ":")) try: observation = env.step(action) reward = float(getattr(observation, "reward", 0.0) or 0.0) done = bool(getattr(observation, "done", False)) rewards.append(reward) steps_taken = step_idx _log_step(step=step_idx, action=action_json, reward=reward, done=done, error="null") previous_action = _action_dict(action) if done: break except Exception as exc: rewards.append(0.0) steps_taken = step_idx _log_step(step=step_idx, action=action_json, reward=0.0, done=True, error=_sanitize_error(exc)) break grade = grader.grade(env.export_episode_log()) if grader is not None else {"score": 0.0} score = float(grade.get("score", 0.0)) score = max(0.0, min(1.0, score)) success = score > 0.0 except Exception as exc: next_step = len(rewards) + 1 rewards.append(0.0) steps_taken = next_step _log_step(step=next_step, action="{}", reward=0.0, done=True, error=_sanitize_error(exc)) success = False finally: _log_end(success=success, steps=steps_taken, score=score, rewards=rewards) return success def main() -> int: try: client = _create_client() except Exception as exc: print(f"[DEBUG] Failed to create LLM client: {exc}", flush=True) client = None for task_id in TASKS: try: _run_task(task_id=task_id, client=client) except Exception as exc: try: _log_start(task=task_id, env_name=ENV_NAME, model=MODEL_NAME if client is not None else "heuristic") _log_step(step=1, action="{}", reward=0.0, done=True, error=_sanitize_error(exc)) _log_end(success=False, steps=1, score=0.0, rewards=[0.0]) except Exception as log_exc: print(f"[DEBUG] Failed to log task failure: {log_exc}", flush=True) # The validator treats non-zero exits as infrastructure failures, so we always # return 0 after emitting structured episode logs for every task. return 0 if __name__ == "__main__": try: exit_code = main() raise SystemExit(exit_code) except Exception as exc: print(f"[ERROR] Unhandled exception in main: {exc}", flush=True) import traceback traceback.print_exc(file=sys.stdout) raise SystemExit(0)