vegarl / agents /ppo_agent.py
ronitraj's picture
Deploy Space without oversized raw dataset
4fbc241
#!/usr/bin/env python3
"""PPO agent — loads pre-trained weights and runs inference only.

Usage:
    from agents.ppo_agent import PPOAgent

    agent = PPOAgent("weights/ppo_task1_static.pt")
    action = agent.act(observation, task_id)
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch # noqa: E402
from llmserve_env.models import ServeAction, ServeObservation # noqa: E402
from rl.env_wrapper import obs_to_vector # noqa: E402
from rl.normalize import RunningNormalizer # noqa: E402
from rl.policy_network import PolicyNetwork # noqa: E402
class PPOAgent:
    """Inference-only agent that loads trained PPO weights.

    The checkpoint is read once at construction. ``act`` then turns a single
    observation into a deterministic ``ServeAction``: the ``*_mean`` outputs
    of the policy are clamped for the continuous knobs, and the discrete
    knobs come from an arg-max over logits or a sign test on a single logit.
    """

    # Label for each index of the policy's ``quant_tier_logits`` output.
    _QUANT_TIERS = ("FP16", "INT8", "INT4")

    def __init__(self, weights_path: str, obs_dim: int = 15) -> None:
        """Build the policy network and restore it from ``weights_path``.

        Args:
            weights_path: Checkpoint file readable by ``torch.load``. The
                loaded object must provide a ``"policy"`` state dict and may
                provide a ``"normalizer"`` state dict.
            obs_dim: Length of the observation vector; used for the policy
                input size and for the normalizer shape.

        Raises:
            KeyError: If the loaded checkpoint mapping has no ``"policy"``
                entry.
        """
        self.policy = PolicyNetwork(obs_dim=obs_dim)
        self.normalizer: RunningNormalizer | None = None

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code embedded in the file. Only load
        # checkpoints from trusted sources.
        # NOTE(review): switch to weights_only=True if the checkpoint holds
        # only tensors and plain containers — confirm against the trainer.
        state = torch.load(weights_path, map_location="cpu", weights_only=False)

        self.policy.load_state_dict(state["policy"])
        self.policy.eval()

        # Observation normalization is optional: it is restored only when
        # the checkpoint carries the statistics for it.
        if "normalizer" in state:
            self.normalizer = RunningNormalizer(shape=(obs_dim,))
            self.normalizer.load_state_dict(state["normalizer"])

    def reset(self) -> None:
        """Do nothing; the agent keeps no internal state to reset."""

    def act(self, observation: ServeObservation, task_id: str) -> ServeAction:
        """Select a deterministic action from the trained policy.

        Args:
            observation: Current environment observation.
            task_id: Ignored by this agent.

        Returns:
            A ``ServeAction`` with ``batch_cap`` clamped to [1, 512] and
            rounded to an int, ``kv_budget_fraction`` clamped to [0.10, 1.0]
            and rounded to two decimals, arg-max choices for speculation
            depth and quantization tier, and boolean flags that are true
            when their logit is positive.
        """
        del task_id

        vec = obs_to_vector(observation)
        if self.normalizer is not None:
            vec = self.normalizer.normalize(vec)

        with torch.no_grad():
            # torch.from_numpy would inherit the array's dtype; a float64
            # vector (NumPy's default) cannot be fed to float32 layers, so
            # cast explicitly. This is a no-op for float32 input.
            # Assumes the policy keeps PyTorch's default float32 parameters
            # — TODO confirm against PolicyNetwork.
            obs_t = torch.as_tensor(vec, dtype=torch.float32).unsqueeze(0)
            params, _ = self.policy.forward(obs_t)

        batch_cap = int(torch.clamp(params["batch_cap_mean"], 1.0, 512.0).round().item())
        kv_budget = float(torch.clamp(params["kv_budget_mean"], 0.10, 1.0).item())
        spec_depth = int(torch.argmax(params["spec_depth_logits"], dim=-1).item())
        quant_tier = int(torch.argmax(params["quant_tier_logits"], dim=-1).item())
        # A positive logit corresponds to a probability above 0.5.
        prefill_split = bool((params["prefill_split_logit"] > 0).item())
        priority_route = bool((params["priority_route_logit"] > 0).item())

        return ServeAction(
            batch_cap=batch_cap,
            kv_budget_fraction=round(kv_budget, 2),
            speculation_depth=spec_depth,
            quantization_tier=self._QUANT_TIERS[quant_tier],
            prefill_decode_split=prefill_split,
            priority_routing=priority_route,
        )
def find_weights(task_id: str) -> str | None:
    """Locate the trained PPO checkpoint that belongs to ``task_id``.

    Checkpoints are looked up as ``weights/ppo_<label>.pt`` under the parent
    of the directory that contains this module.

    Args:
        task_id: Task identifier, e.g. ``"static_workload"``.

    Returns:
        The checkpoint path as a string, or ``None`` when the task id is
        unknown or the file does not exist.
    """
    task_labels = {
        "static_workload": "task1_static",
        "bursty_workload": "task2_bursty",
        "adversarial_multitenant": "task3_adversarial",
    }
    if task_id not in task_labels:
        return None

    base_dir = Path(__file__).resolve().parents[1]
    candidate = base_dir / "weights" / f"ppo_{task_labels[task_id]}.pt"
    if candidate.exists():
        return str(candidate)
    return None