File size: 12,984 Bytes
78131a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
Tests for multi-agent POMDP features:
- Zone assignment and partitioning
- Partial observability (ZoneObservation)
- Safety layer (action validation and correction)
- Oversight agent (coordination monitoring)
- Multi-agent step (combined pipeline)
"""

import copy
import unittest

import networkx as nx
import numpy as np

from src.environment import OpenGridEnv
from src.tasks import TASKS
from src.models import GridAction, BusAdjustment, TopologyAction, ZoneObservation
from src.safety import SafetyLayer
from src.oversight import OversightAgent


def task(task_id: str):
    """Get a deep-copied task config to prevent cross-test contamination."""
    return copy.deepcopy(TASKS[task_id])


class TestZoneAssignment(unittest.TestCase):
    """Tests for multi-agent zone partitioning."""

    def test_all_buses_assigned(self):
        """Every bus should be assigned to exactly one zone."""
        for task_id, config in TASKS.items():
            zone_map = config['zone_assignments']
            for i in range(config['num_buses']):
                self.assertIn(i, zone_map, f"Bus {i} not assigned in {task_id}")

    def test_zone_count_matches(self):
        """Number of zones should match num_agents."""
        for task_id, config in TASKS.items():
            agents = set(config['zone_assignments'].values())
            self.assertEqual(len(agents), config['num_agents'],
                             f"Zone count mismatch in {task_id}")

    def test_no_empty_zones(self):
        """Each zone should have at least 1 bus."""
        for task_id, config in TASKS.items():
            for agent_id in range(config['num_agents']):
                bus_ids = config['zone_bus_ids'][agent_id]
                self.assertGreater(len(bus_ids), 0,
                                   f"Empty zone {agent_id} in {task_id}")

    def test_lines_classified(self):
        """All lines should be classified as internal or boundary."""
        for task_id, config in TASKS.items():
            all_internal = set()
            all_boundary = set()
            for agent_id in range(config['num_agents']):
                all_internal.update(config['internal_lines'].get(agent_id, []))
                all_boundary.update(config['boundary_lines'].get(agent_id, []))

            all_line_ids = {l['id'] for l in config['lines']}
            classified = all_internal | all_boundary
            self.assertEqual(all_line_ids, classified,
                             f"Unclassified lines in {task_id}")


class TestPartialObservability(unittest.TestCase):
    """Tests for POMDP zone observations."""

    def test_partial_obs_returns_zone_obs(self):
        """reset_multi should return ZoneObservation for each agent."""
        config = task("task_easy")
        env = OpenGridEnv(config)
        zone_obs = env.reset_multi()

        self.assertEqual(len(zone_obs), config["num_agents"],
                         "Should have one observation per agent")
        for agent_id, obs in zone_obs.items():
            self.assertIsInstance(obs, ZoneObservation)
            self.assertEqual(obs.agent_id, agent_id)

    def test_partial_obs_only_shows_local_buses(self):
        """Each agent should only see buses in their zone."""
        config = task("task_medium")
        env = OpenGridEnv(config)
        zone_obs = env.reset_multi()

        for agent_id, obs in zone_obs.items():
            expected_bus_ids = set(config['zone_bus_ids'][agent_id])
            actual_bus_ids = {b.id for b in obs.local_buses}
            self.assertEqual(actual_bus_ids, expected_bus_ids,
                             f"Agent {agent_id} sees wrong buses")

    def test_frequency_has_noise(self):
        """POMDP observations should have noisy frequency readings."""
        config = task("task_easy")
        env = OpenGridEnv(config)
        env.reset_multi()

        # Compare zone obs against full obs from the same reset
        full_obs = env.state()
        differences = []
        for agent_id in range(config['num_agents']):
            z_obs = env._get_zone_obs(agent_id)
            diff = abs(z_obs.grid_frequency - full_obs.grid_frequency)
            differences.append(diff)

        # At least one agent should see noisy frequency
        self.assertTrue(any(d > 0.001 for d in differences),
                        "No frequency noise detected in POMDP observations")


class TestSafetyLayer(unittest.TestCase):
    """Tests for the safety constraint filter."""

    def setUp(self):
        self.config = task("task_medium")
        self.safety = SafetyLayer(self.config)
        self.env = OpenGridEnv(self.config)
        self.env.reset()

    def test_zone_boundary_enforcement(self):
        """Agent should not be able to adjust buses in another zone."""
        agent_0_buses = set(self.config['zone_bus_ids'][0])
        other_bus = None
        for bus_cfg in self.config['buses']:
            if bus_cfg['id'] not in agent_0_buses:
                other_bus = bus_cfg['id']
                break

        if other_bus is None:
            self.skipTest("All buses in agent 0's zone (trivial grid)")

        action = GridAction(bus_adjustments=[
            BusAdjustment(bus_id=other_bus, delta=10.0)
        ])

        corrected, report = self.safety.validate_and_correct(
            agent_id=0,
            proposed_action=action,
            current_line_state=self.env.line_state,
            current_bus_state=self.env.bus_state,
            cooldowns=self.env.cooldowns,
        )

        self.assertTrue(report.was_corrected, "Should have corrected cross-zone action")
        self.assertEqual(len(corrected.bus_adjustments), 0,
                         "Cross-zone adjustment should be removed")

    def test_safe_action_passes_through(self):
        """A valid action within the agent's zone should not be corrected."""
        agent_0_buses = self.config['zone_bus_ids'][0]
        controllable = None
        for bus_cfg in self.config['buses']:
            if bus_cfg['id'] in agent_0_buses and bus_cfg['type'] in ['generator', 'battery', 'slack']:
                controllable = bus_cfg['id']
                break

        if controllable is None:
            self.skipTest("No controllable bus in agent 0's zone")

        action = GridAction(bus_adjustments=[
            BusAdjustment(bus_id=controllable, delta=5.0)
        ])

        corrected, report = self.safety.validate_and_correct(
            agent_id=0,
            proposed_action=action,
            current_line_state=self.env.line_state,
            current_bus_state=self.env.bus_state,
            cooldowns=self.env.cooldowns,
        )

        # Should pass through (may have minor clamping)
        self.assertEqual(len(corrected.bus_adjustments), 1,
                         "Valid action should produce one adjustment")

    def test_islanding_blocked(self):
        """Opening a bridge line should be blocked by safety layer."""
        G = nx.Graph()
        for line in self.config['lines']:
            G.add_edge(line['from'], line['to'])
        bridges = list(nx.bridges(G))
        if not bridges:
            self.skipTest("No bridges in grid topology")

        bridge = bridges[0]
        line_id = next(
            l['id'] for l in self.config['lines']
            if (l['from'], l['to']) == bridge or (l['to'], l['from']) == bridge
        )

        action = GridAction(topology_actions=[
            TopologyAction(line_id=line_id, action="open")
        ])

        corrected, report = self.safety.validate_and_correct(
            agent_id=0,
            proposed_action=action,
            current_line_state=self.env.line_state,
            current_bus_state=self.env.bus_state,
            cooldowns=self.env.cooldowns,
        )

        self.assertTrue(report.was_corrected, "Bridge opening should be blocked")
        self.assertEqual(len(corrected.topology_actions), 0,
                         "Bridge opening should be removed")

    def test_duplicate_battery_adjustments_aggregated(self):
        """Multiple adjustments to the same battery should be aggregated."""
        battery = next(
            (b for b in self.config['buses'] if b['type'] == 'battery'), None
        )
        if battery is None:
            self.skipTest("No battery in task")

        bus_id = battery['id']
        agent_id = self.config['zone_assignments'].get(bus_id, 0)

        # Set SOC to a known value
        for b in self.env.bus_state:
            if b['id'] == bus_id:
                b['soc'] = 10.0

        action = GridAction(bus_adjustments=[
            BusAdjustment(bus_id=bus_id, delta=8.0),
            BusAdjustment(bus_id=bus_id, delta=8.0),
        ])

        corrected, report = self.safety.validate_and_correct(
            agent_id=agent_id,
            proposed_action=action,
            current_line_state=self.env.line_state,
            current_bus_state=self.env.bus_state,
            cooldowns=self.env.cooldowns,
        )

        total_delta = sum(a.delta for a in corrected.bus_adjustments)
        self.assertLessEqual(total_delta, 10.0,
                             "Combined discharge should not exceed SOC")


class TestOversightAgent(unittest.TestCase):
    """Tests for the coordination oversight agent."""

    def test_no_conflict_scores_high(self):
        """Cooperative actions should score high coordination."""
        config = task("task_easy")
        oversight = OversightAgent(config)

        # Both agents inject (cooperative)
        agent_actions = {
            0: GridAction(bus_adjustments=[BusAdjustment(bus_id=0, delta=5.0)]),
            1: GridAction(bus_adjustments=[BusAdjustment(bus_id=1, delta=3.0)]),
        }

        report = oversight.evaluate(
            agent_actions=agent_actions,
            safety_reports={},
            pre_frequency=49.8,
            post_frequency=49.9,
            pre_bus_state=[],
            post_bus_state=[],
        )

        self.assertGreater(report.coordination_score, 0.5,
                           "Cooperative actions should score > 0.5")

    def test_reset_clears_history(self):
        """Resetting oversight should clear intervention history."""
        config = task("task_easy")
        oversight = OversightAgent(config)
        oversight.intervention_history[0] = 5
        oversight.reset()
        self.assertEqual(oversight.intervention_history[0], 0)


class TestMultiAgentStep(unittest.TestCase):
    """Integration tests for the full multi-agent pipeline."""

    def test_multi_agent_step_returns_result(self):
        """step_multi should return a complete MultiAgentStepResult."""
        config = task("task_easy")
        env = OpenGridEnv(config)
        env.reset_multi()

        # No-op actions for all agents
        actions = {i: GridAction() for i in range(config['num_agents'])}
        result = env.step_multi(actions)

        self.assertEqual(len(result.observations), config['num_agents'])
        self.assertEqual(len(result.rewards), config['num_agents'])
        self.assertIsInstance(result.team_reward, float)
        self.assertIsInstance(result.done, bool)
        self.assertEqual(len(result.safety_reports), config['num_agents'])

    def test_safety_reports_match_agent_ids(self):
        """Safety reports should contain all expected agent IDs."""
        config = task("task_easy")
        env = OpenGridEnv(config)
        env.reset_multi()

        result = env.step_multi({
            i: GridAction() for i in range(config['num_agents'])
        })

        report_ids = set(result.safety_reports.keys())
        expected_ids = set(range(config['num_agents']))
        self.assertEqual(report_ids, expected_ids,
                         "Safety report agent IDs should match expected agents")

    def test_multi_agent_episode_completes(self):
        """A full multi-agent episode should complete without errors."""
        config = task("task_easy")
        env = OpenGridEnv(config)
        env.reset_multi()

        done = False
        steps = 0
        while not done and steps < config['max_steps'] + 5:
            actions = {i: GridAction() for i in range(config['num_agents'])}
            result = env.step_multi(actions)
            done = result.done
            steps += 1

        self.assertTrue(done, "Episode should terminate")
        self.assertLessEqual(steps, config['max_steps'] + 1)

    def test_backward_compatibility(self):
        """Single-agent reset/step should still work after multi-agent changes."""
        for task_id in TASKS:
            config = task(task_id)
            env = OpenGridEnv(config)
            obs = env.reset()
            self.assertGreater(len(obs.buses), 0,
                               f"No buses in {task_id}")

            obs, reward, done, info = env.step(GridAction())
            self.assertEqual(obs.timestep, 1)
            self.assertIsInstance(reward.value, float)


if __name__ == '__main__':
    unittest.main()