# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Client for SUMO-RL environment.

This module provides a client to interact with the SUMO traffic signal
control environment via WebSocket for persistent sessions.
"""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from .models import SumoAction, SumoObservation, SumoState


class SumoRLEnv(EnvClient[SumoAction, SumoObservation, SumoState]):
    """
    Client for SUMO-RL traffic signal control environment.

    This client maintains a persistent WebSocket connection to a SUMO
    environment server to control traffic signals using reinforcement learning.

    Example:
        >>> # Start container and connect
        >>> env = SumoRLEnv.from_docker_image("sumo-rl-env:latest")
        >>> try:
        ...     # Reset environment
        ...     result = env.reset()
        ...     print(f"Observation shape: {result.observation.observation_shape}")
        ...     print(f"Action space: {result.observation.action_mask}")
        ...
        ...     # Take action
        ...     result = env.step(SumoAction(phase_id=1))
        ...     print(f"Reward: {result.reward}, Done: {result.done}")
        ...
        ...     # Get state
        ...     state = env.state()
        ...     print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}")
        ... finally:
        ...     env.close()

    Example with custom network:
        >>> # Use custom SUMO network via volume mount
        >>> env = SumoRLEnv.from_docker_image(
        ...     "sumo-rl-env:latest",
        ...     port=8000,
        ...     volumes={
        ...         "/path/to/my/nets": {"bind": "/nets", "mode": "ro"}
        ...     },
        ...     environment={
        ...         "SUMO_NET_FILE": "/nets/my-network.net.xml",
        ...         "SUMO_ROUTE_FILE": "/nets/my-routes.rou.xml",
        ...     }
        ... )

    Example with configuration:
        >>> # Adjust simulation parameters
        >>> env = SumoRLEnv.from_docker_image(
        ...     "sumo-rl-env:latest",
        ...     environment={
        ...         "SUMO_NUM_SECONDS": "10000",
        ...         "SUMO_DELTA_TIME": "10",
        ...         "SUMO_REWARD_FN": "queue",
        ...         "SUMO_SEED": "123",
        ...     }
        ... )
    """

    def _step_payload(self, action: SumoAction) -> Dict[str, Any]:
        """
        Convert a SumoAction into the JSON payload for a step request.

        Args:
            action: SumoAction whose ``phase_id`` and ``ts_id`` are sent.

        Returns:
            Dictionary with the keys ``phase_id`` and ``ts_id``.
        """
        return {
            "phase_id": action.phase_id,
            "ts_id": action.ts_id,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SumoObservation]:
        """
        Parse a step/reset result from the server's JSON response.

        Args:
            payload: Decoded JSON response. The nested ``observation`` dict
                is optional; ``reward`` and ``done`` are read from the top
                level of the payload.

        Returns:
            StepResult containing the parsed SumoObservation.
        """
        # `or {}` rather than a .get() default: an explicit JSON null for
        # "observation" would otherwise come back as None and break the
        # .get() calls below.
        obs_data = payload.get("observation") or {}

        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = SumoObservation(
            observation=obs_data.get("observation", []),
            observation_shape=obs_data.get("observation_shape", []),
            action_mask=obs_data.get("action_mask", []),
            sim_time=obs_data.get("sim_time", 0.0),
            # Prefer the nested values, but fall back to the top-level ones
            # so the observation never disagrees with the StepResult when
            # the nested dict omits these keys.
            done=obs_data.get("done", done),
            reward=obs_data.get("reward", reward),
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SumoState:
        """
        Parse environment state from the server's JSON response.

        Every field is optional in the payload; a missing key falls back
        to the default given in the corresponding ``get`` call below.

        Args:
            payload: Decoded JSON response describing the current state.

        Returns:
            SumoState object.
        """
        return SumoState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            net_file=payload.get("net_file", ""),
            route_file=payload.get("route_file", ""),
            num_seconds=payload.get("num_seconds", 20000),
            delta_time=payload.get("delta_time", 5),
            yellow_time=payload.get("yellow_time", 2),
            min_green=payload.get("min_green", 5),
            max_green=payload.get("max_green", 50),
            reward_fn=payload.get("reward_fn", "diff-waiting-time"),
            sim_time=payload.get("sim_time", 0.0),
            total_vehicles=payload.get("total_vehicles", 0),
            total_waiting_time=payload.get("total_waiting_time", 0.0),
            mean_waiting_time=payload.get("mean_waiting_time", 0.0),
            mean_speed=payload.get("mean_speed", 0.0),
        )