#!/usr/bin/env python3
"""LLM agent — uses OpenAI-compatible API to decide serving configuration.
Requires environment variables: API_BASE_URL, MODEL_NAME, HF_TOKEN
Falls back to PPO agent if API is unavailable.
"""
from __future__ import annotations

import json
import os
import sys
from typing import Any

# Make the repository root importable when this file is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from llmserve_env.models import ServeAction, ServeObservation  # noqa: E402

SYSTEM_PROMPT = """You are an LLM serving configuration optimizer. Your goal is to maximize throughput while meeting latency SLOs. Given the current server metrics as JSON, respond with a JSON ServeAction.
Action fields and ranges:
- batch_cap: int 1..512
- kv_budget_fraction: float 0.1..1.0
- speculation_depth: int 0..8
- quantization_tier: one of FP16, INT8, INT4
- prefill_decode_split: bool
- priority_routing: bool
Return ONLY valid JSON. No markdown, no explanation.""".strip()
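
# Illustrative well-formed response (hypothetical values, matching the
# schema described in SYSTEM_PROMPT above):
#   {"batch_cap": 128, "kv_budget_fraction": 0.8, "speculation_depth": 2,
#    "quantization_tier": "INT8", "prefill_decode_split": true,
#    "priority_routing": false}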

class LLMAgent:
    """Agent that uses an OpenAI-compatible API for action selection."""

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
    ) -> None:
        # Imported lazily so the module can be loaded without the openai package.
        from openai import OpenAI

        self.api_key = api_key or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
        self.base_url = base_url or os.getenv("API_BASE_URL", "")
        self.model = model or os.getenv("MODEL_NAME", "gpt-4.1-mini")
        self._history: list[dict[str, Any]] = []
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url or None)

    def reset(self) -> None:
        self._history.clear()

    def act(self, observation: ServeObservation, task_id: str) -> ServeAction:
        """Query the LLM for an action, with one retry and a heuristic fallback."""
        # Serialize the observation compactly so the prompt stays small.
        obs_dict = {
            "queue_depth": observation.queue_depth,
            "active_requests": observation.active_requests,
            "kv_cache_occupancy": round(observation.kv_cache_occupancy, 3),
            "mean_prompt_length": round(observation.mean_prompt_length, 1),
            "p99_ttft_ms": round(observation.p99_ttft_ms, 1),
            "slo_compliance_rate": round(observation.slo_compliance_rate, 3),
            "throughput_tps": round(observation.throughput_tps, 1),
            "eviction_events": observation.eviction_events,
            "request_arrival_rate": round(observation.request_arrival_rate, 1),
            "step_index": observation.step_index,
        }
        user_msg = f"Task: {task_id}\nCurrent metrics: {json.dumps(obs_dict)}"
        if self._history:
            # Give the model its last action as context for incremental tuning.
            user_msg += f"\nPrevious action: {json.dumps(self._history[-1])}"
        for attempt in range(2):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_msg},
                    ],
                    # Near-deterministic sampling; fully greedy on the retry.
                    temperature=0.1 if attempt == 0 else 0.0,
                    max_tokens=200,
                )
                raw = response.choices[0].message.content or ""
                action = self._parse(raw)
                self._history.append(action.model_dump(mode="json"))
                return action
            except Exception:
                if attempt == 0:
                    user_msg += (
                        "\n\nPrevious response was invalid. Return ONLY a JSON "
                        "object with the action fields."
                    )
        # Fall back to the heuristic policy if both LLM attempts fail.
        from server.baseline_agent import HeuristicPolicy

        fallback = HeuristicPolicy()
        return fallback.act(observation, task_id)

    def _parse(self, raw: str) -> ServeAction:
        """Parse the LLM response into a ServeAction."""
        # Strip markdown code fences if the model wrapped its JSON in them.
        text = raw.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            lines = [line for line in lines if not line.strip().startswith("```")]
            text = "\n".join(lines)
        data = json.loads(text)
        return ServeAction(**data)
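

if __name__ == "__main__":
    # Minimal manual smoke test: a sketch, not part of the agent itself.
    # Assumes ServeObservation accepts these fields as keyword arguments
    # (the same fields act() serializes) and that API_BASE_URL / HF_TOKEN
    # point at a reachable endpoint; otherwise the heuristic fallback runs.
    obs = ServeObservation(
        queue_depth=12,
        active_requests=4,
        kv_cache_occupancy=0.62,
        mean_prompt_length=512.0,
        p99_ttft_ms=480.0,
        slo_compliance_rate=0.97,
        throughput_tps=1450.0,
        eviction_events=0,
        request_arrival_rate=8.5,
        step_index=0,
    )
    agent = LLMAgent()
    action = agent.act(obs, task_id="demo")
    print(json.dumps(action.model_dump(mode="json"), indent=2))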