# Spaces: Sleeping  (page-header residue captured with the file; not part of the program)
from __future__ import annotations

import argparse
import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from statistics import mean
from typing import Any

from ev_grid_oracle.city_graph import build_city_graph
from ev_grid_oracle.env import EVGridCore
from ev_grid_oracle.models import ActionType, EVGridAction
from ev_grid_oracle.oracle_agent import OracleAgent
from ev_grid_oracle.policies import baseline_policy
from ev_grid_oracle.scenarios import ScenarioName
@dataclass
class EpisodeMetrics:
    """Aggregate statistics collected over a single rollout by ``run_episode``.

    The ``@dataclass`` decorator is required: instances are built with keyword
    arguments and ``reward_breakdown_mean`` relies on ``field(default_factory=...)``.
    """

    # Mean over steps of the per-step mean of each station's ``avg_wait_minutes``.
    avg_wait: float
    # Number of (step, station) samples whose slot occupancy ratio exceeded 0.85.
    grid_stress_events: int
    # Number of steps whose ``grid_load_pct`` exceeded 0.80.
    peak_violations: int
    # Mean over steps of the state's ``renewable_pct``.
    renewable_mean: float
    # Steps where the chosen action was a defer while the first pending EV
    # had less than 15.0 battery percent.
    critical_deferred: int
    # Steps whose observation carried non-empty ``anti_cheat_flags``.
    anti_cheat_steps: int
    # Per-component mean reward over the episode's steps ("total" and
    # underscore-prefixed keys are excluded by the producer).
    reward_breakdown_mean: dict[str, float] = field(default_factory=dict)
def _episode_metrics_to_json(m: EpisodeMetrics) -> dict[str, Any]:
    """Convert one episode's metrics into a plain dict suitable for ``json.dump``.

    Scalar fields are copied as-is; ``reward_breakdown_mean`` is shallow-copied
    so the returned dict does not alias the metrics object's mapping.
    """
    scalar_fields = (
        "avg_wait",
        "grid_stress_events",
        "peak_violations",
        "renewable_mean",
        "critical_deferred",
        "anti_cheat_steps",
    )
    payload: dict[str, Any] = {name: getattr(m, name) for name in scalar_fields}
    payload["reward_breakdown_mean"] = dict(m.reward_breakdown_mean)
    return payload
def _oracle_skip_llm_requested() -> bool:
    """Return True when the ORACLE_SKIP_LLM environment variable holds a truthy value."""
    # Compared case-insensitively so "FALSE"/"False"/"false" all mean "do not skip".
    flag = os.getenv("ORACLE_SKIP_LLM", "").strip().lower()
    return flag not in ("", "0", "false")


def run_episode(
    env: EVGridCore,
    *,
    policy: str,
    seed: int,
    scenario: ScenarioName = "baseline",
    oracle_repo: str | None = None,
) -> EpisodeMetrics:
    """Roll out one episode on ``env`` and return its aggregated metrics.

    The environment is reset with ``seed`` and ``scenario`` and then stepped at
    most ``env.max_steps`` times, stopping early once an observation reports
    ``done``. Wait time, station occupancy, grid load and renewable share are
    sampled from the state *before* each step; anti-cheat flags and the reward
    breakdown are read from the observation returned *by* the step.

    Args:
        env: Environment to reset and step (it is mutated by this call).
        policy: ``"oracle"`` builds an ``OracleAgent`` and lets it act; any
            other value (including ``"baseline"``) uses ``baseline_policy``.
        seed: Seed forwarded to ``env.reset``.
        scenario: Scenario name forwarded to ``env.reset``.
        oracle_repo: LoRA repo id passed to ``OracleAgent``; blank or ``None``
            means no repo. Ignored when ``ORACLE_SKIP_LLM`` is truthy, in which
            case the agent is built with ``lora_repo_id=None``.

    Returns:
        An ``EpisodeMetrics`` with the per-episode aggregates. Averages fall
        back to ``0.0`` when no samples were collected.
    """
    obs = env.reset(seed=seed, scenario=scenario)
    graph = env.city_graph

    oracle = None
    if policy == "oracle":
        if _oracle_skip_llm_requested():
            lora_repo_id = None
        else:
            lora_repo_id = (oracle_repo or "").strip() or None
        oracle = OracleAgent(lora_repo_id=lora_repo_id)

    waits: list[float] = []
    renewable: list[float] = []
    grid_stress = 0
    peak_violations = 0
    critical_deferred = 0
    anti_cheat_steps = 0
    rb_sums: dict[str, float] = defaultdict(float)
    rb_steps = 0

    for _ in range(env.max_steps):
        state = obs.state
        stations = state.stations

        # statistics.mean raises on an empty sequence, so a state without
        # stations contributes no wait sample instead of aborting the rollout.
        if stations:
            waits.append(mean([s.avg_wait_minutes for s in stations]))
        # max(1, ...) avoids dividing by zero for a station with no slots.
        grid_stress += sum(
            1 for s in stations if (s.occupied_slots / max(1, s.total_slots)) > 0.85
        )
        if state.grid_load_pct > 0.80:
            peak_violations += 1
        renewable.append(state.renewable_pct)

        pending = state.pending_evs
        if not pending:
            # Nothing to route: submit a placeholder load-shift action so the
            # environment still advances one step.
            action = EVGridAction(action_type=ActionType.load_shift, ev_id="EV-000", defer_minutes=0)
        else:
            if oracle is not None:
                action = oracle.act(state, obs.prompt, graph)
            else:
                action = baseline_policy(state, graph)
            # NOTE(review): the battery check looks at the first pending EV,
            # not necessarily the EV named by action.ev_id - confirm intent.
            if action.action_type == ActionType.defer and pending[0].battery_pct_0_100 < 15.0:
                critical_deferred += 1

        obs = env.step(action)
        if obs.anti_cheat_flags:
            anti_cheat_steps += 1

        # Tolerate an observation that carries no reward breakdown.
        breakdown = obs.reward_breakdown or {}
        for key, value in breakdown.items():
            name = str(key)
            # Skip the aggregate and any private/bookkeeping entries.
            if name == "total" or name.startswith("_"):
                continue
            rb_sums[name] += float(value)
        rb_steps += 1

        if obs.done:
            break

    rb_mean: dict[str, float] = {}
    if rb_steps > 0:
        rb_mean = {name: rb_sums[name] / rb_steps for name in sorted(rb_sums)}

    return EpisodeMetrics(
        avg_wait=float(mean(waits)) if waits else 0.0,
        grid_stress_events=grid_stress,
        peak_violations=peak_violations,
        renewable_mean=float(mean(renewable)) if renewable else 0.0,
        critical_deferred=critical_deferred,
        anti_cheat_steps=anti_cheat_steps,
        reward_breakdown_mean=rb_mean,
    )
def summarize(metrics: list[EpisodeMetrics]) -> dict[str, float]:
    """Average the scalar metrics across a list of episodes.

    Args:
        metrics: Per-episode metrics; may be empty.

    Returns:
        A dict keyed by summary name holding the arithmetic mean of the
        corresponding attribute over all episodes. For an empty ``metrics``
        list every value is ``0.0`` (``statistics.mean`` would otherwise raise
        ``StatisticsError``).
    """
    # (output key, EpisodeMetrics attribute) pairs, in output order.
    columns = (
        ("avg_wait_minutes", "avg_wait"),
        ("grid_stress_events", "grid_stress_events"),
        ("peak_violations", "peak_violations"),
        ("renewable_mean", "renewable_mean"),
        ("critical_deferred", "critical_deferred"),
        ("anti_cheat_steps", "anti_cheat_steps"),
    )
    if not metrics:
        return {key: 0.0 for key, _ in columns}
    return {key: mean([getattr(m, attr) for m in metrics]) for key, attr in columns}
def summarize_reward_breakdown(metrics: list[EpisodeMetrics]) -> dict[str, float]:
    """Average each reward component across episodes.

    A component is averaged only over the episodes whose
    ``reward_breakdown_mean`` contains it. The result is keyed in sorted
    component-name order and is empty when no episode reports any component.
    """
    totals: dict[str, float] = {}
    occurrences: dict[str, int] = {}
    for episode in metrics:
        for name, value in episode.reward_breakdown_mean.items():
            totals[name] = totals.get(name, 0.0) + float(value)
            occurrences[name] = occurrences.get(name, 0) + 1

    averaged: dict[str, float] = {}
    for name in sorted(totals):
        # Every name in ``totals`` was seen at least once, so the divisor is >= 1.
        averaged[name] = totals[name] / occurrences[name]
    return averaged
# Scenario names accepted by the --scenario command-line option.
SCENARIO_CHOICES: tuple[str, ...] = (
    "baseline",
    "heatwave_peak",
    "festival_surge",
    "transformer_derate",
    "station_outage",
    "tariff_shock",
)
def main() -> None:
    """Run paired baseline/oracle rollouts and write a JSON report.

    For each of ``--episodes`` episodes, the baseline and the oracle policy are
    both rolled out on the same environment object with the same seed
    (``--seed`` plus the episode index) and the same ``--scenario``. Aggregate
    and per-episode results are written to ``--out`` and echoed to stdout.

    The oracle's LoRA repo id is read from the ``ORACLE_LORA_REPO`` environment
    variable (blank means none).
    """
    ap = argparse.ArgumentParser(description="Paired baseline vs oracle rollouts (same seed + scenario per episode).")
    ap.add_argument("--episodes", type=int, default=50)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument(
        "--scenario",
        type=str,
        default="baseline",
        choices=SCENARIO_CHOICES,
        help="Deterministic scenario schedule (see ev_grid_oracle/scenarios.py).",
    )
    ap.add_argument("--out", type=str, default="training/eval_results.json")
    args = ap.parse_args()

    # With no episodes there is nothing to average; fail with a usage error
    # instead of a traceback from the summary step.
    if args.episodes < 1:
        ap.error("--episodes must be at least 1")

    scenario: ScenarioName = args.scenario  # type: ignore[assignment]
    graph = build_city_graph()
    env = EVGridCore(city_graph=graph)
    oracle_repo = os.getenv("ORACLE_LORA_REPO", "").strip() or None

    # An episode counts as "high stress" above this many stress events.
    high_stress_threshold = int(env.max_steps * 0.25)

    baseline_runs: list[EpisodeMetrics] = []
    oracle_runs: list[EpisodeMetrics] = []
    per_episode: list[dict[str, Any]] = []

    for i in range(args.episodes):
        episode_seed = args.seed + i
        b = run_episode(env, policy="baseline", seed=episode_seed, scenario=scenario, oracle_repo=oracle_repo)
        o = run_episode(env, policy="oracle", seed=episode_seed, scenario=scenario, oracle_repo=oracle_repo)
        baseline_runs.append(b)
        oracle_runs.append(o)
        per_episode.append(
            {
                "episode_index": i,
                "episode_seed": episode_seed,
                "scenario": scenario,
                "baseline": _episode_metrics_to_json(b),
                "oracle": _episode_metrics_to_json(o),
                "binary": {
                    "baseline_any_peak_violation": b.peak_violations > 0,
                    "oracle_any_peak_violation": o.peak_violations > 0,
                    "baseline_any_anti_cheat": b.anti_cheat_steps > 0,
                    "oracle_any_anti_cheat": o.anti_cheat_steps > 0,
                    "baseline_any_critical_defer": b.critical_deferred > 0,
                    "oracle_any_critical_defer": o.critical_deferred > 0,
                    "baseline_high_stress": b.grid_stress_events > high_stress_threshold,
                    "oracle_high_stress": o.grid_stress_events > high_stress_threshold,
                },
            }
        )

    out = {
        "episodes": args.episodes,
        "seed": args.seed,
        "scenario": scenario,
        "paired_same_world": True,
        "baseline": summarize(baseline_runs),
        "oracle": summarize(oracle_runs),
        "baseline_reward_breakdown_mean": summarize_reward_breakdown(baseline_runs),
        "oracle_reward_breakdown_mean": summarize_reward_breakdown(oracle_runs),
        "per_episode": per_episode,
        "note": "Per-episode: identical episode_seed for baseline and oracle. Oracle uses ORACLE_LORA_REPO if set; "
        "ORACLE_SKIP_LLM forces baseline policy inside oracle path.",
    }

    # Serialise once and reuse the text for both the file and stdout.
    report = json.dumps(out, indent=2)
    out_path = Path(args.out)
    # The default output lives under training/, which may not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(report, encoding="utf-8")
    print(report)


if __name__ == "__main__":
    main()