"""
Comprehensive training pipeline verification.
Tests: scenarios, reward functions, policies, GRPO integration, safety.
"""
import json
import copy
import sys
sys.path.insert(0, ".")
from src.tasks import TASKS, get_task
from src.environment import OpenGridEnv
from src.models import GridAction, GridObservation
from src.grader import RobustnessGrader
from src.baseline import heuristic_policy
from src.safety import SafetyLayer
# Header banner plus the shared error accumulator that every section appends to;
# a non-empty list at the end makes the script exit non-zero.
WIDTH = 60
print("=" * WIDTH)
print(" COMPREHENSIVE TRAINING PIPELINE VERIFICATION")
print("=" * WIDTH)
errors = []
# --- 1. Scenario loading ---
# Verify every expected task id is registered and print its key config fields.
print("\n[1/7] Scenario Loading...")
expected_tasks = [
    "task_easy", "task_medium", "task_hard",
    "task_karnataka", "karnataka_easy", "karnataka_medium", "karnataka_hard",
]
for tid in expected_tasks:
    if tid in TASKS:
        cfg = TASKS[tid]
        print(f" OK: {tid} - {cfg['num_buses']}b/{cfg['num_agents']}a zones={cfg['zone_names']}")
    else:
        errors.append(f"Missing task: {tid}")
        print(f" FAIL: {tid} not in TASKS")
# --- 2. Environment step for each scenario ---
# Smoke-test reset() + one no-op step() on every registered scenario.
print("\n[2/7] Environment Step Test...")
noop_payload = json.dumps({"bus_adjustments": [], "topology_actions": []})
for tid in expected_tasks:
    try:
        env = OpenGridEnv(get_task(tid))
        env.reset()
        obs2, reward, done, info = env.step(GridAction.model_validate_json(noop_payload))
        print(f" OK: {tid} - freq={obs2.grid_frequency:.2f}Hz reward={reward.value:.2f}")
    except Exception as e:
        errors.append(f"Env step failed for {tid}: {e}")
        print(f" FAIL: {tid} - {e}")
# --- 3. Reward function (GRPO) test ---
# Exercise the trainer's reward callable with one corrective action, one
# no-op, and one malformed completion (the latter must not raise).
print("\n[3/7] GRPO Reward Function Test...")
from training.train_grpo import compute_grpo_reward_env

test_completions = [
    '{"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []}',
    '{"bus_adjustments": [], "topology_actions": []}',
    'not valid json',
]
test_observations = [
    {"grid_frequency": 49.5, "buses": [], "lines": []},
    {"grid_frequency": 50.0, "buses": [], "lines": []},
    {"grid_frequency": 48.0, "buses": [], "lines": []},
]
try:
    rewards = compute_grpo_reward_env(
        test_completions, test_observations, get_task("karnataka_easy"), horizon=1
    )
    for i, r in enumerate(rewards):
        print(f" Completion {i}: reward={r:.3f}")
    print(f" OK: GRPO rewards computed for {len(rewards)} completions")
except Exception as e:
    errors.append(f"GRPO reward failed: {e}")
    print(f" FAIL: {e}")
# --- 4. Karnataka Difficulty Gradient Test ---
# Roll each Karnataka tier for 5 no-op steps and compare cumulative reward:
# the easy tier should not score below the hard one.
print("\n[4/7] Karnataka Difficulty Gradient Test...")
ka_rewards = {}
for tid in ["karnataka_easy", "karnataka_medium", "karnataka_hard"]:
    try:
        env = OpenGridEnv(get_task(tid))
        obs = env.reset()
        total_r = 0
        for _ in range(5):
            noop = GridAction.model_validate_json(
                json.dumps({"bus_adjustments": [], "topology_actions": []})
            )
            obs, reward, done, info = env.step(noop)
            total_r += reward.value
            if done:
                break
        ka_rewards[tid] = total_r
        print(f" {tid}: 5-step reward={total_r:.2f}")
    except Exception as e:
        errors.append(f"Ka difficulty test failed for {tid}: {e}")
        print(f" FAIL: {tid} - {e}")
if len(ka_rewards) == 3:
    # Easy should generally give higher or equal rewards than hard
    easy_r = ka_rewards["karnataka_easy"]
    hard_r = ka_rewards["karnataka_hard"]
    if easy_r >= hard_r:
        print(f" OK: Difficulty gradient correct (easy >= hard)")
    else:
        print(f" WARN: easy ({easy_r:.2f}) < hard ({hard_r:.2f}) - may vary by seed")
# --- 5. Heuristic policy test ---
# Run the baseline heuristic for up to 10 steps on a spread of scenarios.
print("\n[5/7] Heuristic Policy Test...")
for tid in ["task_easy", "karnataka_easy", "task_karnataka"]:
    try:
        env = OpenGridEnv(get_task(tid))
        obs = env.reset()
        total_r = 0
        for _ in range(10):
            obs, reward, done, info = env.step(heuristic_policy(obs))
            total_r += reward.value
            if done:
                break
        print(f" OK: {tid} - 10-step heuristic reward={total_r:.2f}")
    except Exception as e:
        errors.append(f"Heuristic policy failed for {tid}: {e}")
        print(f" FAIL: {tid} - {e}")
# --- 6. Safety layer test ---
# Feed the layer a deliberately oversized bus adjustment (delta=100) so it
# has something to clamp, and report whether a correction was applied.
print("\n[6/7] Safety Layer Test...")
for tid in ["task_easy", "karnataka_easy", "karnataka_hard"]:
    try:
        cfg = get_task(tid)
        layer = SafetyLayer(cfg)
        action = GridAction.model_validate_json(
            json.dumps({"bus_adjustments": [{"bus_id": 0, "delta": 100.0}], "topology_actions": []})
        )
        bus_state = [
            {"id": b["id"], "p": b.get("base_p", 0), "soc": b.get("init_soc", 0)}
            for b in cfg["buses"]
        ]
        line_state = [
            {"id": ln["id"], "connected": True, "flow": 0}
            for ln in cfg["lines"]
        ]
        safe_action, report = layer.validate_and_correct(0, action, line_state, bus_state, {})
        print(f" OK: {tid} - corrected={report.was_corrected}, n1_violations={report.n1_violations_detected}")
    except Exception as e:
        errors.append(f"Safety layer failed for {tid}: {e}")
        print(f" FAIL: {tid} - {e}")
# --- 7. Curriculum order test ---
# Every task id the trainer's curriculum references must exist in TASKS.
print("\n[7/7] Curriculum Order Test...")
from training.train_grpo import CURRICULUM_ORDER
for tid in CURRICULUM_ORDER:
    if tid not in TASKS:
        errors.append(f"Curriculum task missing: {tid}")
        print(f" FAIL: {tid} not in TASKS")
    else:
        print(f" OK: {tid} available")
# --- Summary ---
# Exit non-zero when any section recorded an error so CI can gate on it.
print("\n" + "=" * 60)
if not errors:
    print(" ALL CHECKS PASSED - Training pipeline ready")
    print("=" * 60)
else:
    print(f" FAILED: {len(errors)} errors")
    for e in errors:
        print(f" - {e}")
    sys.exit(1)