# sumo_rl_env/server/sumo_environment.py
# (Hugging Face Hub page residue: uploaded by burtenshaw [HF Staff],
#  "Upload folder using huggingface_hub", commit 429558a, verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
SUMO-RL Environment Server Implementation.
This module wraps the SUMO-RL SumoEnvironment and exposes it
via the OpenEnv Environment interface for traffic signal control.
"""
import os
import uuid
from typing import Any, Dict, Optional

# Set SUMO_HOME before importing sumo_rl
os.environ.setdefault("SUMO_HOME", "/usr/share/sumo")

from openenv.core.env_server import Action, Environment, Observation

# Support both in-repo and standalone imports
try:
    # In-repo imports (when running from OpenEnv repository)
    from ..models import SumoAction, SumoObservation, SumoState
except ImportError as e:
    if "relative import" not in str(e) and "no known parent package" not in str(e):
        raise
    # Standalone imports (when running via uvicorn server.app:app)
    from models import SumoAction, SumoObservation, SumoState

# Import SUMO-RL
try:
    from sumo_rl import SumoEnvironment as BaseSumoEnv
except ImportError as e:
    raise ImportError(
        "sumo-rl is not installed. Please install it with: pip install sumo-rl"
    ) from e
def _json_safe(value: Any) -> Any:
"""Convert library-specific values into JSON-serializable Python types."""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in value.items()}
if isinstance(value, (list, tuple, set)):
return [_json_safe(item) for item in value]
tolist = getattr(value, "tolist", None)
if callable(tolist):
try:
return tolist()
except Exception:
pass
item = getattr(value, "item", None)
if callable(item):
try:
return item()
except Exception:
pass
return str(value)
class SumoEnvironment(Environment):
    """
    SUMO-RL Environment wrapper for OpenEnv.

    This environment wraps the SUMO traffic signal control environment
    for single-agent reinforcement learning.

    Args:
        net_file: Path to SUMO network file (.net.xml)
        route_file: Path to SUMO route file (.rou.xml)
        num_seconds: Simulation duration in seconds (default: 20000)
        delta_time: Seconds between agent actions (default: 5)
        yellow_time: Yellow phase duration in seconds (default: 2)
        min_green: Minimum green time in seconds (default: 5)
        max_green: Maximum green time in seconds (default: 50)
        reward_fn: Reward function name (default: "diff-waiting-time")
        sumo_seed: Random seed for reproducibility (default: 42)

    Example:
        >>> env = SumoEnvironment(
        ...     net_file="/app/nets/single-intersection.net.xml",
        ...     route_file="/app/nets/single-intersection.rou.xml"
        ... )
        >>> obs = env.reset()
        >>> print(obs.observation_shape)
        >>> obs = env.step(SumoAction(phase_id=1))
        >>> print(obs.reward, obs.done)
        >>> env.close()
    """

    def __init__(
        self,
        net_file: str,
        route_file: str,
        num_seconds: int = 20000,
        delta_time: int = 5,
        yellow_time: int = 2,
        min_green: int = 5,
        max_green: int = 50,
        reward_fn: str = "diff-waiting-time",
        sumo_seed: int = 42,
    ):
        """Initialize SUMO traffic signal environment."""
        super().__init__()

        # Store configuration
        self.net_file = net_file
        self.route_file = route_file
        self.num_seconds = num_seconds
        self.delta_time = delta_time
        self.yellow_time = yellow_time
        self.min_green = min_green
        self.max_green = max_green
        self.reward_fn = reward_fn
        self.sumo_seed = sumo_seed

        # Create SUMO environment (single-agent mode)
        # Key settings:
        # - use_gui=False: No GUI in Docker
        # - single_agent=True: Returns single obs/reward (not dict)
        # - sumo_warnings=False: Suppress SUMO warnings
        # - out_csv_name=None: Don't write CSV files
        # - add_system_info=True: info carries the "system_*" metrics that
        #   _sync_state_from_info and _make_observation read
        self.env = BaseSumoEnv(
            net_file=net_file,
            route_file=route_file,
            use_gui=False,
            single_agent=True,
            num_seconds=num_seconds,
            delta_time=delta_time,
            yellow_time=yellow_time,
            min_green=min_green,
            max_green=max_green,
            reward_fn=reward_fn,
            sumo_seed=sumo_seed,
            sumo_warnings=False,
            out_csv_name=None,  # Disable CSV output
            add_system_info=True,
            add_per_agent_info=False,
        )

        # Initialize state
        self._state = SumoState(
            net_file=net_file,
            route_file=route_file,
            num_seconds=num_seconds,
            delta_time=delta_time,
            yellow_time=yellow_time,
            min_green=min_green,
            max_green=max_green,
            reward_fn=reward_fn,
        )

        # Info dict from the most recent reset()/step() call.
        self._last_info: Dict[str, Any] = {}

    def reset(self) -> Observation:
        """
        Reset the environment and return initial observation.

        Starts a new episode:
        - assigns a fresh ``episode_id``;
        - zeroes the step counter;
        - refreshes the simulation time and system metrics in :attr:`state`
          from the reset info, so that values from the previous episode are
          not carried over.

        Returns:
            Initial SumoObservation for the agent (``reward`` is ``None``).
        """
        # Reset SUMO simulation
        obs, info = self.env.reset()

        # Update state tracking
        self._state.episode_id = str(uuid.uuid4())
        self._state.step_count = 0
        self._sync_state_from_info(info)

        # Store info for metadata
        self._last_info = info

        return self._make_observation(obs, reward=None, done=False, info=info)

    def step(self, action: Action) -> Observation:
        """
        Execute agent's action and return resulting observation.

        Args:
            action: SumoAction containing the phase_id to execute.

        Returns:
            SumoObservation after action execution.

        Raises:
            ValueError: If action is not a SumoAction, or its phase_id is
                outside ``[0, action_space.n - 1]``.
        """
        if not isinstance(action, SumoAction):
            raise ValueError(f"Expected SumoAction, got {type(action)}")

        # Validate phase_id
        num_phases = self.env.action_space.n
        if action.phase_id < 0 or action.phase_id >= num_phases:
            raise ValueError(
                f"Invalid phase_id: {action.phase_id}. "
                f"Valid range: [0, {num_phases - 1}]"
            )

        # Execute action in SUMO
        # Returns: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = self.env.step(action.phase_id)
        # bool() guards against array-library boolean scalars.
        done = bool(terminated or truncated)

        # Update state
        self._state.step_count += 1
        self._sync_state_from_info(info)

        # Store info for metadata
        self._last_info = info

        return self._make_observation(obs, reward=reward, done=done, info=info)

    def close(self) -> None:
        """
        Shut down the wrapped SUMO-RL environment.

        Delegates to the wrapped environment's ``close()`` so the simulation
        it manages is released. If the parent class also defines ``close()``,
        that is called afterwards.
        """
        self.env.close()
        parent_close = getattr(super(), "close", None)
        if callable(parent_close):
            parent_close()

    @property
    def state(self) -> SumoState:
        """Get current environment state."""
        return self._state

    def _sync_state_from_info(self, info: Dict[str, Any]) -> None:
        """
        Copy simulation time and system-level metrics from ``info`` into state.

        Missing keys fall back to zero, so the state never keeps values from
        an earlier step or episode.

        Args:
            info: Info dictionary returned by the wrapped environment's
                ``reset()`` or ``step()``.
        """
        self._state.sim_time = info.get("step", 0.0)
        self._state.total_vehicles = info.get("system_total_running", 0)
        self._state.total_waiting_time = info.get("system_total_waiting_time", 0.0)
        self._state.mean_waiting_time = info.get("system_mean_waiting_time", 0.0)
        self._state.mean_speed = info.get("system_mean_speed", 0.0)

    def _make_observation(
        self, obs: Any, reward: Optional[float], done: bool, info: Dict[str, Any]
    ) -> SumoObservation:
        """
        Create SumoObservation from SUMO environment output.

        Args:
            obs: Observation from the SUMO environment (array-like; anything
                exposing ``tolist()`` or iterable).
            reward: Reward value (None on reset).
            done: Whether episode is complete.
            info: Info dictionary from SUMO environment.

        Returns:
            SumoObservation for the agent.
        """
        # Convert observation to list
        if hasattr(obs, "tolist"):
            obs_list = obs.tolist()
        else:
            obs_list = list(obs)

        # Get action mask (all actions valid in SUMO-RL)
        num_phases = self.env.action_space.n
        action_mask = list(range(num_phases))

        # Extract system metrics for metadata
        system_info = {k: v for k, v in info.items() if k.startswith("system_")}

        # Create observation
        return SumoObservation(
            observation=obs_list,
            observation_shape=[len(obs_list)],
            action_mask=action_mask,
            sim_time=info.get("step", 0.0),
            done=done,
            reward=reward,
            metadata={
                "num_green_phases": num_phases,
                "system_info": _json_safe(system_info),
            },
        )