logiflow-rl / gym_env.py
roshan5emerald's picture
Upload folder using huggingface_hub
47ee65f verified
from __future__ import annotations
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
try:
from .models import CrisisLogisticsAction
from .server.crisis_logistics_env_environment import CrisisLogisticsEnvironment
except ImportError:
from models import CrisisLogisticsAction
from server.crisis_logistics_env_environment import CrisisLogisticsEnvironment
class LogiFlowGymEnv(gym.Env):
    """Gymnasium wrapper for the 12-node delayed logistics benchmark.

    Actions are dicts with three entries: ``source_node`` and ``dest_node``
    (each ``Discrete(12)``) and ``shipment_volume`` (a scalar float in
    ``[1, 60]``).  Observations are flat ``float32`` vectors of length 20:
    twelve node-utilization values followed by eight scalar features (see
    ``_flatten``).
    """

    metadata = {"render_modes": ["human"], "render_fps": 4}

    def __init__(self, task_id: str = "easy", render_mode: str | None = None):
        """Create the wrapped environment and declare the spaces.

        Args:
            task_id: Default task passed to the wrapped environment on
                ``reset``; can be overridden per episode through
                ``options["task_id"]``.
            render_mode: ``None`` or one of ``metadata["render_modes"]``.
                Accepted so the env can be built with the standard
                Gymnasium ``render_mode=...`` keyword.

        Raises:
            ValueError: If ``render_mode`` is not ``None`` and not listed in
                ``metadata["render_modes"]``.
        """
        super().__init__()
        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode {render_mode!r}; "
                f"expected one of {self.metadata['render_modes']} or None"
            )
        self.render_mode = render_mode
        self.task_id = task_id
        self.env = CrisisLogisticsEnvironment()
        # Latest raw observation from the wrapped env; set by reset()/step().
        self.observation = None
        self.action_space = spaces.Dict(
            {
                "source_node": spaces.Discrete(12),
                "dest_node": spaces.Discrete(12),
                "shipment_volume": spaces.Box(low=1.0, high=60.0, shape=(), dtype=np.float32),
            }
        )
        self.observation_space = spaces.Box(low=0.0, high=1.5, shape=(20,), dtype=np.float32)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode.

        Args:
            seed: Seed forwarded both to Gymnasium's own RNG and to the
                wrapped environment.
            options: Optional dict; a ``"task_id"`` entry overrides the
                task chosen at construction time for this episode.

        Returns:
            ``(observation, info)`` where ``observation`` is the flattened
            vector and ``info`` is the diagnostics dict from ``_info``.
        """
        # Required by the Gymnasium API so that self.np_random is seeded.
        super().reset(seed=seed)
        task_id = (options or {}).get("task_id", self.task_id)
        self.observation = self.env.reset(seed=seed, task_id=task_id)
        return self._flatten(self.observation), self._info()

    def step(
        self, action: dict[str, Any]
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply one action to the wrapped environment.

        Args:
            action: Mapping with ``source_node``, ``dest_node`` and
                ``shipment_volume`` entries; each is coerced to a plain
                ``int``/``float`` before being handed to the environment.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``reward`` is ``0.0`` when the wrapped env reports a falsy
            reward; ``truncated`` is always ``False``.
        """
        env_action = CrisisLogisticsAction(
            source_node=int(action["source_node"]),
            dest_node=int(action["dest_node"]),
            shipment_volume=float(action["shipment_volume"]),
        )
        self.observation = self.env.step(env_action)
        terminated = bool(self.observation.done)
        truncated = False
        reward = float(self.observation.reward or 0.0)
        return self._flatten(self.observation), reward, terminated, truncated, self._info()

    def render(self):
        """Print a one-line summary of the wrapped environment's state."""
        print(
            f"step={self.env.step_count} score={self.env.score:.3f} "
            f"retail={self.env.retail_delivered:.1f} transit={len(self.env.in_transit)}"
        )

    def _flatten(self, observation) -> np.ndarray:
        """Convert a raw observation object into a length-20 float32 vector.

        Layout: the first 12 entries are ``node_utilization`` (truncated to
        12 values and zero-padded if shorter); the remaining 8 are, in
        order: incoming load / 60, in-transit shipment count / 25, active
        disruption count / 12, cumulative score, step progress
        (``step_count / max_steps``), dynamic pressure, priority service
        rate, and adaptive disruption rate capped at 1.0.

        NOTE(review): apart from the final cap, values are passed through
        unclipped, so they may fall outside the declared ``[0, 1.5]``
        observation bounds -- confirm against the wrapped env's ranges.
        """
        util = list(observation.node_utilization[:12])
        # Zero-pad so the vector length is fixed regardless of node count.
        util += [0.0] * (12 - len(util))
        extras = [
            observation.incoming_load / 60.0,
            len(observation.in_transit_shipments) / 25.0,
            len(observation.active_disruptions) / 12.0,
            observation.cumulative_score,
            # max(..., 1) guards against division by zero.
            observation.step_count / max(observation.max_steps, 1),
            observation.dynamic_pressure,
            observation.priority_service_rate,
            min(1.0, observation.adaptive_disruption_rate),
        ]
        return np.array(util + extras, dtype=np.float32)

    def _info(self) -> dict[str, Any]:
        """Collect diagnostic values from the wrapped environment."""
        return {
            "score": self.env.score,
            "bottlenecks": self.env.bottlenecks,
            "retail_delivered": self.env.retail_delivered,
            "sla_success_rate": self.env._sla_success_rate(),
            "dynamic_pressure": self.env.dynamic_pressure,
            "priority_service_rate": self.env._priority_service_rate(),
        }