Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Client for the LogiFlow-RL environment server.""" | |
| from typing import Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| from .models import CrisisLogisticsAction, CrisisLogisticsObservation, CrisisLogisticsState | |
class CrisisLogisticsEnv(
    EnvClient[CrisisLogisticsAction, CrisisLogisticsObservation, CrisisLogisticsState]
):
    """Thin client that talks to the HTTP or WebSocket server.

    Converts between the typed action/observation/state models from
    ``.models`` and the plain dict payloads exchanged with the LogiFlow-RL
    environment server. Keys missing from a server payload fall back to the
    hard-coded defaults below.

    NOTE(review): the fallback defaults (e.g. three hub loads, drain rates
    ``[6.0, 5.0, 4.0]``, ``priority_target_node=10`` / ``"Retail North"``)
    presumably mirror the server's initial state -- confirm against the server.
    """

    def _step_payload(self, action: CrisisLogisticsAction) -> Dict:
        """Serialize *action* into the dict sent to the server.

        Fields whose value is ``None`` are omitted from the payload
        (``model_dump(exclude_none=True)``).
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict) -> StepResult[CrisisLogisticsObservation]:
        """Build a ``StepResult`` from a server response payload.

        Args:
            payload: Response dict. Observation fields are read from its
                ``"observation"`` sub-dict; ``"reward"`` and ``"done"`` are
                read from the top level of ``payload``.

        Returns:
            A ``StepResult`` wrapping a ``CrisisLogisticsObservation``. The
            observation carries the same ``reward`` and ``done`` values as the
            ``StepResult`` itself.
        """
        # `or {}` also covers an explicit ``"observation": null``; with a plain
        # ``.get(key, {})`` that case yields None and the lookups below crash.
        obs_data = payload.get("observation") or {}
        # Read once so the observation and the StepResult cannot disagree.
        reward = payload.get("reward")
        done = payload.get("done", False)
        observation = CrisisLogisticsObservation(
            task_id=obs_data.get("task_id", "easy"),
            difficulty=obs_data.get("difficulty", "easy"),
            objective=obs_data.get("objective", ""),
            hub_loads=obs_data.get("hub_loads", [0.0, 0.0, 0.0]),
            drain_rates=obs_data.get("drain_rates", [6.0, 5.0, 4.0]),
            incoming_load=obs_data.get("incoming_load", 0.0),
            step_count=obs_data.get("step_count", 0),
            max_steps=obs_data.get("max_steps", 100),
            overloaded_hubs=obs_data.get("overloaded_hubs", 0),
            cumulative_score=obs_data.get("cumulative_score", 0.0),
            last_reward=obs_data.get("last_reward", 0.0),
            event_label=obs_data.get("event_label", "normal"),
            dynamic_pressure=obs_data.get("dynamic_pressure", 0.0),
            adaptive_disruption_rate=obs_data.get("adaptive_disruption_rate", 0.0),
            priority_target_node=obs_data.get("priority_target_node", 10),
            priority_target_name=obs_data.get("priority_target_name", "Retail North"),
            priority_queue_depth=obs_data.get("priority_queue_depth", 0),
            priority_service_rate=obs_data.get("priority_service_rate", 0.0),
            message=obs_data.get("message", ""),
            node_names=obs_data.get("node_names", []),
            node_types=obs_data.get("node_types", []),
            node_loads=obs_data.get("node_loads", []),
            node_capacities=obs_data.get("node_capacities", []),
            node_utilization=obs_data.get("node_utilization", []),
            node_drain_rates=obs_data.get("node_drain_rates", []),
            node_risk_scores=obs_data.get("node_risk_scores", []),
            connectivity=obs_data.get("connectivity", {}),
            visible_node_ids=obs_data.get("visible_node_ids", []),
            observed_node_loads=obs_data.get("observed_node_loads", []),
            visible_connectivity=obs_data.get("visible_connectivity", {}),
            in_transit_shipments=obs_data.get("in_transit_shipments", []),
            active_disruptions=obs_data.get("active_disruptions", []),
            reward_breakdown=obs_data.get("reward_breakdown", {}),
            last_action=obs_data.get("last_action", {}),
            pending_source_node=obs_data.get("pending_source_node", 0),
            retail_delivered=obs_data.get("retail_delivered", 0.0),
            sla_success_rate=obs_data.get("sla_success_rate", 0.0),
            # reward/done come from the top-level payload, not obs_data.
            reward=reward,
            done=done,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> CrisisLogisticsState:
        """Build a ``CrisisLogisticsState`` from a flat server state payload.

        Every field is read from the top level of ``payload``; missing keys
        fall back to the defaults below. ``episode_id`` becomes ``None`` when
        absent.
        """
        return CrisisLogisticsState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id", "easy"),
            difficulty=payload.get("difficulty", "easy"),
            hub_loads=payload.get("hub_loads", [0.0, 0.0, 0.0]),
            incoming_index=payload.get("incoming_index", 0),
            bottlenecks=payload.get("bottlenecks", 0),
            score=payload.get("score", 0.0),
            node_loads=payload.get("node_loads", []),
            node_utilization=payload.get("node_utilization", []),
            in_transit_count=payload.get("in_transit_count", 0),
            active_disruptions=payload.get("active_disruptions", []),
            retail_delivered=payload.get("retail_delivered", 0.0),
            sla_success_rate=payload.get("sla_success_rate", 0.0),
            dynamic_pressure=payload.get("dynamic_pressure", 0.0),
            adaptive_disruption_rate=payload.get("adaptive_disruption_rate", 0.0),
            priority_target_node=payload.get("priority_target_node", 10),
            priority_service_rate=payload.get("priority_service_rate", 0.0),
        )