pc-ddpm-alberta / src /pc_ddpm_alberta /feasibility.py
jbobym's picture
space deploy: trim short_description to fit HF 60-char cap
93ed35a
"""GNN surrogate for fast voltage-feasibility checks on generated scenarios.
The GraphSAGE surrogate (3 layers, 128 hidden) is trained to map per-bus
[wind, solar, load] contributions to per-bus voltage magnitudes on the
IEEE 118-bus topology. We reuse it here to score generated scenarios in
milliseconds — orders of magnitude faster than running pandapower AC PF
inline in the demo.
Headline metric reference: PC-DDPM was evaluated against pandapower AC PF
under both operational [0.85, 1.10] pu and strict ANSI [0.89, 1.05] pu
voltage bounds (`eval_all_models_bounds.py` in pc-ddpm-epec2026). The
surrogate-based check here is a fast proxy of those numbers, not the
authoritative power-flow evaluation.
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812 — PyTorch convention
from huggingface_hub import hf_hub_download
# Number of buses in the IEEE 118-bus topology the surrogate was trained on;
# used to reshape the flat model output back to one row per bus.
N_BUS = 118
# Hugging Face Hub repository and the default artifact filenames the loader
# pulls from it (model checkpoint and normalisation statistics).
HF_REPO_ID = "jbobym/pc-ddpm-alberta"
DEFAULT_GNN_FILE = "gnn_surrogate.pt"
DEFAULT_GNN_NORM_FILE = "gnn_surrogate_norm.npz"
# (low, high) per-unit voltage limits for the two evaluation regimes:
# the wider operational band and the stricter ANSI band.
V_BOUNDS_OPERATIONAL: tuple[float, float] = (0.85, 1.10)
V_BOUNDS_ANSI: tuple[float, float] = (0.89, 1.05)
class SAGELayer(nn.Module):
    """Single GraphSAGE-style layer with separate self and neighbour projections.

    The layer computes ``relu(W_self(h) + W_neigh(A_norm @ h))``. Only the
    neighbour projection carries a bias, so the sum has exactly one bias term.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        # NOTE: these attribute names determine the module's state_dict keys;
        # renaming them would break loading of existing checkpoints.
        self.W_self = nn.Linear(in_dim, out_dim, bias=False)
        self.W_neigh = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, h: torch.Tensor, A_norm: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbours through ``A_norm`` and apply the ReLU update."""
        aggregated = A_norm @ h
        pre_activation = self.W_self(h) + self.W_neigh(aggregated)
        return F.relu(pre_activation)
class GraphSAGE(nn.Module):
    """Stack of `SAGELayer`s over a fixed graph, followed by a linear head.

    Mirrors EPEC `train_gnn_surrogate.py`. The first layer maps ``in_dim`` to
    ``hidden_dim``; every later layer keeps ``hidden_dim``; the final linear
    layer maps ``hidden_dim`` to ``out_dim`` independently for each node.
    """

    def __init__(
        self,
        in_dim: int = 3,
        hidden_dim: int = 128,
        out_dim: int = 2,
        n_layers: int = 3,
    ) -> None:
        super().__init__()
        # Consecutive entries of `widths` give each layer's (input, output) size.
        widths = [in_dim, *([hidden_dim] * n_layers)]
        stacked = [SAGELayer(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        # NOTE: `sage_layers` and `out` are state_dict key prefixes; keep the names.
        self.sage_layers = nn.ModuleList(stacked)
        self.out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor, A_norm: torch.Tensor) -> torch.Tensor:
        """Run every SAGE layer in order with the same ``A_norm``, then the head."""
        hidden = x
        for sage in self.sage_layers:
            hidden = sage(hidden, A_norm)
        return self.out(hidden)  # type: ignore[no-any-return]
@dataclass
class SurrogateBundle:
    """Everything needed to run the surrogate: model, graph, and scaling stats.

    Field order is part of the positional constructor signature and must not
    be changed.
    """

    # Trained network used for inference.
    model: GraphSAGE
    # Aggregation matrix handed to the model on every forward pass.
    A_norm: torch.Tensor
    # Input normalisation: scenarios are transformed as (x - X_mean) / X_std.
    X_mean: np.ndarray
    X_std: np.ndarray
    # Output de-normalisation: voltages are recovered as v * V_std + V_mean.
    V_mean: np.ndarray
    V_std: np.ndarray
    # Per-bus vectors that spread the normalised wind / solar / load values
    # across buses (via an outer product) to form node features.
    wind_vec: np.ndarray
    solar_vec: np.ndarray
    load_frac: np.ndarray
    # Device the input features are moved to before inference.
    device: torch.device
def load_gnn_surrogate(
    repo_id: str = HF_REPO_ID,
    model_filename: str = DEFAULT_GNN_FILE,
    norm_filename: str = DEFAULT_GNN_NORM_FILE,
    device: str | torch.device | None = None,
) -> SurrogateBundle:
    """Pull weights + norm stats from HF Hub and assemble the bundle.

    Args:
        repo_id: Hugging Face Hub repository holding both artifacts.
        model_filename: Checkpoint file; must contain the keys ``hidden_dim``,
            ``n_layers`` and ``model_state``.
        norm_filename: ``.npz`` archive with ``A_norm``, ``X_mean``, ``X_std``,
            ``V_mean``, ``V_std``, ``wind_vec``, ``solar_vec`` and ``load_frac``.
        device: Target device. ``None`` picks CUDA when available, else CPU;
            a string is converted with ``torch.device``.

    Returns:
        A `SurrogateBundle` whose model is on ``device`` and in eval mode,
        with all NumPy statistics cast to float32.

    Raises:
        KeyError: If an expected key is missing from either artifact.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    stat_keys = (
        "X_mean",
        "X_std",
        "V_mean",
        "V_std",
        "wind_vec",
        "solar_vec",
        "load_frac",
    )
    # An NpzFile keeps its archive open until closed, so read every array
    # inside the context manager instead of leaking the file handle.
    with np.load(hf_hub_download(repo_id=repo_id, filename=norm_filename)) as norm:
        adjacency = norm["A_norm"]
        stats = {key: norm[key].astype(np.float32) for key in stat_keys}

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code from the downloaded file. Only point
    # `repo_id` at trusted repositories; consider weights_only=True once the
    # checkpoint is confirmed to hold only tensors and plain Python scalars.
    ckpt = torch.load(
        hf_hub_download(repo_id=repo_id, filename=model_filename),
        map_location=device,
        weights_only=False,
    )
    model = GraphSAGE(
        in_dim=3,
        hidden_dim=int(ckpt["hidden_dim"]),
        out_dim=2,
        n_layers=int(ckpt["n_layers"]),
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    return SurrogateBundle(
        model=model,
        A_norm=torch.as_tensor(adjacency, dtype=torch.float32, device=device),
        X_mean=stats["X_mean"],
        X_std=stats["X_std"],
        V_mean=stats["V_mean"],
        V_std=stats["V_std"],
        wind_vec=stats["wind_vec"],
        solar_vec=stats["solar_vec"],
        load_frac=stats["load_frac"],
        device=device,
    )
def _build_node_features(
scenarios_flat: np.ndarray,
bundle: SurrogateBundle,
) -> np.ndarray:
"""Map (B, 3) global scenarios to (B, 118, 3) per-bus features."""
X_norm = (scenarios_flat - bundle.X_mean) / bundle.X_std
feat = np.stack(
[
np.outer(X_norm[:, 0], bundle.wind_vec),
np.outer(X_norm[:, 1], bundle.solar_vec),
np.outer(X_norm[:, 2], bundle.load_frac),
],
axis=2,
)
return feat.astype(np.float32)
def predict_voltages(scenarios: np.ndarray, bundle: SurrogateBundle) -> np.ndarray:
    """Run the surrogate on `(N, 3, T)` scenarios; return `(N, T, 118)` voltages in pu.

    Every (scenario, time-step) pair is treated as one sample: the input is
    flattened to `(N*T, 3)`, expanded to per-bus features, pushed through the
    model without gradient tracking, and the first output channel is
    de-normalised with the bundle's voltage statistics.

    Raises:
        ValueError: If `scenarios` is not 3-D with a middle axis of length 3.
    """
    shape_ok = scenarios.ndim == 3 and scenarios.shape[1] == 3
    if not shape_ok:
        raise ValueError(f"scenarios must be (N, 3, T); got {scenarios.shape}")

    n_scenarios, n_steps = scenarios.shape[0], scenarios.shape[2]
    # Move the channel axis last so each row is one [wind, solar, load] sample.
    samples = np.swapaxes(scenarios, 1, 2).reshape(-1, 3)
    node_features = _build_node_features(samples, bundle)

    with torch.no_grad():
        inputs = torch.as_tensor(node_features, device=bundle.device)
        prediction = bundle.model(inputs, bundle.A_norm)
        # Only output channel 0 is used as the (normalised) voltage.
        v_normalised = prediction[..., 0].cpu().numpy()

    v_per_unit = v_normalised * bundle.V_std + bundle.V_mean
    out: np.ndarray = v_per_unit.reshape(n_scenarios, n_steps, N_BUS).astype(np.float32)
    return out
def feasibility_check(
    scenarios: np.ndarray,
    bundle: SurrogateBundle,
) -> dict[str, float | int]:
    """Score scenarios under operational and ANSI voltage bounds.

    Returns the per-scenario feasibility rate (a scenario is feasible iff
    every bus stays in bounds for every hour). The two bound regimes match
    `eval_all_models_bounds.py` in pc-ddpm-epec2026; absolute numbers may
    differ from the AC-PF headline because the GNN is a learned surrogate.

    Args:
        scenarios: Array accepted by `predict_voltages`; the first axis
            indexes scenarios.
        bundle: Loaded surrogate used to predict bus voltages.

    Returns:
        A dict with ``n_total``, the feasible counts
        (``operational_feasible``, ``ansi_feasible``) and the matching
        percentages (``operational_pct``, ``ansi_pct``). When there are no
        scenarios the percentages are reported as 0.0, not NaN.
    """
    # Predict first so shape validation still runs for empty batches.
    v_pu = predict_voltages(scenarios, bundle)
    n_total = int(scenarios.shape[0])

    def _score(bounds: tuple[float, float]) -> tuple[int, float]:
        """Return (feasible count, feasible percentage) for one bound regime."""
        low, high = bounds
        within = (v_pu >= low) & (v_pu <= high)
        # Reduce over the time and bus axes: one verdict per scenario.
        feasible = within.all(axis=(1, 2))
        # The mean of an empty array is NaN (with a RuntimeWarning); guard it.
        pct = float(feasible.mean() * 100) if n_total else 0.0
        return int(feasible.sum()), pct

    op_count, op_pct = _score(V_BOUNDS_OPERATIONAL)
    ansi_count, ansi_pct = _score(V_BOUNDS_ANSI)
    return {
        "n_total": n_total,
        "operational_feasible": op_count,
        "ansi_feasible": ansi_count,
        "operational_pct": op_pct,
        "ansi_pct": ansi_pct,
    }
def traffic_light(operational_pct: float) -> str:
    """Translate an operational feasibility percentage into a badge colour.

    The headline result is 100% under operational bounds, so any shortfall is
    notable. Thresholds are inclusive lower limits: at or above 100 is
    "green", at or above 90 is "yellow", and everything else is "red".
    """
    # Ordered from strictest to loosest; the first threshold met wins.
    thresholds = ((100.0, "green"), (90.0, "yellow"))
    for minimum, colour in thresholds:
        if operational_pct >= minimum:
            return colour
    return "red"