| |
| """LLM agent — uses OpenAI-compatible API to decide serving configuration. |
| |
| Requires environment variables: API_BASE_URL, MODEL_NAME, HF_TOKEN |
| Falls back to PPO agent if API is unavailable. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from typing import Any |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
|
|
| from llmserve_env.models import ServeAction, ServeObservation |
|
|
| SYSTEM_PROMPT = """You are an LLM serving configuration optimizer. Your goal is to maximize throughput while meeting latency SLOs. Given the current server metrics as JSON, respond with a JSON ServeAction. |
| |
| Action fields and ranges: |
| - batch_cap: int 1..512 |
| - kv_budget_fraction: float 0.1..1.0 |
| - speculation_depth: int 0..8 |
| - quantization_tier: one of FP16, INT8, INT4 |
| - prefill_decode_split: bool |
| - priority_routing: bool |
| |
| Return ONLY valid JSON. No markdown, no explanation.""".strip() |
|
|
|
|
| class LLMAgent: |
| """Agent that uses an OpenAI-compatible API for action selection.""" |
|
|
| def __init__( |
| self, |
| api_key: str | None = None, |
| base_url: str | None = None, |
| model: str | None = None, |
| ) -> None: |
| from openai import OpenAI |
|
|
| self.api_key = api_key or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "") |
| self.base_url = base_url or os.getenv("API_BASE_URL", "") |
| self.model = model or os.getenv("MODEL_NAME", "gpt-4.1-mini") |
| self._history: list[dict[str, Any]] = [] |
|
|
| self.client = OpenAI(api_key=self.api_key, base_url=self.base_url or None) |
|
|
| def reset(self) -> None: |
| self._history.clear() |
|
|
| def act(self, observation: ServeObservation, task_id: str) -> ServeAction: |
| """Query the LLM for an action, with retry and fallback.""" |
| obs_dict = { |
| "queue_depth": observation.queue_depth, |
| "active_requests": observation.active_requests, |
| "kv_cache_occupancy": round(observation.kv_cache_occupancy, 3), |
| "mean_prompt_length": round(observation.mean_prompt_length, 1), |
| "p99_ttft_ms": round(observation.p99_ttft_ms, 1), |
| "slo_compliance_rate": round(observation.slo_compliance_rate, 3), |
| "throughput_tps": round(observation.throughput_tps, 1), |
| "eviction_events": observation.eviction_events, |
| "request_arrival_rate": round(observation.request_arrival_rate, 1), |
| "step_index": observation.step_index, |
| } |
|
|
| user_msg = f"Task: {task_id}\nCurrent metrics: {json.dumps(obs_dict)}" |
| if self._history: |
| user_msg += f"\nPrevious action: {json.dumps(self._history[-1])}" |
|
|
| for attempt in range(2): |
| try: |
| response = self.client.chat.completions.create( |
| model=self.model, |
| messages=[ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": user_msg}, |
| ], |
| temperature=0.1 if attempt == 0 else 0.0, |
| max_tokens=200, |
| ) |
| raw = response.choices[0].message.content or "" |
| action = self._parse(raw) |
| self._history.append(action.model_dump(mode="json")) |
| return action |
| except Exception: |
| if attempt == 0: |
| user_msg += "\n\nPrevious response was invalid. Return ONLY a JSON object with the action fields." |
| continue |
|
|
| |
| from server.baseline_agent import HeuristicPolicy |
| fallback = HeuristicPolicy() |
| return fallback.act(observation, task_id) |
|
|
| def _parse(self, raw: str) -> ServeAction: |
| """Parse LLM response into a ServeAction.""" |
| |
| text = raw.strip() |
| if text.startswith("```"): |
| lines = text.split("\n") |
| lines = [l for l in lines if not l.strip().startswith("```")] |
| text = "\n".join(lines) |
| data = json.loads(text) |
| return ServeAction(**data) |
|
|