File size: 3,077 Bytes
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
"""PPO agent — loads pre-trained weights and runs inference only.

Usage:
    from agents.ppo_agent import PPOAgent
    agent = PPOAgent("weights/ppo_task1_static.pt")
    action = agent.act(observation, task_id)
"""
from __future__ import annotations

import os
import sys
from pathlib import Path

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch  # noqa: E402

from llmserve_env.models import ServeAction, ServeObservation  # noqa: E402
from rl.env_wrapper import obs_to_vector  # noqa: E402
from rl.normalize import RunningNormalizer  # noqa: E402
from rl.policy_network import PolicyNetwork  # noqa: E402


class PPOAgent:
    """Inference-only agent that loads trained PPO weights."""

    def __init__(self, weights_path: str, obs_dim: int = 15) -> None:
        self.policy = PolicyNetwork(obs_dim=obs_dim)
        self.normalizer: RunningNormalizer | None = None

        state = torch.load(weights_path, map_location="cpu", weights_only=False)
        self.policy.load_state_dict(state["policy"])
        self.policy.eval()

        if "normalizer" in state:
            self.normalizer = RunningNormalizer(shape=(obs_dim,))
            self.normalizer.load_state_dict(state["normalizer"])

    def reset(self) -> None:
        pass  # No internal state to reset

    def act(self, observation: ServeObservation, task_id: str) -> ServeAction:
        """Select a deterministic action from the trained policy."""
        del task_id
        vec = obs_to_vector(observation)
        if self.normalizer is not None:
            vec = self.normalizer.normalize(vec)

        with torch.no_grad():
            obs_t = torch.from_numpy(vec).unsqueeze(0)
            params, _ = self.policy.forward(obs_t)

        batch_cap = int(torch.clamp(params["batch_cap_mean"], 1.0, 512.0).round().item())
        kv_budget = float(torch.clamp(params["kv_budget_mean"], 0.10, 1.0).item())
        spec_depth = int(torch.argmax(params["spec_depth_logits"], dim=-1).item())
        quant_tier = int(torch.argmax(params["quant_tier_logits"], dim=-1).item())
        prefill_split = bool((params["prefill_split_logit"] > 0).item())
        priority_route = bool((params["priority_route_logit"] > 0).item())

        return ServeAction(
            batch_cap=batch_cap,
            kv_budget_fraction=round(kv_budget, 2),
            speculation_depth=spec_depth,
            quantization_tier=["FP16", "INT8", "INT4"][quant_tier],
            prefill_decode_split=prefill_split,
            priority_routing=priority_route,
        )


def find_weights(task_id: str) -> str | None:
    """Find the weights file for a given task_id."""
    label_map = {
        "static_workload": "task1_static",
        "bursty_workload": "task2_bursty",
        "adversarial_multitenant": "task3_adversarial",
    }
    label = label_map.get(task_id)
    if not label:
        return None
    weights_dir = Path(__file__).resolve().parents[1] / "weights"
    path = weights_dir / f"ppo_{label}.pt"
    return str(path) if path.exists() else None