Opengrid / src /baseline.py
K446's picture
OpenGrid: Multi-agent POMDP power grid environment with GRPO training
78131a0
"""
Baseline Policies for OpenGrid
================================
Provides two agent implementations:
1. heuristic_policy — deterministic rule-based baseline for reproducible scoring
2. llm_policy — LLM-based policy using OpenAI-compatible API
Both support GridObservation (single-agent) and ZoneObservation (multi-agent).
"""
import json
import logging
import os
from typing import List, Union
from openai import OpenAI
from .models import GridAction, BusAdjustment, GridObservation, ZoneObservation
logger = logging.getLogger(__name__)

# API configuration. Each setting falls back to its default when the variable
# is unset *or set to an empty string* (e.g. ``OPENAI_API_KEY=`` in a .env
# file). The key comes from OPENAI_API_KEY when non-empty, otherwise from
# HF_TOKEN (for Hugging Face OpenAI-compatible endpoints).
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN", "")

# Cached client instance, created lazily on first use.
_CLIENT = None
def _get_client() -> OpenAI:
    """Return the module-wide OpenAI client, building it on the first call.

    Raises:
        RuntimeError: if no API key is configured (checked only when the
            client has not been built yet).
    """
    global _CLIENT
    if _CLIENT is not None:
        return _CLIENT
    if not API_KEY:
        raise RuntimeError(
            "Missing API key. Set OPENAI_API_KEY or HF_TOKEN environment variable."
        )
    _CLIENT = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=15.0)
    return _CLIENT
def _obs_buses(obs):
"""Extract bus list from either GridObservation or ZoneObservation."""
return getattr(obs, "buses", getattr(obs, "local_buses", []))
def _obs_lines(obs):
"""Extract line list from either GridObservation or ZoneObservation."""
if hasattr(obs, "lines"):
return obs.lines
internal = getattr(obs, "internal_lines", [])
boundary = getattr(obs, "boundary_lines", [])
return list(internal) + list(boundary)
# System prompt sent as the first ("system") chat message by llm_policy().
# It states the control objectives, the GridAction JSON shape
# (bus_adjustments / topology_actions) and asks for a bare JSON reply, which
# llm_policy() hands to parse_action_response().
# NOTE: this text is runtime input to the model; editing it changes LLM
# behavior and therefore llm_policy() results.
SYSTEM_PROMPT = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
Key objectives:
1. Keep grid frequency close to 50.0 Hz (acceptable: 49.5–50.5 Hz)
2. Prevent transmission line overloads (rho < 1.0)
3. Avoid grid islanding (blackout)
Available actions:
1. bus_adjustments: List of {"bus_id": int, "delta": float}
- Positive delta = increase power injection (discharge battery / ramp up generator)
- Negative delta = decrease power injection (charge battery / ramp down generator)
- Only works on battery and generator buses (NOT slack, load, solar, or wind)
- Slack bus injection is computed by physics — adjustments are ignored
2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
- Opening a line removes it; closing reconnects. 3-step cooldown after each switch.
- WARNING: Opening lines can cause islanding → blackout → -100 reward
- Prefer NO topology actions unless absolutely necessary.
Strategy tips:
- If frequency < 50 Hz: grid needs more generation → discharge batteries or ramp up generators
- If frequency > 50 Hz: grid has excess generation → charge batteries or ramp down generators
- If a line rho > 0.9: reduce generation at one end or increase at the other to shift flow
- Prefer minimal actions. Do-nothing is better than reckless switching.
Respond with ONLY a valid JSON object, no markdown, no explanation. Example:
{"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
"""
def parse_action_response(response_text: str) -> GridAction:
    """Parse an LLM reply into a GridAction, falling back to a no-op.

    Decoding starts at the first ``{`` in the reply and uses
    ``json.JSONDecoder.raw_decode``, which stops at the end of that object.
    Any of the following around it are therefore ignored:

    - leading prose;
    - Markdown code fences;
    - an enclosing JSON list (its first object is used);
    - trailing text, including further JSON objects.

    Args:
        response_text: Raw completion text. ``None`` or any other non-string
            value (e.g. an empty message content) is treated as no reply.

    Returns:
        ``GridAction(**obj)`` for the decoded object, or ``GridAction()``
        (no-op) when no object can be decoded or GridAction rejects it.
    """
    if not isinstance(response_text, str):
        return GridAction()

    start = response_text.find("{")
    if start == -1:
        return GridAction()

    try:
        # raw_decode parses exactly one JSON value beginning at `start` and
        # reports where it ended, so text after the object cannot break it.
        data, _end = json.JSONDecoder().raw_decode(response_text, start)
    except (ValueError, RecursionError) as exc:  # JSONDecodeError is a ValueError
        logger.debug("Could not decode LLM action JSON: %s", exc)
        return GridAction()

    try:
        return GridAction(**data)
    except Exception as exc:
        # Schema/validation errors from GridAction: keep the documented
        # no-op contract, but leave a trace instead of swallowing silently.
        logger.debug("LLM action rejected by GridAction: %s", exc)
        return GridAction()
def llm_policy(
    obs: Union[GridObservation, ZoneObservation],
    *,
    temperature: float = 0.0,
    max_tokens: int = 300,
) -> GridAction:
    """Choose an action by querying an OpenAI-compatible chat model.

    The observation is serialized with ``model_dump_json()`` and sent as the
    user message after ``SYSTEM_PROMPT``. The reply text is converted with
    ``parse_action_response``. Works for both GridObservation and
    ZoneObservation.

    Args:
        obs: Current single-agent or zone observation.
        temperature: Sampling temperature forwarded to the API. The default
            0.0 requests the least random sampling.
        max_tokens: Completion token budget forwarded to the API. A reply cut
            off at this limit may not parse and then becomes a no-op, so
            consider raising it for large grids.

    Returns:
        The parsed GridAction. If the API call fails or the reply cannot be
        used, this is a no-op ``GridAction()``.

    Raises:
        RuntimeError: Propagated from ``_get_client()`` when no API key is
            configured. Configuration errors are surfaced rather than
            silently turned into no-ops.
    """
    # Deliberately outside the try: missing credentials should fail loudly.
    client = _get_client()
    obs_json = obs.model_dump_json()
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"Current Grid State:\n{obs_json}"},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        action_str = response.choices[0].message.content
    except Exception as e:
        # Best-effort boundary: API/transport/response-shape failures degrade
        # to a no-op. Log at WARNING so a broken endpoint or model name is
        # visible instead of silently producing do-nothing baselines. Include
        # the traceback only when DEBUG logging is enabled.
        logger.warning(
            "LLM policy error; falling back to no-op: %s",
            e,
            exc_info=logger.isEnabledFor(logging.DEBUG),
        )
        return GridAction()
    # parse_action_response handles None/garbled replies by returning a no-op.
    return parse_action_response(action_str)
def heuristic_policy(
    obs: Union[GridObservation, ZoneObservation],
) -> GridAction:
    """Deterministic rule-based baseline policy for reproducible scoring.

    Accepts GridObservation (single-agent) or ZoneObservation (multi-agent).
    Only battery and generator injections are adjusted. The slack bus is
    never targeted, because its injection is overwritten by the physics
    solver. Lines are never switched, since opening overloaded lines causes
    cascading failures. Out-of-range deltas are left for the
    environment/safety layer to clamp.

    Adjustments are emitted in this order:

    1. Batteries. When ``|f - 50| > 0.1``, a proportional correction
       ``-15 * (f - 50)``, clipped to [-20, 20], is split evenly across all
       batteries. Discharge (positive) shares skip batteries whose ``soc`` is
       not above 0. Charge (negative) shares go to every battery.
    2. Generators. When ``|f - 50| > 0.25``, each generator gets
       ``-5 * (f - 50)``, clipped to +/- its ``ramp_rate`` (20.0 if the bus
       has no such attribute).
    3. Overload relief. For every connected line with ``rho > 0.95``, the
       first generator not yet used for relief and with ``p_injection > 5``
       is ramped down by 3.0.
    """
    deviation = obs.grid_frequency - 50.0  # > 0: too high, < 0: too low
    bus_list = list(_obs_buses(obs))
    line_list = list(_obs_lines(obs))
    battery_buses = [b for b in bus_list if b.type == 'battery']
    generator_buses = [b for b in bus_list if b.type == 'generator']

    adjustments = []

    # 1. Battery frequency regulation (gain 15, total clipped to +/-20).
    if abs(deviation) > 0.1 and battery_buses:
        total = max(-20.0, min(20.0, -deviation * 15.0))
        share = total / len(battery_buses)
        discharging = share > 0
        for bat in battery_buses:
            # Discharge requests only go to batteries with charge left. Charge
            # requests always go out; the safety layer clamps them to the
            # remaining capacity.
            if (discharging and bat.soc > 0) or share < 0:
                adjustments.append(BusAdjustment(bus_id=bat.id, delta=share))

    # 2. Proportional generator response for larger deviations (gain 5).
    if abs(deviation) > 0.25:
        for gen in generator_buses:
            limit = getattr(gen, 'ramp_rate', 20.0)
            step = max(-limit, min(limit, -deviation * 5.0))
            adjustments.append(BusAdjustment(bus_id=gen.id, delta=step))

    # 3. Overload relief: at most one distinct generator per overloaded line.
    # NOTE(review): a generator may already hold a step-2 adjustment, so the
    # same bus_id can appear twice in the action. How duplicates combine is
    # decided by the environment; confirm. Under-frequency, this ramp-down
    # also works against steps 1-2.
    relieved_ids = set()
    for ln in line_list:
        if not (ln.rho > 0.95 and ln.connected):
            continue
        donor = next(
            (
                g
                for g in generator_buses
                if g.id not in relieved_ids and g.p_injection > 5
            ),
            None,
        )
        if donor is not None:
            adjustments.append(BusAdjustment(bus_id=donor.id, delta=-3.0))
            relieved_ids.add(donor.id)

    # Topology is never touched: opening lines risks islanding / blackout.
    return GridAction(bus_adjustments=adjustments, topology_actions=[])