Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| SUMO-RL Environment Server Implementation. | |
| This module wraps the SUMO-RL SumoEnvironment and exposes it | |
| via the OpenEnv Environment interface for traffic signal control. | |
| """ | |
import os
import uuid
from typing import Any, Dict, Optional
| # Set SUMO_HOME before importing sumo_rl | |
| os.environ.setdefault("SUMO_HOME", "/usr/share/sumo") | |
| from openenv.core.env_server import Action, Environment, Observation | |
| # Support both in-repo and standalone imports | |
| try: | |
| # In-repo imports (when running from OpenEnv repository) | |
| from ..models import SumoAction, SumoObservation, SumoState | |
| except ImportError as e: | |
| if "relative import" not in str(e) and "no known parent package" not in str(e): | |
| raise | |
| # Standalone imports (when running via uvicorn server.app:app) | |
| from models import SumoAction, SumoObservation, SumoState | |
| # Import SUMO-RL | |
| try: | |
| from sumo_rl import SumoEnvironment as BaseSumoEnv | |
| except ImportError as e: | |
| raise ImportError( | |
| "sumo-rl is not installed. Please install it with: pip install sumo-rl" | |
| ) from e | |
def _json_safe(value: Any) -> Any:
    """Convert library-specific values into JSON-serializable Python types.

    Conversion rules, applied in order:

    * ``None`` and instances of ``str``, ``int``, ``float`` or ``bool`` are
      returned unchanged.
    * A ``dict`` becomes a new dict whose keys are passed through ``str()``
      and whose values are converted recursively.
    * A ``list``, ``tuple`` or ``set`` becomes a list of recursively
      converted items.
    * An object exposing a callable ``tolist()`` or ``item()`` (the accessor
      names used by array/scalar wrapper types) is unwrapped with that
      method, ``tolist`` first, and the unwrapped value is converted again so
      that nested containers or wrapped scalars inside it are sanitized too.
    * Anything else falls back to ``str(value)``.

    Args:
        value: Arbitrary value to sanitize.

    Returns:
        A structure built only from ``None``, ``str``, ``int``, ``float``,
        ``bool``, ``list`` and ``dict`` with string keys (or a string
        representation when no better conversion is available).
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_json_safe(item) for item in value]

    for accessor_name in ("tolist", "item"):
        accessor = getattr(value, accessor_name, None)
        if not callable(accessor):
            continue
        try:
            unwrapped = accessor()
        except Exception:
            # Deliberate best-effort: a failing accessor must not break
            # serialization, so fall through to the next strategy.
            continue
        if unwrapped is value:
            # The accessor made no progress; recursing would never end.
            continue
        # Re-sanitize: the accessor may hand back tuples, sets, dicts with
        # non-string keys, or further wrapped scalars.
        return _json_safe(unwrapped)

    return str(value)
class SumoEnvironment(Environment):
    """
    SUMO-RL Environment wrapper for OpenEnv.

    This environment wraps the SUMO traffic signal control environment
    for single-agent reinforcement learning.

    Args:
        net_file: Path to SUMO network file (.net.xml)
        route_file: Path to SUMO route file (.rou.xml)
        num_seconds: Simulation duration in seconds (default: 20000)
        delta_time: Seconds between agent actions (default: 5)
        yellow_time: Yellow phase duration in seconds (default: 2)
        min_green: Minimum green time in seconds (default: 5)
        max_green: Maximum green time in seconds (default: 50)
        reward_fn: Reward function name (default: "diff-waiting-time")
        sumo_seed: Random seed for reproducibility (default: 42)

    Example:
        >>> env = SumoEnvironment(
        ...     net_file="/app/nets/single-intersection.net.xml",
        ...     route_file="/app/nets/single-intersection.rou.xml"
        ... )
        >>> obs = env.reset()
        >>> print(obs.observation_shape)
        >>> obs = env.step(SumoAction(phase_id=1))
        >>> print(obs.reward, obs.done)
    """

    def __init__(
        self,
        net_file: str,
        route_file: str,
        num_seconds: int = 20000,
        delta_time: int = 5,
        yellow_time: int = 2,
        min_green: int = 5,
        max_green: int = 50,
        reward_fn: str = "diff-waiting-time",
        sumo_seed: int = 42,
    ):
        """Initialize SUMO traffic signal environment."""
        super().__init__()

        # Store configuration
        self.net_file = net_file
        self.route_file = route_file
        self.num_seconds = num_seconds
        self.delta_time = delta_time
        self.yellow_time = yellow_time
        self.min_green = min_green
        self.max_green = max_green
        self.reward_fn = reward_fn
        self.sumo_seed = sumo_seed

        # Create SUMO environment (single-agent mode)
        # Key settings:
        # - use_gui=False: No GUI in Docker
        # - single_agent=True: Returns single obs/reward (not dict)
        # - sumo_warnings=False: Suppress SUMO warnings
        # - out_csv_name=None: Don't write CSV files
        # - add_system_info=True: info dicts carry the "system_*" metrics
        #   that reset()/step() copy into the tracked state
        self.env = BaseSumoEnv(
            net_file=net_file,
            route_file=route_file,
            use_gui=False,
            single_agent=True,
            num_seconds=num_seconds,
            delta_time=delta_time,
            yellow_time=yellow_time,
            min_green=min_green,
            max_green=max_green,
            reward_fn=reward_fn,
            sumo_seed=sumo_seed,
            sumo_warnings=False,
            out_csv_name=None,  # Disable CSV output
            add_system_info=True,
            add_per_agent_info=False,
        )

        # Initialize state
        self._state = SumoState(
            net_file=net_file,
            route_file=route_file,
            num_seconds=num_seconds,
            delta_time=delta_time,
            yellow_time=yellow_time,
            min_green=min_green,
            max_green=max_green,
            reward_fn=reward_fn,
        )

        # Most recent info dict returned by the wrapped environment
        self._last_info: Dict[str, Any] = {}

    def reset(self) -> Observation:
        """
        Reset the environment and return initial observation.

        Besides starting a new episode id and zeroing the step counter and
        simulation time, the traffic metrics held in the state are refreshed
        from the reset info so values from the previous episode do not
        linger.

        Returns:
            Initial SumoObservation for the agent.
        """
        # Reset SUMO simulation
        obs, info = self.env.reset()

        # Update state tracking
        self._state.episode_id = str(uuid.uuid4())
        self._state.step_count = 0
        self._state.sim_time = 0.0
        self._update_traffic_metrics(info)

        # Store info for metadata
        self._last_info = info

        return self._make_observation(obs, reward=None, done=False, info=info)

    def step(self, action: Action) -> Observation:
        """
        Execute agent's action and return resulting observation.

        Args:
            action: SumoAction containing the phase_id to execute.

        Returns:
            SumoObservation after action execution.

        Raises:
            ValueError: If action is not a SumoAction, or if its phase_id is
                outside ``[0, action_space.n - 1]``.
        """
        if not isinstance(action, SumoAction):
            raise ValueError(f"Expected SumoAction, got {type(action)}")

        # Validate phase_id
        num_phases = self.env.action_space.n
        if action.phase_id < 0 or action.phase_id >= num_phases:
            raise ValueError(
                f"Invalid phase_id: {action.phase_id}. "
                f"Valid range: [0, {num_phases - 1}]"
            )

        # Execute action in SUMO
        # Returns: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = self.env.step(action.phase_id)
        done = terminated or truncated

        # Update state
        self._state.step_count += 1
        self._state.sim_time = info.get("step", 0.0)
        self._update_traffic_metrics(info)

        # Store info for metadata
        self._last_info = info

        return self._make_observation(obs, reward=reward, done=done, info=info)

    # NOTE(review): defined as a plain method, so callers use ``env.state()``.
    # Confirm whether the Environment base class expects a property instead.
    def state(self) -> SumoState:
        """Get current environment state."""
        return self._state

    def _update_traffic_metrics(self, info: Dict[str, Any]) -> None:
        """Copy the system-level traffic metrics from ``info`` into the state.

        Keys that are absent from ``info`` fall back to zero.

        Args:
            info: Info dictionary returned by the wrapped environment.
        """
        self._state.total_vehicles = info.get("system_total_running", 0)
        self._state.total_waiting_time = info.get("system_total_waiting_time", 0.0)
        self._state.mean_waiting_time = info.get("system_mean_waiting_time", 0.0)
        self._state.mean_speed = info.get("system_mean_speed", 0.0)

    def _make_observation(
        self, obs: Any, reward: Optional[float], done: bool, info: Dict
    ) -> SumoObservation:
        """
        Create SumoObservation from SUMO environment output.

        Args:
            obs: Observation array from SUMO environment
            reward: Reward value (None on reset)
            done: Whether episode is complete
            info: Info dictionary from SUMO environment

        Returns:
            SumoObservation for the agent.
        """
        # Convert observation to list
        if hasattr(obs, "tolist"):
            obs_list = obs.tolist()
        else:
            obs_list = list(obs)

        # Get action mask (all actions valid in SUMO-RL); it lists every
        # phase index from 0 to num_phases - 1
        num_phases = self.env.action_space.n
        action_mask = list(range(num_phases))

        # Extract system metrics for metadata
        system_info = {k: v for k, v in info.items() if k.startswith("system_")}

        # Create observation
        return SumoObservation(
            observation=obs_list,
            observation_shape=[len(obs_list)],
            action_mask=action_mask,
            sim_time=info.get("step", 0.0),
            done=done,
            reward=reward,
            metadata={
                "num_green_phases": num_phases,
                # Sanitized so the metadata can be JSON-encoded
                "system_info": _json_safe(system_info),
            },
        )