# test-rl-hackathon-budget / visualize.py
# Akshay Babbar
# chore: HF Space export (size filter)
# 98a5a8c
"""
Episode visualization for Budget Router environment.
Generates a 4-panel matplotlib figure showing:
1. Provider health degradation over time
2. Budget remaining curve
3. Action distribution per step (color-coded strip)
4. Cumulative reward trajectory
Usage:
python visualize.py --scenario hard_multi --seed 42
python visualize.py --scenario medium --policy oracle
"""
from __future__ import annotations
import argparse
import random
from typing import Any, Dict, List
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.policies import (
debug_upper_bound_policy,
heuristic_baseline_policy,
random_policy,
)
from budget_router.tasks import TASK_PRESETS
# Action identifiers as they appear in the trace; an action's position in this
# list is also its row index in the per-step action strip.
ACTION_ORDER = [
    "route_to_a",
    "route_to_b",
    "route_to_c",
    "shed_load",
]
# Human-readable tick/legend labels, aligned index-for-index with ACTION_ORDER.
ACTION_LABELS = [
    "Route A",
    "Route B",
    "Route C",
    "Shed Load",
]
# Bar/legend colors (hex), aligned index-for-index with ACTION_ORDER.
ACTION_COLORS = [
    "#e74c3c",
    "#3498db",
    "#2ecc71",
    "#95a5a6",
]
def run_and_trace(
    env: BudgetRouterEnv,
    policy_fn: Any,
    seed: int,
    scenario_name: str,
    policy_name: str = "heuristic_baseline",
) -> Dict[str, List]:
    """Play one episode with ``policy_fn`` and record a per-step trace.

    Args:
        env: Environment to drive; it is reset with ``seed`` and the preset
            looked up from ``TASK_PRESETS[scenario_name]``.
        policy_fn: Callable returning an action for an observation. How it is
            invoked depends on ``policy_name`` (see below).
        seed: Seed passed to ``env.reset``; ``seed + 10000`` seeds the RNG
            handed to random policies.
        scenario_name: Key into ``TASK_PRESETS``.
        policy_name: Selects the calling convention. A name containing
            ``"upper_bound"`` gets ``policy_fn(obs, env._internal)``; otherwise
            a name containing ``"random"`` gets ``policy_fn(obs, rng=rng)``;
            anything else gets ``policy_fn(obs)``.

    Returns:
        A dict of equal-length lists, one entry per executed step, keyed by
        "step", "a_health", "b_health", "c_health", "budget", "budget_pct",
        "reward", "cumulative_reward", "action", "latency_ms" and
        "queue_backlog".
    """
    scenario = TASK_PRESETS[scenario_name]
    obs = env.reset(seed=seed, scenario=scenario)

    # The policy name is fixed for the whole episode, so classify it once.
    wants_internal_state = "upper_bound" in policy_name
    wants_rng = "random" in policy_name
    rng = random.Random(seed + 10000) if wants_rng else None

    columns = (
        "step",
        "a_health",
        "b_health",
        "c_health",
        "budget",
        "budget_pct",
        "reward",
        "cumulative_reward",
        "action",
        "latency_ms",
        "queue_backlog",
    )
    trace: Dict[str, List] = {column: [] for column in columns}

    running_total = 0.0
    step_count = 0
    starting_budget = scenario.initial_budget

    while not (obs.done or step_count >= scenario.max_steps):
        # "upper_bound" takes precedence over "random" when both match.
        if wants_internal_state:
            action = policy_fn(obs, env._internal)
        elif wants_rng:
            action = policy_fn(obs, rng=rng)
        else:
            action = policy_fn(obs)

        obs = env.step(action)
        step_count += 1

        # A missing (falsy) reward counts as zero.
        step_reward = obs.reward or 0.0
        running_total += step_reward

        # Snapshot the environment's internal state *after* the step.
        state = env._internal
        providers = state.providers
        budget_left = state.budget_dollars
        if starting_budget > 0:
            budget_fraction = budget_left / starting_budget
        else:
            budget_fraction = 0

        row = {
            "step": step_count,
            "a_health": providers["A"].current_health,
            "b_health": providers["B"].current_health,
            "c_health": providers["C"].current_health,
            "budget": budget_left,
            "budget_pct": budget_fraction,
            "reward": step_reward,
            "cumulative_reward": running_total,
            "action": action.action_type.value,
            "latency_ms": state.last_latency_ms,
            "queue_backlog": state.queue_backlog_count,
        }
        for column, value in row.items():
            trace[column].append(value)

    return trace
def render_episode(trace: Dict[str, List], scenario_name: str, policy_name: str, seed: int) -> plt.Figure:
    """Create a 4-panel matplotlib figure from episode trace data.

    The 2x2 grid shows, per step: provider health, remaining budget, the
    action taken (as a color-coded strip), and cumulative reward.

    Args:
        trace: Per-step lists keyed by "step", "a_health", "b_health",
            "c_health", "budget", "action" and "cumulative_reward". An empty
            trace is tolerated: the action panel falls back to a fixed x-range
            and the reward total in the title is reported as 0.00.
        scenario_name: Shown upper-cased in the figure title.
        policy_name: Shown in the figure title with underscores replaced by
            spaces and title-cased.
        seed: Shown in the figure title.

    Returns:
        The populated figure; saving and closing it is left to the caller.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(
        f"Budget Router Episode — {scenario_name.upper()} / {policy_name.replace('_', ' ').title()} / seed={seed}",
        fontsize=14,
        fontweight="bold",
    )
    steps = trace["step"]

    # ── Panel 1: Provider health degradation ──
    ax1 = axes[0, 0]
    ax1.plot(steps, trace["a_health"], "o-", color="#e74c3c", label="Provider A", linewidth=2, markersize=4)
    ax1.plot(steps, trace["b_health"], "s-", color="#3498db", label="Provider B", linewidth=2, markersize=4)
    ax1.plot(steps, trace["c_health"], "^-", color="#2ecc71", label="Provider C", linewidth=2, markersize=4)
    # Horizontal reference line, labelled in the legend as the heuristic's threshold.
    ax1.axhline(y=0.52, color="gray", linestyle="--", alpha=0.5, label="Heuristic threshold (0.52)")
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Health (true)")
    ax1.set_title("Provider Health Degradation")
    ax1.legend(fontsize=8, loc="upper right")
    ax1.set_ylim(-0.05, 1.05)
    ax1.grid(True, alpha=0.3)

    # ── Panel 2: Budget remaining ──
    ax2 = axes[0, 1]
    ax2.plot(steps, trace["budget"], "o-", color="#f39c12", linewidth=2, markersize=4)
    ax2.fill_between(steps, trace["budget"], alpha=0.2, color="#f39c12")
    ax2.axhline(y=0, color="red", linestyle="--", alpha=0.7, label="Budget exhausted")
    ax2.set_xlabel("Step")
    ax2.set_ylabel("Budget Remaining ($)")
    ax2.set_title("Budget Remaining")
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

    # ── Panel 3: Action distribution (color-coded strip) ──
    ax3 = axes[1, 0]
    action_map = {name: i for i, name in enumerate(ACTION_ORDER)}
    for action_val, step_val in zip(trace["action"], steps):
        # Unrecognized action values fall back to row 3 (the "Shed Load" row).
        idx = action_map.get(action_val, 3)
        # One bar of width 0.8 centered on the step number.
        ax3.barh(idx, 0.8, left=step_val - 0.4, height=0.6, color=ACTION_COLORS[idx], edgecolor="white", linewidth=0.5)
    ax3.set_yticks(range(len(ACTION_LABELS)))
    ax3.set_yticklabels(ACTION_LABELS)
    ax3.set_xlabel("Step")
    ax3.set_title("Action per Step")
    # max() of an empty sequence raises, so use a fixed range for empty traces.
    ax3.set_xlim(0.5, max(steps) + 0.5 if steps else 20.5)
    ax3.grid(True, alpha=0.3, axis="x")
    # Add legend patches manually: the bars carry no labels of their own.
    from matplotlib.patches import Patch

    legend_patches = [Patch(facecolor=color, label=label) for color, label in zip(ACTION_COLORS, ACTION_LABELS)]
    ax3.legend(handles=legend_patches, fontsize=7, loc="upper right", ncol=2)

    # ── Panel 4: Cumulative reward ──
    ax4 = axes[1, 1]
    cumulative = trace["cumulative_reward"]
    # Indexing [-1] on an empty trace would raise IndexError; report 0 instead,
    # consistent with the empty-trace guard on the action panel's x-limits.
    final_total = cumulative[-1] if cumulative else 0.0
    ax4.plot(steps, cumulative, "o-", color="#9b59b6", linewidth=2, markersize=4)
    ax4.fill_between(steps, cumulative, alpha=0.1, color="#9b59b6")
    ax4.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
    ax4.set_xlabel("Step")
    ax4.set_ylabel("Cumulative Reward")
    ax4.set_title(f"Cumulative Reward (total: {final_total:.2f})")
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
def main() -> None:
    """Command-line entry point: run one episode and save its figure.

    Parses ``--scenario``, ``--seed``, ``--policy`` and ``--output``, runs the
    selected policy through ``run_and_trace``, renders the trace with
    ``render_episode`` and writes the image to ``--output`` (or
    ``docs/{scenario}_{policy}_seed{seed}.png`` when not given), creating the
    target directory if needed.
    """
    # Local import: only this entry point needs it, to prepare the output directory.
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Visualize Budget Router episode")
    parser.add_argument("--scenario", default="hard_multi", choices=list(TASK_PRESETS.keys()))
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--policy",
        default="heuristic_baseline",
        choices=["heuristic_baseline", "oracle", "random"],
    )
    parser.add_argument(
        "--output",
        default=None,
        # Matches the default path actually built below (it includes the seed).
        help="Output file path (default: docs/{scenario}_{policy}_seed{seed}.png)",
    )
    args = parser.parse_args()

    # CLI policy choice -> (callable, internal name). The internal name is what
    # run_and_trace inspects to decide how to call the policy.
    policies = {
        "heuristic_baseline": (heuristic_baseline_policy, "heuristic_baseline"),
        "oracle": (debug_upper_bound_policy, "upper_bound"),
        "random": (random_policy, "random"),
    }
    policy_fn, policy_name = policies[args.policy]

    env = BudgetRouterEnv()
    print(f"Running episode: {args.scenario} / {args.policy} / seed={args.seed}")
    trace = run_and_trace(env, policy_fn, args.seed, args.scenario, policy_name=policy_name)

    # An episode that ends before any step leaves the trace empty; avoid
    # IndexError on [-1] and report a zero total instead.
    rewards = trace["cumulative_reward"]
    total_reward = rewards[-1] if rewards else 0.0
    print(f"Episode complete: {len(trace['step'])} steps, total reward={total_reward:.2f}")

    fig = render_episode(trace, args.scenario, args.policy, args.seed)
    output = args.output or f"docs/{args.scenario}_{args.policy}_seed{args.seed}.png"
    try:
        # savefig does not create missing directories (e.g. a fresh checkout
        # without docs/), so make sure the parent exists first.
        Path(output).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output, dpi=150, bbox_inches="tight")
        print(f"Saved to: {output}")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


if __name__ == "__main__":
    main()