K446 committed on
Commit
78131a0
·
0 Parent(s):

OpenGrid: Multi-agent POMDP power grid environment with GRPO training

Browse files

Features:
- Multi-agent POMDP environment with safety layer and oversight agent
- Environment-grounded GRPO reward function (steps actual physics)
- FastAPI server with single/multi-agent APIs, grading, and visualization
- Heuristic baseline, LLM inference pipeline, and training notebook
- Karnataka KPTCL real-world grid task
- 4 task difficulties: easy, medium, hard, karnataka

.dockerignore ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .pytest_cache/
4
+ .venv/
5
+ venv/
6
+ .git/
7
+ .gitignore
8
+ .vscode/
9
+ .env
10
+
11
+ # Docs (keep README for the Space)
12
+ guide.md
13
+ detailed judging criteria.md
14
+ ui_skill.md
15
+ project-spec.md
16
+ codebase_summary.md
17
+ pyrightconfig.json
18
+
19
+ # Generated files
20
+ inference_output.txt
21
+ generate_code_md.py
22
+ uv.lock
23
+
24
+ # Training outputs (not needed in Docker image)
25
+ training/outputs/
26
+ *.safetensors
27
+ *.bin
28
+
29
+ # Tests not needed in production
30
+ tests/
31
+ test_multiagent.py
.gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.jpg filter=lfs diff=lfs merge=lfs -text
3
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
4
+ *.gif filter=lfs diff=lfs merge=lfs -text
5
+ *.ico filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .pytest_cache/
5
+ .venv/
6
+ venv/
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+ .env
11
+ .vscode/
12
+
13
+ # Generated / temporary files
14
+ inference_output.txt
15
+ codebase_summary.md
16
+ generate_code_md.py
17
+ uv.lock
18
+
19
+ # Reference docs (not part of submission)
20
+ guide.md
21
+ detailed judging criteria.md
22
+ ui_skill.md
23
+ project-spec.md
24
+ pyrightconfig.json
25
+
26
+ # Training outputs (large files — push separately or add to HF)
27
+ training/outputs/
28
+ *.safetensors
29
+ *.bin
30
+
31
+ # OS files
32
+ Thumbs.db
33
+ .DS_Store
34
+
35
+ # Duplicate test file (tests/ directory has the real one)
36
+ test_multiagent.py
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Docker Space — OpenGrid
2
+ # Docs: https://huggingface.co/docs/hub/spaces-sdks-docker
3
+
4
+ FROM python:3.10-slim
5
+
6
+ LABEL org.opencontainers.image.title="OpenGrid"
7
+ LABEL org.opencontainers.image.description="Renewable energy grid load-balancing environment"
8
+ LABEL openenv="true"
9
+
10
+ # Create non-root user required by HF Spaces
11
+ RUN useradd -m -u 1000 user
12
+ USER user
13
+ ENV PATH="/home/user/.local/bin:$PATH"
14
+
15
+ WORKDIR /app
16
+
17
+ # Install dependencies
18
+ COPY --chown=user requirements.txt .
19
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
20
+
21
+ # Copy application code
22
+ COPY --chown=user . /app
23
+
24
+ # Expose HF Spaces default port
25
+ EXPOSE 7860
26
+
27
+ # Healthcheck
28
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=15s \
29
+ CMD python -c "import httpx; httpx.get('http://localhost:7860/health').raise_for_status()" || exit 1
30
+
31
+ # Run the server
32
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 KRISHNA GOYAL
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: OpenGrid
3
+ emoji: ⚡
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ <p align="center">
12
+ <img src="static/logo.png" alt="OpenGrid Logo" width="120">
13
+ </p>
14
+
15
+ <h1 align="center">OpenGrid ⚡</h1>
16
+ <p align="center"><strong>Safe Multi-Agent RL for Power Grid Operations</strong></p>
17
+
18
+ <p align="center">
19
+ <a href="https://huggingface.co/spaces/K446/Opengrid"><img src="https://img.shields.io/badge/🤗%20Live%20Demo-HuggingFace%20Space-yellow" alt="Live Demo"></a>
20
+ <a href="https://github.com/krishnagoyal099/Opengrid_env"><img src="https://img.shields.io/badge/GitHub-Repository-181717?logo=github" alt="GitHub"></a>
21
+ <a href="https://github.com/openenv"><img src="https://img.shields.io/badge/OpenEnv-compatible-blue" alt="OpenEnv"></a>
22
+ <a href="https://www.python.org"><img src="https://img.shields.io/badge/python-3.10%2B-blue" alt="Python 3.10+"></a>
23
+ <a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-green.svg" alt="License: MIT"></a>
24
+ </p>
25
+
26
+ ---
27
+
28
+ ## What is OpenGrid?
29
+
30
+ OpenGrid is a **multi-agent reinforcement learning environment** where AI agents control a power grid. Multiple agents, each managing a zone, must coordinate under **partial observability** to keep the lights on — balancing electricity supply and demand in real-time while managing renewable energy volatility.
31
+
32
+ What makes OpenGrid different:
33
+
34
+ - **Multi-Agent POMDP**: 2-3 agents, each seeing only their local zone + noisy global signals
35
+ - **Safety Layer**: Hard constraint filter blocks unsafe actions before they reach the physics engine (N-1 security, anti-islanding, ramp limits)
36
+ - **Oversight Agent**: Monitors cross-zone coordination, penalizes selfish behavior
37
+ - **Composable Rewards**: 6 independent reward functions — survival, frequency, congestion, safety compliance, coordination, efficiency
38
+ - **Real Physics**: DC power flow solver with droop frequency model
39
+
40
+ > **🔗 Try it live:** [huggingface.co/spaces/K446/Opengrid](https://huggingface.co/spaces/K446/Opengrid)
41
+
42
+ ---
43
+
44
+ ## How It Works
45
+
46
+ ```
47
+ ┌─────────────────────────────────────────────────────────┐
48
+ │ MULTI-AGENT LOOP │
49
+ │ │
50
+ │ Each agent observes LOCAL zone state (POMDP) │
51
+ │ │ │
52
+ │ ▼ │
53
+ │ Each agent proposes action (adjust power, switch │
54
+ │ lines — only within their zone) │
55
+ │ │ │
56
+ │ ▼ │
57
+ │ SAFETY LAYER validates all actions: │
58
+ │ - N-1 security check │
59
+ │ - Anti-islanding │
60
+ │ - Projects unsafe → nearest safe alternative │
61
+ │ │ │
62
+ │ ▼ │
63
+ │ OVERSIGHT AGENT evaluates coordination: │
64
+ │ - Detects conflicts between agents │
65
+ │ - Penalizes selfish behavior │
66
+ │ │ │
67
+ │ ▼ │
68
+ │ Physics engine solves DC power flow │
69
+ │ │ │
70
+ │ ▼ │
71
+ │ Per-agent rewards: local + global + safety + coord │
72
+ │ │ │
73
+ │ Repeat for 50 steps — or until blackout! │
74
+ └─────────────────────────────────────────────────────────┘
75
+ ```
76
+
77
+ The agent interacts through a **REST API** — any language or framework that can make HTTP requests can play. Both single-agent (backward compatible) and multi-agent modes are supported.
78
+
79
+ ---
80
+
81
+ ## Four Difficulty Levels
82
+
83
+ | Task | Grid Size | Agents | Renewable Mix | What Makes It Hard |
84
+ |---|---|---|---|---|
85
+ | `task_easy` | 5 buses | 2 | 20% | Basic frequency control, 2-zone coordination |
86
+ | `task_medium` | 10 buses | 3 | 50% | Volatile renewables + congestion + 3-zone POMDP |
87
+ | `task_hard` | 14 buses | 3 | 70% | High volatility, tight margins, complex topology |
88
+ | `task_karnataka` | 15 buses | 4 | Real mix | Real KPTCL topology (Raichur, Ballari, Bengaluru, Mysuru) with GPS coordinates |
89
+
90
+ All tasks run for **50 timesteps**. Scores range from **0.02 to 0.98** (higher = better).
91
+
92
+ ---
93
+
94
+ ## Quick Start
95
+
96
+ ### 1. Clone & Install
97
+
98
+ ```bash
99
+ git clone https://github.com/krishnagoyal099/Opengrid_env.git
100
+ cd Opengrid_env
101
+
102
+ pip install -r requirements.txt
103
+ ```
104
+
105
+ ### 2. Start the Server
106
+
107
+ ```bash
108
+ uvicorn app:app --host 0.0.0.0 --port 7860
109
+ ```
110
+
111
+ Then open [http://localhost:7860](http://localhost:7860) — you'll see the **interactive SCADA dashboard** with a Leaflet.js GIS map showing the Karnataka grid topology in real-time.
112
+
113
+ ### 3. Run the AI Agent
114
+
115
+ ```bash
116
+ # Set your LLM API credentials
117
+ export API_BASE_URL="https://api.openai.com/v1"
118
+ export MODEL_NAME="gpt-4o"
119
+ export HF_TOKEN="your-api-key"
120
+ export ENV_URL="http://localhost:7860"
121
+
122
+ # Run inference on all tasks
123
+ python inference.py
124
+ ```
125
+
126
+ ### 4. Train with GRPO
127
+
128
+ ```bash
129
+ # Test the training pipeline (no GPU needed)
130
+ python training/train_grpo.py --test-mode
131
+
132
+ # Full training with Unsloth (needs GPU)
133
+ python training/train_grpo.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit --use-unsloth
134
+ ```
135
+
136
+ ### Docker (Alternative)
137
+
138
+ ```bash
139
+ docker build -t opengrid .
140
+ docker run -p 7860:7860 opengrid
141
+ ```
142
+
143
+ ---
144
+
145
+ ## Multi-Agent API
146
+
147
+ ### Reset in Multi-Agent Mode
148
+
149
+ ```bash
150
+ curl -X POST "http://localhost:7860/reset_multi?task_id=task_medium"
151
+ # Returns: {
152
+ # "session_id": "abc-123",
153
+ # "num_agents": 3,
154
+ # "zone_info": {"0": {"zone_name": "Bengaluru_Region", "bus_ids": [...]}, ...},
155
+ # "observations": {"0": {...}, "1": {...}, "2": {...}}
156
+ # }
157
+ ```
158
+
159
+ ### Take a Multi-Agent Step
160
+
161
+ ```bash
162
+ curl -X POST "http://localhost:7860/step_multi?session_id=abc-123" \
163
+ -H "Content-Type: application/json" \
164
+ -d '{
165
+ "agent_actions": {
166
+ "0": {"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []},
167
+ "1": {"bus_adjustments": [], "topology_actions": []},
168
+ "2": {"bus_adjustments": [{"bus_id": 9, "delta": -3.0}], "topology_actions": []}
169
+ }
170
+ }'
171
+ # Returns: per-agent observations, per-agent rewards, safety reports, oversight report
172
+ ```
173
+
174
+ ### Single-Agent API (Backward Compatible)
175
+
176
+ The original single-agent API (`/reset`, `/step`, `/state`, `/grader`) is fully preserved.
177
+
178
+ ---
179
+
180
+ ## What Each Agent Sees (POMDP Observation)
181
+
182
+ Each agent receives a **partial** observation of their zone:
183
+
184
+ | Field | Example | Meaning |
185
+ |---|---|---|
186
+ | `grid_frequency` | `49.87` | **Noisy** frequency reading (Gaussian noise added) |
187
+ | `local_buses[].type` | `"solar"` | Bus type (only buses in agent's zone) |
188
+ | `local_buses[].p_injection` | `35.2` | Power output in MW |
189
+ | `boundary_lines[].rho` | `0.78` | Lines connecting to other zones |
190
+ | `internal_lines[].flow` | `62.4` | Lines within agent's zone |
191
+ | `neighbor_signals` | `{1: 12.5}` | Average injection of neighboring zones |
192
+ | `zone_load_mw` | `85.3` | Total load in this zone |
193
+ | `zone_gen_mw` | `42.1` | Total generation in this zone |
194
+
195
+ Agents do **NOT** see buses or lines in other zones — they must coordinate through limited neighbor signals and the shared (but noisy) frequency reading.
196
+
197
+ ---
198
+
199
+ ## Safety Layer
200
+
201
+ The safety layer validates every action BEFORE it reaches the physics engine:
202
+
203
+ | Check | What It Does | If Violated |
204
+ |---|---|---|
205
+ | **Zone Boundary** | Agent can only adjust buses in their zone | Action removed |
206
+ | **N-1 Security** | Grid must survive loss of any single line | Action blocked |
207
+ | **Anti-Islanding** | Opening a line must not disconnect the grid | Switch blocked |
208
+ | **Ramp Limits** | Power changes within physical ramp rates | Delta clamped |
209
+ | **Capacity Limits** | Generation within min/max bounds | Output clamped |
210
+ | **Battery SoC** | Can't discharge below 0 or charge above capacity | Delta clamped |
211
+
212
+ Critically, unsafe actions are **projected to the nearest safe alternative** rather than simply rejected. This preserves the agent's intent while enforcing safety, and provides a richer training signal.
213
+
214
+ ---
215
+
216
+ ## Reward System
217
+
218
+ Six composable, independent reward functions:
219
+
220
+ | Component | Range | When |
221
+ |---|---|---|
222
+ | **survival** | +1.0 / -100.0 | Grid stays connected / blackout |
223
+ | **frequency** | -1.5 to +0.2 | Based on deviation from 50 Hz |
224
+ | **local_congestion** | ≤ 0 | Line overloads in agent's zone |
225
+ | **safety_compliance** | -0.3 to +0.1 | Penalty if safety layer corrected action |
226
+ | **coordination** | ≤ 0 | Penalty for selfish/conflicting actions |
227
+ | **action_cost** | -0.5 / switch | Topology change cost |
228
+
229
+ ---
230
+
231
+ ## Scoring
232
+
233
+ Scores are normalized to **(0.02 – 0.98)** using:
234
+
235
+ ```
236
+ score = (agent_reward - worst_case) / (best_case - worst_case) + N1_bonus
237
+ ```
238
+
239
+ | Bound | How It's Computed |
240
+ |---|---|
241
+ | **Worst case (floor)** | Random agent that chaotically switches lines — causes blackouts fast |
242
+ | **Best case (ceiling)** | Theoretical perfect agent: survives every step + perfect frequency bonus |
243
+ | **N-1 bonus** | Up to +10% for completing the episode without a blackout |
244
+
245
+ ### Baseline Scores (Heuristic Policy)
246
+
247
+ | Task | Score | Strategy |
248
+ |---|---|---|
249
+ | `task_easy` | ~0.90 | Proportional frequency control, no line switching |
250
+ | `task_medium` | ~0.98 | Same heuristic — medium grid happens to be well-balanced |
251
+ | `task_hard` | ~0.98 | Same heuristic — hard grid has more buses but similar dynamics |
252
+ | `task_karnataka` | ~0.98 | 15-bus real topology, 4 zones, generators warm-started |
253
+
254
+ > Reproduce with: `python get_scores.py`
255
+
256
+ ---
257
+
258
+ ## Project Structure
259
+
260
+ ```
261
+ OpenGrid/
262
+ ├── app.py # FastAPI server (single + multi-agent endpoints)
263
+ ├── inference.py # LLM inference script
264
+ ├── get_scores.py # Reproduce baseline scores
265
+ ├── openenv.yaml # OpenEnv manifest
266
+ ├── Dockerfile # Container config
267
+ ├── requirements.txt # Python dependencies
268
+
269
+ ├── src/ # Core environment
270
+ │ ├── models.py # Pydantic models (single + multi-agent)
271
+ │ ├── environment.py # Grid simulation (POMDP + backward-compatible)
272
+ │ ├── physics.py # DC power flow solver
273
+ │ ├── tasks.py # Procedural grid generation with zone assignment
274
+ │ ├── grader.py # Scoring (floor/ceiling normalization)
275
+ │ ├── baseline.py # Heuristic + LLM policies
276
+ │ ├── safety.py # Safety layer (N-1, anti-islanding, projection)
277
+ │ ├── oversight.py # Oversight agent (coordination monitoring)
278
+ │ └── visualization.py # Grid topology & frequency plots
279
+
280
+ ├── training/ # RL training pipeline
281
+ │ ├── train_grpo.py # TRL GRPO training script
282
+ │ └── opengrid_grpo_colab.ipynb # Google Colab notebook for GPU training
283
+
284
+ ├── tests/ # Test suite (28 tests)
285
+ │ ├── test_solver.py # Physics, environment, grader tests
286
+ │ └── test_multi_agent.py # Multi-agent, safety, oversight tests
287
+
288
+ ├── static/ # Dashboard frontend
289
+ │ ├── index.html
290
+ │ ├── style.css
291
+ │ └── app.js
292
+
293
+ └── server/ # Alternative entry point
294
+ └── app.py
295
+ ```
296
+
297
+ ---
298
+
299
+ ## Training Results (GRPO)
300
+
301
+ We trained **Qwen 2.5 1.5B** using GRPO (Group Relative Policy Optimization) on the Karnataka grid topology.
302
+
303
+ ### Training Loss
304
+
305
+ The loss converges from ~0.09 to near 0 by step ~400, confirming end-to-end training pipeline functionality.
306
+
307
+ ### Before vs After (Average Episode Reward)
308
+
309
+ | Task | Heuristic Baseline | GRPO Trained |
310
+ |---|---|---|
311
+ | `task_easy` | 27.6 | 27.6 |
312
+ | `task_medium` | 48.7 | 48.7 |
313
+ | `task_karnataka` | 19.6 | -316.9 |
314
+
315
+ **Key Finding**: Naive LLM training on simplified proxy rewards does not transfer to real-world grid topologies — Karnataka collapses to -316.9. This validates our architectural decision to pair RL agents with a **safety layer + oversight agent**. The heuristic baseline with safety corrections (19.6 reward, zero blackouts) outperforms pure RL, proving that critical infrastructure needs guardrails, not just learned policies.
316
+
317
+ > **Reproduce training**: Open `training/opengrid_grpo_colab.ipynb` in Google Colab (T4 GPU)
318
+
319
+ ---
320
+
321
+ ## Technical Details
322
+
323
+ <details>
324
+ <summary><strong>Physics Engine</strong></summary>
325
+
326
+ - **DC Power Flow** with B-matrix formulation (standard power systems approximation)
327
+ - **Slack bus** absorbs generation/load imbalance after each power flow solve
328
+ - **Islanding detection** via NetworkX graph connectivity checks
329
+ - **Droop frequency model** calibrated to system size: `f = 50.0 - (2.5 / total_capacity) * P_slack`
330
+
331
+ </details>
332
+
333
+ <details>
334
+ <summary><strong>Multi-Agent Design</strong></summary>
335
+
336
+ - Buses partitioned into zones using **greedy modularity community detection** (NetworkX)
337
+ - Each zone maps to a KPTCL transmission region (Bengaluru, Mysuru, Kalburagi)
338
+ - **Partial observability**: agents see only local buses, boundary lines, noisy frequency
339
+ - **Neighbor signals**: each agent receives average injection of adjacent zones
340
+ - **Safety-first**: all actions validated by constraint filter before physics engine
341
+
342
+ </details>
343
+
344
+ <details>
345
+ <summary><strong>Thread Safety</strong></summary>
346
+
347
+ - All session reads/writes are protected by a `threading.Lock`
348
+ - Grader bounds use double-checked locking to avoid duplicate rollouts
349
+ - Safe for concurrent requests from multiple agents
350
+
351
+ </details>
352
+
353
+ <details>
354
+ <summary><strong>Reproducibility</strong></summary>
355
+
356
+ | Component | Mechanism |
357
+ |---|---|
358
+ | Task grids | Seeded procedural generation (`np.random.default_rng`) |
359
+ | Zone partitioning | Deterministic community detection with seed |
360
+ | Wind variability | Per-episode RNG (same seed → same wind pattern) |
361
+ | Floor estimation | Seeded thrash policy + 10 diverse-seeded episodes |
362
+ | Ceiling | Analytical formula (deterministic) |
363
+ | Scoring | Shared `normalize_score()` across all endpoints |
364
+
365
+ </details>
366
+
367
+ ---
368
+
369
+ ## Related Work
370
+
371
+ - **Massgen**: When Multiple LLMs Think Together (Gradient Network, 2025)
372
+ - **Symphony**: Multi-Agent Intelligence in a Collective Fabric (Gradient Network, 2025)
373
+ - **Grid2Op**: Power grid RL environment (RTE, 2020)
374
+ - **OpenEnv**: Standardized agentic execution environments (Scalar/HuggingFace/Meta, 2026)
375
+
376
+ ---
377
+
378
+ ## Links
379
+
380
+ | Resource | URL |
381
+ |---|---|
382
+ | **Live Demo** | [huggingface.co/spaces/K446/Opengrid](https://huggingface.co/spaces/K446/Opengrid) |
383
+ | **GitHub Repo** | [github.com/krishnagoyal099/Opengrid_env](https://github.com/krishnagoyal099/Opengrid_env) |
384
+ | **API Docs (Swagger)** | [k446-opengrid.hf.space/docs](https://k446-opengrid.hf.space/docs) |
385
+
386
+ ---
387
+
388
+ ## License
389
+
390
+ MIT — see [LICENSE](LICENSE) for details.
app.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+ from typing import Dict, List
5
+ from src.models import (
6
+ GridAction, GridObservation, GridReward,
7
+ MultiAgentAction, MultiAgentStepResult,
8
+ )
9
+ from src.environment import OpenGridEnv
10
+ from src.tasks import TASKS
11
+ from src.grader import RobustnessGrader, normalize_score, _SCORE_EPSILON, _clamp_score
12
+ from src.baseline import heuristic_policy, llm_policy
13
+ from src.visualization import generate_dashboard
14
+ import copy
15
+ import uuid
16
+ import os
17
+ import time
18
+ import pathlib
19
+ import threading
20
+ import warnings
21
+
22
# FastAPI application object; served by uvicorn as `app:app` (see Dockerfile CMD).
app = FastAPI(
    title="OpenGrid Environment",
    description="Multi-agent renewable energy grid load-balancing environment with safety constraints",
    version="2.0.0"
)

# Static files — mount only if present (allows API-only or test deployments).
# When absent, the dashboard UI is disabled but every API endpoint still works.
STATIC_DIR = pathlib.Path(__file__).parent / "static"
if STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
else:
    warnings.warn(
        f"Static directory not found: {STATIC_DIR}. "
        "Dashboard UI disabled; API endpoints remain available."
    )
37
+
38
# ---------------------------------------------------------------------------
# Session storage with TTL + per-session locking
# ---------------------------------------------------------------------------
# _session_lock guards the sessions/history *dicts* for insert/delete/lookup.
# Each session also has its own lock ("lock" key) that serializes env
# operations, preventing race conditions when concurrent requests target
# the same session (e.g. two /step calls, or /step racing with /grader).
# ---------------------------------------------------------------------------
sessions: Dict[str, Dict] = {}  # session_id -> session record (built by _new_session)
history: Dict[str, List] = {}  # session_id -> per-step observation trace (for visualization)
MAX_SESSIONS = 100  # capacity cap enforced by _cleanup_sessions
SESSION_TTL_SECONDS = 3600  # 1 hour
_session_lock = threading.Lock()

# Grader cache: bounds are expensive (10 rollouts per task), compute once.
# Construction AND bounds estimation are serialized under _grader_lock.
_grader_cache: Dict[str, RobustnessGrader] = {}
_grader_lock = threading.Lock()
57
+
58
def _new_session(env: OpenGridEnv, task_id: str, mode: str, **extra) -> dict:
    """Build a fresh session record: env handle, bookkeeping fields, and a
    per-session lock, plus any caller-supplied extras (e.g. per_agent_rewards)."""
    record = {
        "env": env,
        "created": time.time(),
        "last_access": time.time(),
        "task_id": task_id,
        "rewards": [],
        "mode": mode,
        "done": False,
        "is_blackout": False,
        "lock": threading.Lock(),
    }
    record.update(extra)
    return record
73
+
74
+
75
+ def _session_age(s: dict, now: float) -> float:
76
+ """Return the last-access timestamp for a session (for eviction sorting)."""
77
+ ts = s.get("last_access")
78
+ if ts is None:
79
+ ts = s.get("created")
80
+ return float(ts) if ts is not None else now
81
+
82
+
83
def _cleanup_sessions():
    """Drop expired sessions, then trim the least-recently-used until the
    store is under MAX_SESSIONS. Caller must already hold _session_lock."""
    now = time.time()

    # TTL pass: anything idle longer than SESSION_TTL_SECONDS goes away.
    stale = [
        sid for sid, sess in sessions.items()
        if now - _session_age(sess, now) > SESSION_TTL_SECONDS
    ]
    for sid in stale:
        sessions.pop(sid, None)
        history.pop(sid, None)

    # Capacity pass: while still at/over the cap, evict the oldest session.
    while len(sessions) >= MAX_SESSIONS:
        lru_sid = min(sessions, key=lambda sid: _session_age(sessions[sid], 0.0))
        sessions.pop(lru_sid, None)
        history.pop(lru_sid, None)
104
+
105
+
106
def _get_session(session_id: str) -> dict:
    """Fetch a live session and refresh its last_access stamp.

    Acquires _session_lock itself, so the caller must NOT already hold it.
    Raises HTTPException(404) for unknown (or already-evicted) ids.
    """
    with _session_lock:
        found = sessions.get(session_id)
        if found is None:
            raise HTTPException(404, "Session not found")
        found["last_access"] = time.time()
        return found
115
+
116
+
117
def _get_grader(task_id: str) -> RobustnessGrader:
    """Return the cached RobustnessGrader for `task_id`, building it on first use.

    Both construction and the expensive bounds estimation run while holding
    _grader_lock, so concurrent /grader requests never duplicate or race on
    _estimate_bounds() mutations.
    """
    with _grader_lock:
        cached = _grader_cache.get(task_id)
        if cached is None:
            cached = RobustnessGrader(copy.deepcopy(TASKS[task_id]))
            cached.get_bounds()  # force the expensive mutation while locked
            _grader_cache[task_id] = cached
        return cached
130
+
131
+
132
@app.get("/")
def root():
    """Serve the interactive dashboard when bundled, else a small JSON banner."""
    index_page = STATIC_DIR / "index.html"
    if not index_page.exists():
        return {"status": "OpenGrid API", "version": "2.0.0", "docs": "/docs"}
    return FileResponse(str(index_page))
139
+
140
+
141
@app.get("/health")
def health():
    """Liveness probe (used by the container healthcheck); always JSON."""
    return {
        "status": "OpenGrid Running",
        "version": "2.0.0",
        "docs": "/docs",
    }
145
+
146
+
147
@app.get("/tasks")
def get_tasks():
    """Describe every registered task, including multi-agent zone metadata
    and the JSON schemas for actions and observations."""
    schema_action = GridAction.model_json_schema()
    schema_obs = GridObservation.model_json_schema()
    catalog = []
    for task_id, cfg in TASKS.items():
        catalog.append({
            "id": task_id,
            "difficulty": cfg.get("difficulty", task_id.split('_')[1]),
            "num_buses": cfg["num_buses"],
            "max_steps": cfg["max_steps"],
            "num_agents": cfg.get("num_agents", 1),
            "zone_names": cfg.get("zone_names", []),
            "buses": cfg.get("buses", []),
            "action_schema": schema_action,
            "observation_schema": schema_obs,
        })
    return catalog
165
+
166
+
167
+ # ===========================================================================
168
+ # Single-Agent API (backward compatible)
169
+ # ===========================================================================
170
+
171
@app.post("/reset")
def reset(task_id: str = "task_easy"):
    """Create a fresh single-agent session; returns its id and first observation."""
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    # Deep-copy the task spec so episode mutations never leak into TASKS.
    env = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    first_obs = env.reset()
    sid = str(uuid.uuid4())

    with _session_lock:
        _cleanup_sessions()  # make room before registering the new session
        sessions[sid] = _new_session(env, task_id, mode="single")
        history[sid] = [first_obs]

    return {"session_id": sid, "observation": first_obs.model_dump()}
187
+
188
+
189
@app.post("/step")
def step(session_id: str, action: GridAction):
    """Execute one step in the environment.

    Looks up the session (404 if missing/expired), applies `action` to the
    underlying env, records the reward, appends the observation to the
    visualization history, and returns observation/reward/done/info.
    Raises 400 once the episode has already finished.
    """
    session = _get_session(session_id)

    # Per-session lock serializes all env operations for this session
    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset to start a new session.")

        env = session["env"]
        obs, reward, done, info = env.step(action)

        # Bookkeeping consumed later by /grader and the dashboard.
        session["rewards"].append(reward.value)
        session["done"] = done
        session["is_blackout"] = info.is_blackout

        # _session_lock nests inside the session lock here; _get_session only
        # ever takes _session_lock alone, so this ordering cannot deadlock.
        # NOTE(review): if _cleanup_sessions evicted this session concurrently,
        # history[session_id] would raise KeyError — confirm eviction timing.
        with _session_lock:
            history[session_id].append(obs)

        return {
            "observation": obs.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info.model_dump()
        }
215
+
216
+
217
@app.get("/state")
def get_state(session_id: str):
    """Return a snapshot of the session's full grid state."""
    session = _get_session(session_id)

    # Serialize against concurrent /step calls on the same session.
    with session["lock"]:
        snapshot = session["env"].state()
    return snapshot.model_dump()
224
+
225
+
226
+ # ===========================================================================
227
+ # Multi-Agent POMDP API
228
+ # ===========================================================================
229
+
230
@app.post("/reset_multi")
def reset_multi(task_id: str = "task_easy"):
    """Start a multi-agent episode; returns per-agent partial observations
    plus zone assignments for each agent."""
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    # Fresh env from a deep copy so TASKS stays pristine across episodes.
    env = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    per_zone_obs = env.reset_multi()
    sid = str(uuid.uuid4())
    zone_info = env.get_zone_info()

    with _session_lock:
        _cleanup_sessions()
        sessions[sid] = _new_session(
            env, task_id, mode="multi",
            per_agent_rewards={agent: [] for agent in range(env.num_agents)},
        )
        # Keep the full-grid state (not partial views) for visualization history.
        history[sid] = [env.state()]

    return {
        "session_id": sid,
        "num_agents": env.num_agents,
        "zone_info": {str(k): v.model_dump() for k, v in zone_info.items()},
        "observations": {str(k): v.model_dump() for k, v in per_zone_obs.items()},
    }
257
+
258
+
259
@app.post("/step_multi")
def step_multi(session_id: str, actions: MultiAgentAction):
    """Multi-agent step with safety layer and oversight.

    Each agent submits actions for their zone. The safety layer validates,
    the oversight agent evaluates coordination, and per-agent rewards are computed.

    Raises 404 for unknown sessions; 400 for finished episodes, sessions not
    created via /reset_multi, or malformed/out-of-range agent ids.
    """
    session = _get_session(session_id)

    # Per-session lock serializes env access against concurrent requests.
    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset_multi to start a new session.")

        env = session["env"]
        if session.get("mode") != "multi":
            raise HTTPException(400, "Session not in multi-agent mode. Use /reset_multi first.")

        # Convert string keys from JSON to int keys, with validation
        agent_actions = {}
        for k, v in actions.agent_actions.items():
            try:
                agent_id = int(k) if isinstance(k, str) else k
            except (TypeError, ValueError):
                raise HTTPException(400, f"Invalid agent_id: {k!r}")
            if not (0 <= agent_id < env.num_agents):
                raise HTTPException(
                    400,
                    f"Invalid agent_id {agent_id}; expected 0..{env.num_agents - 1}",
                )
            agent_actions[agent_id] = v

        result = env.step_multi(agent_actions)

        # Team-level bookkeeping plus per-agent reward traces (used by /grader).
        session["rewards"].append(result.team_reward)
        session["done"] = result.done
        session["is_blackout"] = result.info.is_blackout
        for agent_id, reward in result.rewards.items():
            # Silently skip ids missing from per_agent_rewards (defensive).
            if agent_id in session.get("per_agent_rewards", {}):
                session["per_agent_rewards"][agent_id].append(reward.value)

        # Store full-grid observation for visualization
        # (_session_lock nests inside the session lock, same order as /step).
        with _session_lock:
            history[session_id].append(env.state())

        return {
            "observations": {str(k): v.model_dump() for k, v in result.observations.items()},
            "rewards": {str(k): v.model_dump() for k, v in result.rewards.items()},
            "team_reward": result.team_reward,
            "done": result.done,
            "safety_reports": {str(k): v.model_dump() for k, v in result.safety_reports.items()},
            "oversight_report": result.oversight_report.model_dump(),
            "info": result.info.model_dump(),
        }
312
+
313
+
314
@app.get("/zones")
def get_zones(session_id: str):
    """Get zone assignments and agent info for a multi-agent session."""
    session = _get_session(session_id)

    with session["lock"]:
        # Read both values under the lock so they form a consistent snapshot —
        # the original accessed session["env"].num_agents after releasing it.
        env = session["env"]
        zone_info = env.get_zone_info()
        num_agents = env.num_agents

    return {
        "num_agents": num_agents,
        "zones": {str(k): v.model_dump() for k, v in zone_info.items()},
    }
326
+
327
+
328
+ # ===========================================================================
329
+ # Grading & Baseline
330
+ # ===========================================================================
331
+
332
@app.get("/grader")
def run_grader(session_id: str):
    """
    Grade a completed (or in-progress) session.
    Returns a score strictly in the open interval (0, 1) using the same
    normalization as the /baseline endpoint (analytical ceiling + empirical floor).
    """
    session = _get_session(session_id)

    # Snapshot the mutable session fields while holding the lock.
    with session["lock"]:
        rewards = list(session["rewards"])
        task_id = session["task_id"]
        is_blackout = session.get("is_blackout", False)

    if not rewards:
        return {"score": _SCORE_EPSILON, "message": "No steps taken yet. Run /step first."}

    cumulative = sum(rewards)

    grader = _get_grader(task_id)
    bounds = grader.get_bounds()

    raw_score = normalize_score(
        cumulative_reward=cumulative,
        reward_floor=bounds["reward_floor"],
        reward_ceiling=bounds["reward_ceiling"],
        n1_survival_rate=0.0 if is_blackout else 1.0,
    )

    return {
        # Defense-in-depth: clamp again at the API boundary.
        "score": _clamp_score(raw_score),
        "cumulative_reward": round(cumulative, 4),
        "steps": len(rewards),
        "is_blackout": is_blackout,
        "task_id": task_id,
        "reward_floor": bounds["reward_floor"],
        "reward_ceiling": bounds["reward_ceiling"]
    }
375
+
376
+
377
@app.get("/baseline")
def run_baseline(use_llm: bool = False):
    """
    Run baseline policy on all registered tasks. Returns 0.0–1.0 scores.
    Default: heuristic (reproducible). Set use_llm=true for LLM agent.

    Uses the same cached grader as /grader — bounds are computed once
    and reused across all endpoints.
    """
    api_key = os.getenv("HF_TOKEN", os.getenv("OPENAI_API_KEY", ""))
    if use_llm and not api_key:
        raise HTTPException(
            400,
            "use_llm=true requires HF_TOKEN or OPENAI_API_KEY environment variable",
        )

    policy = llm_policy if use_llm and api_key else heuristic_policy
    policy_name = "llm" if policy is llm_policy else "heuristic"

    results = {}
    # Iterate task ids only — the config value was fetched but never used.
    for task_id in TASKS:
        grader = _get_grader(task_id)  # cached — no duplicate rollouts
        results[task_id] = grader.evaluate_policy(policy, n_episodes=3)

    return {"policy": policy_name, "baseline_scores": results}
403
+
404
+
405
@app.get("/visualize")
def visualize(session_id: str):
    """Generate a visualization of the current grid state and frequency history."""
    session = _get_session(session_id)

    # Capture the current full-grid observation under the session lock.
    with session["lock"]:
        obs = session["env"].state()
    # Copy the stored frame history under the global registry lock.
    with _session_lock:
        hist = list(history.get(session_id, []))

    return {"image_base64": generate_dashboard(hist, obs)}
changes.md ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Notebook Changes — opengrid_grpo_colab.ipynb
2
+
3
+ ## Bug fixes applied (2026-04-25)
4
+
5
+ ### Cell 7 — Generate Training Prompts
6
+
7
+ | # | Severity | Bug | Fix |
8
+ |---|----------|-----|-----|
9
+ | 1 | 🔴 Critical | `obs_dict = obs.model_dump()` produces dicts with integer keys; `Dataset.from_dict({"obs_context": obs_contexts})` fails with `ArrowTypeError: Expected dict key of type str or bytes, got 'int'` | Changed to `json.loads(obs.model_dump_json())` so all keys are strings; then stored as `json.dumps(obs_dict)` — a flat JSON string PyArrow handles trivially |
10
+ | 2 | 🟡 Bug | `env = OpenGridEnv(task_config)` instantiated before the loop but immediately replaced inside the loop — wasted object creation | Removed stray instantiation |
11
+ | 3 | 🟡 Bug | `import copy`, `import json` inside inner loop body — re-imported on every iteration | Moved to top of cell |
12
+ | 4 | 🟡 Bug | Slack bus included in random action choices — physics solver overwrites it, wasting action budget | Filtered to `['generator', 'battery']` only |
13
+
14
+ ### Cell 8 — Reward Function
15
+
16
+ | # | Severity | Bug | Fix |
17
+ |---|----------|-----|-----|
18
+ | 5 | 🔴 Critical | `reward_fn` received `obs_context` as JSON strings from the dataset column but passed them directly to `compute_grpo_reward` which expects dicts | Added `json.loads(ctx) if isinstance(ctx, str) else ctx` deserialization before scoring |
19
+ | 6 | 🟡 Bug | No assertion to catch silent arity mismatches | Added `assert len(test_rewards) == 2` sanity check |
20
+
21
+ ### Cell 9 — Training
22
+
23
+ | # | Severity | Bug | Fix |
24
+ |---|----------|-----|-----|
25
+ | 7 | 🟡 Bug | `bf16=torch.cuda.is_bf16_supported()` raises `AssertionError` when CUDA is not available (no GPU runtime) | Guarded: `_cuda_ok = torch.cuda.is_available()` then `_bf16 = _cuda_ok and ...` |
26
+
27
+ ### Cell 12 — Before/After Plot
28
+
29
+ | # | Severity | Bug | Fix |
30
+ |---|----------|-----|-----|
31
+ | 8 | 🟡 Bug | Bar labels used `va='bottom'` for all bars; for negative-height bars the label renders inside/below the bar | Fixed: `va='bottom'` when `h >= 0`, `va='top'` when `h < 0`, with matching y-offset |
32
+
33
+ ### Cell 13 — Summary Table
34
+
35
+ | # | Severity | Bug | Fix |
36
+ |---|----------|-----|-----|
37
+ | 9 | 🟡 Bug | `common_tasks` was set in Cell 12; if the user skips the plot cell, Cell 13 raises `NameError: common_tasks` | Rebuilt `common_tasks` defensively at the top of Cell 13 |
38
+
39
+ ---
40
+
41
+ ## `inference.py` — Code review fixes (2026-04-25)
42
+
43
+ ### High-priority fixes
44
+
45
+ | # | Severity | Issue | Fix |
46
+ |---|----------|-------|-----|
47
+ | 1 | 🔴 Bug | `parse_action()` crashes on valid JSON that is not an object (e.g. `[]`) — `AttributeError` not caught by `except (json.JSONDecodeError, KeyError)` | Rewrote with `isinstance(data, dict)` guard, list-unwrapping, field-type validation, and broad `except Exception` |
48
+ | 2 | 🔴 Bug | `parse_action()` markdown/prose stripping is fragile — fails on `Here is the action: {...}` | Extracts first `{...}` substring via `text.find("{")` / `text.rfind("}")` |
49
+ | 3 | 🔴 Reliability | `/grader` call can exceed `httpx` 30s timeout on first use (lazy `RobustnessGrader` bound estimation) | `grade()` now uses `timeout=180.0`; base client uses `httpx.Timeout(connect=10, read=60, write=30, pool=10)` |
50
+ | 4 | 🟡 Bug | `HF_TOKEN` takes precedence over `OPENAI_API_KEY` — if both set with OpenAI endpoint, auth fails | Changed to `API_KEY or OPENAI_API_KEY or HF_TOKEN` priority order |
51
+ | 5 | 🟡 Bug | No JSON-mode enforcement for LLM — models return markdown/prose | Added `response_format={"type": "json_object"}` with fallback for unsupported endpoints |
52
+
53
+ ### System prompt fixes
54
+
55
+ | # | Severity | Issue | Fix |
56
+ |---|----------|-------|-----|
57
+ | 6 | 🟡 Design | Prompt says slack bus is controllable, but physics solver overwrites it | Changed to: "avoid adjusting the slack bus — physics overwrites it" |
58
+ | 7 | 🟡 Design | Single-agent mode allows topology actions without safety layer protection | Added: "Prefer NO topology actions unless absolutely necessary" |
59
+ | 8 | 🟡 Design | Multi-agent prompt says "Only for lines in your zone" but observations include boundary lines | Clarified: "Only for visible internal or boundary lines. Boundary-line switching is risky" |
60
+
61
+ ### Multi-agent robustness fixes
62
+
63
+ | # | Severity | Issue | Fix |
64
+ |---|----------|-------|-----|
65
+ | 9 | 🟡 Bug | Agent iteration uses `range(num_agents)` — assumes contiguous integer IDs | Changed to `sorted(observations.keys())` |
66
+ | 10 | 🟡 Bug | `safety_reports` assumed to be list, but API returns dict keyed by agent ID | Added `isinstance` check to handle both list and dict formats |
67
+ | 11 | 🟡 Design | Safety correction feedback not fed back to LLM — model repeats same invalid actions | Appended `[SAFETY] {reason}` to agent history when corrections occur |
68
+
69
+ ### Other fixes
70
+
71
+ | # | Severity | Issue | Fix |
72
+ |---|----------|-------|-----|
73
+ | 12 | 🟡 Bug | `MAX_STEPS = 50` hardcoded — may truncate future tasks | Changed to `MAX_STEPS = 100` as safety cap; `done` flag is the true terminator |
74
+ | 13 | 🟡 Bug | Default task list excludes `task_karnataka` despite KPTCL multi-agent framing | Added `task_karnataka` to `TASKS` list |
75
+ | 14 | 🟡 Bug | Module docstring says all 3 env vars are required; only API key is | Fixed docstring to document defaults and actual requirements |
76
+ | 15 | 🟡 Bug | `[END]` log prints score at `.2f` but summary prints `.4f` — precision loss | Changed `log_end` to use `:.4f` |
77
+ | 16 | 🟡 Reliability | `OpenAI()` client has no timeout or retry config | Added `timeout=30.0, max_retries=2` |
78
+ | 17 | 🟢 Feature | No `list_tasks()` method on `EnvClient` | Added `list_tasks()` for future task validation |
79
+
80
+ ---
81
+
82
+ ## GRPO Training — Environment-Grounded Rewards (2026-04-25)
83
+
84
+ ### Root Cause: Proxy Reward Disconnect
85
+
86
+ The original `compute_grpo_reward` was a **heuristic proxy scorer** that evaluated JSON format, direction, and proportionality without ever stepping the environment. The model optimized this proxy, which did not correlate with actual grid physics rewards. Result: zero improvement over baseline.
87
+
88
+ ### Changes Made
89
+
90
+ #### `src/environment.py`
91
+
92
+ | # | Change | Purpose |
93
+ |---|--------|---------|
94
+ | 1 | Added `_set_state(obs_dict)` method to `OpenGridEnv` | Enables restoring environment to any observed state for reward computation. Rebuilds bus/line state, frequency, and slack injection from observation dicts. |
95
+
96
+ #### `training/train_grpo.py`
97
+
98
+ | # | Severity | Change | Details |
99
+ |---|----------|--------|---------|
100
+ | 2 | 🔴 Critical | Replaced `compute_grpo_reward` with `compute_grpo_reward_env` | New reward function **actually steps the physics simulation**: restores env state → steps with LLM action → measures real reward → runs mini-rollout with heuristic continuation for trajectory awareness |
101
+ | 3 | 🔴 Critical | Added mini-rollout scoring (horizon=3) | After the LLM's action, runs 2 more steps with heuristic policy to capture trajectory-level impact. Combines: `immediate_reward + 0.5 * rollout_reward` |
102
+ | 4 | 🟡 Medium | Increased `num_generations` from 4 → 8 | Wider GRPO group = more reward variance = stronger ranking signal. Prevents the advantage calculation from collapsing to zero. |
103
+ | 5 | 🟡 Medium | Increased random perturbation range from ±15 → ±30 MW | Creates more diverse/stressed grid states during training data generation. Model sees near-blackout and overload scenarios. |
104
+ | 6 | 🟡 Medium | Added adversarial battery drain (every 5th episode) | Forces model to learn actions when batteries are near-empty — a critical edge case the original data lacked. |
105
+ | 7 | 🟡 Medium | Multi-bus perturbations (1-2 buses per step) | Was single-bus. More diverse action patterns create richer state transitions. |
106
+ | 8 | 🟡 Medium | Increased learning rate from 5e-6 → 1e-5 | Slightly more aggressive to capitalize on the now-meaningful reward signal. |
107
+ | 9 | 🟡 Medium | Increased gradient accumulation (effective batch 16) | Smoother gradients for more stable training. |
108
+ | 10 | 🟡 Medium | Steps per episode increased from 10 → 15 | More temporal diversity in observations. |
109
+ | 11 | 🟢 Minor | obs_context stored as JSON string | Fixes Arrow serialization (PyArrow can't handle dicts with int keys). |
110
+ | 12 | 🟢 Minor | Kept legacy `compute_grpo_reward` for test-mode compat | Backward compatibility with `--test-mode` pipeline verification. |
111
+
inference.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenGrid Inference Script
3
+ =========================
4
+ Runs an LLM agent against all OpenGrid tasks via the OpenAI-compatible API.
5
+ Supports both single-agent and multi-agent POMDP modes.
6
+
7
+ Optional environment variables:
8
+ API_BASE_URL -- defaults to https://api.openai.com/v1
9
+ MODEL_NAME -- defaults to gpt-4o
10
+ Required (one of):
11
+ OPENAI_API_KEY or HF_TOKEN
12
+
13
+ Emits structured [START], [STEP], [END] logs to stdout.
14
+
15
+ Usage:
16
+ # Single-agent mode (backward compatible)
17
+ python inference.py
18
+
19
+ # Multi-agent mode (uses safety layer + oversight)
20
+ python inference.py --multi
21
+ """
22
+
23
+ import os
24
+ import sys
25
+ import json
26
+ import math
27
+ import argparse
28
+ import httpx
29
+
30
+ from openai import OpenAI
31
+
32
+ # ---------- Configuration ----------
33
+
34
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
35
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
36
+
37
+ # Prefer OPENAI_API_KEY when using OpenAI endpoint; otherwise try HF_TOKEN
38
+ API_KEY = (
39
+ os.environ.get("API_KEY")
40
+ or os.environ.get("OPENAI_API_KEY")
41
+ or os.environ.get("HF_TOKEN")
42
+ or ""
43
+ )
44
+
45
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
46
+ BENCHMARK = "OpenGrid"
47
+ # Safety cap — the environment's 'done' flag is the true terminator
48
+ MAX_STEPS = 100
49
+ SUCCESS_SCORE_THRESHOLD = 0.5
50
+
51
+ TASKS = ["task_easy", "task_medium", "task_hard", "task_karnataka"]
52
+
53
+ SYSTEM_PROMPT_SINGLE = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
54
+
55
+ Key objectives:
56
+ 1. Keep grid frequency close to 50.0 Hz (acceptable: 49.5-50.5 Hz)
57
+ 2. Prevent transmission line overloads (rho < 1.0)
58
+ 3. Avoid grid islanding (blackout)
59
+
60
+ Available actions:
61
+ 1. bus_adjustments: List of {"bus_id": int, "delta": float}
62
+ - Positive delta = increase power injection (discharge battery / ramp up generator)
63
+ - Negative delta = decrease power injection (charge battery / ramp down generator)
64
+ - Only works on battery and generator buses (avoid adjusting the slack bus — physics overwrites it)
65
+ 2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
66
+ - Opening a line removes it; closing reconnects. 3-step cooldown.
67
+ - WARNING: Opening lines can cause islanding -> blackout
68
+ - Prefer NO topology actions unless absolutely necessary. Always return "topology_actions": []
69
+
70
+ Strategy:
71
+ - If frequency < 50 Hz -> discharge batteries, ramp up generators
72
+ - If frequency > 50 Hz -> charge batteries, ramp down generators
73
+ - If a line rho > 0.9 -> reduce generation near that line, do NOT open it
74
+ - Prefer minimal actions over aggressive switching
75
+
76
+ Respond with ONLY a valid JSON object. Example:
77
+ {"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
78
+ """
79
+
80
+ SYSTEM_PROMPT_MULTI = """You are a KPTCL Zone Controller AI managing one zone of the Karnataka power grid.
81
+ You can only see and control buses in YOUR zone. Other zones are managed by other agents.
82
+
83
+ Key objectives:
84
+ 1. Keep grid frequency close to 50.0 Hz (you see a noisy reading)
85
+ 2. Prevent line overloads in your zone (rho < 1.0)
86
+ 3. Coordinate with other zones (don't fight against them)
87
+ 4. Avoid actions that would trigger the safety layer
88
+
89
+ Available actions:
90
+ 1. bus_adjustments: List of {"bus_id": int, "delta": float}
91
+ - ONLY adjust battery and generator buses in YOUR zone (avoid slack — physics overwrites it)
92
+ - Positive delta = increase power injection
93
+ - Negative delta = decrease power injection
94
+ 2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
95
+ - Only for visible internal or boundary lines. Safety layer will block dangerous switches.
96
+ - Boundary-line switching is risky; avoid unless necessary.
97
+
98
+ Strategy:
99
+ - If frequency < 50 Hz -> increase generation/discharge in your zone
100
+ - If frequency > 50 Hz -> decrease generation/charge in your zone
101
+ - Check neighbor signals to understand if other zones are compensating
102
+ - Prefer small corrections over large swings
103
+
104
+ Respond with ONLY a valid JSON object. Example:
105
+ {"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
106
+ """
107
+
108
+
109
+ # ---------- Structured Logging ----------
110
+
111
def log_start(task: str, env: str, model: str, mode: str = "single"):
    """Emit a structured [START] marker line to stdout."""
    line = f"[START] task={task} env={env} model={model} mode={mode}"
    print(line, flush=True)
113
+
114
+
115
def log_step(step: int, action: str, reward: float, done: bool, error=None, agent_id=None):
    """Emit a structured [STEP] log line to stdout, optionally tagged with an agent id."""
    if agent_id is None:
        agent_str = ""
    else:
        agent_str = f" agent={agent_id}"
    done_val = str(done).lower()
    # Falsy errors (None, "") are reported as the literal token "null".
    error_val = "null" if not error else str(error)
    print(
        f"[STEP] step={step}{agent_str} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
123
+
124
+
125
def clamp_score(s: float) -> float:
    """Ensure score is strictly in (0, 1). Mirrors grader._clamp_score.

    Non-numeric or non-finite input falls back to 0.5; otherwise the value is
    clamped to [0.02, 0.98], floored to 4 decimal places, and clamped again.
    """
    try:
        value = float(s)
    except (TypeError, ValueError):
        return 0.5
    if not math.isfinite(value):
        return 0.5
    value = min(max(value, 0.02), 0.98)
    # Truncate (not round) to 4 decimal places.
    value = math.floor(value * 10000) / 10000
    return min(max(value, 0.02), 0.98)
136
+
137
+
138
def log_end(success: bool, steps: int, score: float, rewards: list, mode: str = "single"):
    """Emit a structured [END] summary line (clamped score, comma-joined rewards)."""
    clamped = clamp_score(score)
    flag = str(success).lower()
    joined = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={flag} steps={steps} score={clamped:.4f} rewards={joined} mode={mode}",
        flush=True,
    )
146
+
147
+
148
+ # ---------- LLM Call ----------
149
+
150
def get_model_message(client: OpenAI, step: int, obs_json: str, last_reward: float,
                      history: list, system_prompt: str, zone_name: str = None) -> str:
    """Ask the LLM what action to take given the current observation.

    Tries JSON mode first and falls back to a plain completion for endpoints
    that reject response_format. Any failure returns a no-op action string.
    """
    parts = []
    if zone_name:
        parts.append(f"[Zone: {zone_name}] ")
    parts.append(f"Step {step} | Last reward: {last_reward:+.2f}\n")
    if history:
        parts.append("Recent history (last 3):\n" + "\n".join(history[-3:]) + "\n\n")
    parts.append(f"Current Grid State:\n{obs_json}")
    context = "".join(parts)

    request = dict(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": context}
        ],
        temperature=0.0,
        max_tokens=300,
    )

    try:
        try:
            # Use JSON mode if the endpoint supports it (OpenAI-compatible).
            response = client.chat.completions.create(
                response_format={"type": "json_object"}, **request
            )
        except Exception:
            # Fallback: endpoint may not support response_format.
            response = client.chat.completions.create(**request)
        return response.choices[0].message.content.strip()
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return '{"bus_adjustments": [], "topology_actions": []}'
183
+
184
+
185
+ # ---------- Environment Client ----------
186
+
187
class EnvClient:
    """Thin HTTP wrapper around the OpenGrid FastAPI environment."""

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        # Generous read timeout: some endpoints run server-side rollouts.
        self.client = httpx.Client(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=30.0, pool=10.0)
        )
        self.session_id = None

    # --- Single-Agent ---

    def reset(self, task_id: str) -> dict:
        """Start a single-agent episode; remembers the new session id."""
        resp = self.client.post(f"{self.base_url}/reset", params={"task_id": task_id})
        resp.raise_for_status()
        payload = resp.json()
        self.session_id = payload["session_id"]
        return payload["observation"]

    def step(self, action_dict: dict) -> dict:
        """Submit one single-agent action and return the step result."""
        resp = self.client.post(
            f"{self.base_url}/step",
            params={"session_id": self.session_id},
            json=action_dict
        )
        resp.raise_for_status()
        return resp.json()

    # --- Multi-Agent ---

    def reset_multi(self, task_id: str) -> dict:
        """Start a multi-agent episode; remembers the new session id."""
        resp = self.client.post(f"{self.base_url}/reset_multi", params={"task_id": task_id})
        resp.raise_for_status()
        payload = resp.json()
        self.session_id = payload["session_id"]
        return payload

    def step_multi(self, agent_actions: dict) -> dict:
        """Submit the joint multi-agent action and return the step result."""
        resp = self.client.post(
            f"{self.base_url}/step_multi",
            params={"session_id": self.session_id},
            json={"agent_actions": agent_actions}
        )
        resp.raise_for_status()
        return resp.json()

    # --- Shared ---

    def state(self) -> dict:
        """Fetch the full-grid state for the current session."""
        resp = self.client.get(f"{self.base_url}/state", params={"session_id": self.session_id})
        resp.raise_for_status()
        return resp.json()

    def grade(self) -> dict:
        """Grade the current session.

        Grading can trigger lazy bound estimation (multiple rollouts) on the
        server, so this request uses a long per-call timeout.
        """
        resp = self.client.get(
            f"{self.base_url}/grader",
            params={"session_id": self.session_id},
            timeout=180.0,
        )
        resp.raise_for_status()
        return resp.json()

    def list_tasks(self) -> list:
        """Fetch available tasks from the server."""
        resp = self.client.get(f"{self.base_url}/tasks")
        resp.raise_for_status()
        return resp.json()

    def close(self):
        """Release the underlying HTTP connection pool."""
        self.client.close()
258
+
259
+
260
+ # ---------- Parse Action ----------
261
+
262
+ NOOP_ACTION = {"bus_adjustments": [], "topology_actions": []}
263
+
264
+
265
+ def parse_action(response_text: str) -> dict:
266
+ """Parse LLM JSON response into an action dict.
267
+
268
+ Handles markdown fences, prose preambles, JSON lists, and malformed output.
269
+ """
270
+ try:
271
+ text = str(response_text).strip()
272
+
273
+ # Strip markdown code fences
274
+ if text.startswith("```"):
275
+ lines = text.splitlines()
276
+ if lines and lines[0].startswith("```"):
277
+ lines = lines[1:]
278
+ if lines and lines[-1].startswith("```"):
279
+ lines = lines[:-1]
280
+ text = "\n".join(lines).strip()
281
+
282
+ # Extract first JSON object from any surrounding prose
283
+ start = text.find("{")
284
+ end = text.rfind("}")
285
+ if start < 0 or end <= start:
286
+ return dict(NOOP_ACTION)
287
+
288
+ data = json.loads(text[start:end + 1])
289
+
290
+ # Handle list wrapping (e.g. [{...}])
291
+ if isinstance(data, list):
292
+ data = data[0] if data else {}
293
+ if not isinstance(data, dict):
294
+ return dict(NOOP_ACTION)
295
+
296
+ bus_adjustments = data.get("bus_adjustments", [])
297
+ topology_actions = data.get("topology_actions", [])
298
+
299
+ if not isinstance(bus_adjustments, list):
300
+ bus_adjustments = []
301
+ if not isinstance(topology_actions, list):
302
+ topology_actions = []
303
+
304
+ return {
305
+ "bus_adjustments": bus_adjustments,
306
+ "topology_actions": topology_actions,
307
+ }
308
+ except Exception:
309
+ return dict(NOOP_ACTION)
310
+
311
+
312
+ # ---------- Single-Agent Runner ----------
313
+
314
def run_task_single(client: OpenAI, env: EnvClient, task_id: str) -> dict:
    """Run one task in single-agent mode and return results.

    Returns a dict with the task id, clamped grader score, step count,
    and a success flag (score >= SUCCESS_SCORE_THRESHOLD).
    """
    transcript = []
    reward_trace = []
    completed_steps = 0
    score = 0.05
    success = False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME, mode="single")

    try:
        obs = env.reset(task_id)
        last_reward = 0.0

        for step_num in range(1, MAX_STEPS + 1):
            # Ask the LLM for an action based on the current observation.
            reply = get_model_message(client, step_num, json.dumps(obs, indent=2),
                                      last_reward, transcript, SYSTEM_PROMPT_SINGLE)
            action_dict = parse_action(reply)

            # Apply it to the environment.
            result = env.step(action_dict)
            obs = result["observation"]
            reward = result.get("reward", {}).get("value", 0.0)
            done = result.get("done", False)

            reward_trace.append(reward)
            completed_steps = step_num
            last_reward = reward

            action_summary = json.dumps(action_dict)
            if len(action_summary) > 200:
                action_summary = action_summary[:200] + "..."

            log_step(step=step_num, action=action_summary, reward=reward, done=done)
            transcript.append(f"Step {step_num}: action={action_summary[:80]} -> reward {reward:+.2f}")

            if done:
                break

        # Grade the finished (or truncated) episode.
        grade_result = env.grade()
        score = clamp_score(grade_result.get("score", 0.5))
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
        score = 0.05
        success = False

    log_end(success=success, steps=completed_steps, score=score, rewards=reward_trace, mode="single")

    return {"task": task_id, "score": score, "steps": completed_steps, "success": success}
366
+
367
+
368
+ # ---------- Multi-Agent Runner ----------
369
+
370
def run_task_multi(client: OpenAI, env: EnvClient, task_id: str) -> dict:
    """Run one task in multi-agent mode and return results.

    Each agent acts on its partial zone observation; safety-layer corrections
    are counted and fed back into the agents' histories so the LLM can adapt.
    """
    team_rewards = []
    completed_steps = 0
    score = 0.05
    success = False
    total_safety_interventions = 0

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME, mode="multi")

    try:
        reset_data = env.reset_multi(task_id)
        num_agents = reset_data["num_agents"]
        zone_info = reset_data["zone_info"]
        observations = reset_data["observations"]

        # Per-agent conversational memory and last-seen rewards (string keys,
        # matching the JSON payloads from the server).
        agent_histories = {str(i): [] for i in range(num_agents)}
        last_rewards = {str(i): 0.0 for i in range(num_agents)}

        print(f"[INFO] Multi-agent mode: {num_agents} agents", flush=True)
        for aid, zi in zone_info.items():
            print(f"  Agent {aid}: {zi['zone_name']} ({len(zi['bus_ids'])} buses)", flush=True)

        for step_num in range(1, MAX_STEPS + 1):
            # 1) Every agent proposes an action from its partial observation.
            agent_actions = {}
            for aid in sorted(observations.keys()):
                zone_name = zone_info.get(aid, {}).get("zone_name", f"Zone_{aid}")
                reply = get_model_message(
                    client, step_num,
                    json.dumps(observations.get(aid, {}), indent=2),
                    last_rewards[aid],
                    agent_histories[aid],
                    SYSTEM_PROMPT_MULTI,
                    zone_name=zone_name
                )
                agent_actions[aid] = parse_action(reply)

            # 2) Submit the joint action.
            result = env.step_multi(agent_actions)
            observations = result["observations"]
            team_reward = result.get("team_reward", 0.0)
            done = result.get("done", False)

            # 3) Count safety-layer corrections; the API returns a dict keyed
            #    by agent id, but older versions returned a list.
            safety_reports = result.get("safety_reports", {})
            if isinstance(safety_reports, list):
                report_values = safety_reports
            else:
                report_values = list(safety_reports.values())
            step_interventions = sum(1 for sr in report_values if sr.get("was_corrected", False))
            total_safety_interventions += step_interventions

            # 4) Feed safety-correction feedback into agent histories.
            if isinstance(safety_reports, dict):
                for aid, sr in safety_reports.items():
                    if sr.get("was_corrected") and aid in agent_histories:
                        reason = sr.get("correction_reason", "action corrected")[:120]
                        agent_histories[aid].append(f"[SAFETY] {reason}")

            # 5) Record per-agent rewards and action summaries.
            per_agent_rewards = result.get("rewards", {})
            for aid in sorted(observations.keys()):
                agent_reward = per_agent_rewards.get(aid, {}).get("value", 0.0)
                last_rewards[aid] = agent_reward
                action_summary = json.dumps(agent_actions.get(aid, {}))
                if len(action_summary) > 100:
                    action_summary = action_summary[:100] + "..."
                agent_histories[aid].append(
                    f"Step {step_num}: action={action_summary[:60]} -> reward {agent_reward:+.2f}"
                )

            team_rewards.append(team_reward)
            completed_steps = step_num

            # 6) Team-level structured log for this step.
            oversight = result.get("oversight_report", {})
            coord_score = oversight.get("coordination_score", 1.0)
            safety_str = f" safety_corrections={step_interventions}" if step_interventions > 0 else ""
            log_step(step=step_num,
                     action=f"team_reward={team_reward:.2f} coord={coord_score:.2f}{safety_str}",
                     reward=team_reward, done=done)

            if done:
                break

        grade_result = env.grade()
        score = clamp_score(grade_result.get("score", 0.5))
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Task {task_id} multi-agent error: {e}", flush=True)
        score = 0.05
        success = False

    print(f"[INFO] Total safety interventions: {total_safety_interventions}", flush=True)
    log_end(success=success, steps=completed_steps, score=score, rewards=team_rewards, mode="multi")

    return {
        "task": task_id, "score": score, "steps": completed_steps,
        "success": success, "safety_interventions": total_safety_interventions
    }
478
+
479
+
480
+ # ---------- Main ----------
481
+
482
def main():
    """Run inference on all tasks and print a summary table."""
    parser = argparse.ArgumentParser(description="OpenGrid LLM Inference")
    parser.add_argument("--multi", action="store_true",
                        help="Use multi-agent POMDP mode (default: single-agent)")
    parser.add_argument("--tasks", nargs="+", default=TASKS,
                        help="Which tasks to run (default: all)")
    args = parser.parse_args()

    if not API_KEY:
        print("[ERROR] No API key found. Set OPENAI_API_KEY or HF_TOKEN environment variable.", flush=True)
        sys.exit(1)

    mode = "multi-agent" if args.multi else "single-agent"
    for config_line in (
        f"[CONFIG] API_BASE_URL={API_BASE_URL}",
        f"[CONFIG] MODEL_NAME={MODEL_NAME}",
        f"[CONFIG] ENV_URL={ENV_URL}",
        f"[CONFIG] MODE={mode}",
    ):
        print(config_line, flush=True)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=30.0, max_retries=2)
    env = EnvClient(ENV_URL)

    runner = run_task_multi if args.multi else run_task_single
    all_results = []

    try:
        for task_id in args.tasks:
            banner = "=" * 60
            print(f"\n{banner}", flush=True)
            print(f"Running task: {task_id} ({mode})", flush=True)
            print(f"{banner}", flush=True)
            all_results.append(runner(client, env, task_id))
    finally:
        # Always release the HTTP client, even if a task blew up.
        env.close()

    # Summary
    banner = "=" * 60
    print(f"\n{banner}", flush=True)
    print(f"FINAL RESULTS ({mode})", flush=True)
    print(f"{banner}", flush=True)
    for r in all_results:
        status = "PASS" if r["success"] else "FAIL"
        extra = f" safety={r['safety_interventions']}" if "safety_interventions" in r else ""
        print(f"  {r['task']}: score={r['score']:.4f} steps={r['steps']} [{status}]{extra}", flush=True)

    avg_score = sum(r["score"] for r in all_results) / len(all_results) if all_results else 0
    print(f"\n  Average Score: {avg_score:.4f}", flush=True)


if __name__ == "__main__":
    main()
openenv.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: opengrid
3
+ type: space
4
+ runtime: fastapi
5
+ app: app:app
6
+ port: 7860
7
+
8
+ # Environment supports both single-agent and multi-agent POMDP modes.
9
+ # Single-agent: /reset + /step (backward compatible)
10
+ # Multi-agent: /reset_multi + /step_multi (2-3 agents per zone)
11
+
12
+ tasks:
13
+ - id: task_easy
14
+ name: Easy Grid (5 buses, 2 agents, 20% renewables)
15
+ description: Basic frequency control with 2-zone coordination
16
+ agents: 2
17
+ grader:
18
+ endpoint: /grader
19
+ score_range: [0.02, 0.98]
20
+ - id: task_medium
21
+ name: Medium Grid (10 buses, 3 agents, 50% renewables)
22
+ description: Congestion management with 3-zone POMDP and volatile renewables
23
+ agents: 3
24
+ grader:
25
+ endpoint: /grader
26
+ score_range: [0.02, 0.98]
27
+ - id: task_hard
28
+ name: Hard Grid (14 buses, 3 agents, 70% renewables)
29
+ description: High volatility, tight margins, complex topology with safety constraints
30
+ agents: 3
31
+ grader:
32
+ endpoint: /grader
33
+ score_range: [0.02, 0.98]
34
+ - id: task_karnataka
35
+ name: Karnataka KPTCL Grid (5 buses, 2 agents, real-world topology)
36
+ description: Realistic Karnataka power grid with POMDP multi-agent coordination
37
+ agents: 2
38
+ grader:
39
+ endpoint: /grader
40
+ score_range: [0.02, 0.98]
pyproject.toml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "opengrid"
7
+ version = "1.0.0"
8
+ description = "Renewable energy grid load-balancing environment for AI agents"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ {name = "KRISHNA GOYAL", email = "krishnagoyalcse@gmail.com"}
14
+ ]
15
+ dependencies = [
16
+ "fastapi",
17
+ "uvicorn",
18
+ "pydantic>=2.0",
19
+ "numpy",
20
+ "networkx",
21
+ "matplotlib",
22
+ "openai",
23
+ "httpx",
24
+ "openenv-core>=0.2.0",
25
+ ]
26
+
27
+ [project.urls]
28
+ Homepage = "https://github.com/K446/opengrid"
29
+
30
+ [project.scripts]
31
+ server = "server.app:main"
32
+
33
+ [tool.setuptools.packages.find]
34
+ where = ["."]
35
+ include = ["src*", "server*"]
36
+
37
+ [tool.pyright]
38
+ venvPath = "."
39
+ venv = ".venv"
40
+ pythonVersion = "3.13"
41
+ extraPaths = ["."]
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pydantic>=2.0
4
+ numpy
5
+ networkx
6
+ matplotlib
7
+ openai
8
+ httpx
9
+ openenv-core>=0.2.0
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Server package
server/app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenGrid server entry point — used by openenv for multi-mode deployment.
3
+ Re-exports the FastAPI app from the root app module.
4
+ """
5
+ import sys
6
+ import os
7
+ import uvicorn
8
+
9
+ # Add parent directory to path so we can import from the root package
10
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+
12
+ from app import app # type: ignore[import-not-found] # noqa: E402, F401
13
+
14
+
15
+ def main():
16
+ """Entry point for openenv server mode."""
17
+ uvicorn.run(app, host="0.0.0.0", port=7860)
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
src/__init__.py ADDED
File without changes
src/baseline.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline Policies for OpenGrid
3
+ ================================
4
+ Provides two agent implementations:
5
+ 1. heuristic_policy — deterministic rule-based baseline for reproducible scoring
6
+ 2. llm_policy — LLM-based policy using OpenAI-compatible API
7
+
8
+ Both support GridObservation (single-agent) and ZoneObservation (multi-agent).
9
+ """
10
+
11
+ import json
12
+ import logging
13
+ import os
14
+ from typing import List, Union
15
+
16
+ from openai import OpenAI
17
+ from .models import GridAction, BusAdjustment, GridObservation, ZoneObservation
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # API configuration — HF_TOKEN for Hugging Face endpoints, OPENAI_API_KEY for OpenAI
22
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
23
+ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
24
+ API_KEY = os.getenv("OPENAI_API_KEY", os.getenv("HF_TOKEN", ""))
25
+
26
+ # Cached client instance
27
+ _CLIENT = None
28
+
29
+
30
+ def _get_client() -> OpenAI:
31
+ """Lazy-cached client creation."""
32
+ global _CLIENT
33
+ if _CLIENT is None:
34
+ if not API_KEY:
35
+ raise RuntimeError(
36
+ "Missing API key. Set OPENAI_API_KEY or HF_TOKEN environment variable."
37
+ )
38
+ _CLIENT = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=15.0)
39
+ return _CLIENT
40
+
41
+
42
+ def _obs_buses(obs):
43
+ """Extract bus list from either GridObservation or ZoneObservation."""
44
+ return getattr(obs, "buses", getattr(obs, "local_buses", []))
45
+
46
+
47
+ def _obs_lines(obs):
48
+ """Extract line list from either GridObservation or ZoneObservation."""
49
+ if hasattr(obs, "lines"):
50
+ return obs.lines
51
+ internal = getattr(obs, "internal_lines", [])
52
+ boundary = getattr(obs, "boundary_lines", [])
53
+ return list(internal) + list(boundary)
54
+
55
+
56
+ SYSTEM_PROMPT = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
57
+
58
+ Key objectives:
59
+ 1. Keep grid frequency close to 50.0 Hz (acceptable: 49.5–50.5 Hz)
60
+ 2. Prevent transmission line overloads (rho < 1.0)
61
+ 3. Avoid grid islanding (blackout)
62
+
63
+ Available actions:
64
+ 1. bus_adjustments: List of {"bus_id": int, "delta": float}
65
+ - Positive delta = increase power injection (discharge battery / ramp up generator)
66
+ - Negative delta = decrease power injection (charge battery / ramp down generator)
67
+ - Only works on battery and generator buses (NOT slack, load, solar, or wind)
68
+ - Slack bus injection is computed by physics — adjustments are ignored
69
+ 2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
70
+ - Opening a line removes it; closing reconnects. 3-step cooldown after each switch.
71
+ - WARNING: Opening lines can cause islanding → blackout → -100 reward
72
+ - Prefer NO topology actions unless absolutely necessary.
73
+
74
+ Strategy tips:
75
+ - If frequency < 50 Hz: grid needs more generation → discharge batteries or ramp up generators
76
+ - If frequency > 50 Hz: grid has excess generation → charge batteries or ramp down generators
77
+ - If a line rho > 0.9: reduce generation at one end or increase at the other to shift flow
78
+ - Prefer minimal actions. Do-nothing is better than reckless switching.
79
+
80
+ Respond with ONLY a valid JSON object, no markdown, no explanation. Example:
81
+ {"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
82
+ """
83
+
84
+
85
+ def parse_action_response(response_text: str) -> GridAction:
86
+ """Parse LLM response into a GridAction. Falls back to no-op on parse errors."""
87
+ try:
88
+ text = response_text.strip()
89
+
90
+ # Remove fenced code block if present
91
+ if text.startswith("```"):
92
+ lines = text.splitlines()
93
+ if lines[0].startswith("```"):
94
+ lines = lines[1:]
95
+ if lines and lines[-1].startswith("```"):
96
+ lines = lines[:-1]
97
+ text = "\n".join(lines).strip()
98
+
99
+ # Extract first JSON object
100
+ start = text.find("{")
101
+ end = text.rfind("}")
102
+ if start == -1 or end == -1 or end <= start:
103
+ return GridAction()
104
+
105
+ data = json.loads(text[start:end + 1])
106
+
107
+ # Handle list wrapping
108
+ if isinstance(data, list):
109
+ data = data[0] if data else {}
110
+
111
+ return GridAction(**data)
112
+ except Exception:
113
+ return GridAction()
114
+
115
+
116
+ def llm_policy(obs: Union[GridObservation, ZoneObservation]) -> GridAction:
117
+ """LLM-based policy using the OpenAI-compatible API.
118
+
119
+ Supports both GridObservation and ZoneObservation.
120
+ Falls back to no-op on any error.
121
+ """
122
+ client = _get_client()
123
+ obs_json = obs.model_dump_json()
124
+
125
+ try:
126
+ response = client.chat.completions.create(
127
+ model=MODEL_NAME,
128
+ messages=[
129
+ {"role": "system", "content": SYSTEM_PROMPT},
130
+ {"role": "user", "content": f"Current Grid State:\n{obs_json}"}
131
+ ],
132
+ temperature=0.0,
133
+ max_tokens=300,
134
+ )
135
+ action_str = response.choices[0].message.content
136
+ return parse_action_response(action_str)
137
+ except Exception as e:
138
+ logger.debug("LLM policy error: %s", e, exc_info=True)
139
+ return GridAction()
140
+
141
+
142
+ def heuristic_policy(
143
+ obs: Union[GridObservation, ZoneObservation],
144
+ ) -> GridAction:
145
+ """Rule-based baseline policy for reproducible scoring.
146
+
147
+ Strategy:
148
+ - Use batteries and generators for frequency regulation (proportional control)
149
+ - DO NOT open overloaded lines (causes cascading failures)
150
+ - DO NOT adjust the slack bus (overwritten by physics solver)
151
+ - Let the environment/safety layer clamp any out-of-range deltas
152
+
153
+ Supports both GridObservation (single-agent) and ZoneObservation (multi-agent).
154
+ """
155
+ adj = []
156
+ freq = obs.grid_frequency
157
+ freq_error = freq - 50.0 # positive = too high, negative = too low
158
+
159
+ buses = list(_obs_buses(obs))
160
+ lines = list(_obs_lines(obs))
161
+
162
+ batteries = [b for b in buses if b.type == 'battery']
163
+ generators = [b for b in buses if b.type == 'generator']
164
+
165
+ # --- 1. Proportional frequency control via batteries ---
166
+ if abs(freq_error) > 0.1 and batteries:
167
+ # Distribute correction across all available batteries
168
+ correction_total = -freq_error * 15.0 # stronger gain than naive 2.0
169
+ correction_total = max(-20.0, min(20.0, correction_total))
170
+ per_battery = correction_total / len(batteries)
171
+
172
+ for bus in batteries:
173
+ if per_battery > 0 and bus.soc > 0:
174
+ # Discharge — safety layer clamps to actual SOC
175
+ adj.append(BusAdjustment(bus_id=bus.id, delta=per_battery))
176
+ elif per_battery < 0:
177
+ # Charge — safety layer clamps to remaining capacity
178
+ adj.append(BusAdjustment(bus_id=bus.id, delta=per_battery))
179
+
180
+ # --- 2. Generator response for larger deviations ---
181
+ if abs(freq_error) > 0.25:
182
+ for bus in generators:
183
+ delta = -freq_error * 5.0
184
+ ramp = getattr(bus, 'ramp_rate', 20.0)
185
+ delta = max(-ramp, min(ramp, delta))
186
+ adj.append(BusAdjustment(bus_id=bus.id, delta=delta))
187
+
188
+ # --- 3. Overload relief via generators (not slack) ---
189
+ adjusted_for_overload = set()
190
+ for line in lines:
191
+ if line.rho > 0.95 and line.connected:
192
+ for bus in generators:
193
+ if bus.id not in adjusted_for_overload and bus.p_injection > 5:
194
+ adj.append(BusAdjustment(bus_id=bus.id, delta=-3.0))
195
+ adjusted_for_overload.add(bus.id)
196
+ break
197
+
198
+ # No topology actions — much safer than opening overloaded lines
199
+ return GridAction(bus_adjustments=adj, topology_actions=[])
src/environment.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ from typing import List, Dict, Tuple, Optional
4
+ from .models import (
5
+ GridObservation, GridAction, GridReward, GridInfo,
6
+ LineStatus, BusState, ZoneObservation, ZoneInfo,
7
+ SafetyReport, OversightReport, MultiAgentStepResult,
8
+ )
9
+ from .physics import DCSolver, IslandedException
10
+ from .safety import SafetyLayer
11
+ from .oversight import OversightAgent
12
+
13
+
14
+ class OpenGridEnv:
15
+ """
16
+ OpenGrid: A renewable energy grid load-balancing environment.
17
+
18
+ Supports two modes:
19
+ 1. Single-agent (backward compatible): reset()/step()/state()
20
+ 2. Multi-agent POMDP: reset_multi()/step_multi() with per-zone
21
+ partial observability, safety layer, and oversight agent.
22
+
23
+ The agent(s) must maintain grid stability by:
24
+ - Balancing generation and load (frequency control)
25
+ - Managing transmission line loading (congestion management)
26
+ - Coordinating battery storage and topology switching
27
+ """
28
+
29
+ NOMINAL_FREQ = 50.0
30
+ FREQ_DEADBAND = 0.5 # Hz — acceptable deviation band
31
+ FREQ_NOISE_STD = 0.05 # Hz — noise added to POMDP observations
32
+ LINE_NOISE_STD = 0.02 # fraction — noise added to line readings
33
+
34
    def __init__(self, config: Dict):
        """Build the environment from a task config dict.

        Args:
            config: Task definition containing at least 'num_buses',
                'lines' and 'buses', plus optional 'max_steps', 'seed',
                and the multi-agent zone keys read below ('num_agents',
                'zone_bus_ids', 'internal_lines', 'boundary_lines', ...).
        """
        self.config = config
        self.num_buses = config['num_buses']
        self.lines_config = config['lines']
        self.buses_config = config['buses']

        # Resolve slack bus from config (not hardcoded to index 0);
        # falls back to bus 0 if no bus is typed 'slack'.
        self.slack_bus_id = next(
            (b['id'] for b in self.buses_config if b['type'] == 'slack'), 0
        )

        self.solver = DCSolver(self.num_buses, slack_bus=self.slack_bus_id)
        self.timestep = 0
        self.max_steps = config.get('max_steps', 50)

        # Dynamic per-episode state — populated by reset()/reset_multi().
        self.bus_state = []
        self.line_state = []
        self.cooldowns = {}          # line_id -> remaining switch-cooldown steps
        self.slack_injection = 0.0   # slack bus injection, set by power flow
        self._is_blackout = False

        # Build index dicts for O(1) lookups
        self._bus_cfg_by_id = {b['id']: b for b in self.buses_config}
        self._line_cfg_by_id = {l['id']: l for l in self.lines_config}

        # Multi-agent config: per-agent zone partition of buses and lines.
        self.num_agents = config.get('num_agents', 1)
        self.zone_assignments = config.get('zone_assignments', {})
        self.zone_names = config.get('zone_names', [])
        self.zone_bus_ids = config.get('zone_bus_ids', {})
        self.internal_lines = config.get('internal_lines', {})
        self.boundary_lines = config.get('boundary_lines', {})

        # Safety and oversight (initialized on first multi-agent use)
        self.safety_layer = SafetyLayer(config)
        self.oversight_agent = OversightAgent(config)

        # Episode tracking for multi-agent rewards.
        # NOTE(review): annotated as a list, but step_multi later assigns a
        # Dict[int, SafetyReport] here — confirm the intended container type.
        self._safety_reports_this_step: List[SafetyReport] = []
        self._oversight_report_this_step: Optional[OversightReport] = None

        # Calibrate droop constant to system size: the same MW imbalance
        # moves frequency less on a larger system (floor of 50.0 avoids
        # a huge droop on tiny test grids).
        total_load = sum(
            b['base_p'] for b in self.buses_config if b['type'] == 'load'
        )
        total_gen = sum(
            b['max_p'] for b in self.buses_config
            if b['type'] in ['slack', 'generator', 'solar', 'wind']
        )
        total_system = max(total_load + total_gen, 50.0)
        self.droop_constant = 2.5 / total_system

        # Per-episode RNG — initialized early so _update_loads_and_renewables never crashes
        self._seed = config.get('seed', 42)
        self._rng = np.random.default_rng(self._seed)
89
+
90
+ # ======================================================================
91
+ # State Restoration (for GRPO environment-grounded rewards)
92
+ # ======================================================================
93
+
94
    def _set_state(self, obs_dict: dict) -> None:
        """Restore the environment to a state described by an observation dict.

        This enables environment-grounded GRPO rewards: instead of scoring
        actions with a heuristic proxy, we restore the env to the observed state,
        step with the proposed action, and use the real reward.

        NOTE(review): zone observations carry noisy frequency/rho readings,
        so restoration from a ZoneObservation dump is approximate — confirm
        this is acceptable for reward computation.

        Args:
            obs_dict: A dict from ZoneObservation.model_dump() or
                GridObservation.model_dump(), containing at minimum:
                timestep, grid_frequency, and bus/line state.
        """
        self.timestep = obs_dict.get('timestep', 0)
        self._is_blackout = obs_dict.get('is_blackout', False)
        # Missing cooldowns default to 0 for every line currently tracked.
        self.cooldowns = obs_dict.get('cooldowns', {k: 0 for k in self.cooldowns})

        # Restore bus state from observation — ZoneObservation uses
        # 'local_buses', GridObservation uses 'buses'. Buses absent from the
        # dict keep their current dynamic state.
        local_buses = obs_dict.get('local_buses', obs_dict.get('buses', []))
        if local_buses:
            for b_obs in local_buses:
                b_dyn = self._find_bus_state(b_obs['id'])
                if b_dyn is not None:
                    b_dyn['p'] = b_obs.get('p_injection', b_dyn['p'])
                    b_dyn['soc'] = b_obs.get('soc', b_dyn.get('soc', 0.0))

        # Restore line state from observation — accept all three possible
        # line collections (zone-internal, zone-boundary, or flat 'lines').
        all_lines = (obs_dict.get('internal_lines', []) or []) + \
                    (obs_dict.get('boundary_lines', []) or []) + \
                    (obs_dict.get('lines', []) or [])
        for l_obs in all_lines:
            l_dyn = self._find_line(l_obs['id'])
            if l_dyn is not None:
                l_dyn['connected'] = l_obs.get('connected', True)
                l_dyn['flow'] = l_obs.get('flow', 0.0)

        # Rebuild lookup indices (bus_state/line_state entries were mutated
        # in place, but the id maps are refreshed for consistency).
        self._bus_state_by_id = {b['id']: b for b in self.bus_state}
        self._line_state_by_id = {l['id']: l for l in self.line_state}

        # Re-derive slack injection from frequency if available — inverts
        # the droop relation: freq = NOMINAL - droop * slack_injection.
        freq = obs_dict.get('grid_frequency', self.NOMINAL_FREQ)
        self.slack_injection = (self.NOMINAL_FREQ - freq) / self.droop_constant

        # Update slack bus p to match
        slack_dyn = self._find_bus_state(self.slack_bus_id)
        if slack_dyn is not None:
            slack_dyn['p'] = self.slack_injection
141
+
142
+ # ======================================================================
143
+ # Single-Agent API (backward compatible)
144
+ # ======================================================================
145
+
146
    def reset(self) -> GridObservation:
        """Reset the environment to initial state. Returns initial observation.

        Reseeds the per-episode RNG from the configured seed, so repeated
        resets of the same task reproduce identical load/renewable traces.
        Also resets the oversight agent and clears all switch cooldowns.
        """
        self.timestep = 0
        self.slack_injection = 0.0
        self.cooldowns = {l['id']: 0 for l in self.lines_config}
        self._rng = np.random.default_rng(self._seed)
        self.oversight_agent.reset()

        self.bus_state = []
        for b in self.buses_config:
            init_p = 0.0
            # Initialize generators at 50% capacity so slack doesn't absorb all load
            if b['type'] in ['generator']:
                init_p = b['max_p'] * 0.5
            self.bus_state.append({
                'id': b['id'], 'p': init_p, 'soc': b.get('init_soc', 0.0)
            })
        # All lines start connected with zero flow; the flow is filled in by
        # the power-flow solve below.
        self.line_state = [
            {'id': l['id'], 'connected': True, 'flow': 0.0}
            for l in self.lines_config
        ]

        # Build O(1) lookup indices for dynamic state
        self._bus_state_by_id = {b['id']: b for b in self.bus_state}
        self._line_state_by_id = {l['id']: l for l in self.line_state}

        self._is_blackout = False
        # Sample initial loads/renewables, then solve physics once so the
        # first observation already contains consistent flows.
        self._update_loads_and_renewables()
        self._run_power_flow()

        return self._get_obs()
177
+
178
    def step(self, action: GridAction) -> Tuple[GridObservation, GridReward, bool, GridInfo]:
        """Execute one step: apply action, update dynamics, solve physics, compute reward.

        Order of operations: topology switches (cooldown-gated), cooldown
        tick, bus power adjustments (battery/generator limits applied),
        stochastic load/renewable update, DC power flow, reward scoring.

        Returns:
            (observation, reward, done, info) — done is True on blackout
            (islanding) or when max_steps is reached. Reward components:
            survival (+1, or -100 on blackout), frequency, overload,
            action_cost.
        """
        self.timestep += 1
        reward_components = {"survival": 1.0, "frequency": 0.0, "overload": 0.0, "action_cost": 0.0}
        self._is_blackout = False

        # 1. Apply topology actions (with cooldown enforcement)
        for t_act in action.topology_actions:
            l_id = t_act.line_id
            if l_id not in self.cooldowns:
                # Unknown line id — silently ignored.
                continue
            if self.cooldowns[l_id] == 0:
                line = self._find_line(l_id)
                if line is None:
                    continue
                current_status = line['connected']
                new_status = (t_act.action == "close")

                # Only an actual state change counts as a switch and pays
                # the action cost / starts the cooldown.
                if current_status != new_status:
                    line['connected'] = new_status
                    self.cooldowns[l_id] = 3
                    reward_components['action_cost'] -= 0.5

        # Tick cooldowns. Note: a line switched just above is set to 3 and
        # immediately ticked to 2 here, so the effective lockout is the next
        # two steps.
        for l_id in self.cooldowns:
            self.cooldowns[l_id] = max(0, self.cooldowns[l_id] - 1)

        # 2. Apply power adjustment actions
        for adj in action.bus_adjustments:
            bus_cfg = self._find_bus_config(adj.bus_id)
            bus_dyn = self._find_bus_state(adj.bus_id)
            if bus_cfg is None or bus_dyn is None:
                continue

            delta = adj.delta

            if bus_cfg['type'] == 'battery':
                # Clamp to what the battery can actually deliver/absorb.
                max_charge = bus_cfg['capacity'] - bus_dyn['soc']
                max_discharge = bus_dyn['soc']

                if delta > 0:
                    delta = min(delta, max_discharge)   # discharge
                else:
                    delta = max(delta, -max_charge)     # charge

                # Battery injection is SET (not accumulated): p equals this
                # step's (clamped) delta, and SOC moves the opposite way.
                bus_dyn['soc'] = np.clip(bus_dyn['soc'] - delta, 0.0, bus_cfg['capacity'])
                bus_dyn['p'] = delta

            elif bus_cfg['type'] not in ['load', 'solar', 'wind']:
                # Dispatchable units (generator/slack config entries):
                # delta is rate-limited then accumulated onto current p,
                # clamped to the unit's [min_p, max_p] operating range.
                max_ramp = bus_cfg.get('ramp_rate', 10.0)
                delta = np.clip(delta, -max_ramp, max_ramp)
                new_p = bus_dyn['p'] + delta
                bus_dyn['p'] = np.clip(new_p, bus_cfg['min_p'], bus_cfg['max_p'])

        # 3. Update load/renewable dynamics
        self._update_loads_and_renewables()

        # 4. Solve physics
        try:
            self._run_power_flow()

            # Check line overloads: quadratic penalty past the limit
            # (rho > 1), small flat penalty in the 0.8–1.0 warning band.
            for l in self.line_state:
                if l['connected']:
                    flow = l['flow']
                    limit = self._get_line_capacity(l['id'])
                    rho = abs(flow) / limit if limit > 0 else 0.0

                    if rho > 1.0:
                        reward_components['overload'] -= (rho - 1.0) ** 2 * 20
                    elif rho > 0.8:
                        reward_components['overload'] -= 0.1

            # Frequency reward: penalty outside the ±FREQ_DEADBAND band
            # (capped at 1.5), small bonus when within ±0.1 Hz of nominal.
            freq = self._compute_frequency()
            freq_dev = abs(freq - self.NOMINAL_FREQ)
            if freq_dev > self.FREQ_DEADBAND:
                raw_penalty = (freq_dev - self.FREQ_DEADBAND) * 0.5
                reward_components['frequency'] -= min(raw_penalty, 1.5)
            elif freq_dev < 0.1:
                reward_components['frequency'] += 0.2

        except IslandedException:
            # Grid split into islands → blackout; survival flips to -100
            # and the episode terminates below.
            self._is_blackout = True
            reward_components['survival'] = -100.0

        done = self._is_blackout or (self.timestep >= self.max_steps)

        total_reward = sum(reward_components.values())
        reward = GridReward(value=total_reward, components=reward_components)
        info = GridInfo(task_id=self.config['id'], is_blackout=self._is_blackout)

        return self._get_obs(), reward, done, info
271
+
272
    def state(self) -> GridObservation:
        """Return the current full observation without advancing the simulation.

        Alias for the internal observation builder; used by the single-agent
        API to expose state between steps.
        """
        return self._get_obs()
275
+
276
+ # ======================================================================
277
+ # Multi-Agent POMDP API
278
+ # ======================================================================
279
+
280
+ def reset_multi(self) -> Dict[int, ZoneObservation]:
281
+ """Reset environment and return per-agent partial observations."""
282
+ self.reset() # Reuse single-agent reset for state initialization
283
+ return {
284
+ agent_id: self._get_zone_obs(agent_id)
285
+ for agent_id in range(self.num_agents)
286
+ }
287
+
288
    def step_multi(self, agent_actions: Dict[int, GridAction]) -> MultiAgentStepResult:
        """Multi-agent step with safety layer and oversight.

        Flow:
        1. Safety layer validates each agent's actions
        2. Combine corrected actions into one GridAction
        3. Run single-agent step with combined action
        4. Oversight agent evaluates coordination
        5. Compute per-agent rewards (local + global + safety + coordination)

        Args:
            agent_actions: Proposed action per agent id; agents missing from
                the dict are treated as proposing a no-op GridAction().

        Returns:
            MultiAgentStepResult with per-agent observations and rewards,
            the shared team reward, safety/oversight reports, and done/info.
        """
        # Snapshot pre-step state for the oversight evaluation below
        # (shallow dict copies are enough: bus entries hold scalars).
        pre_frequency = self._compute_frequency()
        pre_bus_state = [dict(b) for b in self.bus_state]

        # --- 1. Safety validation per agent ---
        safety_reports: Dict[int, SafetyReport] = {}
        corrected_actions: Dict[int, GridAction] = {}

        for agent_id in range(self.num_agents):
            proposed = agent_actions.get(agent_id, GridAction())
            corrected, report = self.safety_layer.validate_and_correct(
                agent_id=agent_id,
                proposed_action=proposed,
                current_line_state=self.line_state,
                current_bus_state=self.bus_state,
                cooldowns=self.cooldowns,
            )
            corrected_actions[agent_id] = corrected
            safety_reports[agent_id] = report

        # Kept for post-hoc inspection (stored as a dict keyed by agent id).
        self._safety_reports_this_step = safety_reports

        # --- 2. Combine all corrected actions ---
        combined = GridAction(
            bus_adjustments=[
                adj for action in corrected_actions.values()
                for adj in action.bus_adjustments
            ],
            topology_actions=[
                t for action in corrected_actions.values()
                for t in action.topology_actions
            ],
        )

        # --- 3. Run the step ---
        obs, base_reward, done, info = self.step(combined)
        post_frequency = self._compute_frequency()

        # --- 4. Oversight evaluation ---
        # Note: oversight sees the ORIGINAL proposals (agent_actions), not
        # the safety-corrected ones, so it can penalize unsafe intent.
        oversight_report = self.oversight_agent.evaluate(
            agent_actions=agent_actions,
            safety_reports=safety_reports,
            pre_frequency=pre_frequency,
            post_frequency=post_frequency,
            pre_bus_state=pre_bus_state,
            post_bus_state=self.bus_state,
        )
        self._oversight_report_this_step = oversight_report

        # --- 5. Per-agent rewards ---
        per_agent_rewards = {}
        for agent_id in range(self.num_agents):
            agent_reward = self._compute_agent_reward(
                agent_id=agent_id,
                base_reward=base_reward,
                safety_report=safety_reports.get(agent_id),
                oversight_report=oversight_report,
                is_blackout=info.is_blackout,
            )
            per_agent_rewards[agent_id] = agent_reward

        # Team reward is the undivided single-agent step reward.
        team_reward = base_reward.value

        # --- 6. Per-agent partial observations ---
        per_agent_obs = {
            agent_id: self._get_zone_obs(agent_id)
            for agent_id in range(self.num_agents)
        }

        # Propagate blackout to observations
        if info.is_blackout:
            for obs in per_agent_obs.values():
                obs.is_blackout = True

        return MultiAgentStepResult(
            observations=per_agent_obs,
            rewards=per_agent_rewards,
            team_reward=round(team_reward, 4),
            done=done,
            safety_reports=safety_reports,
            oversight_report=oversight_report,
            info=info,
        )
379
+ )
380
+
381
+ def get_zone_info(self) -> Dict[int, ZoneInfo]:
382
+ """Get metadata about each agent's zone."""
383
+ zones = {}
384
+ for agent_id in range(self.num_agents):
385
+ zones[agent_id] = ZoneInfo(
386
+ agent_id=agent_id,
387
+ zone_name=self.zone_names[agent_id] if agent_id < len(self.zone_names) else f"Zone_{agent_id}",
388
+ bus_ids=self.zone_bus_ids.get(agent_id, []),
389
+ boundary_line_ids=self.boundary_lines.get(agent_id, []),
390
+ internal_line_ids=self.internal_lines.get(agent_id, []),
391
+ )
392
+ return zones
393
+
394
+ # ======================================================================
395
+ # Multi-Agent Reward Computation
396
+ # ======================================================================
397
+
398
    def _compute_agent_reward(
        self,
        agent_id: int,
        base_reward: GridReward,
        safety_report: Optional[SafetyReport],
        oversight_report: OversightReport,
        is_blackout: bool,
    ) -> GridReward:
        """Compute per-agent reward with composable components.

        Components:
        - survival: shared team component (same for all)
        - frequency: shared (all agents affected equally)
        - overload_shared: global overload penalty split across agents
        - local_congestion: penalty for overloads in agent's zone
        - safety_compliance: penalty if safety layer corrected the action
        - coordination: penalty from oversight for selfish/conflicting behavior
        - action_cost: topology-switch cost split across agents

        Note: is_blackout is accepted but the blackout signal itself arrives
        through base_reward's survival component (-100 on blackout).
        """
        components = {}

        # Shared components (from base reward)
        components['survival'] = base_reward.components.get('survival', 1.0)
        components['frequency'] = base_reward.components.get('frequency', 0.0)

        # Global overload shared equally — ensures no line's penalty is lost
        components['overload_shared'] = base_reward.components.get('overload', 0.0) / max(self.num_agents, 1)

        # Local congestion: additional penalty for overloads on lines in agent's
        # zone (internal + boundary). Same shape as the global overload penalty
        # but at half weight (x10 quadratic, -0.05 warning band).
        zone_overload = 0.0
        agent_lines = set(self.internal_lines.get(agent_id, []))
        agent_lines.update(self.boundary_lines.get(agent_id, []))
        for l in self.line_state:
            if l['id'] in agent_lines and l['connected']:
                limit = self._get_line_capacity(l['id'])
                rho = abs(l['flow']) / limit if limit > 0 else 0.0
                if rho > 1.0:
                    zone_overload -= (rho - 1.0) ** 2 * 10
                elif rho > 0.8:
                    zone_overload -= 0.05
        components['local_congestion'] = zone_overload

        # Safety compliance penalty — scales with the number of topology
        # actions the safety layer had to block.
        if safety_report and safety_report.was_corrected:
            components['safety_compliance'] = -0.3 * (
                1 + safety_report.blocked_topology_actions
            )
        else:
            components['safety_compliance'] = 0.1  # Bonus for safe actions

        # Coordination penalty from oversight (stored positive, applied negative)
        coord_penalty = oversight_report.coordination_penalties.get(agent_id, 0.0)
        components['coordination'] = -coord_penalty

        # Action cost split equally across agents
        components['action_cost'] = base_reward.components.get('action_cost', 0.0) / max(self.num_agents, 1)

        total = sum(components.values())
        return GridReward(value=round(total, 4), components=components)
456
+
457
+ # ======================================================================
458
+ # POMDP Observation
459
+ # ======================================================================
460
+
461
    def _get_zone_obs(self, agent_id: int) -> ZoneObservation:
        """Build partial observation for one agent (POMDP).

        Each agent sees:
          - Only buses in their zone
          - Internal + boundary lines
          - Noisy global frequency
          - Limited neighbor signals

        Args:
            agent_id: Index of the zone agent the observation is built for.

        Returns:
            ZoneObservation with noisy line loadings and frequency, zone
            load/gen totals, neighbor injection averages, and visible cooldowns.
        """
        # Local buses: only buses assigned to this agent's zone are visible.
        zone_bus_ids = set(self.zone_bus_ids.get(agent_id, []))
        local_buses = []
        zone_load = 0.0
        zone_gen = 0.0
        for b in self.bus_state:
            if b['id'] in zone_bus_ids:
                b_cfg = self._find_bus_config(b['id'])
                if b_cfg is None:
                    # Dynamic bus with no static config — skip silently.
                    continue
                local_buses.append(BusState(
                    id=b['id'], type=b_cfg['type'],
                    p_injection=round(b['p'], 4),
                    soc=round(b.get('soc', 0.0), 4),
                    ramp_rate=b_cfg.get('ramp_rate', 0.0),
                ))
                # Loads are stored as negative injections; accumulate magnitude.
                if b_cfg['type'] == 'load':
                    zone_load += abs(b['p'])
                elif b_cfg['type'] in ('generator', 'solar', 'wind', 'slack'):
                    zone_gen += b['p']
                # battery: not classified as load or gen

        # Internal lines (within zone)
        int_line_ids = set(self.internal_lines.get(agent_id, []))
        internal_lines = []
        for l in self.line_state:
            if l['id'] in int_line_ids:
                limit = self._get_line_capacity(l['id'])
                rho = abs(l['flow']) / limit if l['connected'] and limit > 0 else 0.0
                # Add noise to line readings (POMDP: imperfect sensors).
                # NOTE(review): when self._rng is falsy the reading is exact.
                noisy_rho = rho + self._rng.normal(0, self.LINE_NOISE_STD) if self._rng else rho
                noisy_rho = max(0.0, noisy_rho)
                internal_lines.append(LineStatus(
                    id=l['id'], connected=l['connected'],
                    flow=round(l['flow'], 4),
                    rho=round(noisy_rho, 4),
                ))

        # Boundary lines (connecting to other zones)
        bnd_line_ids = set(self.boundary_lines.get(agent_id, []))
        boundary_lines = []
        for l in self.line_state:
            if l['id'] in bnd_line_ids:
                limit = self._get_line_capacity(l['id'])
                rho = abs(l['flow']) / limit if l['connected'] and limit > 0 else 0.0
                noisy_rho = rho + self._rng.normal(0, self.LINE_NOISE_STD) if self._rng else rho
                noisy_rho = max(0.0, noisy_rho)
                boundary_lines.append(LineStatus(
                    id=l['id'], connected=l['connected'],
                    flow=round(l['flow'], 4),
                    rho=round(noisy_rho, 4),
                ))

        # Noisy frequency (POMDP — agents don't get perfect readings)
        true_freq = self._compute_frequency()
        noisy_freq = true_freq + (self._rng.normal(0, self.FREQ_NOISE_STD) if self._rng else 0.0)

        # Neighbor signals: average bus injection of other zones
        # (a deliberately coarse communication channel between agents).
        neighbor_signals = {}
        for other_id in range(self.num_agents):
            if other_id == agent_id:
                continue
            other_bus_ids = self.zone_bus_ids.get(other_id, [])
            if other_bus_ids:
                avg_inj = np.mean([
                    b['p'] for b in self.bus_state if b['id'] in other_bus_ids
                ])
                neighbor_signals[other_id] = round(float(avg_inj), 2)

        # Cooldowns for lines this agent can see (internal + boundary only).
        visible_lines = int_line_ids | bnd_line_ids
        visible_cooldowns = {
            k: v for k, v in self.cooldowns.items() if k in visible_lines
        }

        # Fallback name for zones beyond the configured name list.
        zone_name = self.zone_names[agent_id] if agent_id < len(self.zone_names) else f"Zone_{agent_id}"

        return ZoneObservation(
            agent_id=agent_id,
            zone_name=zone_name,
            timestep=self.timestep,
            grid_frequency=round(noisy_freq, 4),
            local_buses=local_buses,
            boundary_lines=boundary_lines,
            internal_lines=internal_lines,
            neighbor_signals=neighbor_signals,
            cooldowns=visible_cooldowns,
            is_blackout=False,
            zone_load_mw=round(zone_load, 2),
            zone_gen_mw=round(zone_gen, 2),
        )
561
+
562
+ # ======================================================================
563
+ # Internal Methods (unchanged from original)
564
+ # ======================================================================
565
+
566
+ def _run_power_flow(self):
567
+ """Build active line list, solve DC power flow, update line flows and slack injection."""
568
+ active_lines = []
569
+ for l_cfg in self.lines_config:
570
+ l_dyn = self._find_line(l_cfg['id'])
571
+ if l_dyn and l_dyn['connected']:
572
+ active_lines.append({
573
+ 'id': l_cfg['id'], 'from': l_cfg['from'], 'to': l_cfg['to'],
574
+ 'susceptance': l_cfg['susceptance'], 'connected': True
575
+ })
576
+
577
+ self.solver.update_grid(active_lines)
578
+
579
+ p_inj = np.zeros(self.num_buses)
580
+ for b_dyn in self.bus_state:
581
+ p_inj[b_dyn['id']] = b_dyn['p']
582
+
583
+ theta, flows, slack_inj = self.solver.solve(p_inj)
584
+
585
+ self.slack_injection = slack_inj
586
+ slack_dyn = self._find_bus_state(self.slack_bus_id)
587
+ if slack_dyn is not None:
588
+ slack_dyn['p'] = slack_inj
589
+
590
+ for l in self.line_state:
591
+ if l['connected'] and l['id'] in flows:
592
+ l['flow'] = flows[l['id']]
593
+ elif not l['connected']:
594
+ l['flow'] = 0.0
595
+
596
+ def _compute_frequency(self) -> float:
597
+ """Frequency proxy using droop model, calibrated to system size."""
598
+ return self.NOMINAL_FREQ - self.droop_constant * self.slack_injection
599
+
600
+ def _update_loads_and_renewables(self):
601
+ """Update time-varying loads and renewable generation. Uses per-episode RNG."""
602
+ for b_dyn in self.bus_state:
603
+ b_cfg = self._find_bus_config(b_dyn['id'])
604
+ if b_cfg is None:
605
+ continue
606
+
607
+ if b_cfg['type'] == 'load':
608
+ daily_cycle = math.sin((self.timestep % 24 - 6) * math.pi / 12)
609
+ b_dyn['p'] = -b_cfg['base_p'] * (0.8 + 0.4 * max(0, daily_cycle))
610
+
611
+ elif b_cfg['type'] == 'solar':
612
+ solar_cycle = max(0, math.sin((self.timestep % 24 - 6) * math.pi / 12))
613
+ b_dyn['p'] = b_cfg['max_p'] * solar_cycle
614
+
615
+ elif b_cfg['type'] == 'wind':
616
+ wind_delta = self._rng.uniform(-5, 5)
617
+ b_dyn['p'] = float(np.clip(b_dyn['p'] + wind_delta, 0, b_cfg['max_p']))
618
+
619
+ def _get_obs(self) -> GridObservation:
620
+ """Build observation from current state."""
621
+ obs_lines = []
622
+ for l in self.line_state:
623
+ limit = self._get_line_capacity(l['id'])
624
+ rho = abs(l['flow']) / limit if l['connected'] and limit > 0 else 0.0
625
+ obs_lines.append(LineStatus(
626
+ id=l['id'], connected=l['connected'], flow=round(l['flow'], 4), rho=round(rho, 4)
627
+ ))
628
+
629
+ obs_buses = []
630
+ for b in self.bus_state:
631
+ b_cfg = self._find_bus_config(b['id'])
632
+ if b_cfg is None:
633
+ continue
634
+ obs_buses.append(BusState(
635
+ id=b['id'], type=b_cfg['type'],
636
+ p_injection=round(b['p'], 4),
637
+ soc=round(b.get('soc', 0.0), 4),
638
+ ramp_rate=b_cfg.get('ramp_rate', 0.0)
639
+ ))
640
+
641
+ freq = self._compute_frequency()
642
+
643
+ return GridObservation(
644
+ timestep=self.timestep,
645
+ grid_frequency=round(freq, 4),
646
+ buses=obs_buses,
647
+ lines=obs_lines,
648
+ cooldowns=self.cooldowns,
649
+ is_blackout=getattr(self, '_is_blackout', False)
650
+ )
651
+
652
+ # ---------- Lookup Helpers (O(1) indexed + guarded fallbacks) ----------
653
+
654
+ def _find_line(self, line_id: str):
655
+ # Use index if available (built in reset), fall back to linear scan
656
+ idx = getattr(self, '_line_state_by_id', None)
657
+ if idx is not None:
658
+ return idx.get(line_id)
659
+ return next((l for l in self.line_state if l['id'] == line_id), None)
660
+
661
+ def _find_bus_config(self, bus_id: int):
662
+ return self._bus_cfg_by_id.get(bus_id)
663
+
664
+ def _find_bus_state(self, bus_id: int):
665
+ idx = getattr(self, '_bus_state_by_id', None)
666
+ if idx is not None:
667
+ return idx.get(bus_id)
668
+ return next((b for b in self.bus_state if b['id'] == bus_id), None)
669
+
670
+ def _get_line_capacity(self, line_id: str) -> float:
671
+ cfg = self._line_cfg_by_id.get(line_id)
672
+ return cfg['capacity'] if cfg else 1.0
src/grader.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import numpy as np
4
+ from typing import Dict, Callable, List
5
+ from .environment import OpenGridEnv
6
+ from .models import GridAction, BusAdjustment, TopologyAction
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
def _random_thrash_policy(obs, rng: np.random.Generator) -> GridAction:
    """Deliberately bad policy: random topology switching. Used as reward floor.

    Flips roughly 30% of lines each step (open if connected, close if not)
    so instability is maximized across the whole episode. Takes an explicit
    RNG instance — never global np.random — so floor estimation stays
    reproducible.
    """
    switches = [
        TopologyAction(
            line_id=line.id,
            action="open" if line.connected else "close",
        )
        for line in obs.lines
        if rng.random() > 0.7
    ]
    return GridAction(topology_actions=switches)
24
+
25
+
26
def compute_analytical_ceiling(max_steps: int) -> float:
    """Theoretical maximum episode reward.

    A perfect agent earns +1.0 survival plus a +0.2 tight-frequency bonus
    every step, with zero overload and zero action cost:

        ceiling = max_steps * (1.0 + 0.2)

    NOTE: the +0.2 bonus requires freq_dev < 0.1 Hz, i.e.
    |P_slack| < 0.04 * S_total under the droop model. On high-renewable
    tasks (task_hard) where the slack routinely absorbs >50 MW, that band
    may be structurally unreachable, making the effective ceiling closer
    to max_steps * 1.0. Scores on the same task stay comparable — the
    achievable range is just compressed.
    """
    per_step_max = 1.0 + 0.2
    return max_steps * per_step_max
43
+
44
+
45
+ # Validator requires scores strictly in the open interval (0, 1).
46
+ # Using wide epsilon so that even aggressive rounding (e.g. round(x, 1))
47
+ # can never produce exactly 0.0 or 1.0.
48
+ _SCORE_EPSILON = 0.02
49
+ _SCORE_MIN = _SCORE_EPSILON # 0.02
50
+ _SCORE_MAX = 1.0 - _SCORE_EPSILON # 0.98
51
+
52
+
53
+ def _safe_float(x: float) -> float:
54
+ """Convert to plain Python float; replace NaN/Inf with midpoint."""
55
+ v = float(x)
56
+ if not math.isfinite(v):
57
+ return 0.5 # safe fallback inside (0, 1)
58
+ return v
59
+
60
+
61
def _clamp_score(score: float) -> float:
    """Clamp a score into the open interval (0, 1) as a plain Python float.

    Pure Python min/max (no numpy scalars) guarantees clean JSON encoding.
    The value is truncated — not rounded — to four decimals so that e.g.
    0.98500… can never be rounded up past the boundary later on.
    """
    bounded = max(_SCORE_MIN, min(_SCORE_MAX, _safe_float(score)))
    # floor() truncates toward the safe side of the interval.
    truncated = math.floor(bounded * 10000) / 10000
    # Truncation can only move the value down, but re-clamp defensively.
    return max(_SCORE_MIN, min(_SCORE_MAX, truncated))
75
+
76
+
77
def normalize_score(cumulative_reward: float, reward_floor: float, reward_ceiling: float,
                    n1_survival_rate: float = 1.0) -> float:
    """
    Shared normalization: maps raw cumulative reward into the open interval (0, 1).
    Used by both the /grader endpoint and RobustnessGrader so they agree.

    - reward_floor: empirical worst case (seeded random-thrashing policy)
    - reward_ceiling: analytical upper bound (perfect survival + frequency bonus)
    - n1_survival_rate: fraction of blackout-free episodes (up to +10% bonus)

    Results are clamped to [0.02, 0.98], so they can never be — or round
    to — exactly 0.0 or 1.0, as the OpenEnv Phase-2 validator requires.
    """
    span = _safe_float(reward_ceiling) - _safe_float(reward_floor)
    if span < 1.0:
        span = 1.0  # guard against dividing by a near-zero range

    base = (_safe_float(cumulative_reward) - _safe_float(reward_floor)) / span

    # N-1 bonus: scaled into the remaining headroom so near-ceiling
    # performers still stay distinguishable from one another.
    bonus = float(n1_survival_rate) * 0.1
    headroom = _SCORE_MAX - base
    bonus = min(bonus, headroom * 0.5) if headroom > 0 else 0.0

    return _clamp_score(base + bonus)
108
+
109
+
110
class RobustnessGrader:
    """
    Evaluates a policy's performance on an OpenGrid task.

    Scoring:
    - Floor: empirical estimate from adversarial random topology thrashing
      (seeded RNG for reproducibility, n_samples=10 for stability)
    - Ceiling: analytical upper bound = max_steps * 1.2
      (perfect survival + perfect frequency bonus every step)
    - Normalizes cumulative reward to 0.0-1.0
    - Adds N-1 survival bonus (max 10%)

    The heuristic baseline scores ~0.75-0.90, leaving headroom for
    agents that employ active topology management and predictive scheduling.
    """

    def __init__(self, config: Dict):
        # config: environment config dict; 'seed' and 'max_steps' are read here.
        self.config = config
        # Bounds are computed lazily by _estimate_bounds() on first use.
        self.reward_floor = None
        self.reward_ceiling = None

    def _estimate_bounds(self, n_samples: int = 10):
        """Estimate reward bounds.

        Floor: adversarial random thrashing policy (empirical, seeded).
        Ceiling: analytical upper bound (deterministic).

        n_samples=10 to reduce variance in the floor estimate.
        The floor uses mean - std to be conservatively low.
        Each episode gets its own thrash RNG derived from a master seed
        so that changing n_samples doesn't alter existing episodes.
        """
        # Fixed master seed keeps the whole floor estimate reproducible.
        master_rng = np.random.default_rng(seed=12345)

        floors = []
        base_seed = self.config.get('seed', 42)

        for i in range(n_samples):
            # Per-episode thrash RNG — decoupled from other episodes
            thrash_rng = np.random.default_rng(seed=int(master_rng.integers(0, 2**31)))

            # Vary environment seed so floor reflects environment stochasticity
            config_with_seed = {**self.config, 'seed': base_seed + i}
            env = OpenGridEnv(config_with_seed)
            obs = env.reset()
            done = False
            ep_reward = 0
            while not done:
                action = _random_thrash_policy(obs, rng=thrash_rng)
                obs, reward, done, info = env.step(action)
                ep_reward += reward.value
            floors.append(ep_reward)

        # Conservative floor: one std below the mean of the thrash episodes.
        self.reward_floor = float(np.mean(floors) - np.std(floors))
        logger.debug("Floor estimate: mean=%.2f, std=%.2f, floor=%.2f",
                     np.mean(floors), np.std(floors), self.reward_floor)

        # Ceiling: analytical upper bound (not heuristic)
        max_steps = self.config.get('max_steps', 50)
        analytical_ceiling = compute_analytical_ceiling(max_steps)
        self.reward_ceiling = analytical_ceiling

        # Ensure minimum spread — expand floor downward, not ceiling upward
        if self.reward_ceiling - self.reward_floor < 10.0:
            self.reward_floor = self.reward_ceiling - max(10.0, analytical_ceiling * 0.2)
            logger.debug("Spread too small, adjusted floor to %.2f", self.reward_floor)

    def get_bounds(self) -> Dict[str, float]:
        """Return the reward floor and ceiling, computing if needed."""
        if self.reward_floor is None:
            self._estimate_bounds()
        return {"reward_floor": self.reward_floor, "reward_ceiling": self.reward_ceiling}

    def evaluate_policy(self, policy_fn: Callable, n_episodes: int = 10) -> Dict:
        """Run a policy for n_episodes and return normalized score.

        Each episode uses a different environment seed (offset by 1000 from
        floor estimation seeds) to measure policy robustness across diverse
        wind/load trajectories.

        Args:
            policy_fn: Callable mapping a GridObservation to a GridAction.
            n_episodes: Number of evaluation episodes (must be >= 1).

        Returns:
            Dict with avg_raw_reward, n1_survival_rate, reward bounds,
            and the final normalized score.
        """
        if self.reward_floor is None:
            self._estimate_bounds()

        base_seed = self.config.get('seed', 42)
        rewards = []
        n1_survivals = 0

        for ep in range(n_episodes):
            # Offset by 1000 to avoid overlap with floor estimation seeds
            config_with_seed = {**self.config, 'seed': base_seed + ep + 1000}
            env = OpenGridEnv(config_with_seed)
            obs = env.reset()
            done = False
            ep_reward = 0

            while not done:
                action = policy_fn(obs)
                obs, reward, done, info = env.step(action)
                ep_reward += reward.value

            rewards.append(ep_reward)
            # NOTE(review): relies on every episode running at least one
            # step so `info` is bound here — holds as long as env.reset()
            # never starts terminal; confirm against OpenGridEnv.
            if not info.is_blackout:
                n1_survivals += 1

        avg_reward = float(np.mean(rewards))
        n1_rate = n1_survivals / n_episodes
        logger.debug("Policy eval: avg=%.2f, n1_rate=%.2f, episodes=%d",
                     avg_reward, n1_rate, n_episodes)

        final_score = normalize_score(
            cumulative_reward=avg_reward,
            reward_floor=self.reward_floor,
            reward_ceiling=self.reward_ceiling,
            n1_survival_rate=n1_rate
        )

        return {
            "avg_raw_reward": round(avg_reward, 4),
            "n1_survival_rate": round(n1_rate, 4),
            "reward_floor": round(self.reward_floor, 4),
            "reward_ceiling": round(self.reward_ceiling, 4),
            "score": final_score
        }
src/models.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Literal, Optional
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
class TopologyAction(BaseModel):
    """A topology switching action on a transmission line."""
    line_id: str  # id of the target line
    action: Literal["open", "close"]  # open = disconnect, close = reconnect
9
+
10
+
11
class BusAdjustment(BaseModel):
    """A power injection adjustment on a bus."""
    bus_id: int  # integer bus index
    delta: float  # MW change (positive = inject more)
15
+
16
+
17
class GridAction(BaseModel):
    """Agent action: adjust bus injections and/or switch line topology.

    Both lists default to empty, so ``GridAction()`` is a valid no-op.
    Defaults use Field(default_factory=list) for consistency with the
    other models in this module (explicit per-instance mutable defaults).
    """
    bus_adjustments: List[BusAdjustment] = Field(default_factory=list)
    topology_actions: List[TopologyAction] = Field(default_factory=list)
21
+
22
+
23
class LineStatus(BaseModel):
    """Current state of a transmission line."""
    id: str  # line identifier
    connected: bool  # False = breaker open, line carries no flow
    flow: float = 0.0  # MW; sign follows the solver's from→to convention
    rho: float = Field(0.0, ge=0.0, description="Loading percentage (flow/capacity)")
29
+
30
+
31
class BusState(BaseModel):
    """Current state of a bus (generator, load, battery, or renewable)."""
    id: int  # bus index
    type: Literal["slack", "generator", "load", "battery", "solar", "wind"]
    p_injection: float  # MW; loads appear as negative injections
    soc: float = Field(0.0, ge=0.0, description="State of charge (MWh)")
    ramp_rate: float = 0.0  # taken from the bus config; 0.0 when unspecified
38
+
39
+
40
class GridObservation(BaseModel):
    """Full grid observation returned by reset()/step()/state()."""
    timestep: int  # current episode step
    grid_frequency: float  # Hz (noiseless in the single-agent observation)
    buses: List[BusState]
    lines: List[LineStatus]
    cooldowns: Dict[str, int]  # line_id -> remaining cooldown steps
    is_blackout: bool = False

    def __repr__(self) -> str:
        # Compact summary for logs/debugging instead of the full field dump.
        return (
            f"GridObservation(t={self.timestep}, f={self.grid_frequency:.2f}, "
            f"buses={len(self.buses)}, lines={len(self.lines)}, "
            f"blackout={self.is_blackout})"
        )
55
+
56
+
57
class GridReward(BaseModel):
    """Reward signal with component breakdown."""
    value: float  # total reward for the step
    components: Dict[str, float]  # per-component contributions summing to value
61
+
62
+
63
class GridInfo(BaseModel):
    """Episode info (metadata alongside reward)."""
    task_id: str  # identifier of the active task/difficulty
    is_blackout: bool  # True once the grid has collapsed this episode
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # Multi-Agent POMDP Models
71
+ # ---------------------------------------------------------------------------
72
+
73
class ZoneInfo(BaseModel):
    """Metadata about an agent's zone."""
    agent_id: int
    zone_name: str
    bus_ids: List[int]  # buses owned by this zone
    boundary_line_ids: List[str]  # lines crossing into other zones
    internal_line_ids: List[str]  # lines fully inside this zone
80
+
81
+
82
class ZoneObservation(BaseModel):
    """Partial observation for one agent under POMDP.

    Each agent sees only:
      - Their local buses (within their zone)
      - Boundary lines (connecting to other zones)
      - Internal lines (within their zone)
      - A noisy estimate of global grid frequency
      - Limited communication signals from neighboring agents
    """
    agent_id: int
    zone_name: str
    timestep: int
    grid_frequency: float  # noisy — Gaussian noise added
    local_buses: List[BusState]
    boundary_lines: List[LineStatus]
    internal_lines: List[LineStatus]
    neighbor_signals: Dict[int, float] = Field(
        default_factory=dict,
        description="Limited info from other agents: {agent_id: their avg bus injection}"
    )
    cooldowns: Dict[str, int] = Field(default_factory=dict)  # only lines this agent can see
    is_blackout: bool = False
    zone_load_mw: float = 0.0  # total |load| within the zone
    zone_gen_mw: float = 0.0  # total generation within the zone

    def __repr__(self) -> str:
        # Compact summary for logs/debugging.
        return (
            f"ZoneObservation(agent={self.agent_id}, zone={self.zone_name}, "
            f"t={self.timestep}, f={self.grid_frequency:.2f}, "
            f"buses={len(self.local_buses)}, blackout={self.is_blackout})"
        )
114
+
115
+
116
class SafetyReport(BaseModel):
    """Report from the safety layer about action corrections."""
    agent_id: int
    was_corrected: bool  # True if the safety layer modified the action
    correction_reason: str = ""
    n1_violations_detected: int = 0
    proposed_topology_actions: int = 0
    blocked_topology_actions: int = 0
    original_total_delta_mw: float = 0.0  # net MW of the proposed adjustments
    corrected_total_delta_mw: float = 0.0  # net MW after safety correction
126
+
127
+
128
class OversightReport(BaseModel):
    """Report from the oversight agent about multi-agent coordination."""
    coordination_score: float = Field(
        1.0, description="1.0 = perfect cooperation, 0.0 = total conflict"
    )
    conflicting_actions_detected: int = 0
    selfish_actions_detected: int = 0
    coordination_penalties: Dict[int, float] = Field(default_factory=dict)  # agent_id -> penalty
    global_frequency_contribution: Dict[int, float] = Field(
        default_factory=dict,
        description="Each agent's net impact on frequency deviation"
    )
    notes: List[str] = Field(default_factory=list)  # human-readable diagnostics
141
+
142
+
143
class MultiAgentAction(BaseModel):
    """Request body for /step_multi: per-agent actions keyed by agent_id."""
    # NOTE(review): JSON object keys arrive as strings; pydantic coerces
    # them to int for Dict[int, ...] — confirm clients send numeric keys.
    agent_actions: Dict[int, GridAction] = Field(
        default_factory=dict,
        description="Actions for each agent, keyed by agent_id"
    )
149
+
150
+
151
class MultiAgentStepResult(BaseModel):
    """Result of a multi-agent step — per-agent observations, rewards, reports."""
    observations: Dict[int, ZoneObservation]
    rewards: Dict[int, GridReward]  # per-agent shaped rewards
    team_reward: float  # shared team-level reward for the step
    done: bool
    safety_reports: Dict[int, SafetyReport] = Field(
        default_factory=dict,
        description="Per-agent safety reports, keyed by agent_id"
    )
    oversight_report: OversightReport
    info: GridInfo
src/oversight.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Oversight Agent — Multi-Agent Coordination Monitor
3
+ ===================================================
4
+ A rule-based meta-agent that monitors coordination quality across zones.
5
+
6
+ Responsibilities:
7
+ 1. Detect conflicting actions (agents pulling frequency opposite ways)
8
+ 2. Detect selfish behavior (local improvement at global cost)
9
+ 3. Assign coordination penalties to agents
10
+ 4. Track safety layer intervention frequency
11
+
12
+ This is NOT a trained agent — it's a deterministic rule engine that
13
+ provides additional reward signal to guide multi-agent learning.
14
+
15
+ References:
16
+ - Symphony: Multi-Agent Intelligence in a Collective Fabric (Gradient, 2025)
17
+ - Massgen: When Multiple LLMs Think Together (Gradient, 2025)
18
+ """
19
+
20
+ import logging
21
+ import math
22
+ from typing import Dict, List
23
+ from .models import GridAction, SafetyReport, OversightReport
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class OversightAgent:
    """Rule-based oversight agent for multi-agent coordination.

    Sits above zone agents and evaluates whether their combined actions
    are globally beneficial or harmful. Produces an OversightReport
    with coordination scores and penalties.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.zone_assignments = config.get('zone_assignments', {})
        self.num_agents = config.get('num_agents', 1)
        # Running count of safety-layer corrections per agent (per episode).
        self.intervention_history: Dict[int, int] = {
            i: 0 for i in range(self.num_agents)
        }

    def evaluate(
        self,
        agent_actions: Dict[int, GridAction],
        safety_reports: Dict[int, SafetyReport],
        pre_frequency: float,
        post_frequency: float,
        pre_bus_state: List[Dict],
        post_bus_state: List[Dict],
    ) -> OversightReport:
        """Evaluate multi-agent coordination quality.

        Args:
            agent_actions: {agent_id: GridAction} — proposed actions
            safety_reports: {agent_id: SafetyReport} — per-agent safety results
            pre_frequency: Grid frequency before this step
            post_frequency: Grid frequency after this step
            pre_bus_state: Bus states before actions (currently unused)
            post_bus_state: Bus states after actions (currently unused)

        Returns:
            OversightReport with scores, penalties, and notes
        """
        notes = []
        penalties: Dict[int, float] = {i: 0.0 for i in range(self.num_agents)}
        conflicts = 0
        selfish_count = 0

        # --- 1. Track safety interventions ---
        for agent_id, report in safety_reports.items():
            # Validate agent_id is within expected range
            if agent_id not in self.intervention_history:
                notes.append(f"WARNING: unknown agent_id {agent_id} in safety report")
                continue
            if report.was_corrected:
                self.intervention_history[agent_id] += 1
                # Penalty scales with repeated violations
                repeat_count = self.intervention_history[agent_id]
                penalties[agent_id] += 0.1 * min(repeat_count, 5)
                notes.append(
                    f"Agent {agent_id}: safety correction #{repeat_count}"
                )

        # --- 2. Detect conflicting frequency actions ---
        # If agents are pushing frequency in opposite directions, that's waste
        net_deltas = {}
        for agent_id, action in agent_actions.items():
            # BUGFIX: validate agent_id here too (mirrors the safety-report
            # guard above) — an out-of-range id would otherwise KeyError when
            # penalties[agent_id] is incremented further down.
            if agent_id not in penalties:
                notes.append(f"WARNING: unknown agent_id {agent_id} in agent actions")
                continue
            total_delta = sum(a.delta for a in action.bus_adjustments)
            n_topo = len(action.topology_actions)
            if n_topo > 0:
                notes.append(
                    f"Agent {agent_id}: {n_topo} topology action(s) "
                    f"not included in conflict analysis"
                )
            net_deltas[agent_id] = total_delta

        if len(net_deltas) >= 2:
            deltas = list(net_deltas.values())
            # Check if some agents inject and others withdraw significantly
            injectors = [d for d in deltas if d > 2.0]
            withdrawers = [d for d in deltas if d < -2.0]
            if injectors and withdrawers:
                conflicts += 1
                notes.append(
                    "Conflicting actions: some agents inject while others withdraw"
                )
                # Penalize the agent pushing AGAINST the needed direction
                freq_error = 50.0 - pre_frequency

                if abs(freq_error) > 0.1:
                    # Clear direction needed — penalize the opposing side
                    for agent_id, delta in net_deltas.items():
                        # If freq < 50 (need more injection) but agent withdraws
                        if freq_error > 0.1 and delta < -2.0:
                            penalties[agent_id] += 0.2
                            selfish_count += 1
                            notes.append(
                                f"Agent {agent_id}: withdrew {delta:.1f} MW "
                                f"when grid needed injection"
                            )
                        # If freq > 50 (need less injection) but agent injects
                        elif freq_error < -0.1 and delta > 2.0:
                            penalties[agent_id] += 0.2
                            selfish_count += 1
                            notes.append(
                                f"Agent {agent_id}: injected {delta:.1f} MW "
                                f"when grid had excess"
                            )
                else:
                    # Near-nominal: penalize all significant participants equally
                    for agent_id, delta in net_deltas.items():
                        if abs(delta) > 2.0:
                            penalties[agent_id] += 0.1
                            notes.append(
                                f"Agent {agent_id}: conflicting injection "
                                f"({delta:+.1f} MW) with no clear grid need"
                            )

        # --- 3. Evaluate frequency impact per agent ---
        freq_contribution: Dict[int, float] = {}
        freq_dev_before = abs(pre_frequency - 50.0)
        freq_dev_after = abs(post_frequency - 50.0)
        freq_improved = freq_dev_after < freq_dev_before

        for agent_id in range(self.num_agents):
            # Net MW delta (not frequency impact — would need droop constant)
            total_delta = net_deltas.get(agent_id, 0.0)
            freq_contribution[agent_id] = round(total_delta, 4)

        # --- 4. Compute coordination score ---
        # Sub-linear scaling: diminishing penalty per additional incident
        # prevents score from collapsing to 0.0 for mildly bad teams
        safety_corrections = sum(
            1 for r in safety_reports.values() if r.was_corrected
        )

        conflict_penalty = 1.0 - math.exp(-conflicts * 0.3)
        selfish_penalty = 1.0 - math.exp(-selfish_count * 0.2)
        safety_penalty = 1.0 - math.exp(-safety_corrections * 0.2)

        base_score = (1.0
                      - 0.4 * conflict_penalty
                      - 0.3 * selfish_penalty
                      - 0.3 * safety_penalty)

        # Frequency improvement bonus / degradation penalty
        if freq_improved:
            base_score += 0.1
        else:
            degradation = freq_dev_after - freq_dev_before
            base_score -= min(degradation * 0.5, 0.2)

        coordination_score = max(0.0, min(1.0, base_score))

        return OversightReport(
            coordination_score=round(coordination_score, 4),
            conflicting_actions_detected=conflicts,
            selfish_actions_detected=selfish_count,
            coordination_penalties=penalties,
            global_frequency_contribution=freq_contribution,
            notes=notes,
        )

    def reset(self):
        """Reset intervention history for a new episode."""
        self.intervention_history = {
            i: 0 for i in range(self.num_agents)
        }
src/physics.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DC Power Flow Solver
3
+ ====================
4
+ Implements the standard DC approximation: B * θ = P
5
+
6
+ Assumptions:
7
+ - Flat voltage profile (|V| ≈ 1.0 p.u.)
8
+ - Small angle differences (sin(θ) ≈ θ)
9
+ - Negligible resistance (R ≈ 0, only susceptance used)
10
+
11
+ Flow sign convention:
12
+ flow = b * (θ_from - θ_to)
13
+ Positive flow = power flowing from 'from' bus to 'to' bus.
14
+ """
15
+
16
+ import logging
17
+ import warnings
18
+ import numpy as np
19
+ from typing import List, Dict, Tuple
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class IslandedException(Exception):
    """Raised when the grid graph is not a single connected component,
    or when the reduced B matrix is singular (which implies islanding)."""
    pass
26
+
27
+
28
class DCSolver:
    """DC power flow solver with union-find islanding detection.

    The slack bus absorbs any power imbalance and has its voltage angle
    fixed to 0 (reference). By default this is bus 0, but can be
    configured via the slack_bus parameter.
    """

    def __init__(self, num_buses: int, slack_bus: int = 0):
        """
        Args:
            num_buses: Number of buses in the network (must be >= 1).
            slack_bus: Index of the slack/reference bus.

        Raises:
            ValueError: if num_buses < 1 or slack_bus is out of range.
                (Previously these misconfigurations surfaced later as
                IndexErrors deep inside update_grid().)
        """
        if num_buses < 1:
            raise ValueError(f"num_buses must be >= 1, got {num_buses}")
        if not (0 <= slack_bus < num_buses):
            raise ValueError(
                f"slack_bus {slack_bus} out of range for {num_buses} buses"
            )
        self.num_buses = num_buses
        self.slack_bus = slack_bus
        self.B = np.zeros((num_buses, num_buses))
        self.line_map = {}
        self._grid_loaded = False

    def update_grid(self, lines: List[Dict]):
        """Rebuild the B matrix and check connectivity.

        Skips zero-susceptance lines (no electrical contribution).
        Validates bus indices to prevent silent corruption.

        Raises:
            ValueError: if a line references an out-of-range bus index.
            IslandedException: if the connected lines do not form a single
                connected component covering all buses.
        """
        self.B = np.zeros((self.num_buses, self.num_buses))
        self.line_map = {}

        # Union-Find for near-O(n) connectivity check (replaces NetworkX)
        parent = list(range(self.num_buses))
        rank = [0] * self.num_buses

        def find(x):
            # Path compression keeps repeated lookups cheap
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        def union(x, y):
            rx, ry = find(x), find(y)
            if rx == ry:
                return
            if rank[rx] < rank[ry]:
                rx, ry = ry, rx
            parent[ry] = rx
            if rank[rx] == rank[ry]:
                rank[rx] += 1

        for line in lines:
            if line['connected']:
                i, j = line['from'], line['to']
                b = line['susceptance']

                # Validate bus indices
                if not (0 <= i < self.num_buses and 0 <= j < self.num_buses):
                    raise ValueError(
                        f"Line {line['id']}: bus indices ({i}, {j}) out of range "
                        f"for {self.num_buses} buses"
                    )

                # Skip zero-susceptance lines (no electrical contribution);
                # they are excluded from both B and the connectivity graph.
                if abs(b) < 1e-12:
                    continue

                self.B[i, j] -= b
                self.B[j, i] -= b
                self.B[i, i] += b
                self.B[j, j] += b

                self.line_map[line['id']] = (i, j, b)
                union(i, j)

        # Connectivity check via union-find
        root = find(0)
        if not all(find(i) == root for i in range(self.num_buses)):
            # Build component info for diagnostics
            components = {}
            for i in range(self.num_buses):
                r = find(i)
                components.setdefault(r, []).append(i)
            comp_sizes = [len(c) for c in components.values()]
            raise IslandedException(
                f"Grid is islanded: {len(components)} components, "
                f"sizes={comp_sizes}"
            )

        self._grid_loaded = True

    def solve(self, p_inj: np.ndarray) -> Tuple[np.ndarray, Dict[str, float], float]:
        """Solve DC power flow: B_red * θ_red = P_red.

        Args:
            p_inj: Real power injection at each bus (MW). Shape must be (num_buses,).

        Returns:
            (theta, line_flows, slack_injection) tuple.
            theta: voltage angles (radians). Slack bus angle = 0.
            line_flows: {line_id: flow_MW}. Positive = from→to direction.
            slack_injection: MW absorbed/injected by the slack bus.

        Raises:
            RuntimeError: if called before update_grid().
            ValueError: if p_inj has the wrong length.
            IslandedException: if the reduced B matrix is singular.
        """
        if not self._grid_loaded:
            raise RuntimeError("DCSolver.solve() called before update_grid()")

        # Validate input
        p_inj = np.asarray(p_inj).ravel()
        if len(p_inj) != self.num_buses:
            raise ValueError(
                f"p_inj length {len(p_inj)} != num_buses {self.num_buses}"
            )

        # Remove slack bus row/column
        mask = np.arange(self.num_buses) != self.slack_bus
        B_red = self.B[np.ix_(mask, mask)]
        p_red = p_inj[mask]

        if B_red.size == 0:
            # Degenerate single-bus grid: the slack bus is the whole system,
            # so there is nothing to solve (and np.linalg.cond would fail
            # on the empty matrix).
            theta_red = np.zeros(0)
        else:
            try:
                theta_red = np.linalg.solve(B_red, p_red)
            except np.linalg.LinAlgError as exc:
                # Chain the original error so the numerical cause stays visible
                raise IslandedException("Grid is islanded (singular B matrix)") from exc

            # Check conditioning
            cond = np.linalg.cond(B_red)
            if cond > 1e12:
                warnings.warn(
                    f"DCSolver: B_red is ill-conditioned (cond={cond:.2e}). "
                    f"Results may be numerically unreliable.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        # Insert slack bus angle (= 0)
        theta = np.zeros(self.num_buses)
        theta[mask] = theta_red

        # Compute line flows
        flows = {}
        for line_id, (i, j, b) in self.line_map.items():
            flows[line_id] = (theta[i] - theta[j]) * b

        # Slack injection from power balance (more robust than summing flows)
        slack_injection = -float(p_inj[mask].sum())

        return theta, flows, slack_injection

    def __repr__(self):
        return (
            f"DCSolver(num_buses={self.num_buses}, slack={self.slack_bus}, "
            f"lines={len(self.line_map)}, loaded={self._grid_loaded})"
        )
src/safety.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Safety Layer — Hard Constraint Filter for OpenGrid
3
+ ===================================================
4
+ Validates agent actions BEFORE they are applied to the environment.
5
+ If constraints are violated, actions are projected to the nearest safe alternative.
6
+
7
+ This is the core safety innovation: constraint violations should NEVER
8
+ reach the physics engine. The safety layer catches them first.
9
+
10
+ Checks:
11
+ 1. Anti-Islanding: topology actions that would disconnect the grid are blocked
12
+ 2. N-1 Security: for each critical line, simulate failure → check grid survives
13
+ 3. Generation Limits: bus adjustments respect ramp rates and capacity
14
+ 4. Zone Boundary: agents can only adjust buses in their assigned zone
15
+
16
+ References:
17
+ - KPTCL N-1 security criterion (Indian Grid Code, IEGC)
18
+ - Control Barrier Functions for safe RL (Ames et al., 2019)
19
+ """
20
+
21
+ import logging
22
+ import networkx as nx
23
+ import numpy as np
24
+ from typing import List, Dict, Tuple
25
+ from .models import GridAction, BusAdjustment, TopologyAction, SafetyReport
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class SafetyLayer:
    """Hard constraint filter that validates and corrects agent actions.

    The safety layer sits between agents and the environment:
    Agent proposes action → SafetyLayer validates → corrected action → Environment

    If an action would cause a constraint violation, it is PROJECTED to the
    nearest safe alternative (not just rejected). This preserves the agent's
    intent while enforcing safety, and provides a richer training signal
    than binary accept/reject.
    """

    def __init__(self, config: Dict):
        """Build the filter from a task config dict.

        Args:
            config: Task configuration with keys 'num_buses', 'lines',
                'buses', and optionally 'zone_assignments'
                ({bus_id: agent_id}). Zone enforcement is active only
                when zone assignments are present and non-empty.
        """
        self.config = config
        self.num_buses = config['num_buses']
        self.lines_config = config['lines']
        self.buses_config = config['buses']
        self.zone_assignments = config.get('zone_assignments', {})
        self.zone_enforcement = bool(self.zone_assignments)

        # Build config indices for O(1) lookups (avoids per-action linear scans)
        self._bus_cfg_by_id = {b['id']: b for b in self.buses_config}
        self._line_cfg_by_id = {l['id']: l for l in self.lines_config}

    def validate_and_correct(
        self,
        agent_id: int,
        proposed_action: GridAction,
        current_line_state: List[Dict],
        current_bus_state: List[Dict],
        cooldowns: Dict[str, int],
    ) -> Tuple[GridAction, SafetyReport]:
        """Full validation pipeline for one agent's proposed action.

        Pipeline: zone boundary → generation limits → topology safety
        (anti-islanding) → N-1 advisory check.

        Returns:
            corrected_action: Safe version of the proposed action
            report: Details of what was checked and corrected
        """
        corrections = []
        n1_violations = 0

        # Track original action stats (for the safety report)
        original_delta = sum(abs(a.delta) for a in proposed_action.bus_adjustments)
        proposed_topo_count = len(proposed_action.topology_actions)
        blocked_topo_count = 0

        # Build bus state index for O(1) lookups
        bus_dyn_by_id = {b['id']: b for b in current_bus_state}

        # --- 1. Zone boundary enforcement ---
        safe_bus_adj = []
        for adj in proposed_action.bus_adjustments:
            bus_zone = self.zone_assignments.get(adj.bus_id, -1)
            if not self.zone_enforcement or bus_zone == agent_id:
                # Agent owns this bus, or single-agent mode
                safe_bus_adj.append(adj)
            else:
                corrections.append(
                    f"Blocked bus {adj.bus_id} adjustment: "
                    f"belongs to zone {bus_zone}, not agent {agent_id}"
                )

        # --- 2. Generation limit enforcement ---
        # Aggregate adjustments per bus to prevent double-spending
        bus_deltas: Dict[int, float] = {}
        for adj in safe_bus_adj:
            bus_deltas[adj.bus_id] = bus_deltas.get(adj.bus_id, 0.0) + adj.delta

        clamped_bus_adj = []
        for bus_id, total_delta in bus_deltas.items():
            bus_cfg = self._bus_cfg_by_id.get(bus_id)
            bus_dyn = bus_dyn_by_id.get(bus_id)
            if bus_cfg is None or bus_dyn is None:
                corrections.append(f"Blocked bus {bus_id}: not found")
                continue

            delta = total_delta
            bus_type = bus_cfg['type']

            # Loads and renewables can't be directly adjusted
            if bus_type in ['load', 'solar', 'wind']:
                corrections.append(
                    f"Blocked bus {bus_id}: type '{bus_type}' is not controllable"
                )
                continue

            # Enforce ramp rate (MW change per step)
            max_ramp = bus_cfg.get('ramp_rate', 20.0)
            if abs(delta) > max_ramp:
                delta = np.clip(delta, -max_ramp, max_ramp)
                corrections.append(
                    f"Clamped bus {bus_id} delta to ramp rate ±{max_ramp}"
                )

            # Enforce battery SoC limits: positive delta = discharge
            # (bounded by stored energy), negative = charge (bounded by headroom)
            if bus_type == 'battery':
                soc = bus_dyn.get('soc', 0.0)
                capacity = bus_cfg.get('capacity', 50.0)
                if delta > 0 and delta > soc:
                    delta = soc
                    corrections.append(
                        f"Clamped bus {bus_id} discharge to SoC={soc:.1f}"
                    )
                elif delta < 0 and abs(delta) > (capacity - soc):
                    delta = -(capacity - soc)
                    corrections.append(
                        f"Clamped bus {bus_id} charge to remaining capacity"
                    )

            # Enforce generator limits
            # NOTE: This is a best-effort projection based on pre-step state.
            # If multiple agents adjust the same bus via different zones,
            # the environment provides a secondary clamp.
            if bus_type in ['slack', 'generator']:
                current_p = bus_dyn.get('p', 0.0)
                new_p = current_p + delta
                min_p = bus_cfg.get('min_p', -50)
                max_p = bus_cfg.get('max_p', 100)
                if new_p < min_p or new_p > max_p:
                    new_p = np.clip(new_p, min_p, max_p)
                    delta = new_p - current_p
                    corrections.append(
                        f"Clamped bus {bus_id} to generation limits "
                        f"[{min_p}, {max_p}]"
                    )

            clamped_bus_adj.append(BusAdjustment(bus_id=bus_id, delta=delta))

        # --- 3. Topology safety (anti-islanding + N-1) ---
        # Build base graph once for all topology checks
        base_graph = self._build_connectivity_graph(current_line_state)

        safe_topo = []
        approved_opens: set = set()  # Track approved opens for cumulative check
        for t_act in proposed_action.topology_actions:
            line_id = t_act.line_id

            # Check cooldown (prevents rapid breaker flapping)
            if cooldowns.get(line_id, 0) > 0:
                corrections.append(
                    f"Blocked {line_id} switch: cooldown active "
                    f"({cooldowns[line_id]} steps)"
                )
                blocked_topo_count += 1
                continue

            # Check if opening this line would island the grid
            # (cumulative: checks against already-approved opens)
            if t_act.action == "open":
                if self._would_island(
                    line_id, base_graph, additional_opens=approved_opens
                ):
                    corrections.append(
                        f"Blocked opening {line_id}: would island the grid"
                    )
                    blocked_topo_count += 1
                    n1_violations += 1
                    continue
                approved_opens.add(line_id)

            safe_topo.append(t_act)

        # --- 4. N-1 check on final combined action ---
        # Advisory only: action is still applied, but the report records
        # how many single-line contingencies would island the post-action grid.
        if safe_topo:
            n1_fails = self._check_n1_post_action(safe_topo, current_line_state)
            if n1_fails > 0:
                n1_violations += n1_fails
                corrections.append(
                    f"N-1 warning: {n1_fails} lines would leave grid "
                    f"vulnerable after action"
                )

        corrected_action = GridAction(
            bus_adjustments=clamped_bus_adj,
            topology_actions=safe_topo
        )

        corrected_delta = sum(abs(a.delta) for a in clamped_bus_adj)

        was_corrected = len(corrections) > 0
        report = SafetyReport(
            agent_id=agent_id,
            was_corrected=was_corrected,
            correction_reason="; ".join(corrections) if corrections else "",
            n1_violations_detected=n1_violations,
            proposed_topology_actions=proposed_topo_count,
            blocked_topology_actions=blocked_topo_count,
            original_total_delta_mw=round(original_delta, 4),
            corrected_total_delta_mw=round(corrected_delta, 4),
        )

        return corrected_action, report

    def _build_connectivity_graph(
        self, current_line_state: List[Dict]
    ) -> nx.Graph:
        """Build the connectivity graph from current line state (once).

        A line with no dynamic-state entry is treated as connected,
        consistent with _check_n1_post_action (which defaults missing
        entries to connected) — the config defines it as part of the grid.
        """
        G = nx.Graph()
        G.add_nodes_from(range(self.num_buses))

        line_dyn_by_id = {l['id']: l for l in current_line_state}
        for l_cfg in self.lines_config:
            l_dyn = line_dyn_by_id.get(l_cfg['id'])
            # FIX: previously a missing dynamic entry excluded the edge
            # (treated as disconnected), contradicting the default used
            # in _check_n1_post_action. Default to connected instead.
            if l_dyn is None or l_dyn.get('connected', True):
                G.add_edge(l_cfg['from'], l_cfg['to'])

        return G

    def _would_island(
        self,
        line_id: str,
        base_graph: nx.Graph,
        additional_opens: set = None,
    ) -> bool:
        """Check if opening a line would disconnect the grid.

        Takes cumulative approved opens into account so that
        multiple simultaneous opens are correctly checked.

        NOTE: nx.Graph collapses parallel lines between the same bus
        pair into one edge, so a parallel circuit would be removed
        together with the line under test — conservative but imprecise.
        """
        additional_opens = additional_opens or set()

        # Find the config entry for this line (O(1) via index)
        line_cfg = self._line_cfg_by_id.get(line_id)
        if line_cfg is None:
            return False

        # Build a test graph with all proposed removals
        G = base_graph.copy()
        # Remove previously approved opens
        for open_id in additional_opens:
            open_cfg = self._line_cfg_by_id.get(open_id)
            if open_cfg and G.has_edge(open_cfg['from'], open_cfg['to']):
                G.remove_edge(open_cfg['from'], open_cfg['to'])

        # Remove the line under test
        if G.has_edge(line_cfg['from'], line_cfg['to']):
            G.remove_edge(line_cfg['from'], line_cfg['to'])

        return not nx.is_connected(G)

    def _check_n1_post_action(
        self,
        topo_actions: List[TopologyAction],
        current_line_state: List[Dict],
    ) -> int:
        """Check N-1 security after applying proposed topology actions.

        For each remaining connected line, simulate its loss and check
        connectivity. Uses edge removal/restoration instead of rebuilding
        the full graph for each contingency.

        Returns the number of single-line contingencies that would island.
        """
        # Build the post-action line state (missing entries default connected)
        post_state = {}
        for l_dyn in current_line_state:
            post_state[l_dyn['id']] = l_dyn.get('connected', True)
        for t_act in topo_actions:
            post_state[t_act.line_id] = (t_act.action == "close")

        # Build post-action graph once
        G = nx.Graph()
        G.add_nodes_from(range(self.num_buses))

        # NOTE(review): parallel lines between the same bus pair share one
        # edge here, so each contingency removes the whole corridor.
        edge_to_line = {}
        for l_cfg in self.lines_config:
            if post_state.get(l_cfg['id'], True):
                u, v = l_cfg['from'], l_cfg['to']
                G.add_edge(u, v)
                edge_to_line[(u, v)] = l_cfg['id']

        # Test each contingency via edge removal/restoration
        n1_failures = 0
        for (u, v), line_id in edge_to_line.items():
            G.remove_edge(u, v)
            if not nx.is_connected(G):
                n1_failures += 1
            G.add_edge(u, v)  # restore

        return n1_failures

    def reset(self):
        """Reset any per-episode state (future-proofing)."""
        pass
src/tasks.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grid Generator & Task Definitions
3
+ ===================================
4
+ Generates reproducible power grid configurations for OpenGrid RL tasks.
5
+
6
+ Procedural grids use Watts-Strogatz small-world topology with
7
+ configurable difficulty (bus count, renewable penetration).
8
+
9
+ The Karnataka task is a hand-crafted 15-bus grid based on the
10
+ actual KPTCL transmission map.
11
+ """
12
+
13
+ import copy
14
+ import networkx as nx
15
+ import numpy as np
16
+ from typing import Dict, List, Tuple
17
+
18
+ __all__ = ['generate_procedural_grid', 'generate_karnataka_task', 'TASKS', 'get_task']
19
+
20
+
21
+ # KPTCL-inspired zone names
22
+ def _get_zone_names(num_agents: int) -> List[str]:
23
+ """Get human-readable zone names for a given agent count."""
24
+ base_names = [
25
+ "Bengaluru_Region", "Mysuru_Region", "Kalburagi_Region",
26
+ "Belagavi_Region", "Mangaluru_Region",
27
+ ]
28
+ if num_agents <= len(base_names):
29
+ return base_names[:num_agents]
30
+ return [f"Zone_{i}" for i in range(num_agents)]
31
+
32
+
33
def _partition_into_zones(G: nx.Graph, num_agents: int) -> Dict[int, int]:
    """Partition graph nodes into balanced, connected zones.

    Returns mapping of {bus_id: agent_id}.
    Guarantees: every bus is assigned, each zone has at least 1 node,
    and zones are roughly balanced in size.

    NOTE: Uses greedy modularity which is deterministic for a given graph
    structure but not guaranteed across NetworkX versions.
    """
    nodes = sorted(G.nodes())
    n = len(nodes)

    if n <= num_agents:
        # Trivial case: 1 bus per agent
        return {node: i for i, node in enumerate(nodes)}

    try:
        # cutoff=num_agents stops merging once that many communities remain
        communities = nx.community.greedy_modularity_communities(G, cutoff=num_agents)
        # Sort largest-first so the merge step below folds into communities[0]
        communities = [set(c) for c in sorted(communities, key=len, reverse=True)]
    except Exception:
        # Fallback: round-robin assignment by node index
        communities = [set() for _ in range(num_agents)]
        for i, node in enumerate(nodes):
            communities[i % num_agents].add(node)

    # If we got more communities than agents, merge smallest into largest
    while len(communities) > num_agents:
        smallest = communities.pop()
        communities[0] = communities[0] | smallest

    # If we got fewer, split the largest using topology-aware bisection.
    # Each iteration removes one community and appends two, so the loop
    # gains exactly one community per pass and terminates.
    while len(communities) < num_agents:
        largest = max(communities, key=len)
        # list.remove compares by equality; sets compare element-wise,
        # so this removes the matching community found above
        communities.remove(largest)

        # Attempt topology-aware split
        subG = G.subgraph(largest).copy()
        split_done = False
        if nx.is_connected(subG) and len(largest) >= 2:
            # Find edge whose removal creates the most balanced partition
            # (exhaustive bridge search: only bridges yield 2 components)
            best_edge, best_balance = None, float('inf')
            target = len(largest) / 2
            for u, v in subG.edges():
                subG.remove_edge(u, v)
                components = list(nx.connected_components(subG))
                if len(components) == 2:
                    balance = abs(len(components[0]) - target) + abs(len(components[1]) - target)
                    if balance < best_balance:
                        best_edge = (u, v)
                        best_balance = balance
                subG.add_edge(u, v)
            if best_edge:
                subG.remove_edge(*best_edge)
                parts = list(nx.connected_components(subG))
                communities.extend(parts)
                split_done = True

        if not split_done:
            # Fallback: arbitrary split (halves may not be connected)
            largest_list = sorted(largest)
            half = len(largest) // 2
            communities.append(set(largest_list[:half]))
            communities.append(set(largest_list[half:]))

    # Ensure no empty zones
    for i, comm in enumerate(communities):
        if len(comm) == 0:
            # Steal a node from the largest community
            largest = max(communities, key=len)
            stolen = next(iter(largest))
            largest.remove(stolen)
            communities[i] = {stolen}

    # Flatten communities into the {bus_id: agent_id} mapping
    zone_map = {}
    for agent_id, comm in enumerate(communities):
        for node in comm:
            zone_map[node] = agent_id

    return zone_map
113
+
114
+
115
+ def _classify_lines(
116
+ lines_config: List[Dict], zone_assignments: Dict[int, int]
117
+ ) -> Tuple[Dict[int, List[str]], Dict[int, List[str]]]:
118
+ """Classify lines as internal (both endpoints in same zone) or boundary.
119
+
120
+ Returns:
121
+ internal_lines: {agent_id: [line_ids within this zone]}
122
+ boundary_lines: {agent_id: [line_ids on this zone's boundary]}
123
+ """
124
+ agents = set(zone_assignments.values())
125
+ internal = {a: [] for a in agents}
126
+ boundary = {a: [] for a in agents}
127
+
128
+ for line in lines_config:
129
+ from_zone = zone_assignments.get(line['from'])
130
+ to_zone = zone_assignments.get(line['to'])
131
+
132
+ # Skip lines with unassigned bus endpoints
133
+ if from_zone is None or to_zone is None:
134
+ continue
135
+
136
+ if from_zone == to_zone:
137
+ internal[from_zone].append(line['id'])
138
+ else:
139
+ boundary[from_zone].append(line['id'])
140
+ boundary[to_zone].append(line['id'])
141
+
142
+ return internal, boundary
143
+
144
+
145
def generate_procedural_grid(difficulty: str = "easy", seed: int = 42):
    """Generate a reproducible grid configuration for a given difficulty level.

    Easy: 5 buses, 20% renewables — simple balancing
    Medium: 10 buses, 50% renewables — congestion management
    Hard: 14 buses, 70% renewables — volatile supply, tight margins

    Guarantees: at least 30% of non-slack buses are loads, and at least 1 battery.
    Includes multi-agent zone assignments for POMDP mode.

    NOTE: determinism depends on the exact ORDER of rng calls below —
    do not reorder the sampling statements.
    """
    rng = np.random.default_rng(seed)

    # Difficulty → grid size / renewable share / agent count
    if difficulty == "easy":
        n_buses = 5
        renewable_mix = 0.2
        max_steps = 50
        num_agents = 2  # Small grid: 2 agents
    elif difficulty == "medium":
        n_buses = 10
        renewable_mix = 0.5
        max_steps = 50
        num_agents = 3
    else:  # Hard
        n_buses = 14
        renewable_mix = 0.7
        max_steps = 50
        num_agents = 3

    # Small-world topology; always connected by construction
    G = nx.connected_watts_strogatz_graph(n_buses, k=4, p=0.3, seed=seed)

    # Generate bus types with guaranteed minimums
    n_non_slack = n_buses - 1
    min_loads = max(2, int(n_non_slack * 0.3))  # At least 30% loads
    min_batteries = 1

    # Bus 0 is always the slack bus
    types = ['slack']

    # Assign guaranteed loads first
    assigned = []
    for _ in range(min_loads):
        assigned.append('load')
    for _ in range(min_batteries):
        assigned.append('battery')

    # Fill remaining slots with renewable_mix probability
    remaining = n_non_slack - len(assigned)
    for _ in range(remaining):
        r = rng.random()
        if r < renewable_mix:
            assigned.append(str(rng.choice(['solar', 'wind'])))
        elif r < renewable_mix + 0.15:
            assigned.append('battery')
        else:
            assigned.append('load')

    # Shuffle to avoid spatial bias (loads always first)
    rng.shuffle(assigned)
    types.extend(assigned)

    # Estimate total load for slack bus sizing
    load_estimates = []
    buses = []
    lines = []

    for i, t in enumerate(types):
        # Only loads draw a nonzero base power (MW)
        base_p = float(rng.uniform(20, 50)) if t == 'load' else 0
        if t == 'load':
            load_estimates.append(base_p)

        # Set max_p based on bus type
        if t == 'battery':
            max_p = float(rng.uniform(30, 60))  # batteries can discharge
        elif t in ['solar', 'wind', 'generator']:
            max_p = float(rng.uniform(50, 100))
        elif t == 'slack':
            # Slack max_p sized to cover expected imbalance
            max_p = 0  # placeholder, set below
        else:
            max_p = 0

        buses.append({
            'id': i, 'type': t,
            'base_p': base_p,
            'max_p': max_p,
            'min_p': 0 if t in ['solar', 'wind', 'generator'] else -50,
            'capacity': 50 if t == 'battery' else 0,
            'init_soc': 25.0 if t == 'battery' else 0,
            'ramp_rate': 20.0 if t not in ['load', 'solar', 'wind'] else 0.0,
        })

    # Size slack bus to cover expected imbalance
    total_load_est = sum(load_estimates) if load_estimates else 100
    slack_max_p = max(100, total_load_est * 0.6)
    for b in buses:
        if b['type'] == 'slack':
            b['max_p'] = slack_max_p
            b['min_p'] = -slack_max_p

    # One transmission line per graph edge, random thermal capacity
    for idx, (u, v) in enumerate(G.edges()):
        lines.append({
            'id': f"L_{idx}",
            'from': u, 'to': v,
            'susceptance': 50.0,
            'capacity': float(rng.uniform(80, 150))
        })

    # Multi-agent zone assignment
    zone_assignments = _partition_into_zones(G, num_agents)
    internal_lines, boundary_lines = _classify_lines(lines, zone_assignments)

    zone_names = _get_zone_names(num_agents)

    # Build per-zone bus lists
    zone_bus_ids = {a: [] for a in range(num_agents)}
    for bus_id, agent_id in zone_assignments.items():
        zone_bus_ids[agent_id].append(bus_id)

    return {
        "id": f"task_{difficulty}",
        "num_buses": n_buses,
        "buses": buses,
        "lines": lines,
        "max_steps": max_steps,
        "seed": seed,
        "difficulty": difficulty,
        # Multi-agent fields
        "num_agents": num_agents,
        "zone_assignments": zone_assignments,  # {bus_id: agent_id}
        "zone_names": zone_names,
        "zone_bus_ids": zone_bus_ids,  # {agent_id: [bus_ids]}
        "internal_lines": internal_lines,  # {agent_id: [line_ids]}
        "boundary_lines": boundary_lines,  # {agent_id: [line_ids]}
    }
278
+
279
+
280
def generate_karnataka_task(seed: int = 808) -> Dict:
    """
    A highly realistic 15-bus grid topology based on the actual Karnataka
    KPTCL transmission map. Nodes have real GPS coordinates for GIS rendering.
    """
    # Substation / plant table: id, name, type, GPS coords, capacity, base load
    nodes = [
        {"id": 0, "name": "Raichur_TPS", "type": "slack", "lat": 16.20, "lon": 77.36, "max_p": 200, "base_p": 0},
        {"id": 1, "name": "Kalaburagi", "type": "load", "lat": 17.33, "lon": 76.83, "max_p": 0, "base_p": 40},
        {"id": 2, "name": "Belagavi", "type": "load", "lat": 15.85, "lon": 74.50, "max_p": 0, "base_p": 35},
        {"id": 3, "name": "Hubballi", "type": "load", "lat": 15.36, "lon": 75.13, "max_p": 0, "base_p": 45},
        {"id": 4, "name": "Ballari_TPS", "type": "generator", "lat": 15.14, "lon": 76.92, "max_p": 150, "base_p": 0},
        {"id": 5, "name": "Chitradurga_Wind", "type": "wind", "lat": 14.23, "lon": 76.40, "max_p": 80, "base_p": 0},
        {"id": 6, "name": "Pavagada_Solar", "type": "solar", "lat": 14.10, "lon": 77.27, "max_p": 120, "base_p": 0},
        {"id": 7, "name": "Sharavathi_Hydro", "type": "generator", "lat": 14.18, "lon": 74.83, "max_p": 100, "base_p": 0},
        {"id": 8, "name": "Shivamogga", "type": "load", "lat": 13.93, "lon": 75.57, "max_p": 0, "base_p": 30},
        {"id": 9, "name": "Mangaluru", "type": "load", "lat": 12.87, "lon": 74.88, "max_p": 0, "base_p": 50},
        {"id": 10, "name": "Hassan_BESS", "type": "battery", "lat": 13.01, "lon": 76.10, "max_p": 50, "base_p": 0},
        {"id": 11, "name": "Mysuru", "type": "load", "lat": 12.30, "lon": 76.64, "max_p": 0, "base_p": 40},
        {"id": 12, "name": "Nelamangala", "type": "battery", "lat": 13.10, "lon": 77.39, "max_p": 50, "base_p": 0},
        {"id": 13, "name": "Bengaluru_City", "type": "load", "lat": 12.97, "lon": 77.59, "max_p": 0, "base_p": 120},
        {"id": 14, "name": "Kolar_Solar", "type": "solar", "lat": 13.13, "lon": 78.13, "max_p": 60, "base_p": 0},
    ]

    # Transmission corridors (bus-id pairs)
    edges = [
        (0,1), (0,4), (4,5), (4,6), (5,3), (3,2), (3,7),
        (7,8), (8,9), (8,10), (9,10), # (9,10) added: connects Mangaluru within zone 2
        (10,11), (10,12), (5,12),
        (6,12), (12,13), (13,14), (11,13)
    ]

    # Derive full bus records from the node table; limits depend on type
    buses = [
        {
            'id': n['id'], 'name': n['name'], 'type': n['type'],
            'lat': n['lat'], 'lon': n['lon'],
            'base_p': n['base_p'], 'max_p': n['max_p'],
            'min_p': 0 if n['type'] in ['solar', 'wind', 'generator'] else -50,
            'capacity': 100 if n['type'] == 'battery' else 0,
            'init_soc': 50.0 if n['type'] == 'battery' else 0,
            'ramp_rate': 40.0 if n['type'] not in ['load', 'solar', 'wind'] else 0.0,
        }
        for n in nodes
    ]

    # Uniform electrical parameters for every corridor
    lines = [
        {
            'id': f"L_{u}_{v}", 'from': u, 'to': v,
            'susceptance': 80.0, 'capacity': 150.0
        }
        for u, v in edges
    ]

    # Realistic agents based on regional discoms/SLDC zones
    zone_assignments = {
        0: 0, 1: 0, 4: 0,              # North Zone (Raichur/Bellary)
        2: 1, 3: 1, 5: 1, 7: 1, 8: 1,  # Hubli/Central Zone
        9: 2, 10: 2, 11: 2,            # Mysuru/Coast Zone
        6: 3, 12: 3, 13: 3, 14: 3      # Bengaluru Zone
    }

    internal_lines, boundary_lines = _classify_lines(lines, zone_assignments)

    # Per-zone bus lists in zone_assignments insertion order
    zone_bus_ids: Dict[int, List[int]] = {agent: [] for agent in range(4)}
    for bus_id, agent in zone_assignments.items():
        zone_bus_ids[agent].append(bus_id)

    return {
        "id": "task_karnataka",
        "num_buses": len(buses),
        "buses": buses,
        "lines": lines,
        "max_steps": 50,
        "seed": seed,
        "difficulty": "karnataka",
        "num_agents": 4,
        "zone_assignments": zone_assignments,
        "zone_names": ["Kalaburagi_Region", "Hubballi_Region", "Mysuru_Region", "Bengaluru_Region"],
        "zone_bus_ids": zone_bus_ids,
        "internal_lines": internal_lines,
        "boundary_lines": boundary_lines,
    }
358
+
359
+
360
def get_task(task_id: str) -> Dict:
    """Get a deep-copied task config by ID.

    Raises:
        ValueError: if task_id is not a registered task.
    """
    generator = _TASK_GENERATORS.get(task_id)
    if generator is None:
        raise ValueError(
            f"Unknown task: {task_id}. "
            f"Available: {list(_TASK_GENERATORS.keys())}"
        )
    # Deep copy so callers can mutate freely without corrupting the registry
    return copy.deepcopy(generator())
368
+
369
+
370
# Task registry: factory callables keyed by task ID. Seeds here must stay
# in sync with the TASKS dict below so both views produce identical grids.
_TASK_GENERATORS = {
    "task_easy": lambda: generate_procedural_grid("easy", seed=101),
    "task_medium": lambda: generate_procedural_grid("medium", seed=102),
    "task_hard": lambda: generate_procedural_grid("hard", seed=103),
    "task_karnataka": lambda: generate_karnataka_task(),
}

# Deterministic tasks — same seed always produces the same grid
# NOTE: These are shared instances. Use get_task() for a mutable copy.
TASKS = {
    "task_easy": generate_procedural_grid("easy", seed=101),
    "task_medium": generate_procedural_grid("medium", seed=102),
    "task_hard": generate_procedural_grid("hard", seed=103),
    "task_karnataka": generate_karnataka_task()
}
src/visualization.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grid Visualization — Dashboard Generator
3
+ ==========================================
4
+ Generates a base64-encoded PNG dashboard with two panels:
5
+ 1. Grid topology with bus-type coloring and line-loading heat map
6
+ 2. Frequency stability trace over time
7
+
8
+ Supports both GridObservation (single-agent) and ZoneObservation (multi-agent).
9
+ """
10
+
11
+ import io
12
+ import base64
13
+ import logging
14
+ from typing import List, Optional, Sequence, Dict, Tuple
15
+
16
+ import matplotlib
17
+ matplotlib.use('Agg') # Non-interactive backend for server use
18
+ import matplotlib.pyplot as plt
19
+ from matplotlib.lines import Line2D
20
+ import networkx as nx
21
+
22
+ from .models import GridObservation
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def _parse_line_endpoints(line_id: str) -> Optional[Tuple[int, int]]:
28
+ """Parse line ID format 'L_<from>_<to>' into endpoint bus IDs.
29
+
30
+ Returns (from, to) on success, None on parse failure.
31
+ Requires exactly the format L_<int>_<int>.
32
+ """
33
+ try:
34
+ parts = line_id.split('_')
35
+ if len(parts) == 3 and parts[0] == "L":
36
+ return int(parts[1]), int(parts[2])
37
+ except (ValueError, IndexError):
38
+ pass
39
+ return None
40
+
41
+
42
def generate_dashboard(
    history: Sequence,
    current_obs,
    config: Optional[Dict] = None,
) -> str:
    """Generate a base64-encoded PNG dashboard image.

    Args:
        history: Sequence of observation objects for frequency trace.
            Each item must expose .timestep and .grid_frequency.
        current_obs: Current GridObservation or ZoneObservation for topology.
            Buses must expose .id/.type; lines .id/.rho/.connected.
        config: Optional grid config dict. When provided, line endpoints
            are read from config (robust) instead of parsed from IDs.

    Returns:
        Base64-encoded PNG image string (without data URI prefix).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    try:
        # Support both GridObservation and ZoneObservation
        buses = getattr(current_obs, "buses",
                        getattr(current_obs, "local_buses", []))
        lines = getattr(current_obs, "lines", None)
        if lines is None:
            # ZoneObservation splits lines into internal + boundary
            internal = getattr(current_obs, "internal_lines", [])
            boundary = getattr(current_obs, "boundary_lines", [])
            lines = list(internal) + list(boundary)

        # Build line endpoint lookup from config if available
        line_endpoints: Dict[str, Tuple[int, int]] = {}
        if config:
            for l_cfg in config.get("lines", []):
                line_endpoints[l_cfg["id"]] = (l_cfg["from"], l_cfg["to"])

        # --- Plot 1: Grid Topology ---
        G = nx.Graph()

        color_map = {}
        for bus in buses:
            G.add_node(bus.id)
            if bus.type in ['generator', 'slack']:
                color_map[bus.id] = '#2ecc71'  # green
            elif bus.type == 'load':
                color_map[bus.id] = '#e74c3c'  # red
            elif bus.type == 'battery':
                color_map[bus.id] = '#3498db'  # blue
            else:
                color_map[bus.id] = '#f1c40f'  # yellow (renewables)

        # Build graph with line data as edge attributes
        for line in lines:
            # Get endpoints from config (preferred) or parse from ID
            if line.id in line_endpoints:
                u, v = line_endpoints[line.id]
            else:
                parsed = _parse_line_endpoints(line.id)
                if parsed is None:
                    # Unparseable line ID without config — drop from plot
                    continue
                u, v = parsed

            G.add_edge(u, v, line_id=line.id, rho=line.rho,
                       connected=line.connected)

        # Build edge colors in G.edges() order (correct alignment)
        # rho thresholds: >0.9 critical, >0.7 elevated
        # (presumably rho = |flow|/capacity — confirm against env)
        edge_colors = []
        edge_styles = []
        for u, v, data in G.edges(data=True):
            connected = data.get('connected', True)
            rho = abs(data.get('rho', 0.0))

            if not connected:
                edge_colors.append('lightgray')
                edge_styles.append('dashed')
            elif rho > 0.9:
                edge_colors.append('#e74c3c')  # red
                edge_styles.append('solid')
            elif rho > 0.7:
                edge_colors.append('#e67e22')  # orange
                edge_styles.append('solid')
            else:
                edge_colors.append('#2ecc71')  # green
                edge_styles.append('solid')

        node_colors = [color_map.get(n, 'gray') for n in G.nodes()]

        # Use config coordinates if available (stable layout)
        pos = None
        if config:
            bus_coords = {}
            for b_cfg in config.get("buses", []):
                if "lon" in b_cfg and "lat" in b_cfg:
                    bus_coords[b_cfg["id"]] = (b_cfg["lon"], b_cfg["lat"])
            # Only use GPS layout when EVERY plotted node has coordinates
            if len(bus_coords) == G.number_of_nodes():
                pos = bus_coords

        if pos is None and G.number_of_nodes() > 0:
            # Fixed seed so layout is stable across calls
            pos = nx.spring_layout(G, seed=42)

        if G.number_of_nodes() > 0 and pos:
            # Split edges by style; zip order matches the color lists above
            solid_edges = [
                (u, v) for (u, v, _), s in zip(G.edges(data=True), edge_styles)
                if s == 'solid'
            ]
            solid_colors = [
                c for c, s in zip(edge_colors, edge_styles) if s == 'solid'
            ]
            dashed_edges = [
                (u, v) for (u, v, _), s in zip(G.edges(data=True), edge_styles)
                if s == 'dashed'
            ]
            dashed_colors = [
                c for c, s in zip(edge_colors, edge_styles) if s == 'dashed'
            ]

            nx.draw_networkx_nodes(
                G, pos, ax=ax1, node_color=node_colors, node_size=300
            )
            nx.draw_networkx_labels(G, pos, ax=ax1, font_size=8)

            if solid_edges:
                nx.draw_networkx_edges(
                    G, pos, ax=ax1, edgelist=solid_edges,
                    edge_color=solid_colors, width=2, style='solid'
                )
            if dashed_edges:
                # Disconnected lines drawn thin and dashed
                nx.draw_networkx_edges(
                    G, pos, ax=ax1, edgelist=dashed_edges,
                    edge_color=dashed_colors, width=1, style='dashed'
                )

            # Legend
            legend_elements = [
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#2ecc71', markersize=10,
                       label='Generator/Slack'),
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#e74c3c', markersize=10,
                       label='Load'),
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#3498db', markersize=10,
                       label='Battery'),
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#f1c40f', markersize=10,
                       label='Renewable'),
            ]
            ax1.legend(handles=legend_elements, loc='upper left', fontsize=7)
        else:
            ax1.text(0.5, 0.5, "No buses in observation",
                     ha='center', va='center', transform=ax1.transAxes)

        ax1.set_title("Grid Topology & Loading")

        # --- Plot 2: Frequency Trace ---
        if history:
            # Sort defensively in case history arrives out of order
            history_sorted = sorted(history, key=lambda h: h.timestep)
            timesteps = [h.timestep for h in history_sorted]
            freqs = [h.grid_frequency for h in history_sorted]

            ax2.plot(timesteps, freqs, label='Frequency (Hz)',
                     color='#2980b9', linewidth=1.5)
            # 50 Hz nominal with ±0.5 Hz shaded normal band
            ax2.axhline(y=50.0, color='k', linestyle='--', linewidth=0.8)
            ax2.fill_between(timesteps, 49.5, 50.5,
                             color='green', alpha=0.1, label='Normal band')
            ax2.legend(fontsize=8)
        else:
            ax2.text(0.5, 0.5, "No frequency history",
                     ha='center', va='center', transform=ax2.transAxes)

        ax2.set_title("Frequency Stability")
        ax2.set_xlabel("Timestep")
        ax2.set_ylabel("Hz")
        ax2.set_ylim(48.5, 51.5)

        fig.tight_layout()

        # Render to an in-memory buffer and base64-encode
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return base64.b64encode(buf.read()).decode('utf-8')

    finally:
        # Always release the figure to avoid leaking memory on a server
        plt.close(fig)
static/app.js ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// OpenGrid Control Room
// Base URL for all API calls — same origin that served this page.
const API = window.location.origin;
// Per-agent display palette and fallback zone names, indexed by agent id
// (supports up to 6 agents; zone_info from the server overrides the names).
const AGENT_COLORS = ['#00bfff','#ff69b4','#ff6347','#32cd32','#9370db','#ffa500'];
const AGENT_NAMES = ['Bengaluru','Mysuru','Kalburagi','Hassan','Tumakuru','Bagalkot'];
5
+
6
+ // Real Karnataka state boundary path (source: @svg-maps/india)
7
+ const KARNATAKA_PATH = "m 124.338,505.46021 -0.617,-0.44733 0.776,-0.16422 -0.063,-0.8604 1.544,-0.77275 0.48,-0.70223 0.476,0.96821 0.881,0.0413 1.521,-0.74857 0.512,-1.53442 -0.938,-0.17228 0.62,-0.86141 0.404,0.86745 0.379,-0.0181 -0.412,-1.05888 1.641,-3.03861 -0.711,-0.35364 -0.968,0.47151 -0.458,-0.38889 1.391,-1.25837 1.141,0.50879 -0.068,-1.30269 0.567,-0.8997 -0.205,-0.93495 -1.688,-0.57629 -0.027,-0.50476 -1.422,-0.24583 -0.407,0.51987 0.312,-0.51181 -0.538,-0.73446 0.051,-1.1828 0.369,-0.24886 0.389,0.56622 0.156,-0.64581 -0.554,-0.135 -0.079,-1.12941 -0.891,-0.14911 0.075,-0.95309 -0.652,0.58133 -0.327,-0.41207 0.683,-0.18639 -0.196,-0.9007 0.79,0.92891 0.32,-1.12336 0.758,-0.0786 -0.063,0.39998 0.572,0.23676 0.284,-1.11026 1.444,-0.57126 0.104,-1.2241 0.432,0.74655 1.118,-0.14407 0.474,1.77622 1.304,-0.51987 0.135,-0.67805 0.996,0.0504 -0.625,-0.72439 0.746,-0.8191 0.043,-0.88055 3.282,-1.21706 1.441,0.0192 -0.248,-1.88302 1.091,-0.48057 0.066,-0.60249 -0.842,-0.44329 0.238,-0.33752 1.924,-0.0121 0.034,0.3486 1.225,-0.50375 1.062,1.64625 1.016,0 -0.135,0.69014 0.684,0.0373 1.401,-0.74252 0.119,-1.76514 1.19,0.0494 1.035,-0.52289 0.759,0.28311 0.772,-0.47957 0.515,0.92992 1.629,-0.45438 0.114,-0.9672 0.706,0.10276 0.024,0.73447 0.719,0.40703 0.619,-0.20251 -0.049,-1.65431 -0.596,-0.0151 0.725,-0.57931 0.002,-0.68712 -1.057,-1.6664 0.714,-0.83722 -0.047,-1.16568 -1.129,-0.91884 0.15,-0.85738 -0.592,-0.16422 -0.131,-0.72741 0.78,-0.19646 0.414,-1.88201 0.878,0.4302 0.285,0.99642 0.96,-0.20352 1.367,1.13646 0.469,-1.15761 0.779,0.81405 0.529,-0.69215 0.134,1.39841 0.785,0.64883 2.583,-0.66294 0.506,0.53196 0.889,-0.79693 0.877,0.55916 0.264,0.96015 -0.072,-1.13243 1.508,-0.56823 0.659,0.96922 1.418,-0.42618 0.181,0.86343 0.616,-0.0262 0.552,-1.2634 -0.964,-0.12593 0.234,-1.3037 -0.827,0.0463 -0.06,-0.80197 0.926,-0.54304 -0.661,-0.0191 0.474,-0.61155 -0.546,-0.44733 -0.175,-1.14955 1.758,-0.20553 0.273,-0.88459 1.268,-0.35766 -0.062,-1.16265 
0.781,-0.0373 0.001,-0.96115 1.038,-0.0242 -0.001,1.27348 0.863,-1.45483 1.02,1.77017 0.573,0.1743 0.159,-1.01455 0.617,-0.24079 -0.249,-0.98735 0.985,-0.11384 0.532,-0.86746 -1.061,-0.67301 0.067,-0.90271 1.3,0.65386 2.379,-1.03067 0.026,-2.60337 0.773,-0.14206 -0.16,-0.75159 0.445,-0.51584 -0.957,-0.41912 0.661,-1.51628 0.707,-0.0796 0.755,0.56923 0.186,-0.46546 0.52,0.69316 1.072,-0.008 -0.279,-0.93496 1.14,-0.47453 0.43,-1.41956 0.746,0.0645 -0.226,-0.76772 1.039,-1.27851 -0.101,-0.84126 1.616,-0.99742 0.517,0.51987 0.577,-0.38386 0.002,1.03772 0.845,0.269 -1.074,1.7198 1.624,0.0917 0.607,1.02866 0.938,-0.40804 -0.015,-0.62465 0.847,0.33953 0,0 1.11,0.2952 -0.81,1.70972 0.701,1.07298 0.059,1.15661 -1.148,1.00649 0.974,0.96115 1.129,0.37378 0.151,0.52592 -0.197,0.50576 -0.424,-0.25087 -0.15,1.209 -0.657,-0.11788 0.241,0.83219 -0.501,-0.0524 -0.482,1.20598 -0.497,-0.19243 -0.316,0.55916 -0.134,0.41509 1.287,0.40501 -0.083,0.37479 -2.338,1.22814 -0.218,2.41597 1.049,0.33349 0.243,0.55815 0.54,-0.71029 0.439,0.5229 0.867,-0.29319 0.04,0.66193 1.965,0.59442 -0.034,0.72036 -0.752,-0.18336 0.098,-0.48461 -0.258,0.59946 -0.617,0.134 -0.007,0.56521 -0.783,-0.21964 -0.013,0.54203 -1.307,0.0504 0.531,0.50879 -0.157,0.70222 -0.605,0.39595 -0.995,-0.35968 -0.368,1.80544 0.429,0.27202 -1.552,1.23318 0.386,0.24079 -0.812,1.03369 2.148,1.26239 0.77,2.12078 -0.963,2.15403 0.372,2.84517 -0.704,-0.0887 -0.296,1.50218 0.909,0.0564 -0.037,0.73648 -1.015,0.33852 0.343,0.5511 0.763,-0.2025 -0.109,1.3712 -1.522,0.41509 0.5,1.0357 -0.758,0.15516 -0.268,0.58132 -1.458,-0.16019 -0.097,0.3899 -1.189,0.12191 1.036,1.42158 1.22,0.50879 1.44,0.37176 1.732,-0.28613 2.033,0.83622 -0.027,1.10724 -0.53,-0.23676 -0.653,0.7657 -0.682,-0.11284 -0.286,0.39393 0.025,0.55614 0.46,0.0212 -0.568,1.41352 0.064,0.93395 0.476,0.26698 -0.391,1.59084 0.405,0.6186 -0.014,1.7742 0,0 -5.454,-0.80499 -2.208,0.37379 -1.622,0.9007 -0.915,1.47195 0.871,0.1884 -0.433,1.93137 1.711,1.64222 -0.184,0.7385 
-0.728,-0.6045 -1.092,0.41408 -0.056,2.91167 -1.145,-0.13803 -0.032,0.3355 1.193,1.21202 0.715,2.46333 1.007,-0.11788 0.12,0.81406 0.68,0.0907 0.34,0.5773 -0.906,0.82212 0.78,0.59543 -0.01,0.97425 -0.536,0.13601 0.459,0.48158 -1.574,2.61647 -0.792,-0.11788 0.123,-0.51583 -0.967,-0.008 -0.395,0.4171 -2.2,-0.39796 -1.67,-1.34803 -0.475,0.42113 0.216,1.18784 -0.435,1.01455 1.342,0.40904 0.765,-0.2821 -0.329,3.46982 -1.432,0.53599 0.371,0.96821 -0.793,2.87338 0.828,1.60897 0.583,0.10075 0.893,1.1828 2.16,-0.21258 -0.62,0.71835 -0.046,0.98633 -0.596,-0.30325 -0.627,0.50375 -0.084,0.94805 1.486,1.13344 -0.528,0.47251 0.271,0.69316 2.34,0.0121 0.538,0.65286 0.623,-0.0846 0.143,-1.60595 0.842,-1.08709 1.67,0.57729 1.03,-0.43423 -0.033,1.19086 1.667,0.1471 -0.081,0.84932 0.594,0.82615 0.668,-0.43524 -0.852,-1.8004 0.223,-0.64278 1.187,0.32743 0.259,0.81305 0.87,0.009 0.087,2.61143 -2.317,-0.3093 -0.272,0.77174 -0.606,0.11284 -0.067,0.61155 0.946,-0.48662 0.12,0.32643 -1.45,1.73995 1.197,0.3365 -0.162,0.82514 1.151,0.008 -0.48,0.70525 0.413,0.58032 -0.744,0.63372 -0.03,-0.38487 -0.881,0.0242 0.114,-1.85279 -0.843,-0.91682 -3.478,0.71734 -0.549,-1.09213 -2.039,-0.23978 -0.322,-0.96921 0.241,-1.61301 -1.637,-0.35766 -0.098,0.50577 -0.954,0.14911 -0.162,0.60449 0.907,0.0826 1.005,1.33896 -0.919,0.91884 2.193,2.09761 0.114,0.60853 -0.502,0.0876 -1.023,1.70468 0.506,1.54953 0.395,0.21158 0.313,-0.83522 0.706,0.70324 0.737,-0.47554 1.493,0.134 -0.091,-1.42461 -0.803,-0.6186 0.809,-0.16422 -0.137,-0.92488 0.441,-0.37983 0.037,1.24325 0.547,-0.46042 0.138,0.42617 0.467,-0.72339 0.348,1.19691 1.182,-0.38386 0.274,0.68006 0.826,-0.4302 1.362,0.2277 -0.332,0.77476 1.021,0.0474 -0.161,2.61646 0.695,-0.0846 0.092,-0.58435 0.522,0.45539 0.154,-1.25535 0.762,0.59141 0.828,-0.58536 0.537,0.3496 0.324,-0.16926 -0.55,-0.43624 0.809,-0.44028 0.442,-0.0363 -0.136,0.54505 0.666,-0.19746 0.276,0.6186 0.086,-1.24527 1.374,-0.48259 -0.051,-0.49669 1.082,-0.53297 -0.447,-1.03671 0.25,-0.56723 
1.438,0.0796 0.515,0.73345 1.148,-1.15156 0.243,1.23519 -0.745,0.2831 0.044,1.30169 0.444,-0.005 0.406,-0.89566 1.102,0.18941 0.07,-0.73145 1.516,0.98937 0.098,1.37926 -0.697,0.93798 0.512,0.50173 -0.084,0.55715 -0.865,-0.0816 -0.12,0.4574 0.469,0.68309 1.57,-0.28815 0.1,0.54506 0.7,0.15616 -0.224,1.28859 0.93,-0.66394 3.414,0.19545 -0.746,4.80576 0.884,0.48965 -0.636,0.26497 0.508,0.35262 0.695,-0.33146 0.241,0.44632 0.749,-0.20553 1.027,1.12638 0.729,-0.94402 0.457,0.80499 -0.184,1.24425 -0.581,0.36573 0.589,0.66091 -1.263,0.79391 0.402,0.47957 -0.545,0.33751 0.056,0.62163 -1.11,0.45639 0.133,1.46793 -0.738,-0.11486 0.275,1.46087 -1.203,0.11788 -0.689,-0.70726 -0.886,1.73994 -1.298,-0.005 -0.428,2.03715 0,0 -2.093,-0.37478 -1.548,-1.55457 -0.666,-0.0756 -0.281,1.08406 -0.42,-0.004 -0.75,-1.15459 -0.435,0.7657 -0.326,-0.17833 0.528,-0.73245 -0.35,-0.48057 -2.781,0.95812 0.306,1.00952 -1.425,2.63964 -0.578,-0.20956 -0.533,0.52994 -0.504,-0.54002 -1.339,0.35666 0.157,0.78484 -0.582,1.35407 0.177,1.09314 0.583,0.15515 -0.649,0.67402 1.043,-0.19042 -0.107,1.47699 -0.34,1.17273 -1.279,1.50318 -1.518,0.46849 -0.095,1.27851 5.457,0.74958 0.881,1.32285 -1.654,2.04924 -0.607,1.53744 -3.686,0.12292 -0.157,1.20799 -0.505,-0.269 0.073,1.05888 -0.775,1.89308 -1.251,-0.60147 -0.699,0.42415 -0.864,-0.84327 -0.902,-0.0877 -0.308,0.39997 -2.601,0.44129 0.136,1.076 -0.789,-0.26195 -0.316,-1.11429 -0.716,0.26598 0.195,-0.41106 -0.57,-0.40803 -0.663,0.0413 -0.276,0.76872 -1.254,-0.38788 -1.49,2.97816 0.469,0.90473 -0.285,0.58435 -0.435,-0.48562 -1.471,-0.28512 -3.897,0.009 -0.412,-1.29161 -0.758,-0.58939 -1.106,0.91783 -0.584,-0.0897 0,0 -0.566,-0.89365 0.471,-0.34255 -0.235,-0.77678 -1.521,0.48561 -1.318,-1.56061 -1.12,0.0746 -0.722,-1.36415 -1.59,0.27001 -0.003,-2.55803 -2.375,1.01354 -2.464,-0.37278 -1.096,-0.93294 -0.517,-1.86185 -1.73,0.19444 -0.323,-0.81909 -0.82,0.19545 -0.572,-1.09817 -1.219,-0.17933 -1.97,-2.90361 -1.331,-0.005 0.047,-1.72382 -1.168,-0.86343 0.021,-0.95712 
1.17,-0.18034 -0.168,-0.78686 -1.542,0.8725 -0.125,-0.73447 -1.125,-0.59946 -0.09,-0.98634 1.068,-0.45136 -1.071,-0.56823 -1.126,1.05183 -0.449,-1.34098 -0.885,0.17329 -0.339,-0.3496 0.161,-0.92388 -1.351,-0.46042 -0.063,0.67401 -0.739,0.13602 0.039,-1.17374 -0.891,0.11788 0.106,-0.29318 -0.574,-0.15012 0.499,-0.65689 -0.342,-0.6448 -2.621,0.77376 0,0 -0.965,-2.10365 -2.634,-10.6573 -0.512,-6.16488 -1.337,-5.02237 -0.768,-1.72786 -0.809,-0.39594 -0.627,-1.24728 -0.64,-3.47486 -0.611,-0.87048 -1.843,-6.03994 0.826,-0.61357 -0.599,-0.48662 -0.181,0.68611 -0.971,0.0302 -0.313,-1.75002 -0.524,-0.54808 0.32,-0.28814 -0.384,-1.61905 -0.669,-0.71633 -0.622,0.65084 -2.291,-1.75103 0.587,-0.13299 0.157,-0.89768 -0.396,-0.0121 -0.308,-1.05989 0,0 0.879,-0.538 0.754,0.24986 -0.068,-0.91279 0.831,0.64278 1.22,-0.98231 -0.176,-0.52289 0.851,-1.06593 -0.502,-1.21605 0.235,-1.02664 0.676,-0.8876 -0.318,-0.85033 -1.029,-0.74857 1.761,-0.76671 -0.278,-1.61602 -0.957,-0.46446 -0.003,-1.18481 -0.548,-1.00952 0.647,-0.7939 -0.69,-0.86645 0.298,-0.95108 -0.278,-0.97123 -0.496,-0.26799 -0.384,0.38083 -0.483,-0.69517 -0.35,0.62969 -0.986,0.0383 z";
8
+
9
// Mutable singleton holding all per-episode UI/application state.
let state = {
    // Episode identity and progress
    sessionId: null, task: 'task_karnataka', step: 0, done: false,
    // Multi-agent metadata: latest observations keyed by agent-id string,
    // static task configs keyed by task id (filled from GET /tasks)
    numAgents: 0, zoneInfo: {}, observations: {}, taskConfigs: {},
    // Time series consumed by the charts/sparklines
    rewardHistory: [], freqHistory: [], perAgentRewards: {},
    // Totals and the auto-run loop handle
    totalReward: 0, autoRunning: false, autoTimer: null,
    // Safety-correction counter, last oversight report, map zoom, alarm feed
    safetyTotal: 0, lastOversight: null, mapScale: 1, alarms: []
};
16
+
17
// --- Init ---
// Wire the task-selector buttons, fetch the task catalog, then kick off the
// first episode once configs are available (resetEpisode depends on them).
document.addEventListener('DOMContentLoaded', function () {
    const taskButtons = Array.from(document.querySelectorAll('.task-btn'));
    for (const button of taskButtons) {
        button.addEventListener('click', function () {
            // Single-select behavior: clear every button, mark the clicked one.
            taskButtons.forEach(function (other) { other.classList.remove('active'); });
            button.classList.add('active');
            state.task = button.dataset.task;
        });
    }
    fetch(`${API}/tasks`)
        .then(function (resp) { return resp.json(); })
        .then(function (tasks) {
            tasks.forEach(function (t) { state.taskConfigs[t.id] = t; });
            resetEpisode(); // reset only after configs are loaded
            setTimeout(function () {
                document.getElementById('loading').classList.add('hidden');
            }, 800);
        });
});
32
+
33
// --- API Calls ---
/**
 * Begin a fresh multi-agent episode for the currently selected task.
 * Clears all per-episode counters and UI panels, then calls
 * POST /reset_multi and stores the returned session, agent count,
 * zone metadata, and initial observations. Failures surface as a
 * critical alert rather than an uncaught rejection.
 */
async function resetEpisode() {
    stopAutoRun();
    // Wipe episode-scoped accumulators.
    state.step = 0;
    state.done = false;
    state.totalReward = 0;
    state.rewardHistory = [];
    state.freqHistory = [];
    state.safetyTotal = 0;
    state.alarms = [];
    mapFitted = false;
    document.getElementById('alarmLog').innerHTML = '';
    document.getElementById('simStatus').textContent = 'RUNNING';
    try {
        const resp = await fetch(`${API}/reset_multi?task_id=${state.task}`, {method:'POST'});
        const data = await resp.json();
        state.sessionId = data.session_id;
        state.numAgents = data.num_agents;
        state.zoneInfo = data.zone_info;
        state.observations = data.observations;
        // One cumulative-reward series per agent.
        state.perAgentRewards = {};
        for (let agent = 0; agent < data.num_agents; agent++) {
            state.perAgentRewards[agent] = [];
        }
        updateAll();
    } catch (err) {
        showAlert('critical', 'Reset failed: ' + err.message);
    }
}
54
+
55
/**
 * Advance the episode one timestep: build a heuristic action for every
 * agent from its latest observation, POST the batch to /step_multi, and
 * fold the response (rewards, observations, oversight/safety reports,
 * termination flag) into UI state. On request failure, alerts and halts
 * any auto-run loop.
 */
async function stepEpisode() {
    if (!state.sessionId || state.done) return;

    const actions = {};
    for (let agent = 0; agent < state.numAgents; agent++) {
        const key = String(agent);
        actions[key] = generateHeuristicAction(agent, state.observations[key]);
    }

    try {
        const resp = await fetch(`${API}/step_multi?session_id=${state.sessionId}`, {
            method: 'POST',
            headers: {'Content-Type':'application/json'},
            body: JSON.stringify({agent_actions: actions})
        });
        const d = await resp.json();

        state.step++;
        state.observations = d.observations;
        state.totalReward += d.team_reward;
        state.rewardHistory.push(d.team_reward);
        state.lastOversight = d.oversight_report;
        state.done = d.done;
        state.freqHistory.push(getAvgFreq(d.observations));

        // safety_reports is a string-keyed dict {"0": {...}, "1": {...}}, not an array
        for (const report of Object.values(d.safety_reports || {})) {
            if (report.was_corrected) state.safetyTotal++;
        }
        for (const [aid, rew] of Object.entries(d.rewards)) {
            if (!state.perAgentRewards[aid]) state.perAgentRewards[aid] = [];
            state.perAgentRewards[aid].push(rew.value);
        }

        if (d.done) {
            document.getElementById('simStatus').textContent =
                d.info.is_blackout ? 'BLACKOUT' : 'COMPLETE';
            stopAutoRun();
        }
        updateAll(d);
    } catch (err) {
        showAlert('critical', 'Step failed: ' + err.message);
        stopAutoRun();
    }
}
89
+
90
/**
 * Fetch the grader's score for the current session and render it,
 * color-coded by quality band (>0.7 normal, >0.4 warning, else critical).
 * No-op when no session exists; failures surface as a warning alert.
 */
async function getGrade() {
    if (!state.sessionId) return;
    try {
        const r = await fetch(`${API}/grader?session_id=${state.sessionId}`);
        // Fix: report HTTP failures explicitly instead of letting r.json()
        // fail later with an unhelpful parse error on an HTML error page.
        if (!r.ok) throw new Error(`HTTP ${r.status}`);
        const d = await r.json();
        const scoreEl = document.getElementById('episodeScore');
        scoreEl.textContent = d.score.toFixed(4);
        scoreEl.style.color =
            d.score > 0.7 ? 'var(--status-normal)' :
            d.score > 0.4 ? 'var(--status-warning)' : 'var(--status-critical)';
    } catch(e) { showAlert('warning', 'Grade failed: ' + e.message); }
}
100
+
101
// --- Heuristic Agent ---
/**
 * Proportional frequency-control policy for one agent: dispatchable units
 * (generators and batteries) are pushed against the frequency error
 * (gain 8 MW/Hz, clamped to ±15 MW, deadband 0.5 MW, rounded to 0.1 MW).
 * @param {number} agentId - agent index (unused by the policy itself).
 * @param {Object} obs - agent observation; may be null/undefined.
 * @returns {{bus_adjustments: Array, topology_actions: Array}}
 */
function generateHeuristicAction(agentId, obs) {
    if (!obs) return {bus_adjustments: [], topology_actions: []};

    const error = 50.0 - (obs.grid_frequency || 50);
    // Same proportional command for every dispatchable bus this step.
    const command = Math.max(-15, Math.min(15, error * 8));
    const adjustments = [];

    for (const bus of (obs.local_buses || [])) {
        // Exclude slack — physics solver overwrites its injection; adjusting it wastes the action
        if (bus.type !== 'battery' && bus.type !== 'generator') continue;
        if (Math.abs(command) > 0.5) {
            adjustments.push({bus_id: bus.id, delta: Math.round(command * 10) / 10});
        }
    }
    return {bus_adjustments: adjustments, topology_actions: []};
}
118
+
119
// --- Auto Run ---
// Toggle the automatic stepping loop on or off.
function toggleAutoRun() {
    if (state.autoRunning) {
        stopAutoRun();
        return;
    }
    state.autoRunning = true;
    document.getElementById('btnAutoRun').classList.add('active');
    autoStep();
}

// Halt the loop: clear the flag, cancel any pending timer, un-highlight the button.
function stopAutoRun() {
    state.autoRunning = false;
    if (state.autoTimer) clearTimeout(state.autoTimer);
    document.getElementById('btnAutoRun').classList.remove('active');
}

// One loop iteration: step the episode, then reschedule after 200 ms
// unless the episode finished or the user stopped the run.
async function autoStep() {
    if (!state.autoRunning || state.done) {
        stopAutoRun();
        return;
    }
    await stepEpisode();
    if (state.autoRunning && !state.done) {
        state.autoTimer = setTimeout(autoStep, 200);
    }
}
134
+
135
// --- UI Updates ---
/**
 * Refresh every control-room panel after a reset or a step.
 * @param {Object} [stepData] - raw /step_multi response; undefined on reset,
 *   in which case panels that need per-step data render their defaults.
 */
function updateAll(stepData) {
    updateHeader();
    updateFrequency();
    updateSystemSummary();
    updateOversight();
    updateAgentCards(stepData);
    updateLeaderboard();
    updateGridMap();
    updateCharts();
    updateAlarmLog(stepData);
}
147
+
148
/**
 * Mean grid frequency across all agent observations.
 * @param {Object} [obs] - agentId -> observation map; falls back to
 *   state.observations when omitted.
 * @returns {number} average frequency in Hz; 50 when there are no
 *   observations (missing grid_frequency fields also count as 50).
 */
function getAvgFreq(obs) {
    const readings = Object.values(obs || state.observations)
        .map(o => o.grid_frequency || 50);
    if (readings.length === 0) return 50;
    return readings.reduce((total, f) => total + f, 0) / readings.length;
}
153
+
154
/**
 * Refresh the top status bar: step counter, active-agent count, cumulative
 * team reward, episode label, and the color-coded average frequency.
 * Also mirrors step count and blackout flag into the summary panel.
 */
function updateHeader() {
    const setText = (id, text) => { document.getElementById(id).textContent = text; };
    const maxSteps = state.taskConfigs[state.task]?.max_steps || 50;

    setText('headerStep', `${state.step} / ${maxSteps}`);
    setText('headerAgents', `${state.numAgents} Active`);
    setText('headerReward', state.totalReward.toFixed(2));
    setText('headerEpisode', state.task.replace('task_','').toUpperCase());

    const freq = getAvgFreq();
    const freqEl = document.getElementById('headerFreq');
    freqEl.textContent = freq.toFixed(2) + ' Hz';
    freqEl.className = 'value ' + freqClass(freq);

    setText('totalSteps', state.step);
    // Blackout is signaled via the simStatus badge set by stepEpisode.
    const blackedOut = state.done &&
        document.getElementById('simStatus').textContent === 'BLACKOUT';
    setText('blackoutStatus', blackedOut ? 'Yes' : 'No');
}
167
+
168
/**
 * Render the analog-style frequency gauge as an inline SVG string plus the
 * textual deviation readout and the overall grid-condition badge.
 * The gauge is a semicircular arc spanning 49–51 Hz with colored bands,
 * a needle at the current average frequency, and scale labels.
 */
function updateFrequency() {
    const freq = getAvgFreq();
    const cls = freqClass(freq);
    const colors = {normal:'#00e5a0',warning:'#ffd700',critical:'#ff3d3d'};
    const col = colors[cls];
    // Arc gauge geometry: 200x110 viewBox, semicircle of radius 80 centered at (100,100)
    const container = document.getElementById('freqArc');
    const W=200, H=110, cx=100, cy=100, r=80;
    const minF=49, maxF=51;
    // Needle position as a fraction of the 49–51 Hz span, clamped to [0,1]
    const pct = Math.max(0,Math.min(1,(freq-minF)/(maxF-minF)));
    const startA=Math.PI, endA=0;
    const needleA = startA - pct*(startA-endA);
    let svg = `<svg viewBox="0 0 ${W} ${H}" xmlns="http://www.w3.org/2000/svg">`;
    // Background arc (full faint semicircle under the colored bands)
    svg += `<path d="M${cx-r},${cy} A${r},${r} 0 0,1 ${cx+r},${cy}" fill="none" stroke="rgba(255,255,255,0.06)" stroke-width="10" stroke-linecap="round"/>`;
    // Colored segments: red/yellow/green bands across the 49–51 Hz span
    const segs = [{f:49,t:49.5,c:'#ff3d3d'},{f:49.5,t:49.85,c:'#ffd700'},{f:49.85,t:50.15,c:'#00e5a0'},{f:50.15,t:50.5,c:'#ffd700'},{f:50.5,t:51,c:'#ff3d3d'}];
    segs.forEach(s => {
        // Map each band's [from,to] frequencies to arc angles, then to x/y endpoints
        const a1=Math.PI-((s.f-minF)/(maxF-minF))*Math.PI;
        const a2=Math.PI-((s.t-minF)/(maxF-minF))*Math.PI;
        const x1=cx+r*Math.cos(a1),y1=cy-r*Math.sin(a1);
        const x2=cx+r*Math.cos(a2),y2=cy-r*Math.sin(a2);
        svg += `<path d="M${x1},${y1} A${r},${r} 0 0,0 ${x2},${y2}" fill="none" stroke="${s.c}" stroke-width="6" opacity="0.25" stroke-linecap="round"/>`;
    });
    // Needle (drawn 12px short of the rim) plus hub circle
    const nx=cx+(r-12)*Math.cos(needleA), ny=cy-(r-12)*Math.sin(needleA);
    svg += `<line x1="${cx}" y1="${cy}" x2="${nx}" y2="${ny}" stroke="${col}" stroke-width="2.5" stroke-linecap="round"/>`;
    svg += `<circle cx="${cx}" cy="${cy}" r="4" fill="${col}"/>`;
    // Value text (current frequency, tinted to match the status class)
    svg += `<text x="${cx}" y="${cy-20}" text-anchor="middle" fill="${col}" font-family="JetBrains Mono" font-size="28" font-weight="700" style="text-shadow:0 0 15px ${col}40">${freq.toFixed(2)}</text>`;
    svg += `<text x="${cx}" y="${cy-6}" text-anchor="middle" fill="#90a4ae" font-family="Inter" font-size="11">Hz</text>`;
    // Scale labels at the arc ends and apex
    svg += `<text x="18" y="${cy+14}" fill="#546e7a" font-size="8" font-family="JetBrains Mono">49.0</text>`;
    svg += `<text x="${W-30}" y="${cy+14}" fill="#546e7a" font-size="8" font-family="JetBrains Mono">51.0</text>`;
    svg += `<text x="${cx}" y="12" text-anchor="middle" fill="#546e7a" font-size="8" font-family="JetBrains Mono">50.0</text>`;
    svg += '</svg>';
    container.innerHTML = svg;
    document.getElementById('freqDev').textContent = `Deviation: ${(freq-50).toFixed(3)} Hz | Nominal: 50.00 Hz`;
    // Grid condition badge, stepped by absolute deviation from nominal
    const gc = document.getElementById('gridCondition');
    const dev = Math.abs(freq-50);
    if(dev<0.15){gc.textContent='NORMAL';gc.className='grid-condition normal';}
    else if(dev<0.3){gc.textContent='CONSERVATIVE OPS';gc.className='grid-condition conservative';}
    else if(dev<0.5){gc.textContent='CONSERVATION ALERT';gc.className='grid-condition alert';}
    else{gc.textContent='EMERGENCY';gc.className='grid-condition emergency';}
}
214
+
215
+ function freqClass(f) { return Math.abs(f-50)<0.5?'normal':Math.abs(f-50)<1?'warning':'critical'; }
216
+
217
/**
 * Aggregate generation, load, and transmission-line status across every
 * agent's observation and render the system-summary panel. Overload count
 * only considers connected lines with loading ratio rho > 1.
 */
function updateSystemSummary() {
    let totalGen = 0, totalLoad = 0;
    let connectedLines = 0, overloadedLines = 0, lineCount = 0;

    for (const obs of Object.values(state.observations)) {
        totalGen += obs.zone_gen_mw || 0;
        totalLoad += obs.zone_load_mw || 0;
        const allLines = (obs.internal_lines || []).concat(obs.boundary_lines || []);
        for (const line of allLines) {
            lineCount++;
            if (line.connected) {
                connectedLines++;
                if (line.rho > 1) overloadedLines++;
            }
        }
    }

    document.getElementById('totalGen').textContent = totalGen.toFixed(1) + ' MW';
    document.getElementById('totalLoad').textContent = totalLoad.toFixed(1) + ' MW';
    document.getElementById('netBalance').textContent = (totalGen - totalLoad).toFixed(1) + ' MW';
    document.getElementById('linesConnected').textContent = `${connectedLines} / ${lineCount}`;
    const overloadEl = document.getElementById('linesOverloaded');
    overloadEl.textContent = overloadedLines;
    overloadEl.style.color = overloadedLines > 0 ? 'var(--status-critical)' : 'var(--status-normal)';
}
235
+
236
/**
 * Render the oversight agent's latest report: coordination score
 * (color-coded: >0.7 normal, >0.4 warning, else critical), conflict count,
 * cumulative safety corrections, and selfish-action count.
 * No-op until the first step has produced a report.
 */
function updateOversight() {
    const report = state.lastOversight;
    if (!report) return;

    const scoreEl = document.getElementById('coordScore');
    const score = report.coordination_score;
    scoreEl.textContent = score.toFixed(2);
    if (score > 0.7) scoreEl.style.color = 'var(--status-normal)';
    else if (score > 0.4) scoreEl.style.color = 'var(--status-warning)';
    else scoreEl.style.color = 'var(--status-critical)';

    document.getElementById('conflicts').textContent = report.conflicting_actions_detected;
    document.getElementById('safetyCorrTotal').textContent = state.safetyTotal;
    document.getElementById('selfishActions').textContent = report.selfish_actions_detected;
}
246
+
247
/**
 * Scan post-step world state for alert conditions — frequency deviation,
 * line overload/congestion, per-agent safety corrections, and blackout —
 * and prepend any new alarms to the on-screen alarm log.
 * @param {Object} stepData - raw /step_multi response; when absent (episode
 *   reset) there is nothing new to report and the function is a no-op.
 */
function updateAlarmLog(stepData) {
    if (!stepData) return;
    const logEl = document.getElementById('alarmLog');
    let newAlarms = [];
    const timeStr = `T+${String(state.step).padStart(2,'0')}s`;

    // Frequency alarm: >0.5 Hz deviation warns, >1 Hz is critical
    const freq = getAvgFreq();
    if (Math.abs(freq - 50) > 0.5) {
        newAlarms.push({t: timeStr, msg: `FREQ DEVIATION: ${freq.toFixed(2)} Hz`, type: Math.abs(freq-50)>1?'crit':'warn'});
    }

    // Line alarms (rho > 1 overload, rho > 0.9 congestion) and
    // safety-layer corrections, per agent
    for (const [aid, obs] of Object.entries(state.observations)) {
        (obs.internal_lines||[]).concat(obs.boundary_lines||[]).forEach(l => {
            if (l.rho > 1.0) newAlarms.push({t: timeStr, msg: `OVERLOAD: Line ${l.id} at ${(l.rho*100).toFixed(0)}%`, type: 'crit'});
            else if (l.rho > 0.9) newAlarms.push({t: timeStr, msg: `CONGESTION: Line ${l.id} at ${(l.rho*100).toFixed(0)}%`, type: 'warn'});
        });
        const sr = stepData.safety_reports?.[aid];
        if (sr && sr.was_corrected) {
            newAlarms.push({t: timeStr, msg: `AGENT ${aid} SAFETY CORRECTED`, type: 'warn'});
        }
    }

    // Terminal blackout alarm (simStatus badge is set by stepEpisode)
    if (state.done && document.getElementById('simStatus').textContent==='BLACKOUT') {
        newAlarms.push({t: timeStr, msg: `SYSTEM COLLAPSE - BLACKOUT`, type: 'crit'});
    }

    if (newAlarms.length > 0) {
        state.alarms = [...newAlarms, ...state.alarms].slice(0, 50); // newest first; cap at 50 entries
        logEl.innerHTML = state.alarms.map(a => `<div class="alarm-entry ${a.type}"><span class="alarm-time">[${a.t}]</span>${a.msg}</div>`).join('');
    }
}
280
+
281
/**
 * Rebuild the per-agent card panel: one card per agent showing its zone
 * name, step and cumulative reward, zone load/generation, safety-layer
 * status, and a reward sparkline. Cards are fully re-rendered each call.
 * @param {Object} [stepData] - raw /step_multi response; when absent
 *   (reset), step rewards and safety reports default to 0 / "Safe".
 */
function updateAgentCards(stepData) {
    const container = document.getElementById('agentCards');
    container.innerHTML = '';
    for (let i = 0; i < state.numAgents; i++) {
        const obs = state.observations[String(i)];
        const zi = state.zoneInfo[String(i)] || {};
        const sr = stepData?.safety_reports?.[String(i)];
        const rew = stepData?.rewards?.[String(i)];
        const cumReward = (state.perAgentRewards[i]||[]).reduce((a,b)=>a+b,0);
        const wasCorrected = sr?.was_corrected || false;
        // Card turns amber when the safety layer modified this agent's action
        const cardClass = wasCorrected ? 'warning' : 'active';
        const html = `
        <div class="agent-card ${cardClass}">
          <div class="agent-header">
            <div class="agent-name">
              <span class="agent-dot" style="background:${AGENT_COLORS[i]}"></span>
              Agent ${i} - ${zi.zone_name||AGENT_NAMES[i]}
            </div>
            <span class="agent-status-badge ${wasCorrected?'corrected':'active'}">${wasCorrected?'Corrected':'Safe'}</span>
          </div>
          <div class="agent-metrics">
            <div class="agent-metric">
              <div class="label">Step Reward</div>
              <div class="value" style="color:${(rew?.value||0)>=0?'var(--status-normal)':'var(--status-critical)'}">${(rew?.value||0).toFixed(2)}</div>
            </div>
            <div class="agent-metric">
              <div class="label">Cumulative</div>
              <div class="value">${cumReward.toFixed(1)}</div>
            </div>
            <div class="agent-metric">
              <div class="label">Zone Load</div>
              <div class="value">${(obs?.zone_load_mw||0).toFixed(0)} MW</div>
            </div>
            <div class="agent-metric">
              <div class="label">Zone Gen</div>
              <div class="value">${(obs?.zone_gen_mw||0).toFixed(0)} MW</div>
            </div>
          </div>
          <div class="safety-shield ${wasCorrected?'corrected':'safe'}">
            ${wasCorrected?'&#9888; Safety Corrected':'&#9635; Safety OK'}
            ${sr?.blocked_topology_actions ? ` | ${sr.blocked_topology_actions} blocked` : ''}
          </div>
          <div class="sparkline-container"><svg id="spark${i}"></svg></div>
        </div>`;
        container.innerHTML += html;
    }
    // Draw sparklines in a second pass, after all <svg id="spark{i}"> elements exist
    // (innerHTML += re-parses the container, so mid-loop drawing would be wiped)
    for (let i = 0; i < state.numAgents; i++) {
        drawSparkline(`spark${i}`, state.perAgentRewards[i]||[], AGENT_COLORS[i]);
    }
}
332
+
333
/**
 * Rank agents by cumulative reward (descending) and render the leaderboard
 * list; the top three entries get #1/#2/#3 rank markers.
 */
function updateLeaderboard() {
    const lb = document.getElementById('leaderboard');
    const agents = Array.from({length: state.numAgents}, (_, i) => ({
        id: i,
        name: (state.zoneInfo[String(i)] || {}).zone_name || AGENT_NAMES[i],
        score: (state.perAgentRewards[i] || []).reduce((a, b) => a + b, 0),
    }));
    agents.sort((a, b) => b.score - a.score);
    lb.innerHTML = agents.map((a, idx) => `
    <li>
      <span class="agent-label">
        <span class="agent-dot" style="background:${AGENT_COLORS[a.id]};width:6px;height:6px;border-radius:50%;display:inline-block;"></span>
        ${['#1','#2','#3'][idx]||' '} ${a.name}
      </span>
      <span class="score" style="color:${AGENT_COLORS[a.id]}">${a.score.toFixed(1)}</span>
    </li>`).join('');
}
351
+
352
// --- Grid Map (Leaflet) ---
// Singleton Leaflet map instance; created lazily by initLeafletMap().
let leafletMap = null;
// Layer groups cleared and redrawn on each updateGridMap() call.
let mapLayers = { lines: null, nodes: null, badges: null };
// Whether the map has been fitted to bounds this episode (cleared by resetEpisode).
let mapFitted = false;
356
+
357
/**
 * Lazily create the Leaflet map inside #gridMap: dark CARTO basemap,
 * Karnataka-centered view, and three layer groups (lines/nodes/badges)
 * that updateGridMap() redraws each step. Idempotent — returns
 * immediately if the map already exists.
 */
function initLeafletMap() {
    const container = document.getElementById('gridMap');
    if (leafletMap) return;

    // Karnataka bounds (south-west, north-east corners in lat/lon)
    const kaBounds = [[11.5, 73.5], [18.5, 79.0]];

    leafletMap = L.map(container, {
        center: [14.5, 76.5],
        zoom: 7,
        zoomControl: true,
        attributionControl: false,
        minZoom: 5,
        maxZoom: 15,
        // Canvas renderer — cheaper than SVG for the per-step full redraw
        preferCanvas: true,
    });

    // Dark tile layer for SCADA aesthetic
    L.tileLayer('https://{s}.basemaps.cartocdn.com/dark_all/{z}/{x}/{y}{r}.png', {
        subdomains: 'abcd',
        maxZoom: 19,
    }).addTo(leafletMap);

    // Attribution (small, bottom-right; default control was disabled above)
    L.control.attribution({position: 'bottomright', prefix: false})
        .addAttribution('© <a href="https://carto.com/">CARTO</a>')
        .addTo(leafletMap);

    // Layer groups for easy clearing on redraw
    mapLayers.lines = L.layerGroup().addTo(leafletMap);
    mapLayers.nodes = L.layerGroup().addTo(leafletMap);
    mapLayers.badges = L.layerGroup().addTo(leafletMap);

    // Fix Leaflet size after container is fully rendered
    // (invalidateSize re-measures the container once layout has settled)
    setTimeout(() => {
        leafletMap.invalidateSize();
        leafletMap.fitBounds(kaBounds, { padding: [20, 20] });
    }, 200);
}
396
+
397
+ function updateGridMap() {
398
+ if (!leafletMap) initLeafletMap();
399
+
400
+ // Clear previous layers
401
+ mapLayers.lines.clearLayers();
402
+ mapLayers.nodes.clearLayers();
403
+ mapLayers.badges.clearLayers();
404
+
405
+ const typeIcons = {slack:'S',generator:'G',load:'L',battery:'B',solar:'PV',wind:'W'};
406
+ const typeColors = {slack:'#00e5a0',generator:'#f5a623',load:'#e94560',battery:'#4a90d9',solar:'#ffeb3b',wind:'#64ffda'};
407
+
408
+ // Collect buses — merge static config with runtime state
409
+ let allBuses = [];
410
+ const taskCfg = state.taskConfigs[state.task];
411
+ const runtimeState = {};
412
+ for (const obs of Object.values(state.observations)) {
413
+ (obs.local_buses||[]).forEach(b => { runtimeState[b.id] = b; });
414
+ }
415
+ if (taskCfg && taskCfg.buses) {
416
+ allBuses = taskCfg.buses.map(b => {
417
+ const rt = runtimeState[b.id];
418
+ return {...b, p_injection: rt ? rt.p_injection : (b.base_p || 0)};
419
+ });
420
+ } else {
421
+ allBuses = Object.values(runtimeState);
422
+ }
423
+
424
+ const hasGPS = allBuses.some(b => b.lat !== undefined && b.lon !== undefined);
425
+
426
+ // For non-GPS tasks, generate fake positions around Karnataka center
427
+ const busPositions = {};
428
+ const zones = [
429
+ {id:0, lat:16.8, lon:76.8, color:AGENT_COLORS[0], label:'Kalaburagi'},
430
+ {id:1, lat:15.2, lon:75.2, color:AGENT_COLORS[1], label:'Hubballi'},
431
+ {id:2, lat:12.8, lon:75.5, color:AGENT_COLORS[2], label:'Mysuru'},
432
+ {id:3, lat:13.2, lon:77.5, color:AGENT_COLORS[3], label:'Bengaluru'},
433
+ ];
434
+
435
+ allBuses.forEach((b, idx) => {
436
+ const aid = findAgent(b.id);
437
+ let lat, lon;
438
+ if (hasGPS && b.lat !== undefined && b.lon !== undefined) {
439
+ lat = b.lat;
440
+ lon = b.lon;
441
+ } else {
442
+ // Fallback: spread around zone center
443
+ const zd = zones[aid >= 0 && aid < zones.length ? aid : 0];
444
+ const zBuses = allBuses.filter(bb => findAgent(bb.id) === aid);
445
+ const zi = zBuses.indexOf(b);
446
+ const a = (zi / Math.max(zBuses.length, 1)) * Math.PI * 2;
447
+ lat = zd.lat + Math.cos(a) * 0.3;
448
+ lon = zd.lon + Math.sin(a) * 0.3;
449
+ }
450
+ busPositions[b.id] = {lat, lon, bus: b, agent: aid};
451
+ });
452
+
453
+ // Draw transmission lines
454
+ const drawnLines = new Set();
455
+ for (const obs of Object.values(state.observations)) {
456
+ (obs.internal_lines||[]).concat(obs.boundary_lines||[]).forEach(l => {
457
+ if (drawnLines.has(l.id)) return;
458
+ drawnLines.add(l.id);
459
+ const parts = l.id.replace('L_','').split('_');
460
+ const fromId = parseInt(parts[0]);
461
+ const toId = parseInt(parts[1]);
462
+ const from = busPositions[fromId];
463
+ const to = busPositions[toId];
464
+ if (!from || !to) return;
465
+
466
+ const lc = !l.connected ? '#4a5568' : l.rho > 1 ? '#ff1744' : l.rho > 0.8 ? '#ff9100' : '#e91e63';
467
+ const w = !l.connected ? 1.5 : l.rho > 0.8 ? 5 : 3;
468
+
469
+ const polyline = L.polyline(
470
+ [[from.lat, from.lon], [to.lat, to.lon]],
471
+ { color: lc, weight: w, dashArray: l.connected ? '10 5' : '4 4', opacity: 0.9 }
472
+ );
473
+ if (l.connected && Math.abs(l.flow) > 0.5) {
474
+ polyline.bindTooltip(`${l.id}: ${l.flow.toFixed(0)} MW (${(l.rho*100).toFixed(0)}%)`, {
475
+ permanent: false, className: 'leaflet-tooltip-dark'
476
+ });
477
+ }
478
+ mapLayers.lines.addLayer(polyline);
479
+ });
480
+ }
481
+ // Ensure lines are visible above tiles
482
+ if (drawnLines.size > 0) {
483
+ mapLayers.lines.eachLayer(l => { if (l.bringToFront) l.bringToFront(); });
484
+ }
485
+
486
+ // Draw bus markers
487
+ for (const [bid, pos] of Object.entries(busPositions)) {
488
+ const b = pos.bus;
489
+ const col = AGENT_COLORS[pos.agent] || '#4a5568';
490
+ const fill = typeColors[b.type] || '#666';
491
+ const r = b.type === 'slack' ? 12 : b.type === 'load' ? 7 : 9;
492
+ const inj = (b.p_injection !== undefined ? b.p_injection : 0);
493
+ const busLabel = b.name || `${b.type} ${b.id}`;
494
+ const icon = typeIcons[b.type] || '?';
495
+
496
+ // Outer ring (zone color)
497
+ const outerRing = L.circleMarker([pos.lat, pos.lon], {
498
+ radius: r + 4, fillColor: 'transparent', fillOpacity: 0,
499
+ color: col, weight: 1.5, opacity: 0.4
500
+ });
501
+ mapLayers.nodes.addLayer(outerRing);
502
+
503
+ // Inner node
504
+ const marker = L.circleMarker([pos.lat, pos.lon], {
505
+ radius: r, fillColor: fill, fillOpacity: 0.9,
506
+ color: col, weight: 1, opacity: 0.6
507
+ });
508
+
509
+ // Rich tooltip
510
+ const tooltipHtml = `
511
+ <div style="font-family:'JetBrains Mono',monospace;font-size:11px;min-width:120px;">
512
+ <b style="color:${fill}">${icon}</b> <b>${busLabel}</b><br>
513
+ <span style="color:#888">Type:</span> ${b.type}<br>
514
+ <span style="color:#888">Injection:</span> <b>${inj.toFixed(1)} MW</b><br>
515
+ <span style="color:#888">Zone:</span> ${state.zoneInfo[String(pos.agent)]?.zone_name || 'Agent ' + pos.agent}
516
+ </div>`;
517
+ marker.bindTooltip(tooltipHtml, { className: 'leaflet-tooltip-dark', direction: 'top', offset: [0, -r] });
518
+ mapLayers.nodes.addLayer(marker);
519
+
520
+ // Label under node
521
+ const labelIcon = L.divIcon({
522
+ className: 'bus-label-icon',
523
+ html: `<span style="color:${fill};text-shadow:0 0 4px #000;font-size:9px;font-family:'JetBrains Mono',monospace;white-space:nowrap;">${busLabel}</span>`,
524
+ iconSize: [80, 14],
525
+ iconAnchor: [40, -r - 2],
526
+ });
527
+ L.marker([pos.lat, pos.lon], { icon: labelIcon, interactive: false }).addTo(mapLayers.nodes);
528
+
529
+ // MW label above node
530
+ const mwIcon = L.divIcon({
531
+ className: 'bus-mw-icon',
532
+ html: `<span style="color:#e0e0e0;text-shadow:0 0 4px #000;font-size:10px;font-weight:700;font-family:'JetBrains Mono',monospace;">${inj.toFixed(0)}</span>`,
533
+ iconSize: [40, 14],
534
+ iconAnchor: [20, r + 16],
535
+ });
536
+ L.marker([pos.lat, pos.lon], { icon: mwIcon, interactive: false }).addTo(mapLayers.nodes);
537
+ }
538
+
539
+ // Zone badge overlays
540
+ zones.slice(0, state.numAgents).forEach(z => {
541
+ const zi = state.zoneInfo[String(z.id)] || {};
542
+ const name = zi.zone_name || z.label || AGENT_NAMES[z.id];
543
+ const cum = (state.perAgentRewards[z.id] || []).reduce((a, b) => a + b, 0);
544
+
545
+ const badgeIcon = L.divIcon({
546
+ className: 'zone-badge-leaflet',
547
+ html: `<div style="background:rgba(10,14,26,0.85);border:1px solid ${z.color};border-radius:6px;padding:4px 10px;text-align:center;white-space:nowrap;">
548
+ <div style="color:${z.color};font-size:11px;font-weight:700;font-family:'JetBrains Mono',monospace;">${name}</div>
549
+ <div style="color:${z.color};font-size:10px;font-family:'JetBrains Mono',monospace;opacity:0.8">${cum.toFixed(1)} pts</div>
550
+ </div>`,
551
+ iconSize: [120, 36],
552
+ iconAnchor: [60, 50],
553
+ });
554
+ L.marker([z.lat, z.lon], { icon: badgeIcon, interactive: false }).addTo(mapLayers.badges);
555
+ });
556
+
557
+ // Fit map to bus extent on first data load
558
+ if (!mapFitted && allBuses.length > 0) {
559
+ const lats = allBuses.filter(b => b.lat).map(b => b.lat);
560
+ const lons = allBuses.filter(b => b.lon).map(b => b.lon);
561
+ if (lats.length > 0) {
562
+ leafletMap.fitBounds([
563
+ [Math.min(...lats) - 0.5, Math.min(...lons) - 0.5],
564
+ [Math.max(...lats) + 0.5, Math.max(...lons) + 0.5]
565
+ ]);
566
+ mapFitted = true;
567
+ }
568
+ }
569
+ }
570
+
571
// Populate the floating bus tooltip from the hovered node's data-* attributes
// and position it just offset from the cursor.
function showBusTooltip(e, node) {
  const tooltip = document.getElementById('busTooltip');
  const zoneInfo = state.zoneInfo[node.dataset.agent] || {};
  const setText = (id, text) => { document.getElementById(id).textContent = text; };
  setText('ttTitle', `Bus ${node.dataset.bus} (${node.dataset.type})`);
  setText('ttType', node.dataset.type);
  setText('ttInj', node.dataset.inj + ' MW');
  setText('ttZone', zoneInfo.zone_name || 'Zone ' + node.dataset.agent);
  tooltip.style.left = `${e.clientX + 12}px`;
  tooltip.style.top = `${e.clientY - 20}px`;
  tooltip.classList.add('visible');
}
582
// Hide the floating bus tooltip.
function hideBusTooltip() {
  const tooltip = document.getElementById('busTooltip');
  tooltip.classList.remove('visible');
}
583
+
584
// Return the agent id whose zone contains busId, or -1 when no zone claims it.
// Zones are scanned in state.zoneInfo insertion order; first match wins.
function findAgent(busId) {
  const match = Object.entries(state.zoneInfo)
    .find(([, zi]) => (zi.bus_ids || []).includes(busId));
  return match ? parseInt(match[0]) : -1;
}
590
+
591
+ // --- Charts ---
592
// Render a tiny polyline sparkline of the last 30 data points into the
// SVG element with the given id, normalized to 10%-90% of its height.
function drawSparkline(id, data, color) {
  const host = document.getElementById(id);
  if (!host || !data.length) return;
  const width = host.clientWidth || 120;
  const height = host.clientHeight || 22;
  const lo = Math.min(...data);
  const hi = Math.max(...data);
  const span = (hi - lo) || 1;
  const recent = data.slice(-30);
  const points = recent
    .map((v, i) => {
      const x = (i / (recent.length - 1 || 1)) * width;
      const y = height - (((v - lo) / span) * height * 0.8 + height * 0.1);
      return `${x},${y}`;
    })
    .join(' ');
  host.innerHTML = `<polyline points="${points}" fill="none" stroke="${color}" stroke-width="1.5" opacity="0.8"/>`;
}
601
+
602
// Redraw the two bottom-panel time-series charts from accumulated history:
// team reward (auto-scaled) and grid frequency (pinned to the 49-51 Hz band).
function updateCharts() {
  drawChart('rewardChart', state.rewardHistory, 'var(--chart-reward)', 'Reward');
  drawChart('freqChart', state.freqHistory, 'var(--chart-supply)', 'Hz', 49, 51);
}
608
+
609
// Draw a gridded line chart with a faint filled area into #containerId.
// fixedMin/fixedMax pin the y-axis (used by the frequency chart); when
// omitted the axis auto-scales to the data. An empty series renders a
// "Waiting for data..." placeholder instead.
// NOTE(review): as in the original flow, the generation-mix donut is
// refreshed only when the frequency chart is redrawn with data.
function drawChart(containerId, data, color, label, fixedMin, fixedMax) {
  const host = document.getElementById(containerId);
  if (!host) return;
  const W = host.clientWidth || 300;
  const H = host.clientHeight || 140;
  if (!data.length) {
    host.innerHTML = `<svg viewBox="0 0 ${W} ${H}"><text x="${W/2}" y="${H/2}" text-anchor="middle" fill="var(--text-muted)" font-size="11">Waiting for data...</text></svg>`;
    return;
  }
  const pad = {t:10, r:10, b:20, l:40};
  const cw = W - pad.l - pad.r;
  const ch = H - pad.t - pad.b;
  const yMin = fixedMin !== undefined ? fixedMin : Math.min(...data);
  const yMax = fixedMax !== undefined ? fixedMax : Math.max(...data);
  const span = (yMax - yMin) || 1;
  const toX = (i) => pad.l + (i / (data.length - 1 || 1)) * cw;
  const toY = (v) => pad.t + ch - (((v - yMin) / span) * ch);
  const pts = data.map((v, i) => `${toX(i)},${toY(v)}`).join(' ');
  const parts = [`<svg viewBox="0 0 ${W} ${H}" xmlns="http://www.w3.org/2000/svg">`];
  // Horizontal grid lines with y-axis tick labels at 5 evenly spaced levels.
  for (let i = 0; i <= 4; i++) {
    const y = pad.t + ch * i / 4;
    const v = (yMax - ((yMax - yMin) * i / 4)).toFixed(1);
    parts.push(`<line x1="${pad.l}" y1="${y}" x2="${W-pad.r}" y2="${y}" stroke="rgba(255,255,255,0.05)"/><text x="${pad.l-4}" y="${y+3}" text-anchor="end" fill="var(--text-muted)" font-size="8" font-family="JetBrains Mono">${v}</text>`);
  }
  parts.push(`<polyline points="${pts}" fill="none" stroke="${color}" stroke-width="1.5"/>`);
  // Translucent area fill under the series line.
  const firstX = pad.l;
  const lastX = pad.l + (data.length - 1) / (data.length - 1 || 1) * cw;
  parts.push(`<polygon points="${pts} ${lastX},${pad.t+ch} ${firstX},${pad.t+ch}" fill="${color}" opacity="0.08"/>`);
  parts.push('</svg>');
  host.innerHTML = parts.join('');
  // Gen mix chart
  if (containerId === 'freqChart') updateGenMix();
}
632
+
633
// Render the generation-mix donut chart: one slice per generating bus type,
// sized by its share of total positive injection across all agent
// observations, with the total MW shown in the donut hole.
// Fix: a single-type mix (pct ~ 1) previously produced a zero-length SVG arc
// (start point == end point), which the SVG spec renders as nothing; such a
// slice is now drawn as a full circle instead.
function updateGenMix() {
  const el = document.getElementById('genMixChart');
  if (!el) return;
  const W = el.clientWidth||200, H = el.clientHeight||140;
  // Aggregate positive injections by bus type (generators only).
  let types = {};
  for (const obs of Object.values(state.observations)) {
    (obs.local_buses||[]).forEach(b => {
      if (b.p_injection > 0) types[b.type] = (types[b.type]||0) + b.p_injection;
    });
  }
  const total = Object.values(types).reduce((a,b)=>a+b,0) || 1;
  const colors = {slack:'#00e5a0',generator:'#f5a623',solar:'#ffeb3b',wind:'#64ffda',battery:'#4a90d9'};
  let svg = `<svg viewBox="0 0 ${W} ${H}">`;
  const cx=W/2, cy=H/2-5, r=Math.min(W,H)*0.3;
  let startAngle = -Math.PI/2;
  for (const [type, val] of Object.entries(types)) {
    const pct = val/total;
    const endAngle = startAngle + pct * Math.PI*2;
    const fill = colors[type]||'#666';
    if (pct >= 0.9995) {
      // Degenerate arc guard: a ~360° slice has identical endpoints, so the
      // A-command would render nothing — draw the whole disc instead.
      svg += `<circle cx="${cx}" cy="${cy}" r="${r}" fill="${fill}" opacity="0.8"/>`;
    } else {
      const x1=cx+r*Math.cos(startAngle), y1=cy+r*Math.sin(startAngle);
      const x2=cx+r*Math.cos(endAngle), y2=cy+r*Math.sin(endAngle);
      const large = pct > 0.5 ? 1 : 0;
      svg += `<path d="M${cx},${cy} L${x1},${y1} A${r},${r} 0 ${large},1 ${x2},${y2} Z" fill="${fill}" opacity="0.8"/>`;
    }
    // Label slices with >8% share just outside the ring.
    const mid = (startAngle+endAngle)/2;
    if (pct > 0.08) {
      const lx=cx+(r+14)*Math.cos(mid), ly=cy+(r+14)*Math.sin(mid);
      svg += `<text x="${lx}" y="${ly}" text-anchor="middle" fill="var(--text-secondary)" font-size="8">${type} ${(pct*100).toFixed(0)}%</text>`;
    }
    startAngle = endAngle;
  }
  // Donut hole with the total generation figure.
  svg += `<circle cx="${cx}" cy="${cy}" r="${r*0.55}" fill="var(--bg-card)"/>`;
  svg += `<text x="${cx}" y="${cy-2}" text-anchor="middle" fill="var(--text-primary)" font-family="JetBrains Mono" font-size="14" font-weight="700">${total.toFixed(0)}</text>`;
  svg += `<text x="${cx}" y="${cy+10}" text-anchor="middle" fill="var(--text-muted)" font-size="8">MW</text>`;
  svg += '</svg>';
  el.innerHTML = svg;
}
668
+
669
+ // --- Alerts ---
670
// Show a transient alert banner styled by `type` (e.g. warn/crit) with the
// given message, auto-hiding after 5 seconds.
// Fix: each call previously armed a new hide timer without cancelling the
// previous one, so an older alert's pending timeout could dismiss a newer
// alert early. The pending timer handle is kept on the function object and
// cleared before re-arming.
function showAlert(type, msg) {
  const el = document.getElementById('alertBanner');
  el.className = `alert-banner ${type} visible`;
  document.getElementById('alertText').textContent = msg;
  clearTimeout(showAlert._hideTimer);
  showAlert._hideTimer = setTimeout(() => el.classList.remove('visible'), 5000);
}
676
// Immediately hide the alert banner (user-initiated dismissal).
function dismissAlert() {
  const banner = document.getElementById('alertBanner');
  banner.classList.remove('visible');
}
677
+
678
+ // --- Map Controls ---
679
// Multiply the current map scale by `factor` and redraw the grid map.
function zoomMap(factor) {
  state.mapScale *= factor;
  updateGridMap();
}
680
// Restore the default 1:1 map scale and redraw the grid map.
function resetMapView() {
  state.mapScale = 1;
  updateGridMap();
}
static/index.html ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <meta name="description" content="OpenGrid — Multi-Agent POMDP Power Grid Control Room with Safe RL">
7
+ <title>OpenGrid | Control Room</title>
8
+ <link rel="stylesheet" href="/static/style.css">
9
+ <link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" />
10
+ <script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
11
+ <link rel="icon" href="/static/logo.png" type="image/png">
12
+ </head>
13
+ <body>
14
+
15
+ <!-- Loading Overlay -->
16
+ <div class="loading-overlay" id="loading">
17
+ <div class="loading-spinner"></div>
18
+ <div class="loading-text">OpenGrid — Initializing Control Room</div>
19
+ </div>
20
+
21
+ <!-- Alert Banner -->
22
+ <div class="alert-banner" id="alertBanner">
23
+ <span id="alertText"></span>
24
+ <button class="dismiss" onclick="dismissAlert()">Dismiss</button>
25
+ </div>
26
+
27
+ <!-- Main Layout -->
28
+ <div class="control-room">
29
+
30
+ <!-- ===== HEADER ===== -->
31
+ <header class="header">
32
+ <div class="header-brand">
33
+ <img src="/static/logo.png" alt="OpenGrid" class="logo-img" style="width:32px;height:32px;border-radius:6px;">
34
+ <div>
35
+ <h1>OpenGrid</h1>
36
+ <div class="sub">Multi-Agent Power Grid Control Room</div>
37
+ </div>
38
+ </div>
39
+
40
+ <div class="sim-badge">
41
+ <span class="dot"></span>
42
+ <span id="simStatus">READY</span>
43
+ </div>
44
+
45
+ <div class="header-stats">
46
+ <div class="header-stat">
47
+ <span class="label">Episode</span>
48
+ <span class="value normal" id="headerEpisode">--</span>
49
+ </div>
50
+ <div class="header-stat">
51
+ <span class="label">Step</span>
52
+ <span class="value" id="headerStep">0 / 50</span>
53
+ </div>
54
+ <div class="header-stat">
55
+ <span class="label">Frequency</span>
56
+ <span class="value normal" id="headerFreq">50.00 Hz</span>
57
+ </div>
58
+ <div class="header-stat">
59
+ <span class="label">Agents</span>
60
+ <span class="value normal" id="headerAgents">--</span>
61
+ </div>
62
+ <div class="header-stat">
63
+ <span class="label">Team Reward</span>
64
+ <span class="value" id="headerReward">0.00</span>
65
+ </div>
66
+ </div>
67
+ </header>
68
+
69
+ <!-- ===== LEFT PANEL ===== -->
70
+ <aside class="left-panel">
71
+
72
+ <!-- Frequency Display -->
73
+ <div class="card">
74
+ <div class="card-title">Grid Frequency</div>
75
+ <div class="freq-display">
76
+ <div class="freq-arc-container" id="freqArc"></div>
77
+ <div class="freq-deviation" id="freqDev">Deviation: 0.00 Hz | Nominal: 50.00 Hz</div>
78
+ <div class="grid-condition normal" id="gridCondition">NORMAL</div>
79
+ </div>
80
+ </div>
81
+
82
+ <!-- System Summary -->
83
+ <div class="card">
84
+ <div class="card-title">System Summary</div>
85
+ <div class="stat-row highlight">
86
+ <span class="label">Total Generation</span>
87
+ <span class="value" id="totalGen">-- MW</span>
88
+ </div>
89
+ <div class="stat-row">
90
+ <span class="label">Total Load</span>
91
+ <span class="value" id="totalLoad">-- MW</span>
92
+ </div>
93
+ <div class="stat-row">
94
+ <span class="label">Net Balance</span>
95
+ <span class="value" id="netBalance">-- MW</span>
96
+ </div>
97
+ <div class="stat-row">
98
+ <span class="label">Lines Connected</span>
99
+ <span class="value" id="linesConnected">--</span>
100
+ </div>
101
+ <div class="stat-row">
102
+ <span class="label">Lines Overloaded</span>
103
+ <span class="value" id="linesOverloaded" style="color: var(--status-normal);">0</span>
104
+ </div>
105
+ </div>
106
+
107
+ <!-- Coordination -->
108
+ <div class="card">
109
+ <div class="card-title">Oversight Agent</div>
110
+ <div class="coord-score">
111
+ <div class="big-value" id="coordScore" style="color: var(--status-normal);">1.00</div>
112
+ <div style="font-size:10px; color: var(--text-secondary); margin-top:4px;">Coordination Score</div>
113
+ </div>
114
+ <div class="stat-row">
115
+ <span class="label">Conflicts</span>
116
+ <span class="value" id="conflicts">0</span>
117
+ </div>
118
+ <div class="stat-row">
119
+ <span class="label">Safety Corrections</span>
120
+ <span class="value" id="safetyCorrTotal">0</span>
121
+ </div>
122
+ <div class="stat-row">
123
+ <span class="label">Selfish Actions</span>
124
+ <span class="value" id="selfishActions">0</span>
125
+ </div>
126
+ </div>
127
+
128
+ <!-- Exception Log -->
129
+ <div class="card" style="flex:1; display:flex; flex-direction:column; overflow:hidden;">
130
+ <div class="card-title" style="color: var(--status-warning);">Exception Log</div>
131
+ <div class="alarm-log" id="alarmLog">
132
+ <!-- Populated by JS -->
133
+ </div>
134
+ </div>
135
+
136
+ <!-- Task Selector -->
137
+ <div class="card" style="flex-shrink:0;">
138
+ <div class="card-title">Task &amp; Controls</div>
139
+ <div class="task-selector" id="taskSelector">
140
+ <button class="task-btn" data-task="task_easy">Easy</button>
141
+ <button class="task-btn" data-task="task_medium">Medium</button>
142
+ <button class="task-btn" data-task="task_hard">Hard</button>
143
+ <button class="task-btn active" data-task="task_karnataka" style="color: #ffeb3b; border-color: rgba(255,235,59,0.3);">Karnataka</button>
144
+ </div>
145
+ <div class="controls-row" style="margin-top: var(--gap-sm);">
146
+ <button class="ctrl-btn active" id="btnReset" onclick="resetEpisode()">Reset</button>
147
+ <button class="ctrl-btn" id="btnStep" onclick="stepEpisode()">Step</button>
148
+ <button class="ctrl-btn" id="btnAutoRun" onclick="toggleAutoRun()">Auto</button>
149
+ </div>
150
+ </div>
151
+
152
+ </aside>
153
+
154
+ <!-- ===== CENTER PANEL (Grid Map) ===== -->
155
+ <main class="center-panel" id="centerPanel">
156
+ <div class="grid-map" id="gridMap"></div>
157
+ <div class="bus-tooltip" id="busTooltip">
158
+ <div class="tt-title" id="ttTitle">Bus 0</div>
159
+ <div class="tt-row"><span>Type</span><span class="tt-val" id="ttType">--</span></div>
160
+ <div class="tt-row"><span>Injection</span><span class="tt-val" id="ttInj">-- MW</span></div>
161
+ <div class="tt-row"><span>Zone</span><span class="tt-val" id="ttZone">--</span></div>
162
+ </div>
163
+ </main>
164
+
165
+ <!-- ===== RIGHT PANEL (Agent Monitor) ===== -->
166
+ <aside class="right-panel">
167
+ <div class="card">
168
+ <div class="card-title">Agent Leaderboard</div>
169
+ <ul class="leaderboard" id="leaderboard">
170
+ <!-- Populated by JS -->
171
+ </ul>
172
+ </div>
173
+
174
+ <div id="agentCards">
175
+ <!-- Populated by JS -->
176
+ </div>
177
+ </aside>
178
+
179
+ <!-- ===== BOTTOM PANEL ===== -->
180
+ <footer class="bottom-panel">
181
+
182
+ <!-- Reward History Chart -->
183
+ <div class="bottom-card">
184
+ <div class="card-title">Reward History</div>
185
+ <div class="chart-area" id="rewardChart"></div>
186
+ </div>
187
+
188
+ <!-- Frequency Trend -->
189
+ <div class="bottom-card">
190
+ <div class="card-title">Frequency Trend</div>
191
+ <div class="chart-area" id="freqChart"></div>
192
+ </div>
193
+
194
+ <!-- Generation Mix -->
195
+ <div class="bottom-card">
196
+ <div class="card-title">Generation Mix</div>
197
+ <div class="chart-area" id="genMixChart"></div>
198
+ </div>
199
+
200
+ <!-- Episode Score -->
201
+ <div class="bottom-card">
202
+ <div class="card-title">Episode Score</div>
203
+ <div class="coord-score" style="flex:1; display:flex; flex-direction:column; justify-content:center;">
204
+ <div class="big-value" id="episodeScore" style="color: var(--chart-reward); font-size: 36px;">--</div>
205
+ <div style="font-size:10px; color: var(--text-secondary); margin-top:4px;">Grader Score</div>
206
+ <div style="font-size:11px; margin-top:8px;">
207
+ <span style="color: var(--text-secondary);">Steps:</span>
208
+ <span id="totalSteps" style="font-family: 'JetBrains Mono'; font-weight:600;">0</span>
209
+ <span style="color: var(--text-secondary); margin-left:8px;">Blackout:</span>
210
+ <span id="blackoutStatus" style="font-family: 'JetBrains Mono'; font-weight:600; color: var(--status-normal);">No</span>
211
+ </div>
212
+ </div>
213
+ <div class="controls-row">
214
+ <button class="ctrl-btn" onclick="getGrade()">Grade</button>
215
+ <button class="ctrl-btn danger" onclick="resetEpisode()">New Episode</button>
216
+ </div>
217
+ </div>
218
+
219
+ </footer>
220
+
221
+ </div>
222
+
223
+ <script src="/static/app.js"></script>
224
+ </body>
225
+ </html>
static/karnataka.svg ADDED
static/logo.png ADDED

Git LFS Details

  • SHA256: 7c5b33163678b884123740782fbaab4bafba3d02e4a2a36ec1ae4e138af31915
  • Pointer size: 129 Bytes
  • Size of remote file: 1.37 kB
static/style.css ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* ============================================================================
2
+ OpenGrid KPTCL-SLDC Control Room — Design System
3
+ Inspired by ERCOT control room aesthetics, adapted for Karnataka grid
4
+ ============================================================================ */
5
+
6
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500;600;700&display=swap');
7
+
8
+ /* ---------- CSS Custom Properties ---------- */
9
+ :root {
10
+ /* Background layers */
11
+ --bg-primary: #0a0e1a;
12
+ --bg-secondary: #0f1628;
13
+ --bg-tertiary: #141d35;
14
+ --bg-glass: rgba(15, 22, 40, 0.85);
15
+ --bg-card: rgba(15, 22, 40, 0.7);
16
+
17
+ /* Operational states */
18
+ --status-normal: #00e5a0;
19
+ --status-warning: #ffd700;
20
+ --status-critical:#ff3d3d;
21
+ --status-offline: #4a5568;
22
+ --status-overload:#ff6b35;
23
+
24
+ /* Voltage colors */
25
+ --voltage-400kv: #e94560;
26
+ --voltage-220kv: #f5a623;
27
+ --voltage-110kv: #7ed321;
28
+ --voltage-66kv: #4a90d9;
29
+
30
+ /* Agent identity colors */
31
+ --agent-0: #00bfff;
32
+ --agent-1: #ff69b4;
33
+ --agent-2: #ff6347;
34
+
35
+ /* Text */
36
+ --text-primary: #e8eaf6;
37
+ --text-secondary: #90a4ae;
38
+ --text-accent: #00e5a0;
39
+ --text-danger: #ff5252;
40
+ --text-muted: #546e7a;
41
+
42
+ /* Chart */
43
+ --chart-demand: #00bfff;
44
+ --chart-supply: #00e5a0;
45
+ --chart-reward: #ffd700;
46
+
47
+ /* Spacing */
48
+ --gap-sm: 8px;
49
+ --gap-md: 12px;
50
+ --gap-lg: 16px;
51
+ --gap-xl: 20px;
52
+
53
+ /* Radius */
54
+ --radius-sm: 6px;
55
+ --radius-md: 10px;
56
+ --radius-lg: 14px;
57
+ }
58
+
59
+ /* ---------- Reset & Base ---------- */
60
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
61
+
62
+ html, body {
63
+ height: 100%;
64
+ background: var(--bg-primary);
65
+ color: var(--text-primary);
66
+ font-family: 'Inter', 'Segoe UI', sans-serif;
67
+ font-size: 13px;
68
+ line-height: 1.5;
69
+ overflow: hidden;
70
+ -webkit-font-smoothing: antialiased;
71
+ }
72
+
73
+ /* Subtle scanline overlay */
74
+ body::before {
75
+ content: '';
76
+ position: fixed;
77
+ top: 0; left: 0; right: 0; bottom: 0;
78
+ pointer-events: none;
79
+ z-index: 9999;
80
+ background: repeating-linear-gradient(
81
+ 0deg,
82
+ transparent,
83
+ transparent 2px,
84
+ rgba(0,0,0,0.03) 2px,
85
+ rgba(0,0,0,0.03) 4px
86
+ );
87
+ }
88
+
89
+ /* ---------- Layout ---------- */
90
+ .control-room {
91
+ display: grid;
92
+ grid-template-rows: 52px 1fr 180px;
93
+ grid-template-columns: 260px 1fr 300px;
94
+ grid-template-areas:
95
+ "header header header"
96
+ "left center right"
97
+ "bottom bottom bottom";
98
+ height: 100vh;
99
+ gap: 1px;
100
+ background: rgba(255,255,255,0.04);
101
+ }
102
+
103
+ /* ---------- Header ---------- */
104
+ .header {
105
+ grid-area: header;
106
+ background: linear-gradient(90deg, #0a0e1a, #0f2040);
107
+ display: flex;
108
+ align-items: center;
109
+ padding: 0 var(--gap-lg);
110
+ gap: var(--gap-lg);
111
+ border-bottom: 1px solid rgba(0,229,160,0.15);
112
+ z-index: 10;
113
+ }
114
+
115
+ .header-brand {
116
+ display: flex;
117
+ align-items: center;
118
+ gap: var(--gap-sm);
119
+ flex-shrink: 0;
120
+ }
121
+
122
+ .header-brand .logo {
123
+ width: 28px;
124
+ height: 28px;
125
+ background: linear-gradient(135deg, #00e5a0, #00bfff);
126
+ border-radius: 6px;
127
+ display: flex;
128
+ align-items: center;
129
+ justify-content: center;
130
+ font-weight: 700;
131
+ font-size: 14px;
132
+ color: #0a0e1a;
133
+ }
134
+
135
+ .header-brand h1 {
136
+ font-size: 14px;
137
+ font-weight: 600;
138
+ letter-spacing: 0.5px;
139
+ }
140
+
141
+ .header-brand .sub {
142
+ font-size: 10px;
143
+ color: var(--text-secondary);
144
+ letter-spacing: 1px;
145
+ text-transform: uppercase;
146
+ }
147
+
148
+ .header-stats {
149
+ display: flex;
150
+ gap: var(--gap-lg);
151
+ margin-left: auto;
152
+ align-items: center;
153
+ }
154
+
155
+ .header-stat {
156
+ display: flex;
157
+ flex-direction: column;
158
+ align-items: center;
159
+ padding: 4px 12px;
160
+ border-radius: var(--radius-sm);
161
+ background: rgba(255,255,255,0.04);
162
+ }
163
+
164
+ .header-stat .label {
165
+ font-size: 9px;
166
+ text-transform: uppercase;
167
+ letter-spacing: 1px;
168
+ color: var(--text-secondary);
169
+ }
170
+
171
+ .header-stat .value {
172
+ font-family: 'JetBrains Mono', monospace;
173
+ font-size: 14px;
174
+ font-weight: 600;
175
+ }
176
+
177
+ .header-stat .value.normal { color: var(--status-normal); }
178
+ .header-stat .value.warning { color: var(--status-warning); }
179
+ .header-stat .value.critical { color: var(--status-critical); }
180
+
181
+ .sim-badge {
182
+ display: flex;
183
+ align-items: center;
184
+ gap: 6px;
185
+ padding: 4px 10px;
186
+ border-radius: 20px;
187
+ background: rgba(0,229,160,0.1);
188
+ border: 1px solid rgba(0,229,160,0.25);
189
+ font-size: 10px;
190
+ font-weight: 600;
191
+ color: var(--status-normal);
192
+ text-transform: uppercase;
193
+ letter-spacing: 1px;
194
+ }
195
+
196
+ .sim-badge .dot {
197
+ width: 6px; height: 6px;
198
+ background: var(--status-normal);
199
+ border-radius: 50%;
200
+ animation: pulse-dot 2s infinite;
201
+ }
202
+
203
+ @keyframes pulse-dot {
204
+ 0%, 100% { opacity: 1; box-shadow: 0 0 0 0 rgba(0,229,160,0.4); }
205
+ 50% { opacity: 0.7; box-shadow: 0 0 0 4px rgba(0,229,160,0); }
206
+ }
207
+
208
+ /* ---------- Left Panel ---------- */
209
+ .left-panel {
210
+ grid-area: left;
211
+ background: var(--bg-secondary);
212
+ padding: var(--gap-md);
213
+ overflow-y: auto;
214
+ display: flex;
215
+ flex-direction: column;
216
+ gap: var(--gap-md);
217
+ border-right: 1px solid rgba(255,255,255,0.05);
218
+ }
219
+
220
+ /* ---------- Cards (shared) ---------- */
221
+ .card {
222
+ background: var(--bg-card);
223
+ border: 1px solid rgba(255,255,255,0.06);
224
+ border-radius: var(--radius-md);
225
+ padding: var(--gap-md);
226
+ backdrop-filter: blur(8px);
227
+ }
228
+
229
+ .card-title {
230
+ font-size: 10px;
231
+ font-weight: 600;
232
+ text-transform: uppercase;
233
+ letter-spacing: 1.5px;
234
+ color: var(--text-secondary);
235
+ margin-bottom: var(--gap-sm);
236
+ padding-bottom: 6px;
237
+ border-bottom: 1px solid rgba(255,255,255,0.06);
238
+ }
239
+
240
+ /* ---------- Alarm Log ---------- */
241
+ .alarm-log {
242
+ flex: 1;
243
+ max-height: 90px;
244
+ overflow-y: auto;
245
+ font-family: 'JetBrains Mono', monospace;
246
+ font-size: 10px;
247
+ line-height: 1.4;
248
+ display: flex;
249
+ flex-direction: column;
250
+ gap: 4px;
251
+ }
252
+ .alarm-entry {
253
+ padding: 4px 6px;
254
+ background: rgba(255,255,255,0.03);
255
+ border-left: 2px solid transparent;
256
+ border-radius: 2px;
257
+ }
258
+ .alarm-time { color: var(--text-muted); margin-right: 6px; }
259
+ .alarm-entry.warn { border-left-color: var(--status-warning); background: rgba(255,152,0,0.05); color: #ffb74d; }
260
+ .alarm-entry.crit { border-left-color: var(--status-critical); background: rgba(244,67,54,0.05); color: #ef5350; }
261
+ .alarm-entry.info { border-left-color: var(--status-normal); }
262
+
263
+ /* ---------- Frequency Display ---------- */
264
+ .freq-display {
265
+ text-align: center;
266
+ padding: var(--gap-md) var(--gap-sm);
267
+ }
268
+
269
+ .freq-arc-container {
270
+ position: relative;
271
+ width: 200px;
272
+ height: 110px;
273
+ margin: 0 auto;
274
+ }
275
+
276
+ .freq-arc-container svg { overflow: visible; }
277
+
278
+ .freq-value {
279
+ font-family: 'JetBrains Mono', monospace;
280
+ font-size: 32px;
281
+ font-weight: 700;
282
+ letter-spacing: -1px;
283
+ transition: color 0.3s;
284
+ }
285
+
286
+ .freq-value.normal { color: var(--status-normal); text-shadow: 0 0 20px rgba(0,229,160,0.3); }
287
+ .freq-value.warning { color: var(--status-warning); text-shadow: 0 0 20px rgba(255,215,0,0.3); }
288
+ .freq-value.critical { color: var(--status-critical); text-shadow: 0 0 20px rgba(255,61,61,0.3); animation: freq-blink 0.5s infinite; }
289
+
290
+ @keyframes freq-blink {
291
+ 0%, 100% { opacity: 1; }
292
+ 50% { opacity: 0.6; }
293
+ }
294
+
295
+ .freq-deviation {
296
+ margin-top: 4px;
297
+ font-family: 'JetBrains Mono', monospace;
298
+ font-size: 10px;
299
+ color: var(--text-secondary);
300
+ }
301
+
302
+ /* Grid condition badge */
303
+ .grid-condition {
304
+ display: flex;
305
+ align-items: center;
306
+ justify-content: center;
307
+ gap: 6px;
308
+ margin-top: var(--gap-sm);
309
+ padding: 5px 10px;
310
+ border-radius: 20px;
311
+ font-size: 10px;
312
+ font-weight: 600;
313
+ text-transform: uppercase;
314
+ letter-spacing: 0.8px;
315
+ }
316
+ .grid-condition.normal { background: rgba(0,229,160,0.1); color: var(--status-normal); border: 1px solid rgba(0,229,160,0.2); }
317
+ .grid-condition.conservative { background: rgba(255,215,0,0.08); color: var(--status-warning); border: 1px solid rgba(255,215,0,0.15); }
318
+ .grid-condition.alert { background: rgba(255,107,53,0.1); color: var(--status-overload); border: 1px solid rgba(255,107,53,0.2); }
319
+ .grid-condition.emergency { background: rgba(255,61,61,0.1); color: var(--status-critical); border: 1px solid rgba(255,61,61,0.2); animation: cond-pulse 1s infinite; }
320
+
321
+ @keyframes cond-pulse {
322
+ 0%,100% { box-shadow: 0 0 0 0 rgba(255,61,61,0.2); }
323
+ 50% { box-shadow: 0 0 0 4px rgba(255,61,61,0); }
324
+ }
325
+
326
+ /* ---------- System Summary ---------- */
327
+ .stat-row {
328
+ display: flex;
329
+ justify-content: space-between;
330
+ align-items: center;
331
+ padding: 4px 0;
332
+ font-size: 12px;
333
+ }
334
+
335
+ .stat-row .label { color: var(--text-secondary); }
336
+ .stat-row .value {
337
+ font-family: 'JetBrains Mono', monospace;
338
+ font-weight: 500;
339
+ }
340
+
341
+ .stat-row.highlight .value {
342
+ color: var(--status-normal);
343
+ font-weight: 600;
344
+ }
345
+
346
+ /* Progress bars */
347
+ .progress-bar {
348
+ height: 4px;
349
+ background: rgba(255,255,255,0.06);
350
+ border-radius: 2px;
351
+ overflow: hidden;
352
+ margin-top: 4px;
353
+ }
354
+
355
+ .progress-bar-fill {
356
+ height: 100%;
357
+ border-radius: 2px;
358
+ transition: width 0.5s;
359
+ }
360
+
361
+ /* ---------- Center Panel (Grid Map) ---------- */
362
+ .center-panel {
363
+ grid-area: center;
364
+ background: var(--bg-tertiary);
365
+ position: relative;
366
+ overflow: hidden;
367
+ }
368
+
369
+ .grid-map {
370
+ width: 100%;
371
+ height: 100%;
372
+ }
373
+
374
+ .grid-map svg {
375
+ width: 100%;
376
+ height: 100%;
377
+ }
378
+
379
+ /* SVG map styles */
380
+ .zone-polygon {
381
+ opacity: 0.06;
382
+ transition: opacity 0.4s;
383
+ cursor: pointer;
384
+ filter: blur(0.5px);
385
+ }
386
+ .zone-polygon:hover { opacity: 0.18; }
387
+
388
+ .substation-node { cursor: pointer; }
389
+ .substation-node:hover .node-outer { stroke-width: 2.5; filter: url(#glow); }
390
+ .substation-node:hover .node-label { opacity: 1; }
391
+
392
+ .node-label {
393
+ font-family: 'Inter', sans-serif;
394
+ font-size: 8px;
395
+ fill: var(--text-secondary);
396
+ text-anchor: middle;
397
+ pointer-events: none;
398
+ opacity: 0.7;
399
+ transition: opacity 0.2s;
400
+ }
401
+
402
+ .node-mw {
403
+ font-family: 'JetBrains Mono', monospace;
404
+ font-size: 9px;
405
+ fill: var(--text-primary);
406
+ text-anchor: middle;
407
+ pointer-events: none;
408
+ font-weight: 500;
409
+ }
410
+
411
+ .line-flow {
412
+ fill: none;
413
+ stroke-linecap: round;
414
+ }
415
+
416
+ /* Animated flow on lines */
417
+ @keyframes dash-flow {
418
+ to { stroke-dashoffset: -24; }
419
+ }
420
+ .line-animated {
421
+ animation: dash-flow 1.2s linear infinite;
422
+ }
423
+ .line-animated.reverse {
424
+ animation-direction: reverse;
425
+ }
426
+
427
+ .flow-label {
428
+ font-family: 'JetBrains Mono', monospace;
429
+ font-size: 8px;
430
+ fill: rgba(232,234,246,0.6);
431
+ text-anchor: middle;
432
+ pointer-events: none;
433
+ }
434
+
435
+ .zone-badge { font-family: 'Inter', sans-serif; pointer-events: none; }
436
+ .zone-badge-bg {
437
+ rx: 8;
438
+ fill: rgba(10, 14, 26, 0.88);
439
+ stroke-width: 1;
440
+ backdrop-filter: blur(6px);
441
+ }
442
+ .zone-badge-name { font-size: 10px; font-weight: 600; text-anchor: middle; }
443
+ .zone-badge-status { font-size: 8px; text-anchor: middle; fill: var(--text-secondary); }
444
+ .zone-badge-reward { font-size: 9px; text-anchor: middle; font-weight: 600; font-family: 'JetBrains Mono', monospace; }
445
+
446
+ /* Bus tooltip */
447
+ .bus-tooltip {
448
+ position: absolute;
449
+ background: rgba(10, 14, 26, 0.95);
450
+ border: 1px solid rgba(0,229,160,0.2);
451
+ border-radius: var(--radius-sm);
452
+ padding: 8px 10px;
453
+ font-size: 11px;
454
+ pointer-events: none;
455
+ z-index: 20;
456
+ min-width: 140px;
457
+ backdrop-filter: blur(12px);
458
+ box-shadow: 0 4px 20px rgba(0,0,0,0.4);
459
+ display: none;
460
+ }
461
+ .bus-tooltip.visible { display: block; }
462
+ .bus-tooltip .tt-title {
463
+ font-weight: 600;
464
+ margin-bottom: 4px;
465
+ padding-bottom: 4px;
466
+ border-bottom: 1px solid rgba(255,255,255,0.08);
467
+ }
468
+ .bus-tooltip .tt-row {
469
+ display: flex;
470
+ justify-content: space-between;
471
+ padding: 1px 0;
472
+ }
473
+ .bus-tooltip .tt-row .tt-val {
474
+ font-family: 'JetBrains Mono', monospace;
475
+ font-weight: 500;
476
+ }
477
+
478
+ /* Map overlay controls */
479
+ .map-controls {
480
+ position: absolute;
481
+ top: var(--gap-md);
482
+ right: var(--gap-md);
483
+ display: flex;
484
+ flex-direction: column;
485
+ gap: 4px;
486
+ z-index: 5;
487
+ }
488
+
489
+ .map-btn {
490
+ width: 32px; height: 32px;
491
+ background: var(--bg-glass);
492
+ border: 1px solid rgba(255,255,255,0.1);
493
+ border-radius: var(--radius-sm);
494
+ color: var(--text-secondary);
495
+ font-size: 14px;
496
+ cursor: pointer;
497
+ display: flex;
498
+ align-items: center;
499
+ justify-content: center;
500
+ backdrop-filter: blur(8px);
501
+ transition: all 0.2s;
502
+ }
503
+
504
+ .map-btn:hover {
505
+ background: rgba(0,229,160,0.15);
506
+ color: var(--status-normal);
507
+ border-color: rgba(0,229,160,0.3);
508
+ }
509
+
510
+ /* ---------- Right Panel (Agent Monitor) ---------- */
511
+ .right-panel {
512
+ grid-area: right;
513
+ background: var(--bg-secondary);
514
+ padding: var(--gap-md);
515
+ overflow-y: auto;
516
+ display: flex;
517
+ flex-direction: column;
518
+ gap: var(--gap-md);
519
+ border-left: 1px solid rgba(255,255,255,0.05);
520
+ }
521
+
522
+ /* Agent cards */
523
+ .agent-card {
524
+ border-radius: var(--radius-md);
525
+ padding: var(--gap-md);
526
+ background: var(--bg-card);
527
+ border: 1px solid rgba(255,255,255,0.06);
528
+ backdrop-filter: blur(8px);
529
+ transition: border-color 0.3s, box-shadow 0.3s;
530
+ }
531
+
532
+ .agent-card.active {
533
+ border-color: rgba(0,229,160,0.2);
534
+ }
535
+
536
+ .agent-card.warning {
537
+ border-color: rgba(255,215,0,0.3);
538
+ box-shadow: 0 0 12px rgba(255,215,0,0.05);
539
+ }
540
+
541
+ .agent-card.critical {
542
+ border-color: rgba(255,61,61,0.3);
543
+ box-shadow: 0 0 12px rgba(255,61,61,0.08);
544
+ animation: card-pulse 1.5s infinite;
545
+ }
546
+
547
+ @keyframes card-pulse {
548
+ 0%, 100% { box-shadow: 0 0 12px rgba(255,61,61,0.08); }
549
+ 50% { box-shadow: 0 0 20px rgba(255,61,61,0.15); }
550
+ }
551
+
552
+ .agent-header {
553
+ display: flex;
554
+ justify-content: space-between;
555
+ align-items: center;
556
+ margin-bottom: var(--gap-sm);
557
+ }
558
+
559
+ .agent-name {
560
+ font-size: 12px;
561
+ font-weight: 600;
562
+ display: flex;
563
+ align-items: center;
564
+ gap: 6px;
565
+ }
566
+
567
+ .agent-dot {
568
+ width: 8px; height: 8px;
569
+ border-radius: 50%;
570
+ flex-shrink: 0;
571
+ }
572
+
573
+ .agent-status-badge {
574
+ font-size: 9px;
575
+ font-weight: 600;
576
+ padding: 2px 6px;
577
+ border-radius: 10px;
578
+ text-transform: uppercase;
579
+ letter-spacing: 0.5px;
580
+ }
581
+
582
+ .agent-status-badge.active {
583
+ background: rgba(0,229,160,0.15);
584
+ color: var(--status-normal);
585
+ }
586
+
587
+ .agent-status-badge.corrected {
588
+ background: rgba(255,215,0,0.15);
589
+ color: var(--status-warning);
590
+ }
591
+
592
+ .agent-metrics {
593
+ display: grid;
594
+ grid-template-columns: 1fr 1fr;
595
+ gap: 6px;
596
+ margin-top: var(--gap-sm);
597
+ }
598
+
599
+ .agent-metric {
600
+ padding: 6px 8px;
601
+ background: rgba(255,255,255,0.02);
602
+ border-radius: var(--radius-sm);
603
+ }
604
+
605
+ .agent-metric .label {
606
+ font-size: 9px;
607
+ text-transform: uppercase;
608
+ letter-spacing: 0.5px;
609
+ color: var(--text-muted);
610
+ }
611
+
612
+ .agent-metric .value {
613
+ font-family: 'JetBrains Mono', monospace;
614
+ font-size: 14px;
615
+ font-weight: 600;
616
+ margin-top: 2px;
617
+ }
618
+
619
+ /* Safety shield */
620
+ .safety-shield {
621
+ margin-top: var(--gap-sm);
622
+ padding: 6px 8px;
623
+ border-radius: var(--radius-sm);
624
+ display: flex;
625
+ align-items: center;
626
+ gap: 6px;
627
+ font-size: 10px;
628
+ font-weight: 600;
629
+ text-transform: uppercase;
630
+ letter-spacing: 0.5px;
631
+ }
632
+
633
+ .safety-shield.safe {
634
+ background: rgba(0,229,160,0.08);
635
+ border: 1px solid rgba(0,229,160,0.15);
636
+ color: var(--status-normal);
637
+ }
638
+
639
+ .safety-shield.corrected {
640
+ background: rgba(255,215,0,0.08);
641
+ border: 1px solid rgba(255,215,0,0.2);
642
+ color: var(--status-warning);
643
+ }
644
+
645
+ .safety-shield.violated {
646
+ background: rgba(255,61,61,0.08);
647
+ border: 1px solid rgba(255,61,61,0.2);
648
+ color: var(--status-critical);
649
+ }
650
+
651
+ /* Sparkline */
652
+ .sparkline-container {
653
+ margin-top: var(--gap-sm);
654
+ height: 30px;
655
+ background: rgba(255,255,255,0.02);
656
+ border-radius: var(--radius-sm);
657
+ padding: 4px;
658
+ }
659
+
660
+ .sparkline-container svg {
661
+ width: 100%;
662
+ height: 100%;
663
+ }
664
+
665
+ /* ---------- Bottom Panel ---------- */
666
+ .bottom-panel {
667
+ grid-area: bottom;
668
+ background: var(--bg-secondary);
669
+ display: grid;
670
+ grid-template-columns: 2fr 1fr 1fr 1fr;
671
+ gap: 1px;
672
+ border-top: 1px solid rgba(255,255,255,0.05);
673
+ }
674
+
675
+ .bottom-card {
676
+ background: var(--bg-card);
677
+ padding: var(--gap-md);
678
+ display: flex;
679
+ flex-direction: column;
680
+ }
681
+
682
+ .chart-area {
683
+ flex: 1;
684
+ position: relative;
685
+ min-height: 0;
686
+ }
687
+
688
+ .chart-area canvas, .chart-area svg {
689
+ width: 100%;
690
+ height: 100%;
691
+ }
692
+
693
+ /* Reward chart */
694
+ .reward-history {
695
+ flex: 1;
696
+ }
697
+
698
+ /* Controls */
699
+ .controls-row {
700
+ display: flex;
701
+ gap: var(--gap-sm);
702
+ margin-top: var(--gap-sm);
703
+ }
704
+
705
+ .ctrl-btn {
706
+ flex: 1;
707
+ padding: 6px 10px;
708
+ background: rgba(255,255,255,0.04);
709
+ border: 1px solid rgba(255,255,255,0.1);
710
+ border-radius: var(--radius-sm);
711
+ color: var(--text-primary);
712
+ font-family: 'Inter', sans-serif;
713
+ font-size: 11px;
714
+ font-weight: 500;
715
+ cursor: pointer;
716
+ transition: all 0.2s;
717
+ text-align: center;
718
+ }
719
+
720
+ .ctrl-btn:hover {
721
+ background: rgba(0,229,160,0.1);
722
+ border-color: rgba(0,229,160,0.3);
723
+ }
724
+
725
+ .ctrl-btn.active {
726
+ background: rgba(0,229,160,0.15);
727
+ border-color: var(--status-normal);
728
+ color: var(--status-normal);
729
+ }
730
+
731
+ .ctrl-btn.danger {
732
+ border-color: rgba(255,61,61,0.3);
733
+ }
734
+
735
+ .ctrl-btn.danger:hover {
736
+ background: rgba(255,61,61,0.1);
737
+ border-color: rgba(255,61,61,0.5);
738
+ color: var(--status-critical);
739
+ }
740
+
741
+ /* Task selector */
742
+ .task-selector {
743
+ display: flex;
744
+ gap: 4px;
745
+ }
746
+
747
+ .task-btn {
748
+ flex: 1;
749
+ padding: 4px 8px;
750
+ background: rgba(255,255,255,0.03);
751
+ border: 1px solid rgba(255,255,255,0.08);
752
+ border-radius: var(--radius-sm);
753
+ color: var(--text-secondary);
754
+ font-size: 10px;
755
+ font-weight: 500;
756
+ cursor: pointer;
757
+ transition: all 0.2s;
758
+ text-transform: uppercase;
759
+ letter-spacing: 0.5px;
760
+ }
761
+
762
+ .task-btn:hover { border-color: rgba(0,229,160,0.3); color: var(--text-primary); }
763
+ .task-btn.active { background: rgba(0,229,160,0.1); border-color: var(--status-normal); color: var(--status-normal); }
764
+
765
+ /* Leaderboard */
766
+ .leaderboard {
767
+ list-style: none;
768
+ }
769
+
770
+ .leaderboard li {
771
+ display: flex;
772
+ justify-content: space-between;
773
+ align-items: center;
774
+ padding: 5px 0;
775
+ font-size: 11px;
776
+ border-bottom: 1px solid rgba(255,255,255,0.03);
777
+ }
778
+
779
+ .leaderboard li:last-child { border-bottom: none; }
780
+
781
+ .leaderboard .agent-label {
782
+ display: flex;
783
+ align-items: center;
784
+ gap: 6px;
785
+ }
786
+
787
+ .leaderboard .score {
788
+ font-family: 'JetBrains Mono', monospace;
789
+ font-weight: 600;
790
+ font-size: 12px;
791
+ }
792
+
793
+ /* Coordination score */
794
+ .coord-score {
795
+ text-align: center;
796
+ padding: var(--gap-sm);
797
+ }
798
+
799
+ .coord-score .big-value {
800
+ font-family: 'JetBrains Mono', monospace;
801
+ font-size: 28px;
802
+ font-weight: 700;
803
+ }
804
+
805
+ /* Alert banner */
806
+ .alert-banner {
807
+ position: fixed;
808
+ top: 52px;
809
+ left: 0; right: 0;
810
+ z-index: 100;
811
+ padding: 8px var(--gap-lg);
812
+ display: flex;
813
+ align-items: center;
814
+ gap: var(--gap-sm);
815
+ font-size: 12px;
816
+ font-weight: 500;
817
+ transform: translateY(-100%);
818
+ transition: transform 0.3s;
819
+ }
820
+
821
+ .alert-banner.visible { transform: translateY(0); }
822
+
823
+ .alert-banner.critical {
824
+ background: rgba(255,61,61,0.15);
825
+ border-bottom: 1px solid rgba(255,61,61,0.3);
826
+ color: var(--status-critical);
827
+ }
828
+
829
+ .alert-banner.warning {
830
+ background: rgba(255,215,0,0.1);
831
+ border-bottom: 1px solid rgba(255,215,0,0.2);
832
+ color: var(--status-warning);
833
+ }
834
+
835
+ .alert-banner .dismiss {
836
+ margin-left: auto;
837
+ background: none;
838
+ border: 1px solid currentColor;
839
+ border-radius: var(--radius-sm);
840
+ color: inherit;
841
+ padding: 2px 8px;
842
+ font-size: 10px;
843
+ cursor: pointer;
844
+ opacity: 0.7;
845
+ }
846
+
847
+ .alert-banner .dismiss:hover { opacity: 1; }
848
+
849
+ /* Scrollbar */
850
+ ::-webkit-scrollbar { width: 4px; }
851
+ ::-webkit-scrollbar-track { background: transparent; }
852
+ ::-webkit-scrollbar-thumb { background: rgba(255,255,255,0.1); border-radius: 2px; }
853
+ ::-webkit-scrollbar-thumb:hover { background: rgba(255,255,255,0.2); }
854
+
855
+ /* Loading state */
856
+ .loading-overlay {
857
+ position: fixed;
858
+ top: 0; left: 0; right: 0; bottom: 0;
859
+ background: var(--bg-primary);
860
+ display: flex;
861
+ flex-direction: column;
862
+ align-items: center;
863
+ justify-content: center;
864
+ z-index: 1000;
865
+ transition: opacity 0.5s;
866
+ }
867
+
868
+ .loading-overlay.hidden {
869
+ opacity: 0;
870
+ pointer-events: none;
871
+ }
872
+
873
+ .loading-spinner {
874
+ width: 40px; height: 40px;
875
+ border: 3px solid rgba(0,229,160,0.15);
876
+ border-top-color: var(--status-normal);
877
+ border-radius: 50%;
878
+ animation: spin 0.8s linear infinite;
879
+ }
880
+
881
+ @keyframes spin { to { transform: rotate(360deg); } }
882
+
883
+ .loading-text {
884
+ margin-top: var(--gap-md);
885
+ color: var(--text-secondary);
886
+ font-size: 12px;
887
+ letter-spacing: 2px;
888
+ text-transform: uppercase;
889
+ }
890
+
891
+ /* ── Leaflet Overrides ── */
892
+ .grid-map .leaflet-container {
893
+ background: var(--bg-primary) !important;
894
+ }
895
+
896
+ .leaflet-tooltip-dark {
897
+ background: rgba(10, 14, 26, 0.92) !important;
898
+ border: 1px solid rgba(0, 229, 160, 0.3) !important;
899
+ color: #e0e0e0 !important;
900
+ font-family: 'JetBrains Mono', monospace !important;
901
+ font-size: 11px !important;
902
+ border-radius: 6px !important;
903
+ padding: 6px 10px !important;
904
+ box-shadow: 0 4px 20px rgba(0,0,0,0.6) !important;
905
+ }
906
+
907
+ .leaflet-tooltip-dark::before {
908
+ border-top-color: rgba(10, 14, 26, 0.92) !important;
909
+ }
910
+
911
+ .bus-label-icon, .bus-mw-icon, .zone-badge-leaflet {
912
+ background: none !important;
913
+ border: none !important;
914
+ text-align: center;
915
+ }
916
+
917
+ /* Dark zoom controls */
918
+ .leaflet-control-zoom a {
919
+ background: rgba(15, 22, 40, 0.9) !important;
920
+ color: var(--status-normal) !important;
921
+ border-color: rgba(0, 229, 160, 0.2) !important;
922
+ font-family: 'JetBrains Mono', monospace !important;
923
+ }
924
+ .leaflet-control-zoom a:hover {
925
+ background: rgba(0, 229, 160, 0.15) !important;
926
+ }
927
+
928
+ .leaflet-control-attribution {
929
+ background: rgba(10, 14, 26, 0.6) !important;
930
+ color: #555 !important;
931
+ font-size: 9px !important;
932
+ }
933
+ .leaflet-control-attribution a {
934
+ color: #666 !important;
935
+ }
tests/__init__.py ADDED
File without changes
tests/test_multi_agent.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for multi-agent POMDP features:
3
+ - Zone assignment and partitioning
4
+ - Partial observability (ZoneObservation)
5
+ - Safety layer (action validation and correction)
6
+ - Oversight agent (coordination monitoring)
7
+ - Multi-agent step (combined pipeline)
8
+ """
9
+
10
+ import copy
11
+ import unittest
12
+
13
+ import networkx as nx
14
+ import numpy as np
15
+
16
+ from src.environment import OpenGridEnv
17
+ from src.tasks import TASKS
18
+ from src.models import GridAction, BusAdjustment, TopologyAction, ZoneObservation
19
+ from src.safety import SafetyLayer
20
+ from src.oversight import OversightAgent
21
+
22
+
23
+ def task(task_id: str):
24
+ """Get a deep-copied task config to prevent cross-test contamination."""
25
+ return copy.deepcopy(TASKS[task_id])
26
+
27
+
28
+ class TestZoneAssignment(unittest.TestCase):
29
+ """Tests for multi-agent zone partitioning."""
30
+
31
+ def test_all_buses_assigned(self):
32
+ """Every bus should be assigned to exactly one zone."""
33
+ for task_id, config in TASKS.items():
34
+ zone_map = config['zone_assignments']
35
+ for i in range(config['num_buses']):
36
+ self.assertIn(i, zone_map, f"Bus {i} not assigned in {task_id}")
37
+
38
+ def test_zone_count_matches(self):
39
+ """Number of zones should match num_agents."""
40
+ for task_id, config in TASKS.items():
41
+ agents = set(config['zone_assignments'].values())
42
+ self.assertEqual(len(agents), config['num_agents'],
43
+ f"Zone count mismatch in {task_id}")
44
+
45
+ def test_no_empty_zones(self):
46
+ """Each zone should have at least 1 bus."""
47
+ for task_id, config in TASKS.items():
48
+ for agent_id in range(config['num_agents']):
49
+ bus_ids = config['zone_bus_ids'][agent_id]
50
+ self.assertGreater(len(bus_ids), 0,
51
+ f"Empty zone {agent_id} in {task_id}")
52
+
53
+ def test_lines_classified(self):
54
+ """All lines should be classified as internal or boundary."""
55
+ for task_id, config in TASKS.items():
56
+ all_internal = set()
57
+ all_boundary = set()
58
+ for agent_id in range(config['num_agents']):
59
+ all_internal.update(config['internal_lines'].get(agent_id, []))
60
+ all_boundary.update(config['boundary_lines'].get(agent_id, []))
61
+
62
+ all_line_ids = {l['id'] for l in config['lines']}
63
+ classified = all_internal | all_boundary
64
+ self.assertEqual(all_line_ids, classified,
65
+ f"Unclassified lines in {task_id}")
66
+
67
+
68
+ class TestPartialObservability(unittest.TestCase):
69
+ """Tests for POMDP zone observations."""
70
+
71
+ def test_partial_obs_returns_zone_obs(self):
72
+ """reset_multi should return ZoneObservation for each agent."""
73
+ config = task("task_easy")
74
+ env = OpenGridEnv(config)
75
+ zone_obs = env.reset_multi()
76
+
77
+ self.assertEqual(len(zone_obs), config["num_agents"],
78
+ "Should have one observation per agent")
79
+ for agent_id, obs in zone_obs.items():
80
+ self.assertIsInstance(obs, ZoneObservation)
81
+ self.assertEqual(obs.agent_id, agent_id)
82
+
83
+ def test_partial_obs_only_shows_local_buses(self):
84
+ """Each agent should only see buses in their zone."""
85
+ config = task("task_medium")
86
+ env = OpenGridEnv(config)
87
+ zone_obs = env.reset_multi()
88
+
89
+ for agent_id, obs in zone_obs.items():
90
+ expected_bus_ids = set(config['zone_bus_ids'][agent_id])
91
+ actual_bus_ids = {b.id for b in obs.local_buses}
92
+ self.assertEqual(actual_bus_ids, expected_bus_ids,
93
+ f"Agent {agent_id} sees wrong buses")
94
+
95
+ def test_frequency_has_noise(self):
96
+ """POMDP observations should have noisy frequency readings."""
97
+ config = task("task_easy")
98
+ env = OpenGridEnv(config)
99
+ env.reset_multi()
100
+
101
+ # Compare zone obs against full obs from the same reset
102
+ full_obs = env.state()
103
+ differences = []
104
+ for agent_id in range(config['num_agents']):
105
+ z_obs = env._get_zone_obs(agent_id)
106
+ diff = abs(z_obs.grid_frequency - full_obs.grid_frequency)
107
+ differences.append(diff)
108
+
109
+ # At least one agent should see noisy frequency
110
+ self.assertTrue(any(d > 0.001 for d in differences),
111
+ "No frequency noise detected in POMDP observations")
112
+
113
+
114
+ class TestSafetyLayer(unittest.TestCase):
115
+ """Tests for the safety constraint filter."""
116
+
117
+ def setUp(self):
118
+ self.config = task("task_medium")
119
+ self.safety = SafetyLayer(self.config)
120
+ self.env = OpenGridEnv(self.config)
121
+ self.env.reset()
122
+
123
+ def test_zone_boundary_enforcement(self):
124
+ """Agent should not be able to adjust buses in another zone."""
125
+ agent_0_buses = set(self.config['zone_bus_ids'][0])
126
+ other_bus = None
127
+ for bus_cfg in self.config['buses']:
128
+ if bus_cfg['id'] not in agent_0_buses:
129
+ other_bus = bus_cfg['id']
130
+ break
131
+
132
+ if other_bus is None:
133
+ self.skipTest("All buses in agent 0's zone (trivial grid)")
134
+
135
+ action = GridAction(bus_adjustments=[
136
+ BusAdjustment(bus_id=other_bus, delta=10.0)
137
+ ])
138
+
139
+ corrected, report = self.safety.validate_and_correct(
140
+ agent_id=0,
141
+ proposed_action=action,
142
+ current_line_state=self.env.line_state,
143
+ current_bus_state=self.env.bus_state,
144
+ cooldowns=self.env.cooldowns,
145
+ )
146
+
147
+ self.assertTrue(report.was_corrected, "Should have corrected cross-zone action")
148
+ self.assertEqual(len(corrected.bus_adjustments), 0,
149
+ "Cross-zone adjustment should be removed")
150
+
151
+ def test_safe_action_passes_through(self):
152
+ """A valid action within the agent's zone should not be corrected."""
153
+ agent_0_buses = self.config['zone_bus_ids'][0]
154
+ controllable = None
155
+ for bus_cfg in self.config['buses']:
156
+ if bus_cfg['id'] in agent_0_buses and bus_cfg['type'] in ['generator', 'battery', 'slack']:
157
+ controllable = bus_cfg['id']
158
+ break
159
+
160
+ if controllable is None:
161
+ self.skipTest("No controllable bus in agent 0's zone")
162
+
163
+ action = GridAction(bus_adjustments=[
164
+ BusAdjustment(bus_id=controllable, delta=5.0)
165
+ ])
166
+
167
+ corrected, report = self.safety.validate_and_correct(
168
+ agent_id=0,
169
+ proposed_action=action,
170
+ current_line_state=self.env.line_state,
171
+ current_bus_state=self.env.bus_state,
172
+ cooldowns=self.env.cooldowns,
173
+ )
174
+
175
+ # Should pass through (may have minor clamping)
176
+ self.assertEqual(len(corrected.bus_adjustments), 1,
177
+ "Valid action should produce one adjustment")
178
+
179
+ def test_islanding_blocked(self):
180
+ """Opening a bridge line should be blocked by safety layer."""
181
+ G = nx.Graph()
182
+ for line in self.config['lines']:
183
+ G.add_edge(line['from'], line['to'])
184
+ bridges = list(nx.bridges(G))
185
+ if not bridges:
186
+ self.skipTest("No bridges in grid topology")
187
+
188
+ bridge = bridges[0]
189
+ line_id = next(
190
+ l['id'] for l in self.config['lines']
191
+ if (l['from'], l['to']) == bridge or (l['to'], l['from']) == bridge
192
+ )
193
+
194
+ action = GridAction(topology_actions=[
195
+ TopologyAction(line_id=line_id, action="open")
196
+ ])
197
+
198
+ corrected, report = self.safety.validate_and_correct(
199
+ agent_id=0,
200
+ proposed_action=action,
201
+ current_line_state=self.env.line_state,
202
+ current_bus_state=self.env.bus_state,
203
+ cooldowns=self.env.cooldowns,
204
+ )
205
+
206
+ self.assertTrue(report.was_corrected, "Bridge opening should be blocked")
207
+ self.assertEqual(len(corrected.topology_actions), 0,
208
+ "Bridge opening should be removed")
209
+
210
+ def test_duplicate_battery_adjustments_aggregated(self):
211
+ """Multiple adjustments to the same battery should be aggregated."""
212
+ battery = next(
213
+ (b for b in self.config['buses'] if b['type'] == 'battery'), None
214
+ )
215
+ if battery is None:
216
+ self.skipTest("No battery in task")
217
+
218
+ bus_id = battery['id']
219
+ agent_id = self.config['zone_assignments'].get(bus_id, 0)
220
+
221
+ # Set SOC to a known value
222
+ for b in self.env.bus_state:
223
+ if b['id'] == bus_id:
224
+ b['soc'] = 10.0
225
+
226
+ action = GridAction(bus_adjustments=[
227
+ BusAdjustment(bus_id=bus_id, delta=8.0),
228
+ BusAdjustment(bus_id=bus_id, delta=8.0),
229
+ ])
230
+
231
+ corrected, report = self.safety.validate_and_correct(
232
+ agent_id=agent_id,
233
+ proposed_action=action,
234
+ current_line_state=self.env.line_state,
235
+ current_bus_state=self.env.bus_state,
236
+ cooldowns=self.env.cooldowns,
237
+ )
238
+
239
+ total_delta = sum(a.delta for a in corrected.bus_adjustments)
240
+ self.assertLessEqual(total_delta, 10.0,
241
+ "Combined discharge should not exceed SOC")
242
+
243
+
244
+ class TestOversightAgent(unittest.TestCase):
245
+ """Tests for the coordination oversight agent."""
246
+
247
+ def test_no_conflict_scores_high(self):
248
+ """Cooperative actions should score high coordination."""
249
+ config = task("task_easy")
250
+ oversight = OversightAgent(config)
251
+
252
+ # Both agents inject (cooperative)
253
+ agent_actions = {
254
+ 0: GridAction(bus_adjustments=[BusAdjustment(bus_id=0, delta=5.0)]),
255
+ 1: GridAction(bus_adjustments=[BusAdjustment(bus_id=1, delta=3.0)]),
256
+ }
257
+
258
+ report = oversight.evaluate(
259
+ agent_actions=agent_actions,
260
+ safety_reports={},
261
+ pre_frequency=49.8,
262
+ post_frequency=49.9,
263
+ pre_bus_state=[],
264
+ post_bus_state=[],
265
+ )
266
+
267
+ self.assertGreater(report.coordination_score, 0.5,
268
+ "Cooperative actions should score > 0.5")
269
+
270
+ def test_reset_clears_history(self):
271
+ """Resetting oversight should clear intervention history."""
272
+ config = task("task_easy")
273
+ oversight = OversightAgent(config)
274
+ oversight.intervention_history[0] = 5
275
+ oversight.reset()
276
+ self.assertEqual(oversight.intervention_history[0], 0)
277
+
278
+
279
+ class TestMultiAgentStep(unittest.TestCase):
280
+ """Integration tests for the full multi-agent pipeline."""
281
+
282
+ def test_multi_agent_step_returns_result(self):
283
+ """step_multi should return a complete MultiAgentStepResult."""
284
+ config = task("task_easy")
285
+ env = OpenGridEnv(config)
286
+ env.reset_multi()
287
+
288
+ # No-op actions for all agents
289
+ actions = {i: GridAction() for i in range(config['num_agents'])}
290
+ result = env.step_multi(actions)
291
+
292
+ self.assertEqual(len(result.observations), config['num_agents'])
293
+ self.assertEqual(len(result.rewards), config['num_agents'])
294
+ self.assertIsInstance(result.team_reward, float)
295
+ self.assertIsInstance(result.done, bool)
296
+ self.assertEqual(len(result.safety_reports), config['num_agents'])
297
+
298
+ def test_safety_reports_match_agent_ids(self):
299
+ """Safety reports should contain all expected agent IDs."""
300
+ config = task("task_easy")
301
+ env = OpenGridEnv(config)
302
+ env.reset_multi()
303
+
304
+ result = env.step_multi({
305
+ i: GridAction() for i in range(config['num_agents'])
306
+ })
307
+
308
+ report_ids = set(result.safety_reports.keys())
309
+ expected_ids = set(range(config['num_agents']))
310
+ self.assertEqual(report_ids, expected_ids,
311
+ "Safety report agent IDs should match expected agents")
312
+
313
+ def test_multi_agent_episode_completes(self):
314
+ """A full multi-agent episode should complete without errors."""
315
+ config = task("task_easy")
316
+ env = OpenGridEnv(config)
317
+ env.reset_multi()
318
+
319
+ done = False
320
+ steps = 0
321
+ while not done and steps < config['max_steps'] + 5:
322
+ actions = {i: GridAction() for i in range(config['num_agents'])}
323
+ result = env.step_multi(actions)
324
+ done = result.done
325
+ steps += 1
326
+
327
+ self.assertTrue(done, "Episode should terminate")
328
+ self.assertLessEqual(steps, config['max_steps'] + 1)
329
+
330
+ def test_backward_compatibility(self):
331
+ """Single-agent reset/step should still work after multi-agent changes."""
332
+ for task_id in TASKS:
333
+ config = task(task_id)
334
+ env = OpenGridEnv(config)
335
+ obs = env.reset()
336
+ self.assertGreater(len(obs.buses), 0,
337
+ f"No buses in {task_id}")
338
+
339
+ obs, reward, done, info = env.step(GridAction())
340
+ self.assertEqual(obs.timestep, 1)
341
+ self.assertIsInstance(reward.value, float)
342
+
343
+
344
+ if __name__ == '__main__':
345
+ unittest.main()
tests/test_solver.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for core simulation components:
3
+ - DC power flow solver
4
+ - Environment lifecycle (reset, step, terminate)
5
+ - Grading system (scoring, bounds, reproducibility)
6
+ - Baseline heuristic policy
7
+ """
8
+
9
+ import copy
10
+ import unittest
11
+
12
+ import numpy as np
13
+
14
+ from src.physics import DCSolver, IslandedException
15
+ from src.environment import OpenGridEnv
16
+ from src.tasks import TASKS
17
+ from src.models import GridAction, BusAdjustment
18
+ from src.grader import RobustnessGrader, compute_analytical_ceiling
19
+ from src.baseline import heuristic_policy
20
+
21
+
22
+ def task(task_id: str):
23
+ """Get a deep-copied task config to prevent cross-test contamination."""
24
+ return copy.deepcopy(TASKS[task_id])
25
+
26
+
27
+ class TestDCSolver(unittest.TestCase):
28
+ def setUp(self):
29
+ self.num_buses = 3
30
+ self.lines = [
31
+ {'id': 'L01', 'from': 0, 'to': 1, 'susceptance': 100, 'connected': True},
32
+ {'id': 'L12', 'from': 1, 'to': 2, 'susceptance': 50, 'connected': True},
33
+ {'id': 'L02', 'from': 0, 'to': 2, 'susceptance': 100, 'connected': True}
34
+ ]
35
+ self.solver = DCSolver(self.num_buses)
36
+ self.solver.update_grid(self.lines)
37
+
38
+ def test_power_flow_balance(self):
39
+ """Slack bus should absorb any generation/load imbalance."""
40
+ p_inj = np.array([0.0, 50.0, -100.0])
41
+ theta, flows, slack_inj = self.solver.solve(p_inj)
42
+
43
+ # Check that flows are computed
44
+ self.assertIn('L01', flows)
45
+ self.assertIn('L02', flows)
46
+
47
+ def test_islanding_detection(self):
48
+ """Disconnecting lines to island bus 2 should raise IslandedException."""
49
+ with self.assertRaises(IslandedException):
50
+ broken_lines = [
51
+ {'id': 'L01', 'from': 0, 'to': 1, 'susceptance': 100, 'connected': True},
52
+ {'id': 'L12', 'from': 1, 'to': 2, 'susceptance': 50, 'connected': False},
53
+ {'id': 'L02', 'from': 0, 'to': 2, 'susceptance': 100, 'connected': False}
54
+ ]
55
+ self.solver.update_grid(broken_lines)
56
+
57
+ def test_slack_injection_returned(self):
58
+ """solve() should return slack bus injection as third element."""
59
+ p_inj = np.array([0.0, 50.0, -100.0])
60
+ result = self.solver.solve(p_inj)
61
+ self.assertEqual(len(result), 3)
62
+ theta, flows, slack_inj = result
63
+ # Slack should inject ~50 MW to cover the deficit
64
+ self.assertAlmostEqual(slack_inj, 50.0, places=0)
65
+
66
+ def test_solve_before_update_raises(self):
67
+ """Calling solve() on a fresh solver should raise RuntimeError."""
68
+ fresh = DCSolver(3)
69
+ with self.assertRaises(RuntimeError):
70
+ fresh.solve(np.array([0.0, 10.0, -10.0]))
71
+
72
+ def test_invalid_bus_index_raises(self):
73
+ """Lines referencing out-of-range bus IDs should raise ValueError."""
74
+ bad_lines = [
75
+ {'id': 'L_bad', 'from': 0, 'to': 99, 'susceptance': 50, 'connected': True},
76
+ ]
77
+ solver = DCSolver(3)
78
+ with self.assertRaises(ValueError):
79
+ solver.update_grid(bad_lines)
80
+
81
+
82
+ class TestEnvironment(unittest.TestCase):
83
+ def test_reset_returns_observation(self):
84
+ """reset() should return a valid GridObservation."""
85
+ env = OpenGridEnv(task("task_easy"))
86
+ obs = env.reset()
87
+ self.assertEqual(obs.timestep, 0)
88
+ self.assertGreater(len(obs.buses), 0, "Observation should have buses")
89
+ self.assertGreater(len(obs.lines), 0, "Observation should have lines")
90
+
91
+ def test_step_returns_tuple(self):
92
+ """step() should return (obs, reward, done, info)."""
93
+ env = OpenGridEnv(task("task_easy"))
94
+ env.reset()
95
+ obs, reward, done, info = env.step(GridAction())
96
+ self.assertEqual(obs.timestep, 1)
97
+ self.assertIsInstance(reward.value, float)
98
+ self.assertIsInstance(done, bool)
99
+
100
+ def test_reproducibility(self):
101
+ """Running the same task twice should produce identical initial observations."""
102
+ env1 = OpenGridEnv(task("task_easy"))
103
+ obs1 = env1.reset()
104
+
105
+ env2 = OpenGridEnv(task("task_easy"))
106
+ obs2 = env2.reset()
107
+
108
+ self.assertEqual(obs1.grid_frequency, obs2.grid_frequency)
109
+ self.assertEqual(len(obs1.buses), len(obs2.buses))
110
+
111
+ def test_episode_terminates(self):
112
+ """Episode should end after max_steps."""
113
+ config = task("task_easy")
114
+ env = OpenGridEnv(config)
115
+ env.reset()
116
+ done = False
117
+ steps = 0
118
+ while not done and steps < 100:
119
+ _, _, done, _ = env.step(GridAction())
120
+ steps += 1
121
+ self.assertTrue(done, "Episode should terminate")
122
+ self.assertLessEqual(steps, config["max_steps"])
123
+
124
+ def test_frequency_reasonable(self):
125
+ """Frequency should stay in a reasonable range for do-nothing agent."""
126
+ env = OpenGridEnv(task("task_easy"))
127
+ obs = env.reset()
128
+ for _ in range(10):
129
+ obs, _, done, _ = env.step(GridAction())
130
+ if done:
131
+ break
132
+ self.assertGreater(obs.grid_frequency, 40.0,
133
+ "Frequency below reasonable minimum")
134
+ self.assertLess(obs.grid_frequency, 60.0,
135
+ "Frequency above reasonable maximum")
136
+
137
+
138
+ class TestGrader(unittest.TestCase):
139
+ def test_grader_score_range(self):
140
+ """Grader should return score strictly in (0, 1) — never 0.0 or 1.0."""
141
+ grader = RobustnessGrader(task("task_easy"))
142
+ result = grader.evaluate_policy(heuristic_policy, n_episodes=1)
143
+ self.assertGreater(result["score"], 0.0)
144
+ self.assertLess(result["score"], 1.0)
145
+
146
+ def test_grader_all_tasks(self):
147
+ """Grader should work on all registered tasks."""
148
+ for task_id, config in TASKS.items():
149
+ grader = RobustnessGrader(copy.deepcopy(config))
150
+ result = grader.evaluate_policy(heuristic_policy, n_episodes=1)
151
+ self.assertIn("score", result, f"Missing 'score' for {task_id}")
152
+ self.assertIn("avg_raw_reward", result,
153
+ f"Missing 'avg_raw_reward' for {task_id}")
154
+
155
+
156
+ class TestBaseline(unittest.TestCase):
157
+ def test_heuristic_returns_valid_action(self):
158
+ """Heuristic policy should return a valid GridAction."""
159
+ env = OpenGridEnv(task("task_easy"))
160
+ obs = env.reset()
161
+ action = heuristic_policy(obs)
162
+ self.assertIsInstance(action, GridAction)
163
+
164
+
165
+ class TestReproducibility(unittest.TestCase):
166
+ def test_floor_deterministic(self):
167
+ """Two calls to _estimate_bounds should produce identical floors (seeded RNG)."""
168
+ grader1 = RobustnessGrader(task("task_easy"))
169
+ grader1._estimate_bounds(n_samples=3)
170
+
171
+ grader2 = RobustnessGrader(task("task_easy"))
172
+ grader2._estimate_bounds(n_samples=3)
173
+
174
+ self.assertEqual(grader1.reward_floor, grader2.reward_floor,
175
+ "Floor should be deterministic with same seed")
176
+
177
+ def test_ceiling_is_analytical(self):
178
+ """Ceiling should be max_steps * 1.2, not an empirical estimate."""
179
+ config = task("task_easy")
180
+ grader = RobustnessGrader(config)
181
+ bounds = grader.get_bounds()
182
+ expected_ceiling = compute_analytical_ceiling(config["max_steps"])
183
+ self.assertEqual(bounds["reward_ceiling"], expected_ceiling,
184
+ "Ceiling should match analytical formula")
185
+
186
+ def test_heuristic_score_below_one(self):
187
+ """With analytical ceiling, heuristic should score < 1.0 (not degenerate)."""
188
+ grader = RobustnessGrader(task("task_easy"))
189
+ result = grader.evaluate_policy(heuristic_policy, n_episodes=1)
190
+ self.assertLess(result["score"], 1.0)
191
+ self.assertGreater(result["score"], 0.0)
192
+
193
+
194
+ if __name__ == '__main__':
195
+ unittest.main()
training/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Training module for OpenGrid GRPO pipeline
training/opengrid_grpo_colab.ipynb ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🔋 OpenGrid — GRPO Training Notebook\n",
8
+ "\n",
9
+ "**Multi-Agent RL for Power Grid Operations**\n",
10
+ "\n",
11
+ "This notebook trains an LLM (Qwen 2.5 1.5B) to operate a power grid using GRPO (Group Relative Policy Optimization).\n",
12
+ "\n",
13
+ "- **Environment**: OpenGrid — multi-agent POMDP with safety layer & oversight agent\n",
14
+ "- **Task**: Maintain 50 Hz frequency, prevent line overloads, avoid blackouts\n",
15
+ "- **Training**: TRL GRPOTrainer + Unsloth 4-bit quantization\n",
16
+ "\n",
17
+ "⚡ **Runtime**: Select `T4 GPU` from Runtime → Change runtime type"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## 1. Install Dependencies"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "%%capture\n",
34
+ "!pip install unsloth\n",
35
+ "!pip install --no-deps trl peft accelerate bitsandbytes\n",
36
+ "!pip install fastapi uvicorn pydantic numpy networkx matplotlib openai httpx datasets"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {},
42
+ "source": [
43
+ "## 2. Clone OpenGrid Repository"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "import os\n",
53
+ "\n",
54
+ "# ⚠️ UPDATE THIS with your actual repo URL\n",
55
+ "REPO_URL = \"https://github.com/krishnagoyal099/Opengrid_env.git\"\n",
56
+ "\n",
57
+ "if not os.path.exists(\"opengrid\"):\n",
58
+ " !git clone {REPO_URL} opengrid\n",
59
+ "else:\n",
60
+ " !cd opengrid && git pull\n",
61
+ "\n",
62
+ "os.chdir(\"opengrid\")\n",
63
+ "print(f\"Working directory: {os.getcwd()}\")\n",
64
+ "!ls -la"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": [
71
+ "## 3. Verify GPU & Environment"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "import torch\n",
81
+ "print(f\"PyTorch: {torch.__version__}\")\n",
82
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
83
+ "if torch.cuda.is_available():\n",
84
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
85
+ " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
86
+ "else:\n",
87
+ " print(\"⚠️ No GPU detected! Go to Runtime → Change runtime type → T4 GPU\")"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "# Verify OpenGrid imports work\n",
97
+ "import sys\n",
98
+ "sys.path.insert(0, '.')\n",
99
+ "\n",
100
+ "from src.environment import OpenGridEnv\n",
101
+ "from src.tasks import TASKS\n",
102
+ "from src.models import GridAction, BusAdjustment\n",
103
+ "\n",
104
+ "print(f\"Available tasks: {list(TASKS.keys())}\")\n",
105
+ "for tid, cfg in TASKS.items():\n",
106
+ " print(f\" {tid}: {cfg['num_buses']} buses, {cfg['num_agents']} agents, {cfg.get('difficulty','')}\")"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "## 4. Run Test Mode (Pipeline Verification)"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "!python training/train_grpo.py --test-mode"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "## 5. Baseline Evaluation (Before Training)\n",
130
+ "\n",
131
+ "Run the heuristic policy to get baseline scores. We'll compare against this after training."
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "import json\n",
141
+ "import re\n",
142
+ "import numpy as np\n",
143
+ "from src.environment import OpenGridEnv\n",
144
+ "from src.tasks import TASKS\n",
145
+ "from src.models import GridAction, BusAdjustment\n",
146
+ "from training.train_grpo import (\n",
147
+ " rollout_multi_agent, format_observation_prompt, extract_action\n",
148
+ ")\n",
149
+ "\n",
150
+ "def heuristic_generate(prompt):\n",
151
+ " \"\"\"Simple proportional controller as baseline.\"\"\"\n",
152
+ " freq_match = re.search(r'Frequency: ([\\d.]+)', prompt)\n",
153
+ " freq = float(freq_match.group(1)) if freq_match else 50.0\n",
154
+ " error = 50.0 - freq\n",
155
+ " delta = max(-20, min(20, error * 10))\n",
156
+ " bus_match = re.search(r'Bus (\\d+) \\((generator|battery|slack)\\)', prompt)\n",
157
+ " if bus_match:\n",
158
+ " return json.dumps({\"bus_adjustments\": [{\"bus_id\": int(bus_match.group(1)), \"delta\": round(delta, 1)}], \"topology_actions\": []})\n",
159
+ " return json.dumps({\"bus_adjustments\": [], \"topology_actions\": []})\n",
160
+ "\n",
161
+ "# Evaluate baseline on all tasks\n",
162
+ "baseline_results = {}\n",
163
+ "for task_id in [\"task_easy\", \"task_medium\", \"task_karnataka\"]:\n",
164
+ " if task_id not in TASKS:\n",
165
+ " continue\n",
166
+ " config = TASKS[task_id]\n",
167
+ " rewards = []\n",
168
+ " import copy\n",
169
+ " for ep in range(5):\n",
170
+ " ep_config = copy.deepcopy(config)\n",
171
+ " ep_config['seed'] = 42 + ep\n",
172
+ " env = OpenGridEnv(ep_config)\n",
173
+ " result = rollout_multi_agent(env, heuristic_generate, ep_config)\n",
174
+ " rewards.append(result['total_reward'])\n",
175
+ " baseline_results[task_id] = {\n",
176
+ " \"avg_reward\": np.mean(rewards),\n",
177
+ " \"std_reward\": np.std(rewards),\n",
178
+ " \"rewards\": rewards\n",
179
+ " }\n",
180
+ " print(f\"[BASELINE] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\")\n",
181
+ "\n",
182
+ "# Save baseline for later comparison\n",
183
+ "import pickle\n",
184
+ "os.makedirs(\"training/outputs\", exist_ok=True)\n",
185
+ "with open(\"training/outputs/baseline_results.pkl\", \"wb\") as f:\n",
186
+ " pickle.dump(baseline_results, f)\n",
187
+ "print(\"\\n✅ Baseline scores saved.\")"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "markdown",
192
+ "metadata": {},
193
+ "source": [
194
+ "## 6. Load Model with Unsloth (4-bit Quantized)"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "from unsloth import FastLanguageModel\n",
204
+ "\n",
205
+ "MODEL_NAME = \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\"\n",
206
+ "\n",
207
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
208
+ " model_name=MODEL_NAME,\n",
209
+ " max_seq_length=2048,\n",
210
+ " load_in_4bit=True,\n",
211
+ ")\n",
212
+ "\n",
213
+ "model = FastLanguageModel.get_peft_model(\n",
214
+ " model,\n",
215
+ " r=16,\n",
216
+ " lora_alpha=16,\n",
217
+ " lora_dropout=0,\n",
218
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
219
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
220
+ ")\n",
221
+ "\n",
222
+ "if tokenizer.pad_token is None:\n",
223
+ " tokenizer.pad_token = tokenizer.eos_token\n",
224
+ "\n",
225
+ "print(f\"✅ Model loaded: {MODEL_NAME}\")\n",
226
+ "print(f\" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "metadata": {},
232
+ "source": [
233
+ "## 7. Generate Training Prompts from Environment"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": [
242
+ "import copy\n",
243
+ "import json as _json\n",
244
+ "import numpy as np\n",
245
+ "from training.train_grpo import SYSTEM_PROMPT, format_observation_prompt\n",
246
+ "\n",
247
+ "TRAIN_TASK = \"task_karnataka\" # Change to task_easy for faster first run\n",
248
+ "NUM_EPISODES = 30\n",
249
+ "\n",
250
+ "task_config = TASKS[TRAIN_TASK]\n",
251
+ "base_seed = task_config.get('seed', 42)\n",
252
+ "\n",
253
+ "prompts = []\n",
254
+ "obs_contexts = [] # stored as JSON strings to satisfy PyArrow schema inference\n",
255
+ "\n",
256
+ "for episode in range(NUM_EPISODES):\n",
257
+ " ep_config = copy.deepcopy(task_config)\n",
258
+ " ep_config['seed'] = base_seed + episode\n",
259
+ " env = OpenGridEnv(ep_config)\n",
260
+ " zone_obs = env.reset_multi()\n",
261
+ "\n",
262
+ " for t in range(min(10, task_config['max_steps'])):\n",
263
+ " for agent_id, obs in zone_obs.items():\n",
264
+ " # model_dump_json() → json.loads() ensures all keys are strings\n",
265
+ " obs_dict = _json.loads(obs.model_dump_json())\n",
266
+ " prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)\n",
267
+ " messages = [\n",
268
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
269
+ " {\"role\": \"user\", \"content\": prompt_text},\n",
270
+ " ]\n",
271
+ " formatted = tokenizer.apply_chat_template(\n",
272
+ " messages, tokenize=False, add_generation_prompt=True\n",
273
+ " )\n",
274
+ " prompts.append(formatted)\n",
275
+ " # Store as JSON string — flat scalar, no schema-inference issues\n",
276
+ " obs_contexts.append(_json.dumps(obs_dict))\n",
277
+ "\n",
278
+ " # Advance env with diverse random actions (no slack bus)\n",
279
+ " random_actions = {}\n",
280
+ " for aid in range(env.num_agents):\n",
281
+ " zone_buses = task_config['zone_bus_ids'].get(aid, [])\n",
282
+ " controllable = [bid for bid in zone_buses\n",
283
+ " if next((b for b in task_config['buses'] if b['id'] == bid), {}).get('type')\n",
284
+ " in ['generator', 'battery']]\n",
285
+ " adj = []\n",
286
+ " if controllable:\n",
287
+ " bid = np.random.choice(controllable)\n",
288
+ " adj = [BusAdjustment(bus_id=int(bid), delta=float(np.random.uniform(-15, 15)))]\n",
289
+ " random_actions[aid] = GridAction(bus_adjustments=adj)\n",
290
+ "\n",
291
+ " result = env.step_multi(random_actions)\n",
292
+ " if result.done:\n",
293
+ " break\n",
294
+ " zone_obs = result.observations\n",
295
+ "\n",
296
+ "print(f\"✅ Generated {len(prompts)} training prompts\")\n",
297
+ "print(f\"\\nSample prompt (first 400 chars):\")\n",
298
+ "print(prompts[0][:400])"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "metadata": {},
304
+ "source": [
305
+ "## 8. Define GRPO Reward Function"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "import json as _json\n",
315
+ "from training.train_grpo import compute_grpo_reward, extract_action\n",
316
+ "\n",
317
+ "def reward_fn(completions, obs_context=None, **kwargs):\n",
318
+ " \"\"\"GRPO-compatible reward function for OpenGrid.\n",
319
+ " obs_context arrives as JSON strings from the dataset column.\n",
320
+ " \"\"\"\n",
321
+ " texts = []\n",
322
+ " for c in completions:\n",
323
+ " if isinstance(c, list):\n",
324
+ " text = c[-1]['content'] if c else \"\"\n",
325
+ " else:\n",
326
+ " text = str(c)\n",
327
+ " texts.append(text)\n",
328
+ "\n",
329
+ " # Deserialize JSON strings → dicts for the reward scorer\n",
330
+ " if obs_context is None:\n",
331
+ " batch_obs = [None] * len(texts)\n",
332
+ " else:\n",
333
+ " batch_obs = [\n",
334
+ " _json.loads(ctx) if isinstance(ctx, str) else ctx\n",
335
+ " for ctx in obs_context\n",
336
+ " ]\n",
337
+ " return compute_grpo_reward(texts, batch_obs)\n",
338
+ "\n",
339
+ "# Quick sanity test\n",
340
+ "test_rewards = reward_fn([\n",
341
+ " '{\"bus_adjustments\": [{\"bus_id\": 1, \"delta\": 5.0}], \"topology_actions\": []}',\n",
342
+ " 'invalid json here',\n",
343
+ "])\n",
344
+ "print(f\"Test rewards: {test_rewards}\")\n",
345
+ "assert len(test_rewards) == 2, \"reward_fn must return one score per completion\"\n",
346
+ "print(\"✅ reward_fn OK\")"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "markdown",
351
+ "metadata": {},
352
+ "source": [
353
+ "## 9. Train with GRPO 🚀"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "from trl import GRPOTrainer, GRPOConfig\n",
363
+ "from datasets import Dataset\n",
364
+ "\n",
365
+ "_cuda_ok = torch.cuda.is_available()\n",
366
+ "_bf16 = _cuda_ok and torch.cuda.is_bf16_supported()\n",
367
+ "_fp16 = _cuda_ok and not _bf16\n",
368
+ "\n",
369
+ "grpo_config = GRPOConfig(\n",
370
+ " output_dir=\"training/outputs/grpo_checkpoints\",\n",
371
+ " num_train_epochs=1,\n",
372
+ " per_device_train_batch_size=2,\n",
373
+ " gradient_accumulation_steps=4,\n",
374
+ " learning_rate=5e-6,\n",
375
+ " logging_steps=5,\n",
376
+ " save_steps=50,\n",
377
+ " max_completion_length=256,\n",
378
+ " num_generations=4,\n",
379
+ " report_to=\"none\",\n",
380
+ " remove_unused_columns=False,\n",
381
+ " bf16=_bf16,\n",
382
+ " fp16=_fp16,\n",
383
+ ")\n",
384
+ "\n",
385
+ "# obs_contexts are JSON strings — PyArrow handles flat strings with no issues\n",
386
+ "train_dataset = Dataset.from_dict({\"prompt\": prompts, \"obs_context\": obs_contexts})\n",
387
+ "print(f\"Dataset: {len(train_dataset)} rows, columns: {train_dataset.column_names}\")\n",
388
+ "\n",
389
+ "trainer = GRPOTrainer(\n",
390
+ " model=model,\n",
391
+ " args=grpo_config,\n",
392
+ " train_dataset=train_dataset,\n",
393
+ " reward_funcs=reward_fn,\n",
394
+ " processing_class=tokenizer,\n",
395
+ ")\n",
396
+ "\n",
397
+ "print(f\"Training on {len(prompts)} prompts, {grpo_config.num_train_epochs} epoch(s)\")\n",
398
+ "print(f\"Effective batch size: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}\")\n",
399
+ "print(\"\\n🚀 Starting GRPO training...\")\n",
400
+ "\n",
401
+ "train_result = trainer.train()\n",
402
+ "\n",
403
+ "print(\"\\n✅ Training complete!\")\n",
404
+ "print(f\" Total steps: {trainer.state.global_step}\")"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "markdown",
409
+ "metadata": {},
410
+ "source": [
411
+ "## 10. Save Trained Model"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": null,
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "OUTPUT_PATH = \"training/outputs/trained_model\"\n",
421
+ "trainer.save_model(OUTPUT_PATH)\n",
422
+ "tokenizer.save_pretrained(OUTPUT_PATH)\n",
423
+ "print(f\"✅ Model saved to {OUTPUT_PATH}\")"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "markdown",
428
+ "metadata": {},
429
+ "source": [
430
+ "## 11. Evaluate Trained Model (After Training)"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": null,
436
+ "metadata": {},
437
+ "outputs": [],
438
+ "source": [
439
+ "from transformers import pipeline\n",
440
+ "\n",
441
+ "# Create generation function from trained model\n",
442
+ "FastLanguageModel.for_inference(model)\n",
443
+ "\n",
444
+ "def trained_generate(prompt):\n",
445
+ " \"\"\"Generate action using the trained model.\"\"\"\n",
446
+ " messages = [\n",
447
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
448
+ " {\"role\": \"user\", \"content\": prompt},\n",
449
+ " ]\n",
450
+ " formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
451
+ " inputs = tokenizer(formatted, return_tensors=\"pt\").to(model.device)\n",
452
+ " with torch.no_grad():\n",
453
+ " outputs = model.generate(\n",
454
+ " **inputs,\n",
455
+ " max_new_tokens=256,\n",
456
+ " temperature=0.3,\n",
457
+ " do_sample=True,\n",
458
+ " )\n",
459
+ " response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
460
+ " return response\n",
461
+ "\n",
462
+ "# Evaluate on same tasks as baseline\n",
463
+ "trained_results = {}\n",
464
+ "for task_id in [\"task_easy\", \"task_medium\", \"task_karnataka\"]:\n",
465
+ " if task_id not in TASKS:\n",
466
+ " continue\n",
467
+ " config = TASKS[task_id]\n",
468
+ " rewards = []\n",
469
+ " import copy\n",
470
+ " for ep in range(5):\n",
471
+ " ep_config = copy.deepcopy(config)\n",
472
+ " ep_config['seed'] = 42 + ep\n",
473
+ " env = OpenGridEnv(ep_config)\n",
474
+ " result = rollout_multi_agent(env, trained_generate, ep_config)\n",
475
+ " rewards.append(result['total_reward'])\n",
476
+ " print(f\" {task_id} ep{ep}: reward={result['total_reward']:.2f}, blackout={result['is_blackout']}\")\n",
477
+ " trained_results[task_id] = {\n",
478
+ " \"avg_reward\": np.mean(rewards),\n",
479
+ " \"std_reward\": np.std(rewards),\n",
480
+ " \"rewards\": rewards\n",
481
+ " }\n",
482
+ " print(f\"[TRAINED] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\\n\")"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "markdown",
487
+ "metadata": {},
488
+ "source": [
489
+ "## 12. Generate Before/After Plots 📊"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": [
498
+ "import matplotlib.pyplot as plt\n",
499
+ "import pickle\n",
500
+ "\n",
501
+ "# Load baseline\n",
502
+ "with open(\"training/outputs/baseline_results.pkl\", \"rb\") as f:\n",
503
+ " baseline_results = pickle.load(f)\n",
504
+ "\n",
505
+ "# ── Plot 1: Before vs After Bar Chart ──\n",
506
+ "common_tasks = [t for t in baseline_results if t in trained_results]\n",
507
+ "fig, ax = plt.subplots(figsize=(10, 6))\n",
508
+ "x = np.arange(len(common_tasks))\n",
509
+ "width = 0.35\n",
510
+ "\n",
511
+ "before_vals = [baseline_results[t]['avg_reward'] for t in common_tasks]\n",
512
+ "after_vals = [trained_results[t]['avg_reward'] for t in common_tasks]\n",
513
+ "\n",
514
+ "bars1 = ax.bar(x - width/2, before_vals, width, label='Heuristic Baseline', color='#ff6b6b', alpha=0.8)\n",
515
+ "bars2 = ax.bar(x + width/2, after_vals, width, label='GRPO Trained', color='#00d4aa', alpha=0.8)\n",
516
+ "\n",
517
+ "ax.set_xlabel('Task', fontsize=12)\n",
518
+ "ax.set_ylabel('Average Episode Reward', fontsize=12)\n",
519
+ "ax.set_title('OpenGrid — GRPO Training: Before vs After', fontsize=14, fontweight='bold')\n",
520
+ "ax.set_xticks(x)\n",
521
+ "ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])\n",
522
+ "ax.legend(fontsize=11)\n",
523
+ "ax.grid(True, alpha=0.3, axis='y')\n",
524
+ "\n",
525
+ "# Fix label positioning for negative bar heights\n",
526
+ "for bars in (bars1, bars2):\n",
527
+ " for bar in bars:\n",
528
+ " h = bar.get_height()\n",
529
+ " ax.text(\n",
530
+ " bar.get_x() + bar.get_width() / 2.,\n",
531
+ " h + (2 if h >= 0 else -5),\n",
532
+ " f'{h:.1f}',\n",
533
+ " ha='center', va='bottom' if h >= 0 else 'top', fontsize=10\n",
534
+ " )\n",
535
+ "\n",
536
+ "plt.tight_layout()\n",
537
+ "plt.savefig('training/outputs/before_after.png', dpi=150)\n",
538
+ "plt.show()\n",
539
+ "print(\"✅ Saved: training/outputs/before_after.png\")"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "execution_count": null,
545
+ "metadata": {},
546
+ "outputs": [],
547
+ "source": [
548
+ "# ── Plot 2: Training Reward Curve ──\n",
549
+ "history = trainer.state.log_history\n",
550
+ "\n",
551
+ "steps = [h['step'] for h in history if 'loss' in h]\n",
552
+ "losses = [h['loss'] for h in history if 'loss' in h]\n",
553
+ "\n",
554
+ "fig, ax = plt.subplots(figsize=(10, 5))\n",
555
+ "ax.plot(steps, losses, color='#ff6b6b', linewidth=1.5, alpha=0.6, label='Loss')\n",
556
+ "if len(losses) > 10:\n",
557
+ " window = min(20, len(losses) // 3)\n",
558
+ " smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')\n",
559
+ " ax.plot(steps[window-1:], smoothed, color='#ff6b6b', linewidth=2.5, label=f'Smoothed (w={window})')\n",
560
+ "\n",
561
+ "ax.set_xlabel('Training Step', fontsize=12)\n",
562
+ "ax.set_ylabel('Loss', fontsize=12)\n",
563
+ "ax.set_title('OpenGrid GRPO — Training Loss', fontsize=14, fontweight='bold')\n",
564
+ "ax.legend()\n",
565
+ "ax.grid(True, alpha=0.3)\n",
566
+ "plt.tight_layout()\n",
567
+ "plt.savefig('training/outputs/training_loss.png', dpi=150)\n",
568
+ "plt.show()\n",
569
+ "print(\"✅ Saved: training/outputs/training_loss.png\")"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "markdown",
574
+ "metadata": {},
575
+ "source": [
576
+ "## 13. Summary & Next Steps\n",
577
+ "\n",
578
+ "### Results Table"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": null,
584
+ "metadata": {},
585
+ "outputs": [],
586
+ "source": [
587
+ "print(\"=\"*60)\n",
588
+ "print(\" OpenGrid GRPO Training — Results Summary\")\n",
589
+ "print(\"=\"*60)\n",
590
+ "\n",
591
+ "# Rebuild common_tasks in case Cell 12 was skipped\n",
592
+ "common_tasks = [t for t in baseline_results if t in trained_results]\n",
593
+ "\n",
594
+ "print(f\"{'Task':<20} {'Baseline':>12} {'Trained':>12} {'Δ':>10}\")\n",
595
+ "print(\"-\"*60)\n",
596
+ "for t in common_tasks:\n",
597
+ " b = baseline_results[t]['avg_reward']\n",
598
+ " a = trained_results[t]['avg_reward']\n",
599
+ " delta = a - b\n",
600
+ " arrow = '↑' if delta > 0 else '↓'\n",
601
+ " print(f\"{t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(delta):.2f}\")\n",
602
+ "print(\"=\"*60)"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": null,
608
+ "metadata": {},
609
+ "outputs": [],
610
+ "source": [
611
+ "# Download plots for your README\n",
612
+ "from google.colab import files\n",
613
+ "files.download('training/outputs/before_after.png')\n",
614
+ "files.download('training/outputs/training_loss.png')"
615
+ ]
616
+ }
617
+ ],
618
+ "metadata": {
619
+ "accelerator": "GPU",
620
+ "colab": {
621
+ "gpuType": "T4",
622
+ "provenance": []
623
+ },
624
+ "kernelspec": {
625
+ "display_name": "Python 3",
626
+ "name": "python3"
627
+ },
628
+ "language_info": {
629
+ "name": "python",
630
+ "version": "3.10.0"
631
+ }
632
+ },
633
+ "nbformat": 4,
634
+ "nbformat_minor": 0
635
+ }
training/train_grpo.py ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenGrid GRPO Training Script
3
+ ==============================
4
+ Uses TRL's GRPOTrainer to train an LLM for multi-agent power grid control.
5
+
6
+ The LLM receives grid observations (partial, per-zone) as text prompts,
7
+ generates JSON actions, and is trained via GRPO to maximize grid stability rewards.
8
+
9
+ Compatible with:
10
+ - Unsloth for 4-bit quantized training (recommended)
11
+ - HuggingFace TRL GRPOTrainer
12
+ - Colab / HF Spaces with GPU
13
+
14
+ Usage:
15
+ # Quick test (no GPU needed, just verifies the pipeline)
16
+ python training/train_grpo.py --test-mode
17
+
18
+ # Full training on GPU
19
+ python training/train_grpo.py --model Qwen/Qwen2.5-1.5B-Instruct --epochs 3
20
+
21
+ # With Unsloth quantization (faster, less memory)
22
+ python training/train_grpo.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit --use-unsloth
23
+ """
24
+
25
+ import argparse
26
+ import copy
27
+ import json
28
+ import random
29
+ import sys
30
+ import os
31
+ import re
32
+ import time
33
+ from pathlib import Path
34
+
35
+ # Add project root to path
36
+ sys.path.insert(0, str(Path(__file__).parent.parent))
37
+
38
+ import numpy as np
39
+ import matplotlib
40
+ matplotlib.use('Agg')
41
+ import matplotlib.pyplot as plt
42
+
43
+ from src.environment import OpenGridEnv
44
+ from src.tasks import TASKS
45
+ from src.models import GridAction, BusAdjustment, TopologyAction
46
+
47
+
48
+ # ============================================================================
49
+ # Prompt Engineering
50
+ # ============================================================================
51
+
52
# System prompt sent with every agent turn: sets the grid-operator persona,
# the control objectives, and a strict JSON-only output contract that
# extract_action() below relies on when parsing model completions.
SYSTEM_PROMPT = """You are an AI power grid operator for the Karnataka Power Transmission Corporation (KPTCL).
You manage one zone of a multi-agent grid. Your goal: keep frequency at 50.0 Hz, avoid line overloads, and prevent blackouts.

You receive partial observations of your zone and must output a JSON action.
Respond ONLY with valid JSON matching this schema:
{"bus_adjustments": [{"bus_id": <int>, "delta": <float>}], "topology_actions": []}

Rules:
- Positive delta = inject more power (discharge battery / increase generation)
- Negative delta = reduce injection (charge battery / decrease generation)
- Only adjust buses in YOUR zone
- Keep frequency close to 50.0 Hz
- Avoid overloading lines (rho > 1.0 is dangerous)"""
65
+
66
+
67
def format_observation_prompt(obs_dict: dict, zone_name: str = "") -> str:
    """Render a per-zone observation dict as the LLM's user prompt.

    Layout (behavior-identical to the original): header with frequency and
    a severity tag, local buses, stressed lines (rho > 0.8), neighbor-zone
    signals, zone generation/load balance, then the closing instruction.
    """
    hz = obs_dict.get('grid_frequency', 50.0)
    step = obs_dict.get('timestep', 0)
    deviation = hz - 50.0

    pieces = [f"[Zone: {zone_name}] Step {step} | Frequency: {hz:.3f} Hz"]

    # Severity tag: >0.3 Hz off-nominal is critical, >0.1 Hz is a warning.
    if abs(deviation) > 0.3:
        pieces.append(f" [!] CRITICAL: {deviation:+.3f} Hz deviation!")
    elif abs(deviation) > 0.1:
        pieces.append(f" WARNING: {deviation:+.3f} Hz deviation")

    # Buses visible to (and adjustable by) this zone's agent.
    local = obs_dict.get('local_buses', [])
    if local:
        pieces.append("\n\nYour buses:")
        for bus in local:
            entry = f" Bus {bus['id']} ({bus['type']}): {bus['p_injection']:.1f} MW"
            if bus['type'] == 'battery':
                entry += f" | SoC: {bus['soc']:.1f} MWh"
            pieces.append(f"\n{entry}")

    # Lines loaded past 80% of their limit and still connected.
    candidates = obs_dict.get('internal_lines', []) + obs_dict.get('boundary_lines', [])
    stressed = [ln for ln in candidates
                if ln.get('rho', 0) > 0.8 and ln.get('connected', True)]
    if stressed:
        pieces.append("\n\n[!] Stressed lines:")
        for ln in stressed:
            pieces.append(f"\n {ln['id']}: {ln['rho']:.2f} loading ({ln['flow']:.1f} MW)")

    # Coarse inter-zone signal: average injection per neighboring zone.
    signals = obs_dict.get('neighbor_signals', {})
    if signals:
        pieces.append("\n\nNeighbor zones (avg injection):")
        for zone_id, mw in signals.items():
            pieces.append(f"\n Zone {zone_id}: {mw:.1f} MW")

    # Zone-level balance, shown whenever either figure is non-zero.
    load_mw = obs_dict.get('zone_load_mw', 0)
    gen_mw = obs_dict.get('zone_gen_mw', 0)
    if load_mw or gen_mw:
        pieces.append(f"\n\nZone balance: Gen={gen_mw:.1f} MW, Load={load_mw:.1f} MW, Net={gen_mw-load_mw:.1f} MW")

    pieces.append("\n\nWhat action do you take? Respond with JSON only.")
    return "".join(pieces)
113
+
114
+
115
def extract_action(text: str) -> GridAction:
    """Parse LLM output into a GridAction, falling back to a no-op.

    Finds the widest ``{...}`` span in the completion (so surrounding prose
    or markdown fences are ignored), decodes it as JSON, and validates the
    payload against the action schema. Any parse or validation failure
    returns an empty ``GridAction()`` so a bad generation never crashes a
    rollout.

    Fix: the original ``except (json.JSONDecodeError, Exception)`` tuple was
    redundant — ``Exception`` already subsumes ``JSONDecodeError``.
    """
    text = text.strip()

    # Greedy match from the first '{' to the last '}' in the response.
    json_match = re.search(r'\{[\s\S]*\}', text)
    if json_match:
        try:
            data = json.loads(json_match.group())
            return GridAction(
                bus_adjustments=[
                    BusAdjustment(**a) for a in data.get('bus_adjustments', [])
                ],
                topology_actions=[
                    TopologyAction(**t) for t in data.get('topology_actions', [])
                ],
            )
        except Exception:
            # Swallows JSONDecodeError plus model-validation errors from bad
            # field names/types — deliberate best-effort parsing.
            pass

    # Fallback: no-op action
    return GridAction()
137
+
138
+
139
+ # ============================================================================
140
+ # Environment Rollout
141
+ # ============================================================================
142
+
143
def rollout_single_agent(env: OpenGridEnv, generate_fn, task_config: dict) -> dict:
    """Run one episode in single-agent mode and summarize it.

    Each step: format the observation as a prompt, query ``generate_fn``,
    parse the reply into an action, and step the environment. Stops early
    on episode termination.

    Returns a dict with total/step rewards, step count, blackout flag, and
    per-step average reward.
    """
    observation = env.reset()
    step_rewards = []
    blackout = False

    for _ in range(task_config['max_steps']):
        prompt = format_observation_prompt(observation.model_dump(), zone_name="Full_Grid")
        action = extract_action(generate_fn(prompt))

        observation, reward, done, info = env.step(action)
        step_rewards.append(reward.value)

        if done:
            blackout = info.is_blackout
            break

    episode_total = sum(step_rewards)
    n_steps = len(step_rewards)
    return {
        "total_reward": episode_total,
        "rewards": step_rewards,
        "steps": n_steps,
        "is_blackout": blackout,
        # max(..., 1) guards a zero-step episode against division by zero.
        "avg_reward": episode_total / max(n_steps, 1),
    }
174
+
175
+
176
def rollout_multi_agent(env: OpenGridEnv, generate_fn, task_config: dict) -> dict:
    """Run one episode in multi-agent mode and summarize it.

    Every step, each zone agent gets its own prompt/completion cycle; the
    resulting actions are applied jointly via ``env.step_multi``. Stops
    early on episode termination.

    Returns team and per-agent reward traces, step count, blackout flag,
    safety-layer intervention count, and per-step average team reward.
    """
    observations = env.reset_multi()
    team_rewards = []
    agent_reward_trace = {i: [] for i in range(env.num_agents)}
    corrections = 0
    blackout = False

    for _ in range(task_config['max_steps']):
        # One LLM query per zone agent, keyed by agent id.
        actions = {
            agent_id: extract_action(
                generate_fn(
                    format_observation_prompt(obs.model_dump(), zone_name=obs.zone_name)
                )
            )
            for agent_id, obs in observations.items()
        }

        outcome = env.step_multi(actions)

        team_rewards.append(outcome.team_reward)
        for agent_id, rew in outcome.rewards.items():
            agent_reward_trace[agent_id].append(rew.value)

        # safety_reports is Dict[int, SafetyReport] — count corrected actions
        corrections += sum(
            1 for report in outcome.safety_reports.values() if report.was_corrected
        )

        if outcome.done:
            blackout = outcome.info.is_blackout
            break

        observations = outcome.observations

    episode_total = sum(team_rewards)
    n_steps = len(team_rewards)
    return {
        "total_reward": episode_total,
        "rewards": team_rewards,
        "per_agent_rewards": agent_reward_trace,
        "steps": n_steps,
        "is_blackout": blackout,
        "safety_interventions": corrections,
        # max(..., 1) guards a zero-step episode against division by zero.
        "avg_reward": episode_total / max(n_steps, 1),
    }
224
+
225
+
226
+ # ============================================================================
227
+ # GRPO Reward Functions
228
+ # ============================================================================
229
+
230
# NOTE(review): declared as a cache for reward-computation envs, but nothing
# in this file reads or writes it — _get_reward_env() below deep-copies the
# config and builds a fresh env on every call. Either wire the cache in or
# remove it; as-is the comment overstates what the code does.
_REWARD_ENV_CACHE = {}
232
+
233
+
234
def _get_reward_env(task_config: dict) -> OpenGridEnv:
    """Build and reset a fresh environment for reward computation.

    The config is deep-copied so reward-time stepping can never mutate
    the caller's task configuration.
    """
    cfg_copy = copy.deepcopy(task_config)
    fresh_env = OpenGridEnv(cfg_copy)
    fresh_env.reset()
    return fresh_env
239
+
240
+
241
def compute_grpo_reward_env(
    completions: list,
    observations: list,
    task_config: dict,
    horizon: int = 3,
) -> list:
    """Environment-grounded reward: step the actual physics simulation.

    For each LLM-generated action:
    1. Restore the env to the observation state
    2. Step with the proposed action and get the real reward
    3. Run a short rollout (horizon steps) with heuristic continuation
       to capture trajectory-level impact
    4. Add format/schema bonuses

    This directly addresses the proxy-reward disconnect that caused
    the original GRPO training to show zero improvement.

    Args:
        completions: Raw LLM output strings (JSON actions, possibly noisy).
        observations: Matching observation dicts, JSON strings, or None.
        task_config: Task configuration used to rebuild the scoring env.
        horizon: Total steps (proposed action + heuristic continuation).

    Returns:
        One float per completion, clamped to [-1.0, 1.0].
    """
    from src.baseline import heuristic_policy

    rewards = []
    for completion, obs_dict in zip(completions, observations):
        if obs_dict is None:
            rewards.append(0.0)
            continue

        # Deserialize if needed (TRL may pass strings)
        if isinstance(obs_dict, str):
            try:
                obs_dict = json.loads(obs_dict)
            except (json.JSONDecodeError, TypeError):
                rewards.append(0.0)
                continue

        action = extract_action(completion)
        has_adjustments = bool(action.bus_adjustments)

        # ── 1. Format reward (small but keeps gradient alive) ──
        format_score = 0.0
        if has_adjustments:
            format_score += 0.05
        else:
            freq = obs_dict.get('grid_frequency', 50.0)
            if abs(freq - 50.0) < 0.05:
                format_score += 0.05  # No-op when stable is fine
            else:
                format_score -= 0.05  # No-op during deviation is bad

        # ── 2. Environment-grounded reward ──
        try:
            env = _get_reward_env(task_config)
            env._set_state(obs_dict)

            # Step with the LLM's proposed action
            obs_after, reward, done, info = env.step(action)
            env_score = reward.value

            # Blackout = catastrophic: pin the score, skip the rollout.
            if info.is_blackout:
                rewards.append(-1.0)
                continue

            # ── 3. Mini-rollout: what happens next? ──
            # Run a few more steps with heuristic to measure trajectory impact
            rollout_reward = 0.0
            for _ in range(horizon - 1):
                if done:
                    break
                h_action = heuristic_policy(obs_after)
                obs_after, r, done, info = env.step(h_action)
                rollout_reward += r.value
                if info.is_blackout:
                    rollout_reward -= 10.0
                    break

            # Combine: immediate reward + discounted future
            total_env_score = env_score + 0.5 * rollout_reward

            # Normalize to [-1, 1] range
            # Typical per-step reward is ~0.5 to 1.5, rollout adds ~1-4
            # So total_env_score is roughly in [-10, 4] range
            normalized = total_env_score / 5.0

        except Exception:
            # Fallback: score with the lightweight heuristic when the env
            # rollout fails for any reason (bad state dict, physics error).
            # (Was `except Exception as e` — the binding was never used.)
            normalized = _compute_heuristic_score(action, obs_dict)

        total = format_score + normalized
        rewards.append(max(-1.0, min(1.0, total)))

    return rewards
332
+
333
+
334
def _compute_heuristic_score(action: GridAction, obs_dict: dict) -> float:
    """Lightweight fallback scorer when env rollout fails.

    Scores the action on direction (pushing frequency back toward 50 Hz),
    proportionality (magnitude near the ideal correction), and stability
    (small nudges while the grid is already stable). Clamped to [-0.5, 0.5].
    """
    if not action.bus_adjustments:
        return 0.0

    freq_error = obs_dict.get('grid_frequency', 50.0) - 50.0
    abs_error = abs(freq_error)
    net_delta = sum(adj.delta for adj in action.bus_adjustments)

    score = 0.0

    if abs_error > 0.05:
        # Direction: under-frequency wants more power, over-frequency less.
        pushes_up = freq_error < 0 and net_delta > 0
        pushes_down = freq_error > 0 and net_delta < 0
        score += 0.3 if (pushes_up or pushes_down) else -0.3

        # Proportionality: reward magnitudes close to the ideal correction.
        ideal = abs_error * 15.0
        magnitude = abs(net_delta)
        if magnitude > 0.1:
            score += 0.2 * (min(magnitude, ideal) / max(magnitude, ideal, 0.1))

    # Stability bonus: tiny total adjustment while frequency is in band.
    if abs_error < 0.05 and abs(net_delta) < 2.0:
        score += 0.1

    return max(-0.5, min(0.5, score))
365
+
366
+
367
+ # Keep old function for backward compat / test mode
368
def compute_grpo_reward(completions: list, observations: list, env_url: str = None) -> list:
    """Legacy heuristic reward (used in test mode only)."""
    scores = []
    for completion, obs in zip(completions, observations):
        parsed_action = extract_action(completion)
        scores.append(_compute_heuristic_score(parsed_action, obs or {}))
    return scores
372
+
373
+
374
+ # ============================================================================
375
+ # Training Loop
376
+ # ============================================================================
377
+
378
def train_grpo(args):
    """Main GRPO training loop using TRL.

    Pipeline: load the model (optionally 4-bit via Unsloth), generate a
    prompt dataset by rolling out the multi-agent environment with random
    perturbation actions, then run TRL's GRPOTrainer with the
    environment-grounded reward function and save the result.

    Args:
        args: argparse.Namespace with model, task, epochs, batch_size,
            num_prompts, output_dir, and use_unsloth attributes.

    Returns:
        The object returned by trainer.train().
    """
    # TRL/transformers are heavy optional deps — fail fast with install hints.
    try:
        from trl import GRPOTrainer, GRPOConfig
        from transformers import AutoTokenizer, AutoModelForCausalLM
    except ImportError:
        print("ERROR: TRL not installed. Run: pip install trl transformers")
        print("For quantized training: pip install unsloth")
        sys.exit(1)

    print(f"[TRAIN] Model: {args.model}")
    print(f"[TRAIN] Task: {args.task}")
    print(f"[TRAIN] Epochs: {args.epochs}")
    print(f"[TRAIN] Batch size: {args.batch_size}")

    # Load model — Unsloth 4-bit + LoRA if requested, plain HF otherwise.
    if args.use_unsloth:
        try:
            from unsloth import FastLanguageModel
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=args.model,
                max_seq_length=2048,
                load_in_4bit=True,
            )
            model = FastLanguageModel.get_peft_model(
                model,
                r=16, lora_alpha=16, lora_dropout=0,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
            )
            print("[TRAIN] Loaded with Unsloth 4-bit quantization")
        except ImportError:
            # Graceful fallback so --use-unsloth on a box without unsloth
            # still trains, just unquantized.
            print("WARNING: Unsloth not available, falling back to standard loading")
            tokenizer = AutoTokenizer.from_pretrained(args.model)
            model = AutoModelForCausalLM.from_pretrained(args.model)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModelForCausalLM.from_pretrained(args.model)

    # Some chat models ship without a pad token; batched generation needs one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Prepare training data: observation prompts from the environment
    task_config = copy.deepcopy(TASKS[args.task])
    base_seed = task_config.get('seed', 42)

    # Generate prompts with diverse grid states:
    # - Larger random perturbations (-30 to +30 MW)
    # - Adversarial states (drained batteries, high frequency deviation)
    # - More steps per episode for temporal diversity
    print("[TRAIN] Generating training prompts from environment...")
    prompts = []
    obs_contexts = []
    rng = np.random.RandomState(base_seed)

    steps_per_episode = min(15, task_config['max_steps'])

    for episode in range(args.num_prompts):
        # Fresh env per episode with a distinct seed → independent rollouts.
        ep_config = copy.deepcopy(task_config)
        ep_config['seed'] = base_seed + episode
        env = OpenGridEnv(ep_config)
        zone_obs = env.reset_multi()

        # Adversarial injection: every 5th episode, drain batteries
        if episode % 5 == 0:
            for b in env.bus_state:
                b_cfg = env._find_bus_config(b['id'])
                if b_cfg and b_cfg['type'] == 'battery':
                    b['soc'] = max(1.0, b['soc'] * 0.1)  # Near-empty

        for t in range(steps_per_episode):
            for agent_id, obs in zone_obs.items():
                # JSON round-trip yields a plain, fully-serializable dict.
                obs_dict = json.loads(obs.model_dump_json())
                prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)

                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_text},
                ]

                formatted = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                prompts.append(formatted)
                obs_contexts.append(json.dumps(obs_dict))  # Store as string for Arrow compat

            # Larger random perturbations for state diversity
            random_actions = {}
            for agent_id in range(env.num_agents):
                zone_buses = task_config['zone_bus_ids'].get(agent_id, [])
                # Only generator/battery buses are controllable here.
                controllable = [
                    bid for bid in zone_buses
                    if next((b for b in task_config['buses'] if b['id'] == bid), {}).get('type')
                    in ['generator', 'battery']
                ]
                adj = []
                if controllable:
                    # Pick 1-2 buses with larger perturbations
                    n_adj = min(len(controllable), rng.randint(1, 3))
                    chosen = rng.choice(controllable, size=n_adj, replace=False)
                    for bid in chosen:
                        adj.append(BusAdjustment(
                            bus_id=int(bid),
                            delta=float(rng.uniform(-30, 30))  # Was ±15
                        ))
                random_actions[agent_id] = GridAction(bus_adjustments=adj)

            result = env.step_multi(random_actions)
            if result.done:
                break
            zone_obs = result.observations

    print(f"[TRAIN] Generated {len(prompts)} training prompts")

    # GRPO reward function: environment-grounded
    def reward_fn(completions, obs_context=None, **kwargs):
        """Environment-grounded GRPO reward.

        Steps the actual physics simulation to score each action,
        rather than using a disconnected heuristic proxy.
        """
        # TRL may hand back chat-format completions (list of message dicts)
        # or plain strings — normalize to text first.
        texts = []
        for c in completions:
            if isinstance(c, list):
                text = c[-1]['content'] if c else ""
            else:
                text = str(c)
            texts.append(text)

        if obs_context is None:
            obs_context = [None] * len(texts)

        # Deserialize obs_context strings
        obs_dicts = []
        for ctx in obs_context:
            if isinstance(ctx, str):
                try:
                    obs_dicts.append(json.loads(ctx))
                except (json.JSONDecodeError, TypeError):
                    obs_dicts.append(None)
            else:
                obs_dicts.append(ctx)

        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=3)

    # GRPO Config — tuned for sustained learning signal
    grpo_config = GRPOConfig(
        output_dir=str(Path(args.output_dir) / "grpo_checkpoints"),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        # Keep the effective batch near 16 regardless of per-device size.
        gradient_accumulation_steps=max(1, 16 // args.batch_size),
        learning_rate=1e-5,  # Was 5e-6 — slightly more aggressive
        logging_steps=5,
        save_steps=50,
        max_completion_length=256,
        num_generations=8,  # Was 4 — wider group for better ranking signal
        report_to="none",
        remove_unused_columns=False,
    )

    # Create dataset — include obs_context so TRL passes it to reward_fn
    from datasets import Dataset
    train_dataset = Dataset.from_dict({
        "prompt": prompts,
        "obs_context": obs_contexts,
    })

    # Initialize trainer
    trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        train_dataset=train_dataset,
        reward_funcs=reward_fn,
        processing_class=tokenizer,
    )

    # Train
    print("[TRAIN] Starting GRPO training...")
    train_result = trainer.train()

    # Save model (adapter/weights plus tokenizer files side by side).
    output_path = Path(args.output_dir) / "trained_model"
    trainer.save_model(str(output_path))
    tokenizer.save_pretrained(str(output_path))
    print(f"[TRAIN] Model saved to {output_path}")

    return train_result
565
+
566
+
567
+ # ============================================================================
568
+ # Evaluation & Plotting
569
+ # ============================================================================
570
+
571
def evaluate_model(generate_fn, task_ids=None, n_episodes=3, multi_agent=True):
    """Evaluate a model across tasks. Returns per-task results.

    Each episode uses a distinct seed to produce meaningful variance.
    """
    if task_ids is None:
        task_ids = list(TASKS.keys())

    # Pick the rollout entry point once, outside the loops.
    rollout_fn = rollout_multi_agent if multi_agent else rollout_single_agent

    results = {}
    for task_id in task_ids:
        base_config = TASKS[task_id]
        base_seed = base_config.get('seed', 42)

        episode_rewards = []
        for ep in range(n_episodes):
            # Vary seed per episode to get independent rollouts
            ep_config = copy.deepcopy(base_config)
            ep_config['seed'] = base_seed + ep
            episode_env = OpenGridEnv(ep_config)

            episode_data = rollout_fn(episode_env, generate_fn, ep_config)
            episode_rewards.append(episode_data['total_reward'])

        results[task_id] = {
            "avg_reward": np.mean(episode_rewards),
            "std_reward": np.std(episode_rewards),
            "rewards": episode_rewards,
        }

    return results
604
+
605
+
606
def plot_training_curves(training_log: list, output_path: str):
    """Generate reward curves from training log."""
    if not training_log:
        print("[PLOT] No training data to plot.")
        return

    fig, (reward_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: raw per-step reward plus a moving-average overlay.
    reward_series = [entry.get('reward', 0) for entry in training_log]
    reward_ax.plot(range(len(training_log)), reward_series, color='#00d4aa',
                   linewidth=1.5, alpha=0.6, label='Step Reward')

    if len(reward_series) > 10:
        window = min(20, len(reward_series) // 5)
        kernel = np.ones(window) / window
        smoothed = np.convolve(reward_series, kernel, mode='valid')
        # 'valid' mode drops window-1 leading points; shift x accordingly.
        reward_ax.plot(range(window - 1, len(reward_series)), smoothed,
                       color='#00d4aa', linewidth=2.5,
                       label=f'Smoothed (window={window})')

    reward_ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    reward_ax.set_xlabel('Training Step')
    reward_ax.set_ylabel('Reward')
    reward_ax.set_title('GRPO Training — Reward Curve')
    reward_ax.legend()
    reward_ax.grid(True, alpha=0.3)

    # Right panel: loss, only for entries that actually recorded one.
    loss_series = [entry['loss'] for entry in training_log if 'loss' in entry]
    if loss_series:
        loss_ax.plot(range(len(loss_series)), loss_series, color='#ff6b6b', linewidth=1.5)
        loss_ax.set_xlabel('Training Step')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training Loss')
        loss_ax.grid(True, alpha=0.3)
    else:
        loss_ax.text(0.5, 0.5, 'Loss data not available', ha='center', va='center',
                     transform=loss_ax.transAxes, fontsize=14, color='gray')
        loss_ax.set_title('Training Loss')

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[PLOT] Saved training curves to {output_path}")
651
+
652
+
653
def plot_before_after(before_results: dict, after_results: dict, output_path: str):
    """Generate before/after comparison chart."""
    fig, ax = plt.subplots(figsize=(10, 6))

    task_names = list(before_results.keys())
    positions = np.arange(len(task_names))
    bar_width = 0.35

    pre_scores = [before_results[name]['avg_reward'] for name in task_names]
    post_scores = [after_results[name]['avg_reward'] for name in task_names]

    pre_bars = ax.bar(positions - bar_width/2, pre_scores, bar_width,
                      label='Before Training', color='#ff6b6b', alpha=0.8)
    post_bars = ax.bar(positions + bar_width/2, post_scores, bar_width,
                       label='After Training', color='#00d4aa', alpha=0.8)

    ax.set_xlabel('Task')
    ax.set_ylabel('Average Episode Reward')
    ax.set_title('OpenGrid — GRPO Training: Before vs After')
    ax.set_xticks(positions)
    ax.set_xticklabels([name.replace('task_', '').title() for name in task_names])
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars (handle negative heights)
    for bar in list(pre_bars) + list(post_bars):
        height = bar.get_height()
        above = height >= 0
        ax.text(bar.get_x() + bar.get_width()/2., height + (1 if above else -1),
                f'{height:.1f}', ha='center', va='bottom' if above else 'top',
                fontsize=9)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[PLOT] Saved before/after comparison to {output_path}")
689
+
690
+
691
+ # ============================================================================
692
+ # Test Mode
693
+ # ============================================================================
694
+
695
def run_test_mode():
    """Quick pipeline verification without GPU. Runs a few episodes with heuristic.

    Exercises, in order: prompt generation, action extraction, multi-agent
    rollouts with a pseudo-LLM, the legacy reward function, and plot output.
    """
    print("\n" + "="*60)
    print(" OpenGrid GRPO Training — TEST MODE")
    print(" (Verifies the pipeline without training)")
    print("="*60 + "\n")

    # Test 1: Prompt generation — one prompt per zone agent on the easy task.
    print("[TEST] Generating prompts...")
    env = OpenGridEnv(TASKS["task_easy"])
    zone_obs = env.reset_multi()
    for agent_id, obs in zone_obs.items():
        prompt = format_observation_prompt(obs.model_dump(), zone_name=obs.zone_name)
        print(f"\n--- Agent {agent_id} ({obs.zone_name}) ---")
        print(prompt[:500])

    # Test 2: Action extraction — valid JSON, JSON embedded in prose, garbage.
    print("\n[TEST] Testing action extraction...")
    test_cases = [
        '{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}',
        'Here is my action: {"bus_adjustments": [], "topology_actions": []}',
        'invalid garbage',
    ]
    for tc in test_cases:
        action = extract_action(tc)
        print(f" Input: {tc[:60]}... -> {len(action.bus_adjustments)} adjustments")

    # Test 3: Multi-agent rollout with heuristic
    print("\n[TEST] Running multi-agent rollout...")
    # NOTE(review): heuristic_policy is imported here but never used below —
    # heuristic_generate re-implements its own proportional controller.
    from src.baseline import heuristic_policy

    def heuristic_generate(prompt):
        """Pseudo-LLM: use heuristic policy and format as JSON."""
        # Extract frequency from prompt (handles negative/signed values)
        freq_match = re.search(r'Frequency:\s*([-+]?\d+(?:\.\d+)?)', prompt)
        freq = float(freq_match.group(1)) if freq_match else 50.0

        # Simple proportional control
        error = 50.0 - freq
        delta = error * 10  # proportional gain
        delta = max(-20, min(20, delta))  # clamp to a safe adjustment range

        # Find controllable buses (generator/battery, NOT slack — physics overwrites it)
        bus_matches = re.findall(r'Bus (\d+) \((generator|battery)\)', prompt)
        if bus_matches:
            # Distribute across all controllable buses
            per_bus = delta / len(bus_matches)
            adjustments = [
                {"bus_id": int(m[0]), "delta": round(per_bus, 1)}
                for m in bus_matches
            ]
            return json.dumps({
                "bus_adjustments": adjustments,
                "topology_actions": []
            })
        # No controllable buses found in the prompt → emit a no-op action.
        return json.dumps({"bus_adjustments": [], "topology_actions": []})

    for task_id in ["task_easy", "task_medium"]:
        config = copy.deepcopy(TASKS[task_id])
        env = OpenGridEnv(config)
        result = rollout_multi_agent(env, heuristic_generate, config)
        print(f" {task_id}: reward={result['total_reward']:.2f}, "
              f"steps={result['steps']}, blackout={result['is_blackout']}, "
              f"safety_interventions={result['safety_interventions']}")

    # Test 4: Reward function (legacy heuristic path, not the env-grounded one).
    print("\n[TEST] Testing GRPO reward function...")
    test_completions = [
        '{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}',
        '{"bus_adjustments": [], "topology_actions": []}',
        'not valid json at all',
    ]
    test_obs = [{"grid_frequency": 49.5}, {"grid_frequency": 50.0}, {"grid_frequency": 50.3}]
    grpo_rewards = compute_grpo_reward(test_completions, test_obs)
    for tc, r in zip(test_completions, grpo_rewards):
        print(f" Reward: {r:.2f} for: {tc[:50]}...")

    # Test 5: Generate plots from synthetic logs/results.
    output_dir = Path("training/outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Fake log trends upward (reward) and downward (loss) so curves look sane.
    fake_log = [{"reward": np.random.normal(0.5, 0.3) + i * 0.01, "loss": 2.0 - i * 0.02}
                for i in range(100)]
    plot_training_curves(fake_log, str(output_dir / "test_training_curves.png"))

    fake_before = {t: {"avg_reward": np.random.uniform(20, 35)} for t in TASKS}
    fake_after = {t: {"avg_reward": np.random.uniform(40, 55)} for t in TASKS}
    plot_before_after(fake_before, fake_after, str(output_dir / "test_before_after.png"))

    print("\n" + "="*60)
    print(" [OK] ALL TESTS PASSED - Pipeline is ready for GPU training")
    print("="*60)
787
+
788
+
789
+ # ============================================================================
790
+ # Main
791
+ # ============================================================================
792
+
793
def main():
    """CLI entry point: parse arguments, then dispatch to test mode or training."""
    parser = argparse.ArgumentParser(description="OpenGrid GRPO Training")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct",
                        help="HuggingFace model name or path")
    parser.add_argument("--task", default="task_easy", choices=list(TASKS.keys()),
                        help="Which task to train on")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Batch size per device")
    parser.add_argument("--num-prompts", type=int, default=50,
                        help="Number of episodes to generate prompts from")
    parser.add_argument("--output-dir", default="training/outputs",
                        help="Directory for checkpoints and plots")
    parser.add_argument("--use-unsloth", action="store_true",
                        help="Use Unsloth for 4-bit quantized training")
    parser.add_argument("--test-mode", action="store_true",
                        help="Run pipeline verification without GPU")

    args = parser.parse_args()

    # Test mode short-circuits training entirely.
    if args.test_mode:
        run_test_mode()
        return

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Run training. (The return value was previously bound to an unused
    # `train_result` local; train_grpo writes all artifacts to output_dir.)
    train_grpo(args)

    print("\n[DONE] Training complete!")
    print(f" Output: {args.output_dir}")


if __name__ == "__main__":
    main()
validate-submission.sh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# validate-submission.sh — three-step pre-submission check:
#   1. ping the deployed HF Space's /reset endpoint
#   2. docker build of the repo
#   3. `openenv validate` CLI check
# Note: -e is deliberately omitted from `set`; each step checks its own
# exit status and stops via stop_at() with a readable message.
set -uo pipefail

# Seconds allowed for the docker build before `timeout` kills it.
DOCKER_BUILD_TIMEOUT=600
# ANSI color/style codes for terminal output.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BOLD='\033[1m'
NC='\033[0m'

# Positional args: required ping URL, optional repo dir (defaults to cwd).
PING_URL="${1:-}"
REPO_DIR="${2:-.}"

if [ -z "$PING_URL" ]; then
printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
exit 1
fi

# Normalize: absolute repo path; strip any trailing slash from the URL.
REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"
PING_URL="${PING_URL%/}"
# NOTE(review): PASS is incremented by pass() but never read — the final
# banner hardcodes "All 3/3". Consider printing "$PASS/3" or dropping it.
PASS=0

# Output helpers: timestamped log lines, pass/fail markers, a hint line,
# and stop_at() which aborts the whole run with a non-zero exit.
log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
fail() { log "${RED}FAILED${NC} -- $1"; }
hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
stop_at() {
printf "\n${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
exit 1
}

printf "\n${BOLD}========================================${NC}\n"
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
printf "${BOLD}========================================${NC}\n"
log "Repo: $REPO_DIR"
log "Ping URL: $PING_URL"
printf "\n"

# Step 1: POST an empty JSON body to /reset. curl prints only the HTTP
# status; "000" means the request itself failed (no connection/timeout).
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST \
-H "Content-Type: application/json" -d '{}' \
"$PING_URL/reset" --max-time 30 2>/dev/null || printf "000")

if [ "$HTTP_CODE" = "200" ]; then
pass "HF Space is live and responds to /reset"
elif [ "$HTTP_CODE" = "000" ]; then
fail "HF Space not reachable (connection failed or timed out)"
stop_at "Step 1"
else
fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
stop_at "Step 1"
fi

# Step 2: docker build. The Dockerfile may live at the repo root or in server/.
log "${BOLD}Step 2/3: Running docker build${NC} ..."
if ! command -v docker &>/dev/null; then
fail "docker command not found"
stop_at "Step 2"
fi

if [ -f "$REPO_DIR/Dockerfile" ]; then
DOCKER_CONTEXT="$REPO_DIR"
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
DOCKER_CONTEXT="$REPO_DIR/server"
else
fail "No Dockerfile found"
stop_at "Step 2"
fi

log " Found Dockerfile in $DOCKER_CONTEXT"
BUILD_OK=false
# Capture build output so only the last 20 lines are shown on failure.
BUILD_OUTPUT=$(timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

if [ "$BUILD_OK" = true ]; then
pass "Docker build succeeded"
else
fail "Docker build failed"
printf "%s\n" "$BUILD_OUTPUT" | tail -20
stop_at "Step 2"
fi

# Step 3: run the openenv CLI validator from inside the repo directory.
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
if ! command -v openenv &>/dev/null; then
fail "openenv command not found"
hint "Install it: pip install openenv-core"
stop_at "Step 3"
fi

VALIDATE_OK=false
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true

if [ "$VALIDATE_OK" = true ]; then
pass "openenv validate passed"
[ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
else
fail "openenv validate failed"
printf "%s\n" "$VALIDATE_OUTPUT"
stop_at "Step 3"
fi

printf "\n${BOLD}========================================${NC}\n"
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
printf "${BOLD}========================================${NC}\n\n"
exit 0