vegarl / inference.py
ronitraj's picture
fix: enhance error handling and logging across multiple modules
1d6826f
#!/usr/bin/env python3
"""InferenceGym submission runner.
Expected environment variables for judged LLM path:
- API_BASE_URL
- MODEL_NAME
- HF_TOKEN
"""
from __future__ import annotations
import json
import os
import re
import sys
from typing import Any
from openai import OpenAI
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from llmserve_env.models import ServeAction, default_action # noqa: E402
from server.grader import GraderEngine # noqa: E402
from server.llmserve_environment import LLMServeEnvironment # noqa: E402
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
DEFAULT_SEED = int(os.getenv("SEED", "42"))
MAX_STEPS = int(os.getenv("MAX_STEPS", "60"))
ENV_NAME = "InferenceGym"
TASKS = ["static_workload", "bursty_workload", "adversarial_multitenant"]
SYSTEM_PROMPT = (
"You are controlling an LLM serving environment. "
"Return exactly one JSON object with these keys: "
"batch_cap (1..512), kv_budget_fraction (0.1..1.0), speculation_depth (0..8), "
"quantization_tier (FP16|INT8|INT4), prefill_decode_split (bool), priority_routing (bool). "
"Do not include markdown or extra text."
)
def _action_dict(action: ServeAction) -> dict[str, Any]:
payload = action.model_dump(mode="json")
payload.pop("metadata", None)
return payload
def _create_fallback_agent(task_id: str):
try:
from agents.ppo_agent import PPOAgent, find_weights
weights_path = find_weights(task_id)
if weights_path:
return PPOAgent(weights_path)
except Exception:
pass
from server.baseline_agent import HeuristicPolicy
return HeuristicPolicy()
def _create_client() -> OpenAI | None:
if not HF_TOKEN:
return None
return OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
def _parse_action_payload(raw: str) -> dict[str, Any] | None:
candidate = raw.strip()
if candidate.startswith("```"):
candidate = re.sub(r"^```(?:json)?\s*|\s*```$", "", candidate, flags=re.IGNORECASE | re.DOTALL).strip()
start = candidate.find("{")
end = candidate.rfind("}")
if start != -1 and end != -1 and end > start:
candidate = candidate[start : end + 1]
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _llm_action(client: OpenAI, task_id: str, observation: Any, previous_action: dict[str, Any] | None) -> ServeAction:
user_payload = {
"task_id": task_id,
"observation": observation.model_dump(mode="json"),
"previous_action": previous_action,
}
response = client.chat.completions.create(
model=MODEL_NAME,
temperature=0,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))},
],
response_format={"type": "json_object"},
)
raw = response.choices[0].message.content or "{}"
payload = _parse_action_payload(raw)
if payload is None:
return default_action()
try:
return ServeAction.model_validate(payload)
except Exception:
return default_action()
def _sanitize_error(error: Exception | str | None) -> str:
if error is None:
return "null"
text = str(error).strip()
if not text:
return "null"
return text.replace("\n", " ").replace("\r", " ")[:220]
def _log_start(task: str, env_name: str, model: str) -> None:
print(f"[START] task={task} env={env_name} model={model}", flush=True)
def _log_step(step: int, action: str, reward: float, done: bool, error: str) -> None:
print(
f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error}",
flush=True,
)
def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
flush=True,
)
def _run_task(task_id: str, client: OpenAI | None) -> bool:
model_label = MODEL_NAME if client is not None else "heuristic"
_log_start(task=task_id, env_name=ENV_NAME, model=model_label)
rewards: list[float] = []
steps_taken = 0
score = 0.0
success = False
previous_action: dict[str, Any] | None = None
env: LLMServeEnvironment | None = None
grader: GraderEngine | None = None
fallback_agent: Any = None
try:
env = LLMServeEnvironment(seed=DEFAULT_SEED, mode="sim")
grader = GraderEngine()
fallback_agent = _create_fallback_agent(task_id)
if hasattr(fallback_agent, "reset"):
fallback_agent.reset()
observation = env.reset(seed=DEFAULT_SEED, task_id=task_id)
task_cfg = env.task_config or {}
configured_max_steps = int(task_cfg.get("max_steps", MAX_STEPS))
max_steps = min(configured_max_steps, MAX_STEPS)
for step_idx in range(1, max_steps + 1):
if client is not None:
try:
action = _llm_action(client, task_id, observation, previous_action)
except Exception as exc:
action = fallback_agent.act(observation, task_id)
else:
action = fallback_agent.act(observation, task_id)
action_json = json.dumps(_action_dict(action), separators=(",", ":"))
try:
observation = env.step(action)
reward = float(getattr(observation, "reward", 0.0) or 0.0)
done = bool(getattr(observation, "done", False))
rewards.append(reward)
steps_taken = step_idx
_log_step(step=step_idx, action=action_json, reward=reward, done=done, error="null")
previous_action = _action_dict(action)
if done:
break
except Exception as exc:
rewards.append(0.0)
steps_taken = step_idx
_log_step(step=step_idx, action=action_json, reward=0.0, done=True, error=_sanitize_error(exc))
break
grade = grader.grade(env.export_episode_log()) if grader is not None else {"score": 0.0}
score = float(grade.get("score", 0.0))
score = max(0.0, min(1.0, score))
success = score > 0.0
except Exception as exc:
next_step = len(rewards) + 1
rewards.append(0.0)
steps_taken = next_step
_log_step(step=next_step, action="{}", reward=0.0, done=True, error=_sanitize_error(exc))
success = False
finally:
_log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
return success
def main() -> int:
try:
client = _create_client()
except Exception as exc:
print(f"[DEBUG] Failed to create LLM client: {exc}", flush=True)
client = None
for task_id in TASKS:
try:
_run_task(task_id=task_id, client=client)
except Exception as exc:
try:
_log_start(task=task_id, env_name=ENV_NAME, model=MODEL_NAME if client is not None else "heuristic")
_log_step(step=1, action="{}", reward=0.0, done=True, error=_sanitize_error(exc))
_log_end(success=False, steps=1, score=0.0, rewards=[0.0])
except Exception as log_exc:
print(f"[DEBUG] Failed to log task failure: {log_exc}", flush=True)
# The validator treats non-zero exits as infrastructure failures, so we always
# return 0 after emitting structured episode logs for every task.
return 0
if __name__ == "__main__":
try:
exit_code = main()
raise SystemExit(exit_code)
except Exception as exc:
print(f"[ERROR] Unhandled exception in main: {exc}", flush=True)
import traceback
traceback.print_exc(file=sys.stdout)
raise SystemExit(0)