Opengrid / app.py
K446's picture
Polish for hackathon submission: training evidence, two pipelines, UI, docs
e81353d
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from typing import Dict, List
from src.models import (
GridAction, GridObservation, GridReward,
MultiAgentAction, MultiAgentStepResult,
)
from src.environment import OpenGridEnv
from src.tasks import TASKS
from src.grader import RobustnessGrader, normalize_score, _SCORE_EPSILON, _clamp_score
from src.baseline import heuristic_policy, llm_policy
from src.visualization import generate_dashboard
import copy
import json
import uuid
import os
import time
import pathlib
import threading
import warnings
# Application instance; every route in this module registers on it through
# the @app.get / @app.post decorators.
app = FastAPI(
    title="OpenGrid Environment",
    description="Multi-agent renewable energy grid load-balancing environment with safety constraints",
    version="2.0.0",
)

# Dashboard assets live next to this file. They are optional so that
# API-only or test deployments can run without them.
STATIC_DIR = pathlib.Path(__file__).parent / "static"
if not STATIC_DIR.exists():
    # No assets: keep the API up and only tell the operator the UI is off.
    warnings.warn(
        f"Static directory not found: {STATIC_DIR}. "
        "Dashboard UI disabled; API endpoints remain available."
    )
else:
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# ---------------------------------------------------------------------------
# Session storage with TTL + per-session locking
# ---------------------------------------------------------------------------
# _session_lock guards the sessions/history *dicts* for insert/delete/lookup.
# Each session also has its own lock ("lock" key) that serializes env
# operations, preventing race conditions when concurrent requests target
# the same session (e.g. two /step calls, or /step racing with /grader).
# ---------------------------------------------------------------------------
# session_id -> session dict (key layout is produced by _new_session).
sessions: Dict[str, Dict] = {}
# session_id -> list of observations/states appended by the reset and step
# endpoints and read back by /visualize.
history: Dict[str, List] = {}
# Capacity bound: _cleanup_sessions evicts least-recently-accessed sessions
# while len(sessions) >= MAX_SESSIONS.
MAX_SESSIONS = 100
# Sessions whose last access is older than this are evicted by _cleanup_sessions.
SESSION_TTL_SECONDS = 3600 # 1 hour
# Guards structural access (insert/delete/lookup) to `sessions` and `history`.
_session_lock = threading.Lock()
# Grader cache: bounds are expensive (10 rollouts per task), compute once.
# Construction AND bounds estimation are serialized under _grader_lock.
# task_id -> RobustnessGrader, populated lazily by _get_grader.
_grader_cache: Dict[str, RobustnessGrader] = {}
_grader_lock = threading.Lock()
def _new_session(env: OpenGridEnv, task_id: str, mode: str, **extra) -> dict:
    """Build the bookkeeping dict stored in ``sessions`` for one environment.

    Args:
        env: The environment instance owned by this session.
        task_id: Key of the task the environment was built from.
        mode: Session mode tag (the endpoints in this module pass
            ``"single"`` or ``"multi"``).
        **extra: Additional entries merged into the session dict last, so
            they may add keys (e.g. ``per_agent_rewards``) or override
            the defaults below.

    Returns:
        A new dict with the env, timestamps, reward log, terminal flags and
        a fresh per-session ``threading.Lock`` under ``"lock"``.
    """
    # Read the clock once so a brand-new session always satisfies
    # created == last_access (two separate reads could differ slightly).
    now = time.time()
    session = {
        "env": env,
        "created": now,
        "last_access": now,
        "task_id": task_id,
        "rewards": [],
        "mode": mode,
        "done": False,
        "is_blackout": False,
        # Serializes env operations for this one session.
        "lock": threading.Lock(),
    }
    session.update(extra)
    return session
def _session_age(s: dict, now: float) -> float:
"""Return the last-access timestamp for a session (for eviction sorting)."""
ts = s.get("last_access")
if ts is None:
ts = s.get("created")
return float(ts) if ts is not None else now
def _cleanup_sessions():
    """Drop expired sessions, then make room below ``MAX_SESSIONS``.

    Caller must hold ``_session_lock``. Both ``sessions`` and ``history``
    are pruned together so the two dicts keep the same key set.
    """
    now = time.time()

    def _evict(sid):
        sessions.pop(sid, None)
        history.pop(sid, None)

    # Pass 1: TTL. Collect first, then delete, so the dict is not mutated
    # while it is being iterated.
    stale = []
    for sid, entry in sessions.items():
        if now - _session_age(entry, now) > SESSION_TTL_SECONDS:
            stale.append(sid)
    for sid in stale:
        _evict(sid)

    # Pass 2: capacity. Evict the least-recently-accessed session until
    # there is room for one more; sessions without a timestamp rank as 0.0
    # and therefore go first.
    while len(sessions) >= MAX_SESSIONS:
        victim = min(sessions, key=lambda sid: _session_age(sessions[sid], 0.0))
        _evict(victim)
def _get_session(session_id: str) -> dict:
    """Fetch a session by id and refresh its ``last_access`` timestamp.

    Acquires ``_session_lock`` itself, so the caller must NOT already hold it.

    Raises:
        HTTPException: 404 when no session is stored under *session_id*.
    """
    with _session_lock:
        found = sessions.get(session_id)
        if found is not None:
            found["last_access"] = time.time()
            return found
        raise HTTPException(404, "Session not found")
def _get_grader(task_id: str) -> RobustnessGrader:
    """Return the cached ``RobustnessGrader`` for *task_id*, building it once.

    The whole lookup-or-build sequence runs under ``_grader_lock`` so that
    concurrent callers neither construct duplicate graders nor overlap the
    bounds computation.
    """
    with _grader_lock:
        try:
            return _grader_cache[task_id]
        except KeyError:
            pass
        fresh = RobustnessGrader(copy.deepcopy(TASKS[task_id]))
        # Trigger the expensive bounds computation now, while still locked,
        # so later get_bounds() calls on the cached grader find it done.
        fresh.get_bounds()
        _grader_cache[task_id] = fresh
        return fresh
@app.get("/")
def root():
    """Serve the dashboard page, or a small JSON banner when it is absent."""
    index = STATIC_DIR / "index.html"
    if not index.exists():
        return {"status": "OpenGrid API", "version": "2.0.0", "docs": "/docs"}
    return FileResponse(str(index))
@app.get("/health")
def health():
    """Liveness probe: returns a fixed JSON status payload."""
    payload = {
        "status": "OpenGrid Running",
        "version": "2.0.0",
        "docs": "/docs",
    }
    return payload
@app.get("/tasks")
def get_tasks():
    """List every registered task with its metadata and the I/O schemas.

    Returns:
        A list with one dict per entry in ``TASKS``. ``num_buses`` and
        ``max_steps`` are required keys of the task config; the remaining
        fields fall back to defaults when the config omits them.
    """
    # The schemas do not depend on the task, so build them once and share.
    action_schema = GridAction.model_json_schema()
    obs_schema = GridObservation.model_json_schema()

    def _difficulty(task_id, spec):
        # Derive the fallback lazily: an eager default argument to
        # dict.get() raises IndexError for ids without an underscore even
        # when the config already declares a difficulty.
        if "difficulty" in spec:
            return spec["difficulty"]
        parts = task_id.split("_")
        return parts[1] if len(parts) > 1 else task_id

    return [
        {
            "id": task_id,
            "difficulty": _difficulty(task_id, spec),
            "num_buses": spec["num_buses"],
            "max_steps": spec["max_steps"],
            "num_agents": spec.get("num_agents", 1),
            "zone_names": spec.get("zone_names", []),
            "buses": spec.get("buses", []),
            "lines": spec.get("lines", []),
            "action_schema": action_schema,
            "observation_schema": obs_schema,
        }
        for task_id, spec in TASKS.items()
    ]
# ===========================================================================
# Single-Agent API (backward compatible)
# ===========================================================================
@app.post("/reset")
def reset(task_id: str = "task_easy"):
    """Start a new single-agent session and return its first observation.

    Raises:
        HTTPException: 404 when *task_id* is not a key of ``TASKS``.
    """
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    # Work on a private copy so the episode cannot mutate the shared config.
    task_config = copy.deepcopy(TASKS[task_id])
    env = OpenGridEnv(task_config)
    first_obs = env.reset()
    new_id = str(uuid.uuid4())

    with _session_lock:
        # Evict before inserting so the new session fits under the cap.
        _cleanup_sessions()
        sessions[new_id] = _new_session(env, task_id, mode="single")
        history[new_id] = [first_obs]

    response = {"session_id": new_id, "observation": first_obs.model_dump()}
    return response
@app.post("/step")
def step(session_id: str, action: GridAction):
    """Apply one action to the session's environment.

    Returns:
        Dict with the dumped ``observation``, ``reward`` and ``info`` models
        plus the ``done`` flag.

    Raises:
        HTTPException: 404 if the session is unknown (from ``_get_session``),
            400 if the episode has already finished.
    """
    session = _get_session(session_id)
    # The per-session lock serializes all env operations for this session.
    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset to start a new session.")

        obs, reward, done, info = session["env"].step(action)
        session["rewards"].append(reward.value)
        session["done"] = done
        session["is_blackout"] = info.is_blackout

        with _session_lock:
            # A concurrent /reset may have evicted this session (TTL or
            # capacity) after _get_session returned. Use .get so the step
            # result is still delivered instead of failing with KeyError.
            episode_history = history.get(session_id)
            if episode_history is not None:
                episode_history.append(obs)

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info.model_dump(),
    }
@app.get("/state")
def get_state(session_id: str):
    """Return the dumped result of ``env.state()`` for a session."""
    session = _get_session(session_id)
    with session["lock"]:
        # Read the env under the session lock so a concurrent step cannot
        # interleave with the snapshot.
        snapshot = session["env"].state()
        return snapshot.model_dump()
# ===========================================================================
# Multi-Agent POMDP API
# ===========================================================================
@app.post("/reset_multi")
def reset_multi(task_id: str = "task_easy"):
    """Start a new multi-agent session and return per-agent observations.

    Raises:
        HTTPException: 404 when *task_id* is not a key of ``TASKS``.
    """
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    def _dump_by_key(mapping):
        # JSON object keys must be strings, so stringify the keys.
        return {str(key): model.model_dump() for key, model in mapping.items()}

    env = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    zone_obs = env.reset_multi()
    new_id = str(uuid.uuid4())
    zone_info = env.get_zone_info()

    with _session_lock:
        _cleanup_sessions()
        reward_log = {}
        for agent_idx in range(env.num_agents):
            reward_log[agent_idx] = []
        sessions[new_id] = _new_session(
            env, task_id, mode="multi",
            per_agent_rewards=reward_log,
        )
        # History keeps the env.state() snapshot, not the per-agent
        # observations, so it can be used for visualization.
        history[new_id] = [env.state()]

    response = {"session_id": new_id, "num_agents": env.num_agents}
    response["zone_info"] = _dump_by_key(zone_info)
    response["observations"] = _dump_by_key(zone_obs)
    return response
@app.post("/step_multi")
def step_multi(session_id: str, actions: MultiAgentAction):
    """Advance a multi-agent session by one step.

    The submitted per-agent actions are re-keyed by integer agent id,
    validated against the environment's agent count and passed to
    ``env.step_multi``. The safety and oversight reports in the response
    come straight from that call's result.

    Raises:
        HTTPException: 404 if the session is unknown (from ``_get_session``);
            400 if the episode is finished, the session is not in
            multi-agent mode, or an agent id is malformed or out of range.
    """
    session = _get_session(session_id)
    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset_multi to start a new session.")
        env = session["env"]
        if session.get("mode") != "multi":
            raise HTTPException(400, "Session not in multi-agent mode. Use /reset_multi first.")

        agent_actions = _parse_agent_actions(actions.agent_actions, env.num_agents)
        result = env.step_multi(agent_actions)

        session["rewards"].append(result.team_reward)
        session["done"] = result.done
        session["is_blackout"] = result.info.is_blackout

        # Look the per-agent log up once instead of once per agent.
        per_agent_log = session.get("per_agent_rewards", {})
        for agent_id, reward in result.rewards.items():
            if agent_id in per_agent_log:
                per_agent_log[agent_id].append(reward.value)

        # Record the env.state() snapshot for visualization.
        with _session_lock:
            # A concurrent reset may have evicted this session after
            # _get_session returned. Use .get so the step result is still
            # delivered instead of failing with KeyError.
            episode_history = history.get(session_id)
            if episode_history is not None:
                episode_history.append(env.state())

    def _dump_by_key(mapping):
        return {str(key): model.model_dump() for key, model in mapping.items()}

    return {
        "observations": _dump_by_key(result.observations),
        "rewards": _dump_by_key(result.rewards),
        "team_reward": result.team_reward,
        "done": result.done,
        "safety_reports": _dump_by_key(result.safety_reports),
        "oversight_report": result.oversight_report.model_dump(),
        "info": result.info.model_dump(),
    }


def _parse_agent_actions(raw_actions, num_agents: int) -> dict:
    """Re-key a submitted action mapping by validated integer agent id.

    String keys (as delivered by JSON) are converted with ``int``; other keys
    are used as-is. Raises HTTPException(400) for a key that cannot be
    converted or that falls outside ``0..num_agents - 1``.
    """
    parsed = {}
    for key, action in raw_actions.items():
        try:
            agent_id = int(key) if isinstance(key, str) else key
        except (TypeError, ValueError):
            raise HTTPException(400, f"Invalid agent_id: {key!r}") from None
        if not (0 <= agent_id < num_agents):
            raise HTTPException(
                400,
                f"Invalid agent_id {agent_id}; expected 0..{num_agents - 1}",
            )
        parsed[agent_id] = action
    return parsed
@app.get("/zones")
def get_zones(session_id: str):
    """Return the agent count and dumped zone info for a session."""
    session = _get_session(session_id)
    with session["lock"]:
        env = session["env"]
        zone_info = env.get_zone_info()
        zones = {}
        for zone_key, zone in zone_info.items():
            # JSON object keys must be strings.
            zones[str(zone_key)] = zone.model_dump()
        return {"num_agents": env.num_agents, "zones": zones}
# ===========================================================================
# Grading & Baseline
# ===========================================================================
@app.get("/grader")
def run_grader(session_id: str):
    """
    Grade a completed (or in-progress) session.

    The cumulative reward is normalized with ``normalize_score`` against the
    cached grader's reward floor/ceiling and then passed through
    ``_clamp_score`` once more at the API boundary. With no steps recorded
    the endpoint short-circuits and reports ``_SCORE_EPSILON``.
    """
    session = _get_session(session_id)
    with session["lock"]:
        # Copy everything needed while locked so a concurrent /step cannot
        # change the values halfway through grading.
        task_id = session["task_id"]
        is_blackout = session.get("is_blackout", False)
        rewards = list(session["rewards"])

    if len(rewards) == 0:
        return {"score": _SCORE_EPSILON, "message": "No steps taken yet. Run /step first."}

    cumulative = sum(rewards)
    bounds = _get_grader(task_id).get_bounds()
    floor = bounds["reward_floor"]
    ceiling = bounds["reward_ceiling"]

    # A blackout episode counts as zero N-1 survival, anything else as full.
    survival_rate = 1.0 if not is_blackout else 0.0
    raw_score = normalize_score(
        cumulative_reward=cumulative,
        reward_floor=floor,
        reward_ceiling=ceiling,
        n1_survival_rate=survival_rate,
    )

    return {
        # Defense-in-depth: clamp again at the API boundary.
        "score": _clamp_score(raw_score),
        "cumulative_reward": round(cumulative, 4),
        "steps": len(rewards),
        "is_blackout": is_blackout,
        "task_id": task_id,
        "reward_floor": floor,
        "reward_ceiling": ceiling,
    }
@app.get("/baseline")
def run_baseline(use_llm: bool = False):
    """
    Run a baseline policy on every registered task.

    Default: heuristic policy. Set use_llm=true for the LLM policy, which
    requires HF_TOKEN or OPENAI_API_KEY to be set to a non-empty value.
    Graders come from the same cache as /grader, so each task's grader is
    built only once per process.

    Returns:
        ``{"policy": <"llm"|"heuristic">, "baseline_scores": {task_id: result}}``
        where each result is whatever ``grader.evaluate_policy`` returns for
        3 episodes.

    Raises:
        HTTPException: 400 when use_llm is requested without an API key.
    """
    # Chain with `or` rather than nesting getenv defaults: getenv's default
    # only applies when the variable is unset, so HF_TOKEN="" would
    # otherwise hide a valid OPENAI_API_KEY.
    api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or ""
    if use_llm and not api_key:
        raise HTTPException(
            400,
            "use_llm=true requires HF_TOKEN or OPENAI_API_KEY environment variable",
        )

    # Past the guard above, use_llm alone decides the policy.
    policy = llm_policy if use_llm else heuristic_policy
    policy_name = "llm" if use_llm else "heuristic"

    results = {
        task_id: _get_grader(task_id).evaluate_policy(policy, n_episodes=3)
        for task_id in TASKS
    }
    return {"policy": policy_name, "baseline_scores": results}
@app.get("/visualize")
def visualize(session_id: str):
    """Render the session's dashboard image and return it base64-encoded."""
    session = _get_session(session_id)

    # Current state is read under the per-session lock.
    with session["lock"]:
        current = session["env"].state()

    # History is copied under the global lock so rendering works on a
    # private list; a missing entry is treated as an empty history.
    with _session_lock:
        frames = list(history.get(session_id, []))

    return {"image_base64": generate_dashboard(frames, current)}
# ===========================================================================
# Training Results
# ===========================================================================
@app.get("/training-results")
def training_results():
    """Return the training summary from ``training/outputs/summary.json``.

    Returns:
        ``{"available": False}`` when the summary file does not exist;
        ``{"available": True, "error": ...}`` when the summary records an
        error or cannot be read or parsed as a JSON object; otherwise the
        summary dict extended with ``available`` and a ``plots`` mapping of
        plot name to URL for every plot image present on disk.
    """
    outputs_dir = pathlib.Path("training/outputs")
    summary_path = outputs_dir / "summary.json"

    # Open directly instead of exists()-then-open, so a file removed in
    # between is reported as "not available" rather than a server error.
    try:
        with open(summary_path, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {"available": False}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {"available": True, "error": f"Could not read training summary: {exc}"}

    if not isinstance(data, dict):
        return {"available": True, "error": "Training summary is not a JSON object"}

    # The summary itself may record a failed run.
    if "error" in data:
        return {"available": True, "error": data["error"]}

    # Advertise only the plots that have been generated.
    data["available"] = True
    data["plots"] = {
        name: f"/training-plots/{name}"
        for name in ("before_after", "training_loss", "training_reward_curve")
        if (outputs_dir / f"{name}.png").exists()
    }
    return data
@app.get("/training-plots/{name}")
def training_plot(name: str):
    """Serve one of the known training plot images as PNG.

    Raises:
        HTTPException: 404 when *name* is not an allowed plot name, or when
            the corresponding image file has not been generated.
    """
    # Allow-list check first: *name* comes from the URL path and is
    # interpolated into a filesystem path below.
    allowed = {"before_after", "training_loss", "training_reward_curve"}
    if name not in allowed:
        raise HTTPException(404, "Plot not found")

    plot_path = pathlib.Path(f"training/outputs/{name}.png")
    # is_file() also rejects a directory that happens to carry the name.
    if not plot_path.is_file():
        raise HTTPException(404, "Plot not generated yet")
    return FileResponse(str(plot_path), media_type="image/png")