File size: 15,675 Bytes
78131a0
 
 
 
 
 
 
 
 
 
 
 
 
 
e81353d
78131a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e81353d
78131a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcce6af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e81353d
bcce6af
 
 
 
 
 
 
 
 
 
e81353d
bcce6af
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from typing import Dict, List
from src.models import (
    GridAction, GridObservation, GridReward,
    MultiAgentAction, MultiAgentStepResult,
)
from src.environment import OpenGridEnv
from src.tasks import TASKS
from src.grader import RobustnessGrader, normalize_score, _SCORE_EPSILON, _clamp_score
from src.baseline import heuristic_policy, llm_policy
from src.visualization import generate_dashboard
import copy
import json
import uuid
import os
import time
import pathlib
import threading
import warnings

app = FastAPI(
    title="OpenGrid Environment",
    version="2.0.0",
    description="Multi-agent renewable energy grid load-balancing environment with safety constraints",
)

# The dashboard's static assets live in a "static" directory next to this
# module. They are optional: when the directory is missing (API-only or test
# deployments) we only warn and skip the mount.
STATIC_DIR = pathlib.Path(__file__).parent / "static"
if not STATIC_DIR.exists():
    warnings.warn(
        f"Static directory not found: {STATIC_DIR}. "
        "Dashboard UI disabled; API endpoints remain available."
    )
else:
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# ---------------------------------------------------------------------------
# Session storage with TTL + per-session locking
# ---------------------------------------------------------------------------
# Two levels of locking are used:
#   * _session_lock guards the module-level ``sessions`` / ``history`` dicts
#     themselves (insert, delete, lookup, appending to a history list).
#   * Each session dict carries its own "lock" entry that serializes work on
#     that session's environment, so concurrent requests aimed at the same
#     session (two /step calls, /step racing /grader, ...) cannot interleave.
# Endpoints that need both take the per-session lock first, then _session_lock.
# ---------------------------------------------------------------------------
MAX_SESSIONS = 100
SESSION_TTL_SECONDS = 3600  # idle seconds (since last access) before eviction: 1 hour

sessions: Dict[str, Dict] = {}  # session_id -> session dict built by _new_session()
history: Dict[str, List] = {}  # session_id -> observations/states recorded per step
_session_lock = threading.Lock()

# Grader cache: estimating a task's reward bounds is expensive, so one
# RobustnessGrader per task is built once and reused by every endpoint.
# _grader_lock covers both construction and the bounds estimation.
_grader_lock = threading.Lock()
_grader_cache: Dict[str, RobustnessGrader] = {}


def _new_session(env: OpenGridEnv, task_id: str, mode: str, **extra) -> dict:
    """Build the bookkeeping dict stored in ``sessions`` for a new episode.

    Args:
        env: the freshly reset environment owned by this session.
        task_id: key into ``TASKS`` the environment was built from.
        mode: "single" or "multi" (the values used by the reset endpoints).
        **extra: additional entries; they are applied last and therefore
            override any default key of the same name.

    Returns:
        A dict with creation/last-access timestamps, an empty reward log,
        ``done``/``is_blackout`` flags cleared, and a private
        ``threading.Lock`` that serializes operations on this session.
    """
    return {
        "env": env,
        "created": time.time(),
        "last_access": time.time(),
        "task_id": task_id,
        "rewards": [],
        "mode": mode,
        "done": False,
        "is_blackout": False,
        "lock": threading.Lock(),
        **extra,
    }


def _session_age(s: dict, now: float) -> float:
    """Return the timestamp used to order sessions for eviction.

    Despite the name this is a point in time, not a duration. It is the
    session's ``last_access`` if set, else its ``created`` time. When neither
    is set, ``now`` is returned unchanged, so callers choose how timestamp-less
    sessions rank.
    """
    for field in ("last_access", "created"):
        stamp = s.get(field)
        if stamp is not None:
            return float(stamp)
    return now


def _cleanup_sessions():
    """Evict idle sessions, then trim the table so one more can be inserted.

    Caller must hold ``_session_lock``. Each eviction removes the entry from
    both ``sessions`` and ``history``.

    1. Sessions idle (since last access, else creation) for longer than
       ``SESSION_TTL_SECONDS`` are dropped.
    2. If ``MAX_SESSIONS`` or more remain, the least-recently-used ones are
       dropped until fewer than ``MAX_SESSIONS`` are left. Sessions without
       any timestamp rank as oldest, and ties go by insertion order.
    """
    now = time.time()

    def _evict(sid):
        sessions.pop(sid, None)
        history.pop(sid, None)

    # Phase 1: TTL expiry.
    stale = [
        sid for sid, sess in sessions.items()
        if now - _session_age(sess, now) > SESSION_TTL_SECONDS
    ]
    for sid in stale:
        _evict(sid)

    # Phase 2: capacity. sorted() is stable, so equal timestamps keep
    # insertion order.
    overflow = len(sessions) - MAX_SESSIONS + 1
    if overflow > 0:
        by_age = sorted(sessions, key=lambda sid: _session_age(sessions[sid], 0.0))
        for sid in by_age[:overflow]:
            _evict(sid)


def _get_session(session_id: str) -> dict:
    """Fetch a live session and stamp it as just used.

    Acquires ``_session_lock`` internally. Callers must NOT already hold it,
    because ``threading.Lock`` is not reentrant. Updating ``last_access``
    resets the session's TTL clock and moves it to the back of the eviction
    order used by ``_cleanup_sessions``.

    Raises:
        HTTPException(404): if no session with this id exists.
    """
    with _session_lock:
        found = sessions.get(session_id)
        if found is None:
            raise HTTPException(404, "Session not found")
        found["last_access"] = time.time()
    return found


def _get_grader(task_id: str) -> RobustnessGrader:
    """Return the cached RobustnessGrader for ``task_id``, building it once.

    The grader gets a deep copy of the task config. Its bounds are computed
    eagerly via ``get_bounds()`` while ``_grader_lock`` is held, so concurrent
    callers neither build duplicate graders nor race on the bounds
    estimation. The trade-off is that a first-time build blocks every other
    grader lookup until it finishes.
    """
    with _grader_lock:
        cached = _grader_cache.get(task_id)
        if cached is None:
            cached = RobustnessGrader(copy.deepcopy(TASKS[task_id]))
            cached.get_bounds()  # pay the expensive estimation while locked
            _grader_cache[task_id] = cached
        return cached


@app.get("/")
def root():
    """Serve the dashboard page, or a small JSON banner without static files.

    Returns ``static/index.html`` when it exists. Otherwise returns a dict
    pointing clients at the auto-generated ``/docs`` page.
    """
    index_page = STATIC_DIR / "index.html"
    if not index_page.exists():
        return {"status": "OpenGrid API", "version": "2.0.0", "docs": "/docs"}
    return FileResponse(str(index_page))


@app.get("/health")
def health():
    """Liveness probe: report service status, version and docs location as JSON."""
    payload = dict(status="OpenGrid Running", version="2.0.0", docs="/docs")
    return payload


@app.get("/tasks")
def get_tasks():
    """List every registered task with metadata and JSON schemas.

    Each entry carries:
    - the task id and difficulty;
    - grid size and episode length;
    - multi-agent zone info and bus/line topology;
    - the JSON schemas of ``GridAction`` and ``GridObservation``, so clients
      can build valid requests.
    """
    action_schema = GridAction.model_json_schema()
    obs_schema = GridObservation.model_json_schema()

    def _difficulty(task_id, cfg):
        # Prefer an explicit "difficulty"; otherwise derive it from the id
        # (e.g. "task_easy" -> "easy"). The fallback is computed only when
        # needed, and an id without an underscore falls back to the id itself
        # instead of raising IndexError.
        if "difficulty" in cfg:
            return cfg["difficulty"]
        parts = task_id.split("_")
        return parts[1] if len(parts) > 1 else task_id

    return [
        {
            "id": k,
            "difficulty": _difficulty(k, v),
            "num_buses": v["num_buses"],
            "max_steps": v["max_steps"],
            "num_agents": v.get("num_agents", 1),
            "zone_names": v.get("zone_names", []),
            "buses": v.get("buses", []),
            "lines": v.get("lines", []),
            "action_schema": action_schema,
            "observation_schema": obs_schema,
        }
        for k, v in TASKS.items()
    ]


# ===========================================================================
# Single-Agent API (backward compatible)
# ===========================================================================

@app.post("/reset")
def reset(task_id: str = "task_easy"):
    """Start a fresh single-agent episode for ``task_id``.

    The task config is deep-copied so each environment gets a private copy.
    Expired or excess sessions are evicted before the new one is stored.

    Returns:
        ``{"session_id": ..., "observation": ...}`` with the initial observation.

    Raises:
        HTTPException(404): if ``task_id`` is not a registered task.
    """
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    environment = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    first_obs = environment.reset()
    new_sid = str(uuid.uuid4())

    with _session_lock:
        _cleanup_sessions()
        sessions[new_sid] = _new_session(environment, task_id, mode="single")
        history[new_sid] = [first_obs]

    return dict(session_id=new_sid, observation=first_obs.model_dump())


@app.post("/step")
def step(session_id: str, action: GridAction):
    """Advance a single-agent session by one step.

    Args:
        session_id: id returned by ``/reset``.
        action: the agent's ``GridAction`` for this step.

    Returns:
        Dict with the new ``observation``, the ``reward``, the ``done`` flag
        and the step ``info``, each serialized with ``model_dump()``.

    Raises:
        HTTPException(404): unknown or evicted session.
        HTTPException(400): the episode has already finished.
    """
    session = _get_session(session_id)

    # Per-session lock serializes all env operations for this session
    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset to start a new session.")

        env = session["env"]
        obs, reward, done, info = env.step(action)

        session["rewards"].append(reward.value)
        session["done"] = done
        session["is_blackout"] = info.is_blackout

        with _session_lock:
            # _cleanup_sessions() (run by a concurrent reset) may have evicted
            # this session and its history after _get_session() returned.
            # Skip the append instead of failing with a KeyError after the env
            # has already advanced.
            step_history = history.get(session_id)
            if step_history is not None:
                step_history.append(obs)

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info.model_dump()
    }


@app.get("/state")
def get_state(session_id: str):
    """Return the full grid state of ``session_id`` as a plain dict.

    The state is read and serialized while holding the session's lock, so it
    cannot interleave with a concurrent step.

    Raises:
        HTTPException(404): if the session does not exist.
    """
    sess = _get_session(session_id)
    with sess["lock"]:
        snapshot = sess["env"].state()
        return snapshot.model_dump()


# ===========================================================================
# Multi-Agent POMDP API
# ===========================================================================

@app.post("/reset_multi")
def reset_multi(task_id: str = "task_easy"):
    """Start a fresh multi-agent episode for ``task_id``.

    Returns the session id, the number of agents, per-zone metadata and each
    agent's partial observation. Agent ids are converted to strings so they
    can serve as JSON object keys.

    Raises:
        HTTPException(404): if ``task_id`` is not a registered task.
    """
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    environment = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    partial_obs = environment.reset_multi()
    new_sid = str(uuid.uuid4())
    zones = environment.get_zone_info()

    with _session_lock:
        _cleanup_sessions()
        sessions[new_sid] = _new_session(
            environment, task_id, mode="multi",
            per_agent_rewards={agent: [] for agent in range(environment.num_agents)},
        )
        # History records full-grid states (not the agents' partial views)
        # so the visualization sees the whole network.
        history[new_sid] = [environment.state()]

    def _dump_by_agent(mapping):
        return {str(key): model.model_dump() for key, model in mapping.items()}

    return {
        "session_id": new_sid,
        "num_agents": environment.num_agents,
        "zone_info": _dump_by_agent(zones),
        "observations": _dump_by_agent(partial_obs),
    }


@app.post("/step_multi")
def step_multi(session_id: str, actions: MultiAgentAction):
    """Multi-agent step with safety layer and oversight.

    Each agent submits actions for their zone. The safety layer validates,
    the oversight agent evaluates coordination, and per-agent rewards are computed.

    Args:
        session_id: id returned by ``/reset_multi``.
        actions: per-agent actions. Keys are agent ids (ints or numeric
            strings) in ``0..num_agents-1``.

    Returns:
        A dict containing:
        - per-agent observations, rewards and safety reports, keyed by agent
          id as a string;
        - the team reward, done flag, oversight report and step info.

    Raises:
        HTTPException(404): unknown or evicted session.
        HTTPException(400): the session is not in multi-agent mode, the episode
            has already finished, or an agent id is malformed or out of range.
    """
    session = _get_session(session_id)

    with session["lock"]:
        # Validate the mode before the done flag, so a finished single-agent
        # session gets the accurate "not in multi-agent mode" error.
        if session.get("mode") != "multi":
            raise HTTPException(400, "Session not in multi-agent mode. Use /reset_multi first.")
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset_multi to start a new session.")

        env = session["env"]

        # Convert string keys from JSON to int keys, with validation
        agent_actions = {}
        for k, v in actions.agent_actions.items():
            try:
                agent_id = int(k) if isinstance(k, str) else k
            except (TypeError, ValueError):
                raise HTTPException(400, f"Invalid agent_id: {k!r}") from None
            if not (0 <= agent_id < env.num_agents):
                raise HTTPException(
                    400,
                    f"Invalid agent_id {agent_id}; expected 0..{env.num_agents - 1}",
                )
            agent_actions[agent_id] = v

        result = env.step_multi(agent_actions)

        session["rewards"].append(result.team_reward)
        session["done"] = result.done
        session["is_blackout"] = result.info.is_blackout
        per_agent = session.get("per_agent_rewards", {})
        for agent_id, reward in result.rewards.items():
            if agent_id in per_agent:
                per_agent[agent_id].append(reward.value)

        # Store full-grid observation for visualization. The session may have
        # been evicted by a concurrent _cleanup_sessions() after
        # _get_session() returned, so don't assume its history still exists.
        with _session_lock:
            step_history = history.get(session_id)
            if step_history is not None:
                step_history.append(env.state())

    return {
        "observations": {str(k): v.model_dump() for k, v in result.observations.items()},
        "rewards": {str(k): v.model_dump() for k, v in result.rewards.items()},
        "team_reward": result.team_reward,
        "done": result.done,
        "safety_reports": {str(k): v.model_dump() for k, v in result.safety_reports.items()},
        "oversight_report": result.oversight_report.model_dump(),
        "info": result.info.model_dump(),
    }


@app.get("/zones")
def get_zones(session_id: str):
    """Describe the zone-to-agent assignment of a session.

    Returns the agent count and per-zone info, keyed by zone id as a string.

    Raises:
        HTTPException(404): if the session does not exist.
    """
    sess = _get_session(session_id)
    environment = sess["env"]

    with sess["lock"]:
        zones = environment.get_zone_info()

    zone_payload = {str(zid): info.model_dump() for zid, info in zones.items()}
    return {"num_agents": environment.num_agents, "zones": zone_payload}


# ===========================================================================
# Grading & Baseline
# ===========================================================================

@app.get("/grader")
def run_grader(session_id: str):
    """
    Grade a completed (or in-progress) session.

    The cumulative reward is normalized against the task's cached reward
    floor/ceiling via ``normalize_score``, using the same grader as
    ``/baseline``. A blackout zeroes the N-1 survival input. The result is
    clamped with ``_clamp_score`` before being returned. A session with no
    steps yet scores ``_SCORE_EPSILON``.
    """
    sess = _get_session(session_id)

    # Snapshot what we need under the per-session lock, then release it
    # before the (potentially slow) grader lookup.
    with sess["lock"]:
        reward_log = list(sess["rewards"])
        task_id = sess["task_id"]
        blackout = sess.get("is_blackout", False)

    if not reward_log:
        return {"score": _SCORE_EPSILON, "message": "No steps taken yet. Run /step first."}

    total = sum(reward_log)
    bounds = _get_grader(task_id).get_bounds()
    floor, ceiling = bounds["reward_floor"], bounds["reward_ceiling"]

    raw_score = normalize_score(
        cumulative_reward=total,
        reward_floor=floor,
        reward_ceiling=ceiling,
        n1_survival_rate=1.0 if not blackout else 0.0,
    )

    return {
        "score": _clamp_score(raw_score),  # defense-in-depth at the API boundary
        "cumulative_reward": round(total, 4),
        "steps": len(reward_log),
        "is_blackout": blackout,
        "task_id": task_id,
        "reward_floor": floor,
        "reward_ceiling": ceiling,
    }


@app.get("/baseline")
def run_baseline(use_llm: bool = False):
    """
    Run a baseline policy on every registered task and report its results.

    The default is the heuristic policy (reproducible). With ``use_llm=true``
    the LLM policy is used instead, which requires ``HF_TOKEN`` or
    ``OPENAI_API_KEY`` to be set to a non-empty value. Each task is evaluated
    over 3 episodes.

    Uses the same cached grader as /grader — bounds are computed once
    and reused across all endpoints.

    Raises:
        HTTPException(400): ``use_llm`` is true but no API key is configured.
    """
    # Use ``or`` rather than a getenv default, so an *empty* HF_TOKEN (a
    # declared but unset secret) still falls through to OPENAI_API_KEY.
    api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
    if use_llm and not api_key:
        raise HTTPException(
            400,
            "use_llm=true requires HF_TOKEN or OPENAI_API_KEY environment variable",
        )

    # Past the guard above, use_llm implies a usable key is present.
    policy = llm_policy if use_llm else heuristic_policy
    policy_name = "llm" if use_llm else "heuristic"

    results = {}
    for task_id in TASKS:
        grader = _get_grader(task_id)  # cached — no duplicate rollouts
        results[task_id] = grader.evaluate_policy(policy, n_episodes=3)

    return {"policy": policy_name, "baseline_scores": results}


@app.get("/visualize")
def visualize(session_id: str):
    """Render a dashboard image of the current grid state plus recorded history.

    The current state and a copy of the history list are captured under the
    locks. Rendering then happens outside them, so a slow plot does not block
    other requests.

    Returns:
        ``{"image_base64": ...}`` as produced by ``generate_dashboard``.
    """
    sess = _get_session(session_id)

    with sess["lock"]:
        current = sess["env"].state()
        with _session_lock:
            frames = list(history.get(session_id, []))

    return {"image_base64": generate_dashboard(frames, current)}


# ===========================================================================
# Training Results
# ===========================================================================

@app.get("/training-results")
def training_results():
    """Return GRPO training results if available.

    Reads ``training/outputs/summary.json`` (relative to the process working
    directory) and responds with one of:

    * ``{"available": False}`` when no summary exists yet;
    * ``{"available": True, "error": ...}`` when the summary records an error,
      cannot be read or parsed, or is not a JSON object;
    * otherwise the summary's own fields plus ``available=True`` and a
      ``plots`` mapping from plot name to its ``/training-plots/{name}`` URL.
      Only plots whose PNG file exists are listed.
    """
    outputs_dir = pathlib.Path("training/outputs")
    summary_path = outputs_dir / "summary.json"
    if not summary_path.exists():
        return {"available": False}

    try:
        with open(summary_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError,
        # e.g. a truncated or partially written summary file.
        return {"available": True, "error": f"Could not read training summary: {exc}"}

    if not isinstance(data, dict):
        return {"available": True, "error": "Training summary is not a JSON object"}
    # The training run recorded a failure instead of results.
    if "error" in data:
        return {"available": True, "error": data["error"]}

    data["available"] = True
    data["plots"] = {
        name: f"/training-plots/{name}"
        for name in ("before_after", "training_loss", "training_reward_curve")
        if (outputs_dir / f"{name}.png").exists()
    }
    return data


@app.get("/training-plots/{name}")
def training_plot(name: str):
    """Serve one of the known training plot PNGs from ``training/outputs``.

    Only names in a fixed allow-list are accepted. The path parameter
    therefore cannot be used to reach arbitrary files.

    Raises:
        HTTPException(404): unknown plot name, or the plot has not been
            generated yet.
    """
    allowed = frozenset({"before_after", "training_loss", "training_reward_curve"})
    if name not in allowed:
        raise HTTPException(404, "Plot not found")
    p = pathlib.Path(f"training/outputs/{name}.png")
    if not p.exists():
        raise HTTPException(404, "Plot not generated yet")
    return FileResponse(str(p), media_type="image/png")