"""
eval_trained.py — Evaluate the GRPO-trained model against the heuristic baseline.

Loads the merged model from trained_models/grpo_qwen3_0.6b/ directly (no API server needed).
Runs N episodes on hard_multi and prints mean reward vs heuristic baseline.

USAGE
    uv run python train/eval_trained.py

HOW IT WORKS
    The trained model is loaded as a plain AutoModelForCausalLM (LoRA already merged).
    At each step, we feed the current observation as a chat message and parse the
    model's text output as a tool call (same _parse_llm_action logic as inference.py).
"""

from __future__ import annotations

import argparse
import os
import sys

os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType, Observation
from budget_router.policies import heuristic_baseline_policy
from budget_router.reward import grade_episode
from budget_router.tasks import HARD_MULTI

N_EPISODES = 10
SCENARIO = HARD_MULTI

SYSTEM_PROMPT = (
    "You are a budget-aware API router. "
    "Use the available tools to route each request to the best provider. "
    "Providers can degrade mid-episode — monitor health and switch early.\n\n"
    "At each step output EXACTLY ONE action string from: "
    "route_to_a | route_to_b | route_to_c | shed_load"
)

_VALID_ACTIONS = ["route_to_a", "route_to_b", "route_to_c", "shed_load"]


def _parse_action(text: str) -> str:
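    """Map free-form model text to the first valid action it contains.

    Falls back to shed_load when nothing parses, so a malformed generation
    never crashes the episode.
    """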
    text = text.strip().lower()
    for a in _VALID_ACTIONS:
        if a in text:
            return a
    return "shed_load"


def _obs_to_text(obs: Observation) -> str:
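    """Render the observation as the plain-text prompt shown to the model each step."""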
    return (
        f"provider_a_status: {obs.provider_a_status:.3f}\n"
        f"provider_b_status: {obs.provider_b_status:.3f}\n"
        f"provider_c_status: {obs.provider_c_status:.3f}\n"
        f"budget_remaining:  {obs.budget_remaining:.3f}\n"
        f"step_count:        {obs.step_count:.3f}\n"
        f"Your action:"
    )


def run_episode_llm(model, tokenizer, seed: int, device: str) -> float:
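    """Roll out one full episode with the trained model and return its overall_score.

    The chat history grows across steps, so the model can condition on earlier
    observations and its own past actions.
    """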
    env = BudgetRouterEnv()
    obs = env.reset(scenario=SCENARIO, seed=seed)
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    while not obs.done:
        messages.append({"role": "user", "content": _obs_to_text(obs)})
        # Qwen3's chat template takes enable_thinking as a direct template kwarg;
        # passing it here disables <think> blocks. Templates without the flag
        # simply ignore the extra variable.
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = tokenizer(text, return_tensors="pt").to(device)
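        # Greedy decoding (do_sample=False) keeps evaluation deterministic;
        # 20 new tokens is ample for a single action string.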
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        generated = tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        action_str = _parse_action(generated)
        messages.append({"role": "assistant", "content": action_str})

        action = Action(action_type=ActionType(action_str))
        obs = env.step(action)

    return float(grade_episode(env._internal.history)["overall_score"])


def run_episode_heuristic(seed: int) -> float:
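    """Roll out one episode with the hand-written baseline policy on the same seed."""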
    env = BudgetRouterEnv()
    obs = env.reset(scenario=SCENARIO, seed=seed)
    while not obs.done:
        action = heuristic_baseline_policy(obs)
        obs = env.step(action)
    return float(grade_episode(env._internal.history)["overall_score"])


def main():
    parser = argparse.ArgumentParser(description="Evaluate a GRPO-trained model vs heuristic baseline.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="trained_models/grpo_Qwen_Qwen3-1.7B",
        help="Path to merged trained model directory (default: trained_models/grpo_Qwen_Qwen3-1.7B).",
    )
    parser.add_argument("--n-episodes", type=int, default=N_EPISODES, help="Number of eval episodes.")
    args = parser.parse_args()

    model_path = args.model_path
    if not os.path.exists(model_path):
        print(f"❌ Trained model not found at {MODEL_PATH}")
        print("   Run train/learn_experiment.py first.")
        sys.exit(1)

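    # Prefer Apple Silicon's MPS backend when available; bf16 on MPS, fp32 on CPU.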
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "mps" else torch.float32

    print(f"Loading trained model from {model_path} ...")
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print(f"\nRunning {args.n_episodes} episodes on {SCENARIO.name} ...")
    print(f"{'Seed':<6} {'LLM':>8} {'Heuristic':>12}")
    print("-" * 30)

    llm_scores, heuristic_scores = [], []
    for seed in range(args.n_episodes):
        llm_r = run_episode_llm(model, tokenizer, seed, device)
        heur_r = run_episode_heuristic(seed)
        llm_scores.append(llm_r)
        heuristic_scores.append(heur_r)
        print(f"{seed:<6} {llm_r:>8.4f} {heur_r:>12.4f}")

    llm_mean = sum(llm_scores) / len(llm_scores)
    heur_mean = sum(heuristic_scores) / len(heuristic_scores)

    print("-" * 30)
    print(f"{'Mean':<6} {llm_mean:>8.4f} {heur_mean:>12.4f}")
    print()
    if llm_mean >= heur_mean:
        print(f"✅ LLM ({llm_mean:.4f}) >= Heuristic ({heur_mean:.4f}) — BEATS BASELINE")
    else:
        gap = heur_mean - llm_mean
        print(f"⚠️  LLM ({llm_mean:.4f}) < Heuristic ({heur_mean:.4f}) — gap={gap:.4f}")


if __name__ == "__main__":
    main()