| |
| """PPO agent — loads pre-trained weights and runs inference only. |
| |
| Usage: |
| from agents.ppo_agent import PPOAgent |
| agent = PPOAgent("weights/ppo_task1_static.pt") |
| action = agent.act(observation, task_id) |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
|
|
| import torch |
|
|
| from llmserve_env.models import ServeAction, ServeObservation |
| from rl.env_wrapper import obs_to_vector |
| from rl.normalize import RunningNormalizer |
| from rl.policy_network import PolicyNetwork |
|
|
|
|
class PPOAgent:
    """Inference-only agent that loads trained PPO weights.

    No sampling is performed when decoding an action: continuous heads use
    their clamped mean, categorical heads use the argmax of their logits, and
    binary heads are thresholded at a logit of zero.
    """

    # Quantization tier names, indexed by the argmax of ``quant_tier_logits``.
    # NOTE(review): assumes the policy emits exactly three tier logits — confirm
    # against PolicyNetwork; a larger head would raise IndexError in ``act``.
    _QUANT_TIERS = ("FP16", "INT8", "INT4")

    def __init__(self, weights_path: str, obs_dim: int = 15) -> None:
        """Build the policy network and restore it from a checkpoint.

        Args:
            weights_path: Path to a checkpoint readable by ``torch.load``. The
                loaded object must provide a ``"policy"`` entry (passed to
                ``PolicyNetwork.load_state_dict``) and may provide a
                ``"normalizer"`` entry (passed to
                ``RunningNormalizer.load_state_dict``).
            obs_dim: Observation vector length used to size both the policy
                network and the optional normalizer.

        Raises:
            KeyError: If the checkpoint has no ``"policy"`` entry (assuming
                the checkpoint is a plain dict).
        """
        self.policy = PolicyNetwork(obs_dim=obs_dim)
        self.normalizer: RunningNormalizer | None = None

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from
        # trusted sources. Kept as-is because the checkpoint may contain
        # non-tensor normalizer state — TODO confirm whether weights_only=True
        # is sufficient for the checkpoints produced by training.
        checkpoint = torch.load(weights_path, map_location="cpu", weights_only=False)
        self.policy.load_state_dict(checkpoint["policy"])
        # Put the network into evaluation mode for inference.
        self.policy.eval()

        # Observation normalization is optional; without saved statistics the
        # agent feeds raw observation vectors to the policy.
        if "normalizer" in checkpoint:
            self.normalizer = RunningNormalizer(shape=(obs_dim,))
            self.normalizer.load_state_dict(checkpoint["normalizer"])

    def reset(self) -> None:
        """Do nothing; the agent keeps no per-episode state."""

    def act(self, observation: ServeObservation, task_id: str) -> ServeAction:
        """Select a deterministic action from the trained policy.

        Args:
            observation: Current environment observation; converted to a
                vector with ``obs_to_vector``.
            task_id: Accepted for interface compatibility and ignored.

        Returns:
            A ``ServeAction`` decoded from the policy heads:
            ``batch_cap`` is the mean clamped to [1, 512] and rounded to an
            int; ``kv_budget_fraction`` is the mean clamped to [0.10, 1.0] and
            rounded to two decimals; ``speculation_depth`` and
            ``quantization_tier`` come from the argmax of their logits; the
            two boolean flags are true when their logit is positive.
        """
        del task_id  # The loaded weights are already task-specific.

        vec = obs_to_vector(observation)
        if self.normalizer is not None:
            vec = self.normalizer.normalize(vec)

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            # torch.from_numpy requires a NumPy array; unsqueeze adds the
            # leading batch dimension of size one.
            obs_t = torch.from_numpy(vec).unsqueeze(0)
            params, _ = self.policy.forward(obs_t)

        batch_cap = int(torch.clamp(params["batch_cap_mean"], 1.0, 512.0).round().item())
        kv_budget = float(torch.clamp(params["kv_budget_mean"], 0.10, 1.0).item())
        spec_depth = int(torch.argmax(params["spec_depth_logits"], dim=-1).item())
        quant_tier = int(torch.argmax(params["quant_tier_logits"], dim=-1).item())
        # A positive logit corresponds to a probability above 0.5 under a
        # sigmoid, so thresholding at zero picks the more likely outcome.
        prefill_split = bool((params["prefill_split_logit"] > 0).item())
        priority_route = bool((params["priority_route_logit"] > 0).item())

        return ServeAction(
            batch_cap=batch_cap,
            kv_budget_fraction=round(kv_budget, 2),
            speculation_depth=spec_depth,
            quantization_tier=self._QUANT_TIERS[quant_tier],
            prefill_decode_split=prefill_split,
            priority_routing=priority_route,
        )
|
|
|
|
def find_weights(task_id: str) -> str | None:
    """Find the weights file for a given task_id.

    Checkpoints are looked up as ``weights/ppo_<label>.pt`` where the
    ``weights`` directory sits next to the directory containing this module
    (i.e. one level above it).

    Args:
        task_id: Task identifier; one of ``"static_workload"``,
            ``"bursty_workload"`` or ``"adversarial_multitenant"``.

    Returns:
        The checkpoint path as a string, or ``None`` when the task id is
        unknown or no checkpoint file exists for it.
    """
    label_map = {
        "static_workload": "task1_static",
        "bursty_workload": "task2_bursty",
        "adversarial_multitenant": "task3_adversarial",
    }
    label = label_map.get(task_id)
    if label is None:
        return None

    weights_dir = Path(__file__).resolve().parents[1] / "weights"
    path = weights_dir / f"ppo_{label}.pt"
    # is_file() rather than exists(): a directory with the checkpoint's name
    # is not loadable and must not be reported as available weights.
    return str(path) if path.is_file() else None
|
|