| """ |
| Tests for multi-agent POMDP features: |
| - Zone assignment and partitioning |
| - Partial observability (ZoneObservation) |
| - Safety layer (action validation and correction) |
| - Oversight agent (coordination monitoring) |
| - Multi-agent step (combined pipeline) |
| """ |
|
|
| import copy |
| import unittest |
|
|
| import networkx as nx |
| import numpy as np |
|
|
| from src.environment import OpenGridEnv |
| from src.tasks import TASKS |
| from src.models import GridAction, BusAdjustment, TopologyAction, ZoneObservation |
| from src.safety import SafetyLayer |
| from src.oversight import OversightAgent |
|
|
|
|
| def task(task_id: str): |
| """Get a deep-copied task config to prevent cross-test contamination.""" |
| return copy.deepcopy(TASKS[task_id]) |
|
|
|
|
| class TestZoneAssignment(unittest.TestCase): |
| """Tests for multi-agent zone partitioning.""" |
|
|
| def test_all_buses_assigned(self): |
| """Every bus should be assigned to exactly one zone.""" |
| for task_id, config in TASKS.items(): |
| zone_map = config['zone_assignments'] |
| for i in range(config['num_buses']): |
| self.assertIn(i, zone_map, f"Bus {i} not assigned in {task_id}") |
|
|
| def test_zone_count_matches(self): |
| """Number of zones should match num_agents.""" |
| for task_id, config in TASKS.items(): |
| agents = set(config['zone_assignments'].values()) |
| self.assertEqual(len(agents), config['num_agents'], |
| f"Zone count mismatch in {task_id}") |
|
|
| def test_no_empty_zones(self): |
| """Each zone should have at least 1 bus.""" |
| for task_id, config in TASKS.items(): |
| for agent_id in range(config['num_agents']): |
| bus_ids = config['zone_bus_ids'][agent_id] |
| self.assertGreater(len(bus_ids), 0, |
| f"Empty zone {agent_id} in {task_id}") |
|
|
| def test_lines_classified(self): |
| """All lines should be classified as internal or boundary.""" |
| for task_id, config in TASKS.items(): |
| all_internal = set() |
| all_boundary = set() |
| for agent_id in range(config['num_agents']): |
| all_internal.update(config['internal_lines'].get(agent_id, [])) |
| all_boundary.update(config['boundary_lines'].get(agent_id, [])) |
|
|
| all_line_ids = {l['id'] for l in config['lines']} |
| classified = all_internal | all_boundary |
| self.assertEqual(all_line_ids, classified, |
| f"Unclassified lines in {task_id}") |
|
|
|
|
| class TestPartialObservability(unittest.TestCase): |
| """Tests for POMDP zone observations.""" |
|
|
| def test_partial_obs_returns_zone_obs(self): |
| """reset_multi should return ZoneObservation for each agent.""" |
| config = task("task_easy") |
| env = OpenGridEnv(config) |
| zone_obs = env.reset_multi() |
|
|
| self.assertEqual(len(zone_obs), config["num_agents"], |
| "Should have one observation per agent") |
| for agent_id, obs in zone_obs.items(): |
| self.assertIsInstance(obs, ZoneObservation) |
| self.assertEqual(obs.agent_id, agent_id) |
|
|
| def test_partial_obs_only_shows_local_buses(self): |
| """Each agent should only see buses in their zone.""" |
| config = task("task_medium") |
| env = OpenGridEnv(config) |
| zone_obs = env.reset_multi() |
|
|
| for agent_id, obs in zone_obs.items(): |
| expected_bus_ids = set(config['zone_bus_ids'][agent_id]) |
| actual_bus_ids = {b.id for b in obs.local_buses} |
| self.assertEqual(actual_bus_ids, expected_bus_ids, |
| f"Agent {agent_id} sees wrong buses") |
|
|
| def test_frequency_has_noise(self): |
| """POMDP observations should have noisy frequency readings.""" |
| config = task("task_easy") |
| env = OpenGridEnv(config) |
| env.reset_multi() |
|
|
| |
| full_obs = env.state() |
| differences = [] |
| for agent_id in range(config['num_agents']): |
| z_obs = env._get_zone_obs(agent_id) |
| diff = abs(z_obs.grid_frequency - full_obs.grid_frequency) |
| differences.append(diff) |
|
|
| |
| self.assertTrue(any(d > 0.001 for d in differences), |
| "No frequency noise detected in POMDP observations") |
|
|
|
|
| class TestSafetyLayer(unittest.TestCase): |
| """Tests for the safety constraint filter.""" |
|
|
| def setUp(self): |
| self.config = task("task_medium") |
| self.safety = SafetyLayer(self.config) |
| self.env = OpenGridEnv(self.config) |
| self.env.reset() |
|
|
| def test_zone_boundary_enforcement(self): |
| """Agent should not be able to adjust buses in another zone.""" |
| agent_0_buses = set(self.config['zone_bus_ids'][0]) |
| other_bus = None |
| for bus_cfg in self.config['buses']: |
| if bus_cfg['id'] not in agent_0_buses: |
| other_bus = bus_cfg['id'] |
| break |
|
|
| if other_bus is None: |
| self.skipTest("All buses in agent 0's zone (trivial grid)") |
|
|
| action = GridAction(bus_adjustments=[ |
| BusAdjustment(bus_id=other_bus, delta=10.0) |
| ]) |
|
|
| corrected, report = self.safety.validate_and_correct( |
| agent_id=0, |
| proposed_action=action, |
| current_line_state=self.env.line_state, |
| current_bus_state=self.env.bus_state, |
| cooldowns=self.env.cooldowns, |
| ) |
|
|
| self.assertTrue(report.was_corrected, "Should have corrected cross-zone action") |
| self.assertEqual(len(corrected.bus_adjustments), 0, |
| "Cross-zone adjustment should be removed") |
|
|
| def test_safe_action_passes_through(self): |
| """A valid action within the agent's zone should not be corrected.""" |
| agent_0_buses = self.config['zone_bus_ids'][0] |
| controllable = None |
| for bus_cfg in self.config['buses']: |
| if bus_cfg['id'] in agent_0_buses and bus_cfg['type'] in ['generator', 'battery', 'slack']: |
| controllable = bus_cfg['id'] |
| break |
|
|
| if controllable is None: |
| self.skipTest("No controllable bus in agent 0's zone") |
|
|
| action = GridAction(bus_adjustments=[ |
| BusAdjustment(bus_id=controllable, delta=5.0) |
| ]) |
|
|
| corrected, report = self.safety.validate_and_correct( |
| agent_id=0, |
| proposed_action=action, |
| current_line_state=self.env.line_state, |
| current_bus_state=self.env.bus_state, |
| cooldowns=self.env.cooldowns, |
| ) |
|
|
| |
| self.assertEqual(len(corrected.bus_adjustments), 1, |
| "Valid action should produce one adjustment") |
|
|
| def test_islanding_blocked(self): |
| """Opening a bridge line should be blocked by safety layer.""" |
| G = nx.Graph() |
| for line in self.config['lines']: |
| G.add_edge(line['from'], line['to']) |
| bridges = list(nx.bridges(G)) |
| if not bridges: |
| self.skipTest("No bridges in grid topology") |
|
|
| bridge = bridges[0] |
| line_id = next( |
| l['id'] for l in self.config['lines'] |
| if (l['from'], l['to']) == bridge or (l['to'], l['from']) == bridge |
| ) |
|
|
| action = GridAction(topology_actions=[ |
| TopologyAction(line_id=line_id, action="open") |
| ]) |
|
|
| corrected, report = self.safety.validate_and_correct( |
| agent_id=0, |
| proposed_action=action, |
| current_line_state=self.env.line_state, |
| current_bus_state=self.env.bus_state, |
| cooldowns=self.env.cooldowns, |
| ) |
|
|
| self.assertTrue(report.was_corrected, "Bridge opening should be blocked") |
| self.assertEqual(len(corrected.topology_actions), 0, |
| "Bridge opening should be removed") |
|
|
| def test_duplicate_battery_adjustments_aggregated(self): |
| """Multiple adjustments to the same battery should be aggregated.""" |
| battery = next( |
| (b for b in self.config['buses'] if b['type'] == 'battery'), None |
| ) |
| if battery is None: |
| self.skipTest("No battery in task") |
|
|
| bus_id = battery['id'] |
| agent_id = self.config['zone_assignments'].get(bus_id, 0) |
|
|
| |
| for b in self.env.bus_state: |
| if b['id'] == bus_id: |
| b['soc'] = 10.0 |
|
|
| action = GridAction(bus_adjustments=[ |
| BusAdjustment(bus_id=bus_id, delta=8.0), |
| BusAdjustment(bus_id=bus_id, delta=8.0), |
| ]) |
|
|
| corrected, report = self.safety.validate_and_correct( |
| agent_id=agent_id, |
| proposed_action=action, |
| current_line_state=self.env.line_state, |
| current_bus_state=self.env.bus_state, |
| cooldowns=self.env.cooldowns, |
| ) |
|
|
| total_delta = sum(a.delta for a in corrected.bus_adjustments) |
| self.assertLessEqual(total_delta, 10.0, |
| "Combined discharge should not exceed SOC") |
|
|
|
|
| class TestOversightAgent(unittest.TestCase): |
| """Tests for the coordination oversight agent.""" |
|
|
| def test_no_conflict_scores_high(self): |
| """Cooperative actions should score high coordination.""" |
| config = task("task_easy") |
| oversight = OversightAgent(config) |
|
|
| |
| agent_actions = { |
| 0: GridAction(bus_adjustments=[BusAdjustment(bus_id=0, delta=5.0)]), |
| 1: GridAction(bus_adjustments=[BusAdjustment(bus_id=1, delta=3.0)]), |
| } |
|
|
| report = oversight.evaluate( |
| agent_actions=agent_actions, |
| safety_reports={}, |
| pre_frequency=49.8, |
| post_frequency=49.9, |
| pre_bus_state=[], |
| post_bus_state=[], |
| ) |
|
|
| self.assertGreater(report.coordination_score, 0.5, |
| "Cooperative actions should score > 0.5") |
|
|
| def test_reset_clears_history(self): |
| """Resetting oversight should clear intervention history.""" |
| config = task("task_easy") |
| oversight = OversightAgent(config) |
| oversight.intervention_history[0] = 5 |
| oversight.reset() |
| self.assertEqual(oversight.intervention_history[0], 0) |
|
|
|
|
| class TestMultiAgentStep(unittest.TestCase): |
| """Integration tests for the full multi-agent pipeline.""" |
|
|
| def test_multi_agent_step_returns_result(self): |
| """step_multi should return a complete MultiAgentStepResult.""" |
| config = task("task_easy") |
| env = OpenGridEnv(config) |
| env.reset_multi() |
|
|
| |
| actions = {i: GridAction() for i in range(config['num_agents'])} |
| result = env.step_multi(actions) |
|
|
| self.assertEqual(len(result.observations), config['num_agents']) |
| self.assertEqual(len(result.rewards), config['num_agents']) |
| self.assertIsInstance(result.team_reward, float) |
| self.assertIsInstance(result.done, bool) |
| self.assertEqual(len(result.safety_reports), config['num_agents']) |
|
|
| def test_safety_reports_match_agent_ids(self): |
| """Safety reports should contain all expected agent IDs.""" |
| config = task("task_easy") |
| env = OpenGridEnv(config) |
| env.reset_multi() |
|
|
| result = env.step_multi({ |
| i: GridAction() for i in range(config['num_agents']) |
| }) |
|
|
| report_ids = set(result.safety_reports.keys()) |
| expected_ids = set(range(config['num_agents'])) |
| self.assertEqual(report_ids, expected_ids, |
| "Safety report agent IDs should match expected agents") |
|
|
| def test_multi_agent_episode_completes(self): |
| """A full multi-agent episode should complete without errors.""" |
| config = task("task_easy") |
| env = OpenGridEnv(config) |
| env.reset_multi() |
|
|
| done = False |
| steps = 0 |
| while not done and steps < config['max_steps'] + 5: |
| actions = {i: GridAction() for i in range(config['num_agents'])} |
| result = env.step_multi(actions) |
| done = result.done |
| steps += 1 |
|
|
| self.assertTrue(done, "Episode should terminate") |
| self.assertLessEqual(steps, config['max_steps'] + 1) |
|
|
| def test_backward_compatibility(self): |
| """Single-agent reset/step should still work after multi-agent changes.""" |
| for task_id in TASKS: |
| config = task(task_id) |
| env = OpenGridEnv(config) |
| obs = env.reset() |
| self.assertGreater(len(obs.buses), 0, |
| f"No buses in {task_id}") |
|
|
| obs, reward, done, info = env.step(GridAction()) |
| self.assertEqual(obs.timestep, 1) |
| self.assertIsInstance(reward.value, float) |
|
|
|
|
| if __name__ == '__main__': |
| unittest.main() |
|
|