"""
Gymnasium wrapper for BudgetRouterEnv.
Wraps a single scenario (default: Easy) into a standard Gymnasium interface
compatible with stable-baselines3 and other RL libraries.
Observation space: Box(7,) float32 — all fields normalized to [0, 1]
[provider_a_status, provider_b_status, provider_c_status,
budget_remaining, queue_backlog, system_latency, step_count]
Action space: Discrete(4)
0 → route_to_a
1 → route_to_b
2 → route_to_c
3 → shed_load
"""
from __future__ import annotations
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.tasks import EASY, TaskConfig
# Maps a Discrete(4) action index to the underlying ActionType.
# Order is part of the public contract documented in the module docstring:
# 0 -> route_to_a, 1 -> route_to_b, 2 -> route_to_c, 3 -> shed_load.
_ACTION_MAP = [
    ActionType.ROUTE_TO_A,
    ActionType.ROUTE_TO_B,
    ActionType.ROUTE_TO_C,
    ActionType.SHED_LOAD,
]
class BudgetRouterGymEnv(gym.Env):
    """Gymnasium-compatible wrapper around BudgetRouterEnv.

    Exposes the underlying environment through the standard
    ``reset``/``step`` API so it can be driven by stable-baselines3 and
    other Gymnasium-based RL libraries.

    Args:
        scenario: TaskConfig to use. Defaults to EASY.
        seed: Fixed seed for reproducible resets. If None, a random seed is
            drawn from the Gymnasium RNG on each reset().
    """

    metadata = {"render_modes": []}

    def __init__(self, scenario: TaskConfig = EASY, seed: int | None = None) -> None:
        super().__init__()
        self._env = BudgetRouterEnv()
        self._scenario = scenario
        self._fixed_seed = seed
        self._episode_seed: int | None = seed
        # Observation: 7 normalized features in [0, 1] (see module docstring
        # for the field order).
        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(7,),
            dtype=np.float32,
        )
        # Actions: route_to_a / route_to_b / route_to_c / shed_load.
        self.action_space = spaces.Discrete(4)

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[np.ndarray, dict]:
        """Start a new episode and return ``(observation, info)``.

        Seed precedence: a seed fixed at construction wins, then the
        ``seed`` argument, then a fresh draw from Gymnasium's RNG.
        NOTE(review): a construction-time seed silently overrides an
        explicit ``reset(seed=...)`` — intentional per the class
        docstring, but worth confirming against caller expectations.
        """
        super().reset(seed=seed)
        # Cascade through the seed sources in priority order.
        episode_seed = self._fixed_seed
        if episode_seed is None:
            episode_seed = seed
        if episode_seed is None:
            episode_seed = int(self.np_random.integers(0, 2**31 - 1))
        self._episode_seed = episode_seed
        first_obs = self._env.reset(seed=episode_seed, scenario=self._scenario)
        return self._obs_to_array(first_obs), {}

    def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Apply one action and return the 5-tuple Gymnasium step result."""
        chosen = _ACTION_MAP[int(action)]
        result = self._env.step(Action(action_type=chosen))
        # The wrapped env folds its max-step limit into `done`, so the
        # Gymnasium `truncated` flag is always False here.
        return (
            self._obs_to_array(result),
            float(result.reward or 0.0),
            bool(result.done),
            False,
            dict(result.metadata or {}),
        )

    def render(self) -> None:
        """No-op: this environment has no rendering."""

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _obs_to_array(obs) -> np.ndarray:
        """Convert Observation dataclass to a (7,) float32 numpy array."""
        fields = (
            obs.provider_a_status,
            obs.provider_b_status,
            obs.provider_c_status,
            obs.budget_remaining,
            obs.queue_backlog,
            obs.system_latency,
            obs.step_count,
        )
        return np.asarray(fields, dtype=np.float32)