File size: 4,493 Bytes
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""Evaluate agents on InferenceGym tasks and print benchmark table.

Usage:
    python evaluate.py --agent ppo --task all --episodes 20 --seed 42
    python evaluate.py --agent heuristic --task static_workload --episodes 10
    python evaluate.py --agent random --task all --episodes 10
"""
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np  # noqa: E402

from server.llmserve_environment import LLMServeEnvironment  # noqa: E402

# Task identifiers evaluated when ``--task all`` is requested.
TASK_IDS = ["static_workload", "bursty_workload", "adversarial_multitenant"]
# Agent kinds accepted by ``--agent`` (which additionally accepts ``all``).
AGENT_TYPES = ["random", "heuristic", "ppo"]
# ``weights`` directory next to this script; ``_get_agent`` looks here for PPO weight files.
WEIGHTS_DIR = Path(__file__).resolve().parent / "weights"


def _get_agent(agent_type: str, task_id: str):
    """Construct the policy object for the requested agent kind.

    Args:
        agent_type: One of ``"heuristic"``, ``"random"`` or ``"ppo"``.
        task_id: Task the agent will run on. Only the PPO branch uses it, to
            choose which weight file to load.

    Returns:
        An object exposing ``act(obs, task_id)``. When the PPO weight file is
        missing, a warning is printed and a ``HeuristicPolicy`` is returned
        in its place.

    Raises:
        ValueError: If ``agent_type`` is not one of the supported kinds.
    """
    # Project imports live inside the branches so that only the modules
    # needed for the selected agent are imported.
    if agent_type == "random":
        from random import Random

        from agents.random_agent import random_action

        generator = Random(42)  # fixed seed, shared by every act() call

        class _RandomAgent:
            """Adapter giving ``random_action`` the common agent interface."""

            def reset(self):
                """Do nothing; the generator is not reseeded between episodes."""
                return None

            def act(self, obs, tid):
                """Sample an action; the observation and task id are unused."""
                return random_action(generator)

        return _RandomAgent()

    if agent_type == "ppo":
        from agents.ppo_agent import PPOAgent

        weight_labels = {
            "static_workload": "task1_static",
            "bursty_workload": "task2_bursty",
            "adversarial_multitenant": "task3_adversarial",
        }
        # Task ids without an entry use the static-workload weights.
        suffix = weight_labels.get(task_id, "task1_static")
        checkpoint = WEIGHTS_DIR / f"ppo_{suffix}.pt"
        if checkpoint.exists():
            return PPOAgent(str(checkpoint))
        print(f"[WARN] PPO weights not found at {checkpoint}, falling back to heuristic")
    elif agent_type != "heuristic":
        raise ValueError(f"Unknown agent type: {agent_type}")

    # Reached for the heuristic agent and for the PPO fallback above.
    from server.baseline_agent import HeuristicPolicy

    return HeuristicPolicy()


def run_episode(env: LLMServeEnvironment, agent, task_id: str, seed: int) -> float:
    """Play a single episode and return the sum of the step rewards.

    Args:
        env: Environment exposing ``reset(seed=..., task_id=...)``,
            ``step(action)`` and a ``task_config`` attribute.
        agent: Object with ``act(obs, task_id)``; its ``reset()`` is called
            first when it has one.
        task_id: Task passed to the environment reset and to the agent.
        seed: Seed passed to the environment reset.

    Returns:
        Total reward accumulated until the observation reports ``done`` or
        the step limit is reached.
    """
    if hasattr(agent, "reset"):
        agent.reset()

    observation = env.reset(seed=seed, task_id=task_id)

    # The step limit comes from the task configuration; 60 is used when the
    # environment exposes no (or an empty) configuration.
    config = env.task_config
    step_budget = 60
    if config:
        step_budget = int(config["max_steps"])

    episode_return = 0.0
    steps_taken = 0
    while steps_taken < step_budget:
        observation = env.step(agent.act(observation, task_id))
        steps_taken += 1

        # A missing, None or zero reward contributes nothing to the total.
        reward = getattr(observation, "reward", 0.0)
        episode_return += float(reward) if reward else 0.0

        if getattr(observation, "done", False):
            break

    return episode_return


def main(argv: list[str] | None = None) -> int:
    """Run the benchmark described by the command-line arguments.

    For every selected agent and task, ``--episodes`` episodes are played
    with seeds ``seed, seed + 1, ...``. The mean and standard deviation of
    the episode returns are printed as one table row per agent/task pair.
    The collected results are then printed as JSON and, when ``--output`` is
    given, also written to that file.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` on success. Invalid arguments end the run through argparse's
        ``SystemExit`` instead of returning.
    """
    parser = argparse.ArgumentParser(description="Evaluate agents on InferenceGym")
    parser.add_argument("--agent", default="ppo", choices=AGENT_TYPES + ["all"])
    parser.add_argument("--task", default="all")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", type=str, default=None)
    args = parser.parse_args(argv)

    # With no episodes the reward list is empty: the mean/std would be NaN
    # and the JSON output would contain NaN, which strict parsers reject.
    if args.episodes < 1:
        parser.error("--episodes must be a positive integer")

    tasks = TASK_IDS if args.task == "all" else [args.task]
    env = LLMServeEnvironment(seed=args.seed, mode="sim")

    results = {}
    selected_agents = AGENT_TYPES if args.agent == "all" else [args.agent]

    print(f"\n{'Agent':<12} {'Task':<28} {'Mean Reward':>12} {'Std':>8} {'Episodes':>9}")
    print("-" * 72)

    for agent_type in selected_agents:
        agent_results = {}
        for task_id in tasks:
            # A fresh agent per task: the PPO weights depend on the task.
            agent = _get_agent(agent_type, task_id)
            rewards = [
                run_episode(env, agent, task_id, args.seed + ep)
                for ep in range(args.episodes)
            ]
            mean_r = float(np.mean(rewards))
            std_r = float(np.std(rewards))
            agent_results[task_id] = {
                "mean_reward": round(mean_r, 4),
                "std_reward": round(std_r, 4),
                "episodes": args.episodes,
            }
            print(f"{agent_type:<12} {task_id:<28} {mean_r:>12.4f} {std_r:>8.4f} {args.episodes:>9d}")
        # Results are nested by agent only when several agents were run.
        if args.agent == "all":
            results[agent_type] = agent_results
        else:
            results = agent_results

    if args.output:
        Path(args.output).parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {args.output}")

    print(f"\n{json.dumps(results, indent=2)}")
    return 0


if __name__ == "__main__":
    # Use the integer returned by main() as the process exit status.
    raise SystemExit(main())