# ev-grid-oracle / training / evaluate.py
# Uploaded by: NITISHRG15102007
# sync: push from tools/sync_space_to_hub.py (no artifacts/)
# Commit 32e1e21 (verified)
from __future__ import annotations
import argparse
import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from statistics import mean
from typing import Any
from ev_grid_oracle.city_graph import build_city_graph
from ev_grid_oracle.env import EVGridCore
from ev_grid_oracle.models import ActionType, EVGridAction
from ev_grid_oracle.oracle_agent import OracleAgent
from ev_grid_oracle.policies import baseline_policy
from ev_grid_oracle.scenarios import ScenarioName
@dataclass(frozen=True, slots=True)
class EpisodeMetrics:
avg_wait: float
grid_stress_events: int
peak_violations: int
renewable_mean: float
critical_deferred: int
anti_cheat_steps: int
reward_breakdown_mean: dict[str, float] = field(default_factory=dict)
def _episode_metrics_to_json(m: EpisodeMetrics) -> dict[str, Any]:
    """Convert one episode's metrics into a plain, JSON-serializable dict.

    Scalar fields are copied as-is under their attribute names; the reward
    breakdown is copied into a fresh ``dict`` so the result does not alias
    the mapping held by ``m``.
    """
    scalar_fields = (
        "avg_wait",
        "grid_stress_events",
        "peak_violations",
        "renewable_mean",
        "critical_deferred",
        "anti_cheat_steps",
    )
    payload: dict[str, Any] = {name: getattr(m, name) for name in scalar_fields}
    payload["reward_breakdown_mean"] = dict(m.reward_breakdown_mean)
    return payload
def _env_flag_enabled(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy value.

    Unset, empty, "0", "false", "no" and "off" (compared case-insensitively,
    ignoring surrounding whitespace) count as disabled; anything else enables
    the flag.
    """
    return os.getenv(name, "").strip().lower() not in ("", "0", "false", "no", "off")


def run_episode(
    env: EVGridCore,
    *,
    policy: str,
    seed: int,
    scenario: ScenarioName = "baseline",
    oracle_repo: str | None = None,
) -> EpisodeMetrics:
    """Roll out one episode with the chosen policy and collect its metrics.

    Args:
        env: Environment to drive. It is reset with ``seed`` and ``scenario``
            at the start, so the same instance can be reused across episodes.
        policy: ``"baseline"`` uses ``baseline_policy``; ``"oracle"`` uses an
            ``OracleAgent``. Any other value falls back to ``baseline_policy``
            because no agent is created for it.
        seed: Seed forwarded to ``env.reset``.
        scenario: Scenario name forwarded to ``env.reset``.
        oracle_repo: LoRA repo id handed to ``OracleAgent`` (blank is treated
            as ``None``). Ignored when the ``ORACLE_SKIP_LLM`` environment
            variable is enabled, in which case the agent is built with
            ``lora_repo_id=None``.

    Returns:
        An ``EpisodeMetrics`` with the per-episode aggregates. Metrics derived
        from the state are sampled *before* each action is applied; anti-cheat
        flags and the reward breakdown are read from the observation returned
        by that step.
    """
    obs = env.reset(seed=seed, scenario=scenario)
    graph = env.city_graph

    oracle = None
    if policy == "oracle":
        if _env_flag_enabled("ORACLE_SKIP_LLM"):
            lora_repo_id = None
        else:
            lora_repo_id = (oracle_repo or "").strip() or None
        oracle = OracleAgent(lora_repo_id=lora_repo_id)

    waits: list[float] = []
    renewable: list[float] = []
    grid_stress = 0
    peak_violations = 0
    critical_deferred = 0
    anti_cheat_steps = 0
    rb_sums: dict[str, float] = defaultdict(float)
    rb_steps = 0

    for _ in range(env.max_steps):
        state = obs.state

        # statistics.mean raises on empty input, so a step without any
        # stations is recorded as a zero wait instead of aborting the rollout.
        station_waits = [s.avg_wait_minutes for s in state.stations]
        waits.append(mean(station_waits) if station_waits else 0.0)

        # One stress event per station above 85% occupancy; max(1, ...) guards
        # against a station reporting zero total slots.
        grid_stress += sum(
            1 for s in state.stations if (s.occupied_slots / max(1, s.total_slots)) > 0.85
        )
        if state.grid_load_pct > 0.80:
            peak_violations += 1
        renewable.append(state.renewable_pct)

        if not state.pending_evs:
            # Nothing to route: submit a placeholder action so the environment
            # still advances by one step.
            action = EVGridAction(action_type=ActionType.load_shift, ev_id="EV-000", defer_minutes=0)
        elif policy == "baseline" or oracle is None:
            action = baseline_policy(state, graph)
        else:
            action = oracle.act(state, obs.prompt, graph)

        # The check looks at the first pending EV, not at the EV named in the
        # action. NOTE(review): confirm that this matches the intended metric.
        if (
            state.pending_evs
            and action.action_type == ActionType.defer
            and state.pending_evs[0].battery_pct_0_100 < 15.0
        ):
            critical_deferred += 1

        obs = env.step(action)
        if obs.anti_cheat_flags:
            anti_cheat_steps += 1

        # Accumulate every reward component except the aggregate "total" and
        # private ("_"-prefixed) entries. A missing breakdown is treated as
        # empty (defensive; whether the env can return None is not visible here).
        for key, value in (obs.reward_breakdown or {}).items():
            name = str(key)
            if name == "total" or name.startswith("_"):
                continue
            rb_sums[name] += float(value)
        rb_steps += 1

        if obs.done:
            break

    rb_mean: dict[str, float] = {}
    if rb_steps > 0:
        rb_mean = {name: rb_sums[name] / rb_steps for name in sorted(rb_sums)}

    return EpisodeMetrics(
        avg_wait=float(mean(waits)) if waits else 0.0,
        grid_stress_events=int(grid_stress),
        peak_violations=int(peak_violations),
        renewable_mean=float(mean(renewable)) if renewable else 0.0,
        critical_deferred=int(critical_deferred),
        anti_cheat_steps=int(anti_cheat_steps),
        reward_breakdown_mean=rb_mean,
    )
def summarize(metrics: list[EpisodeMetrics]) -> dict[str, float]:
    """Average the scalar metrics of several episodes.

    Args:
        metrics: Per-episode metrics to aggregate.

    Returns:
        A dict mapping each summary key to the mean of the corresponding
        ``EpisodeMetrics`` attribute across ``metrics``. For an empty list
        every key maps to ``0.0`` (``statistics.mean`` would otherwise raise
        ``StatisticsError``), mirroring how ``summarize_reward_breakdown``
        returns an empty result instead of failing.
    """
    # (summary key, EpisodeMetrics attribute) in output order.
    columns = (
        ("avg_wait_minutes", "avg_wait"),
        ("grid_stress_events", "grid_stress_events"),
        ("peak_violations", "peak_violations"),
        ("renewable_mean", "renewable_mean"),
        ("critical_deferred", "critical_deferred"),
        ("anti_cheat_steps", "anti_cheat_steps"),
    )
    if not metrics:
        return {key: 0.0 for key, _ in columns}
    return {key: mean([getattr(m, attr) for m in metrics]) for key, attr in columns}
def summarize_reward_breakdown(metrics: list[EpisodeMetrics]) -> dict[str, float]:
    """Average each reward component across episodes.

    A component is averaged only over the episodes whose
    ``reward_breakdown_mean`` actually contains it. The result is keyed in
    sorted order and is empty when no episode reports any component.
    """
    totals: dict[str, float] = {}
    occurrences: dict[str, int] = {}
    for episode in metrics:
        for name, value in episode.reward_breakdown_mean.items():
            totals[name] = totals.get(name, 0.0) + float(value)
            occurrences[name] = occurrences.get(name, 0) + 1

    averaged: dict[str, float] = {}
    for name in sorted(totals):
        averaged[name] = totals[name] / occurrences[name]
    return averaged
# Scenario names accepted by the ``--scenario`` command-line option; ``main``
# passes this tuple to argparse as ``choices``.
# NOTE(review): presumably mirrors the values allowed by ``ScenarioName`` in
# ev_grid_oracle/scenarios.py — confirm when adding or renaming a scenario.
SCENARIO_CHOICES: tuple[str, ...] = (
    "baseline",
    "heatwave_peak",
    "festival_surge",
    "transformer_derate",
    "station_outage",
    "tariff_shock",
)
def main() -> None:
    """Run paired baseline/oracle rollouts and write the results as JSON.

    For each episode the same seed (``--seed`` plus the episode index) and the
    same scenario are used for both policies, so each pair is evaluated on the
    same world. The aggregated summaries, reward breakdowns and per-episode
    records are written to ``--out`` and echoed to stdout.

    Environment variables:
        ORACLE_LORA_REPO: optional LoRA repo id forwarded to ``run_episode``.
    """
    ap = argparse.ArgumentParser(description="Paired baseline vs oracle rollouts (same seed + scenario per episode).")
    ap.add_argument("--episodes", type=int, default=50)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument(
        "--scenario",
        type=str,
        default="baseline",
        choices=SCENARIO_CHOICES,
        help="Deterministic scenario schedule (see ev_grid_oracle/scenarios.py).",
    )
    ap.add_argument("--out", type=str, default="training/eval_results.json")
    args = ap.parse_args()

    # With no episodes there is nothing to average; fail with a clear usage
    # error instead of a traceback from the summary step.
    if args.episodes < 1:
        ap.error("--episodes must be at least 1")

    # Create the output directory before the (potentially long) rollouts so a
    # bad path is detected immediately rather than after all work is done.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    scenario: ScenarioName = args.scenario  # type: ignore[assignment]
    graph = build_city_graph()
    env = EVGridCore(city_graph=graph)
    oracle_repo = os.getenv("ORACLE_LORA_REPO", "").strip() or None

    baseline_runs: list[EpisodeMetrics] = []
    oracle_runs: list[EpisodeMetrics] = []
    per_episode: list[dict[str, Any]] = []

    for i in range(args.episodes):
        episode_seed = args.seed + i
        b = run_episode(env, policy="baseline", seed=episode_seed, scenario=scenario, oracle_repo=oracle_repo)
        o = run_episode(env, policy="oracle", seed=episode_seed, scenario=scenario, oracle_repo=oracle_repo)
        baseline_runs.append(b)
        oracle_runs.append(o)

        # An episode counts as "high stress" when its stress events exceed a
        # quarter of the step budget. Read max_steps after the rollouts, as
        # the original code did, in case a reset adjusts it.
        high_stress_cutoff = int(env.max_steps * 0.25)
        per_episode.append(
            {
                "episode_index": i,
                "episode_seed": episode_seed,
                "scenario": scenario,
                "baseline": _episode_metrics_to_json(b),
                "oracle": _episode_metrics_to_json(o),
                "binary": {
                    "baseline_any_peak_violation": b.peak_violations > 0,
                    "oracle_any_peak_violation": o.peak_violations > 0,
                    "baseline_any_anti_cheat": b.anti_cheat_steps > 0,
                    "oracle_any_anti_cheat": o.anti_cheat_steps > 0,
                    "baseline_any_critical_defer": b.critical_deferred > 0,
                    "oracle_any_critical_defer": o.critical_deferred > 0,
                    "baseline_high_stress": b.grid_stress_events > high_stress_cutoff,
                    "oracle_high_stress": o.grid_stress_events > high_stress_cutoff,
                },
            }
        )

    out = {
        "episodes": args.episodes,
        "seed": args.seed,
        "scenario": scenario,
        "paired_same_world": True,
        "baseline": summarize(baseline_runs),
        "oracle": summarize(oracle_runs),
        "baseline_reward_breakdown_mean": summarize_reward_breakdown(baseline_runs),
        "oracle_reward_breakdown_mean": summarize_reward_breakdown(oracle_runs),
        "per_episode": per_episode,
        "note": "Per-episode: identical episode_seed for baseline and oracle. Oracle uses ORACLE_LORA_REPO if set; "
        "ORACLE_SKIP_LLM forces baseline policy inside oracle path.",
    }

    # Serialize once and reuse the text for both the file and stdout.
    report = json.dumps(out, indent=2)
    with open(args.out, "w", encoding="utf-8") as f:
        f.write(report)
    print(report)


if __name__ == "__main__":
    main()