File size: 5,919 Bytes
8a02303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
Comprehensive training pipeline verification.
Tests: scenarios, reward functions, policies, GRPO integration, safety.
"""
import json
import copy
import sys
sys.path.insert(0, ".")

from src.tasks import TASKS, get_task
from src.environment import OpenGridEnv
from src.models import GridAction, GridObservation
from src.grader import RobustnessGrader
from src.baseline import heuristic_policy
from src.safety import SafetyLayer

# Opening banner for the verification report.
banner_rule = "=" * 60
print(banner_rule)
print("  COMPREHENSIVE TRAINING PIPELINE VERIFICATION")
print(banner_rule)

# Failure messages collected by every check below; the summary at the end of
# the script reports them.
errors = []

# --- 1. Scenario loading ---
# Every expected task id must be registered in TASKS; for those that are,
# print the bus count, agent count and zone names from the task config.
print("\n[1/7] Scenario Loading...")
expected_tasks = [
    "task_easy",
    "task_medium",
    "task_hard",
    "task_karnataka",
    "karnataka_easy",
    "karnataka_medium",
    "karnataka_hard",
]
for task_id in expected_tasks:
    if task_id in TASKS:
        task_cfg = TASKS[task_id]
        n_buses = task_cfg['num_buses']
        n_agents = task_cfg['num_agents']
        zones = task_cfg['zone_names']
        print(f"  OK: {task_id} - {n_buses}b/{n_agents}a zones={zones}")
        continue
    # Not registered: record the failure and report it immediately.
    errors.append(f"Missing task: {task_id}")
    print(f"  FAIL: {task_id} not in TASKS")

# --- 2. Environment step for each scenario ---
# For each expected task: build the environment, reset it, apply a single
# action with no bus adjustments and no topology actions, and print the
# resulting grid frequency and reward. Any exception is recorded per task.
print("\n[2/7] Environment Step Test...")
noop_payload = json.dumps({"bus_adjustments": [], "topology_actions": []})
for scenario in expected_tasks:
    try:
        sim = OpenGridEnv(get_task(scenario))
        sim.reset()
        noop = GridAction.model_validate_json(noop_payload)
        next_obs, step_reward, _done, _info = sim.step(noop)
        frequency = next_obs.grid_frequency
        reward_value = step_reward.value
        print(f"  OK: {scenario} - freq={frequency:.2f}Hz reward={reward_value:.2f}")
    except Exception as exc:
        errors.append(f"Env step failed for {scenario}: {exc}")
        print(f"  FAIL: {scenario} - {exc}")

# --- 3. Reward function (GRPO) test ---
# Score three canned completions (one with a bus adjustment, one empty action,
# one that is not JSON at all) against three canned observations on the
# "karnataka_easy" task, and print the reward returned for each.
print("\n[3/7] GRPO Reward Function Test...")
test_completions = [
    '{"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []}',
    '{"bus_adjustments": [], "topology_actions": []}',
    'not valid json',
]
test_observations = [
    {"grid_frequency": 49.5, "buses": [], "lines": []},
    {"grid_frequency": 50.0, "buses": [], "lines": []},
    {"grid_frequency": 48.0, "buses": [], "lines": []},
]
try:
    # Imported inside the try so that a missing or broken training module is
    # recorded as a failed check (like every other check in this script)
    # instead of aborting the run before the remaining checks and the summary.
    from training.train_grpo import compute_grpo_reward_env

    cfg = get_task("karnataka_easy")
    rewards = compute_grpo_reward_env(test_completions, test_observations, cfg, horizon=1)
    for i, r in enumerate(rewards):
        print(f"  Completion {i}: reward={r:.3f}")
    print(f"  OK: GRPO rewards computed for {len(rewards)} completions")
except Exception as e:
    errors.append(f"GRPO reward failed: {e}")
    print(f"  FAIL: {e}")

# --- 4. Karnataka Difficulty Gradient Test ---
# Run up to five empty-action steps on each Karnataka difficulty level and
# accumulate the reward; totals are kept in ka_rewards keyed by task id.
print("\n[4/7] Karnataka Difficulty Gradient Test...")
ka_rewards = {}
idle_payload = json.dumps({"bus_adjustments": [], "topology_actions": []})
for level in ("karnataka_easy", "karnataka_medium", "karnataka_hard"):
    try:
        grid = OpenGridEnv(get_task(level))
        grid.reset()
        cumulative = 0
        for _step in range(5):
            idle = GridAction.model_validate_json(idle_payload)
            _obs, step_reward, finished, _info = grid.step(idle)
            cumulative += step_reward.value
            if finished:
                # Episode ended early; stop stepping this environment.
                break
        ka_rewards[level] = cumulative
        print(f"  {level}: 5-step reward={cumulative:.2f}")
    except Exception as exc:
        errors.append(f"Ka difficulty test failed for {level}: {exc}")
        print(f"  FAIL: {level} - {exc}")

# The gradient comparison only runs when all three levels produced a total.
if len(ka_rewards) == 3:
    easy_total = ka_rewards["karnataka_easy"]
    hard_total = ka_rewards["karnataka_hard"]
    # Easy should generally give higher or equal rewards than hard; a
    # violation is only a warning and is not added to errors.
    gradient_ok = easy_total >= hard_total
    if gradient_ok:
        print("  OK: Difficulty gradient correct (easy >= hard)")
    else:
        print(f"  WARN: easy ({easy_total:.2f}) < hard ({hard_total:.2f}) - may vary by seed")

# --- 5. Heuristic policy test ---
# Drive three scenarios with heuristic_policy for at most ten steps each,
# feeding every new observation back into the policy, and print the summed
# reward. Exceptions are recorded per scenario.
print("\n[5/7] Heuristic Policy Test...")
for scenario in ("task_easy", "karnataka_easy", "task_karnataka"):
    try:
        sim = OpenGridEnv(get_task(scenario))
        observation = sim.reset()
        episode_return = 0
        for _step in range(10):
            chosen = heuristic_policy(observation)
            observation, step_reward, finished, _info = sim.step(chosen)
            episode_return += step_reward.value
            if finished:
                # Environment signalled the end of the episode.
                break
        print(f"  OK: {scenario} - 10-step heuristic reward={episode_return:.2f}")
    except Exception as exc:
        errors.append(f"Heuristic policy failed for {scenario}: {exc}")
        print(f"  FAIL: {scenario} - {exc}")

# --- 6. Safety layer test ---
# Pass an action carrying a 100.0 delta on bus 0 through SafetyLayer for three
# scenarios and print whether the report marks it as corrected, plus the
# number of N-1 violations the report lists.
print("\n[6/7] Safety Layer Test...")
probe_payload = json.dumps(
    {"bus_adjustments": [{"bus_id": 0, "delta": 100.0}], "topology_actions": []}
)
for scenario in ("task_easy", "karnataka_easy", "karnataka_hard"):
    try:
        task_cfg = get_task(scenario)
        guard = SafetyLayer(task_cfg)
        probe_action = GridAction.model_validate_json(probe_payload)
        # Synthetic state built from the task config: each bus at its base
        # power / initial SoC (0 when the key is absent), every line connected
        # with zero flow.
        bus_state = [
            {"id": bus["id"], "p": bus.get("base_p", 0), "soc": bus.get("init_soc", 0)}
            for bus in task_cfg["buses"]
        ]
        line_state = [
            {"id": line["id"], "connected": True, "flow": 0}
            for line in task_cfg["lines"]
        ]
        # NOTE(review): the leading 0 and trailing {} are passed positionally;
        # their meaning is defined by SafetyLayer.validate_and_correct - confirm there.
        _safe_action, outcome = guard.validate_and_correct(0, probe_action, line_state, bus_state, {})
        print(f"  OK: {scenario} - corrected={outcome.was_corrected}, n1_violations={outcome.n1_violations_detected}")
    except Exception as exc:
        errors.append(f"Safety layer failed for {scenario}: {exc}")
        print(f"  FAIL: {scenario} - {exc}")

# --- 7. Curriculum order test ---
# Every task id listed in the training module's CURRICULUM_ORDER must be
# registered in TASKS.
print("\n[7/7] Curriculum Order Test...")
try:
    # Guarded so that a missing or broken training module is recorded as a
    # failed check and the summary below still runs, instead of the script
    # dying with an unhandled import error.
    from training.train_grpo import CURRICULUM_ORDER
except Exception as e:
    errors.append(f"Curriculum import failed: {e}")
    print(f"  FAIL: {e}")
else:
    for tid in CURRICULUM_ORDER:
        if tid in TASKS:
            print(f"  OK: {tid} available")
        else:
            errors.append(f"Curriculum task missing: {tid}")
            print(f"  FAIL: {tid} not in TASKS")

# --- Summary ---
# Print every recorded failure (or the all-clear line) between two rules, then
# exit with status 1 if anything failed so callers such as CI can detect it.
print("\n" + "=" * 60)
if errors:
    print(f"  FAILED: {len(errors)} errors")
    for e in errors:
        print(f"    - {e}")
else:
    print("  ALL CHECKS PASSED - Training pipeline ready")
print("=" * 60)
# Exit only after the closing rule has been printed; exiting inside the
# failure branch above would skip it and leave the report unterminated.
if errors:
    sys.exit(1)