"""
eval_trained.py — Evaluate the GRPO-trained model against the heuristic baseline.
Loads the merged model from trained_models/grpo_qwen3_0.6b/ directly (no API server needed).
Runs N episodes on hard_multi and prints mean reward vs heuristic baseline.
USAGE
uv run python train/eval_trained.py
HOW IT WORKS
The trained model is loaded as a plain AutoModelForCausalLM (LoRA already merged).
At each step, we feed the current observation as a chat message and parse the
model's text output as a tool call (same _parse_llm_action logic as inference.py).
"""
from __future__ import annotations
import argparse
import os
import sys
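
# Must be set before torch is imported: PYTORCH_ENABLE_MPS_FALLBACK lets ops
# unsupported on Apple's MPS backend fall back to CPU instead of raising.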
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType, Observation
from budget_router.policies import heuristic_baseline_policy
from budget_router.reward import grade_episode
from budget_router.tasks import HARD_MULTI
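
# Eval defaults: each seed 0..N-1 is run through both the LLM and the
# heuristic, so every printed row is a paired comparison.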
N_EPISODES = 10
SCENARIO = HARD_MULTI

SYSTEM_PROMPT = (
    "You are a budget-aware API router. "
    "Use the available tools to route each request to the best provider. "
    "Providers can degrade mid-episode — monitor health and switch early.\n\n"
    "At each step output EXACTLY ONE action string from: "
    "route_to_a | route_to_b | route_to_c | shed_load"
)
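
# Accepted action strings; these must match the ActionType enum values, since
# the parsed string is passed straight to ActionType() in run_episode_llm.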
_VALID_ACTIONS = ["route_to_a", "route_to_b", "route_to_c", "shed_load"]


def _parse_action(text: str) -> str:
    """Scan the model output for a valid action string; default to shed_load."""
    text = text.strip().lower()
    for a in _VALID_ACTIONS:
        if a in text:
            return a
    return "shed_load"


def _obs_to_text(obs: Observation) -> str:
    """Render the observation as key/value prompt lines (the formatting should
    match what the model saw during training, so keep it stable)."""
    return (
        f"provider_a_status: {obs.provider_a_status:.3f}\n"
        f"provider_b_status: {obs.provider_b_status:.3f}\n"
        f"provider_c_status: {obs.provider_c_status:.3f}\n"
        f"budget_remaining: {obs.budget_remaining:.3f}\n"
        f"step_count: {obs.step_count:.3f}\n"
        f"Your action:"
    )


def run_episode_llm(model, tokenizer, seed: int, device: str) -> float:
    env = BudgetRouterEnv()
    obs = env.reset(scenario=SCENARIO, seed=seed)
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    while not obs.done:
        messages.append({"role": "user", "content": _obs_to_text(obs)})
        try:
            # enable_thinking=False turns off Qwen3's reasoning block so the
            # model emits the action string directly.
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except TypeError:
            # Older Transformers versions may not accept template kwargs here.
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
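        # Tokenize the prompt and greedy-decode a short completion; 20 new
        # tokens is plenty for a single action string.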
        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens (everything after the prompt).
        generated = tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        action_str = _parse_action(generated)
        messages.append({"role": "assistant", "content": action_str})
        action = Action(action_type=ActionType(action_str))
        obs = env.step(action)
    return float(grade_episode(env._internal.history)["overall_score"])


def run_episode_heuristic(seed: int) -> float:
    env = BudgetRouterEnv()
    obs = env.reset(scenario=SCENARIO, seed=seed)
    while not obs.done:
        action = heuristic_baseline_policy(obs)
        obs = env.step(action)
    return float(grade_episode(env._internal.history)["overall_score"])


def main():
    parser = argparse.ArgumentParser(description="Evaluate a GRPO-trained model vs heuristic baseline.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="trained_models/grpo_Qwen_Qwen3-1.7B",
        help="Path to merged trained model directory (default: trained_models/grpo_Qwen_Qwen3-1.7B).",
    )
    parser.add_argument("--n-episodes", type=int, default=N_EPISODES, help="Number of eval episodes.")
    args = parser.parse_args()

    model_path = args.model_path
    if not os.path.exists(model_path):
        print(f"❌ Trained model not found at {model_path}")
        print("   Run train/learn_experiment.py first.")
        sys.exit(1)
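
    # Prefer Apple's MPS backend when available; bf16 on MPS, fp32 on CPU.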
device = "mps" if torch.backends.mps.is_available() else "cpu"
dtype = torch.bfloat16 if device == "mps" else torch.float32
print(f"Loading trained model from {model_path} ...")
model = AutoModelForCausalLM.from_pretrained(model_path, dtype=dtype)
model = model.to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f"\nRunning {args.n_episodes} episodes on {SCENARIO.name} ...")
print(f"{'Seed':<6} {'LLM':>8} {'Heuristic':>12}")
print("-" * 30)
llm_scores, heuristic_scores = [], []
for seed in range(args.n_episodes):
llm_r = run_episode_llm(model, tokenizer, seed, device)
heur_r = run_episode_heuristic(seed)
llm_scores.append(llm_r)
heuristic_scores.append(heur_r)
print(f"{seed:<6} {llm_r:>8.4f} {heur_r:>12.4f}")
llm_mean = sum(llm_scores) / len(llm_scores)
heur_mean = sum(heuristic_scores) / len(heuristic_scores)
print("-" * 30)
print(f"{'Mean':<6} {llm_mean:>8.4f} {heur_mean:>12.4f}")
print()
if llm_mean >= heur_mean:
print(f"✅ LLM ({llm_mean:.4f}) >= Heuristic ({heur_mean:.4f}) — BEATS BASELINE")
else:
gap = heur_mean - llm_mean
print(f"⚠️ LLM ({llm_mean:.4f}) < Heuristic ({heur_mean:.4f}) — gap={gap:.4f}")


if __name__ == "__main__":
    main()