Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Data models for the LogiFlow-RL environment.""" | |
| from typing import Any, Dict, List, Optional | |
| from openenv.core.env_server.types import Action, Observation, State | |
| from pydantic import Field | |
class CrisisLogisticsAction(Action):
    """Move freight through the supply-chain network.

    Two action forms are accepted:

    * Phase 2 form: ``source_node``, ``dest_node`` and ``shipment_volume``.
    * Legacy form: ``target_hub`` on its own. It remains supported so that
      existing validators and demos keep working.

    Every field is optional and defaults to ``None``. Node ids are
    constrained to ``0..11`` and the volume to ``(0, 60]``.
    """

    # Legacy single-field route selector.
    target_hub: Optional[int] = Field(
        description=(
            "Backward-compatible route selector. If source_node/dest_node are omitted, "
            "this chooses one outgoing edge from the current shipment source."
        ),
        default=None,
        ge=0,
        le=11,
    )

    # Phase 2 explicit routing.
    # NOTE(review): nothing in this model rejects source_node == dest_node.
    # It also does not check that the pair is a connected edge. Presumably
    # the environment validates this when resolving the action; confirm.
    source_node: Optional[int] = Field(
        description="Node id that dispatches freight.",
        default=None,
        ge=0,
        le=11,
    )
    dest_node: Optional[int] = Field(
        description="Connected downstream node id that receives freight after transit delay.",
        default=None,
        ge=0,
        le=11,
    )
    shipment_volume: Optional[float] = Field(
        description="Freight volume to move. Defaults to the active shipment volume.",
        default=None,
        gt=0.0,
        le=60.0,
    )

    # Free-text rationale; carried alongside the structured fields above.
    reasoning: Optional[str] = Field(
        description="Optional natural-language rationale from the agent before the structured action.",
        default=None,
    )
class CrisisLogisticsObservation(Observation):
    """Partially observable state of the supply-chain network.

    The fields fall into these groups:

    * task metadata;
    * a backward-compatible 3-tier summary (supplier, warehouse, downstream);
    * episode progress and reward signals;
    * urgent-demand tracking;
    * per-node vectors for the whole network;
    * a view restricted to currently visible nodes;
    * shipments and disruptions in flight.
    """

    # --- Task metadata -------------------------------------------------------
    task_id: str = Field(default="easy", description="Active benchmark task identifier.")
    difficulty: str = Field(default="easy", description="Difficulty label for the active task.")
    objective: str = Field(default="", description="Task objective visible to the agent.")

    # --- Backward-compatible 3-value tier summary used by the existing visualizer.
    hub_loads: List[float] = Field(
        default_factory=lambda: [0.0, 0.0, 0.0],
        description="Aggregate utilization for supplier, warehouse, and downstream tiers.",
    )
    drain_rates: List[float] = Field(
        default_factory=lambda: [0.0, 0.0, 0.0],
        description="Aggregate drain rates for supplier, warehouse, and downstream tiers.",
    )
    incoming_load: float = Field(
        default=0.0,
        description="Volume of the active shipment introduced at the current source node.",
    )

    # --- Episode progress and reward signals -----------------------------------
    step_count: int = Field(default=0, description="Current step in the episode.")
    max_steps: int = Field(default=80, description="Maximum episode length.")
    overloaded_hubs: int = Field(
        default=0,
        description="Number of nodes currently above capacity.",
    )
    # NOTE(review): the [0.0, 1.0] range is documented but not enforced by a
    # ge/le constraint. Presumably the environment clips the value; confirm.
    cumulative_score: float = Field(
        default=0.0,
        description="Normalized score in the range [0.0, 1.0] for progress so far.",
    )
    last_reward: float = Field(default=0.0, description="Reward from the previous action.")
    event_label: str = Field(
        default="normal",
        description="Current network condition, e.g. normal, port_closure, or supplier_failure.",
    )
    # NOTE(review): like cumulative_score, the [0, 1] range is not enforced here.
    dynamic_pressure: float = Field(
        default=0.0,
        description="Adaptive stress estimate in [0, 1] driven by congestion, SLA risk, and delivery shortfall.",
    )
    adaptive_disruption_rate: float = Field(
        default=0.0,
        description="Effective disruption probability after adaptive pressure scaling.",
    )

    # --- Urgent (priority) demand tracking ---------------------------------------
    # The defaults (node 10 / "Retail North") are expected to describe the same node.
    priority_target_node: int = Field(
        default=10,
        description="Retail node that should receive urgent demand in the current step window.",
    )
    priority_target_name: str = Field(
        default="Retail North",
        description="Human-readable name for the active priority retail target.",
    )
    priority_queue_depth: int = Field(
        default=0,
        description="Number of in-transit urgent shipments still pending completion.",
    )
    priority_service_rate: float = Field(
        default=0.0,
        description="Fraction of urgent shipments delivered on target and within SLA.",
    )
    message: str = Field(default="", description="Human-readable state summary.")

    # --- Full-network per-node vectors (presumably index-aligned by node id) -------
    node_names: List[str] = Field(default_factory=list, description="Names for all network nodes.")
    node_types: List[str] = Field(default_factory=list, description="Node type for each network node.")
    node_loads: List[float] = Field(default_factory=list, description="Current load at each node.")
    node_capacities: List[float] = Field(default_factory=list, description="Capacity for each node.")
    node_utilization: List[float] = Field(default_factory=list, description="Load/capacity for each node.")
    node_drain_rates: List[float] = Field(default_factory=list, description="Processing rate for each node.")
    node_risk_scores: List[float] = Field(default_factory=list, description="Disruption risk for each node.")
    # String keys, presumably so the mapping serializes directly as a JSON object.
    connectivity: Dict[str, List[int]] = Field(
        default_factory=dict,
        description="Directed connectivity graph. Keys are node ids encoded as strings.",
    )

    # --- Partial observation: only nodes the agent can currently see -------------
    visible_node_ids: List[int] = Field(
        default_factory=list,
        description="Nodes visible to the agent under the two-hop partial-observation rule.",
    )
    observed_node_loads: List[Optional[float]] = Field(
        default_factory=list,
        description="Node loads with hidden nodes represented as null.",
    )
    visible_connectivity: Dict[str, List[int]] = Field(
        default_factory=dict,
        description="Connectivity restricted to currently visible nodes.",
    )

    # --- In-flight shipments, disruptions and bookkeeping -------------------------
    in_transit_shipments: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Shipments currently moving between nodes with delayed arrival.",
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Active stochastic disruptions and their remaining duration.",
    )
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        description="Inspectable per-step reward components.",
    )
    last_action: Dict[str, Any] = Field(default_factory=dict, description="Last resolved structured action.")
    pending_source_node: int = Field(default=0, description="Source node for the active incoming shipment.")
    retail_delivered: float = Field(default=0.0, description="Total volume that has arrived at retail sinks.")
    sla_success_rate: float = Field(default=0.0, description="Fraction of retail deliveries within SLA.")
class CrisisLogisticsState(State):
    """Environment-side state, exposed so it can be validated and debugged.

    This model mirrors a subset of the observation fields. It adds counters
    that only make sense server-side, such as ``incoming_index``,
    ``bottlenecks`` and ``in_transit_count``.
    """

    # Task identity.
    task_id: str = Field(description="Active task identifier.", default="easy")
    difficulty: str = Field(description="Task difficulty label.", default="easy")

    # Legacy 3-tier summary; a fresh three-element list per instance.
    hub_loads: List[float] = Field(
        description="Backward-compatible aggregate tier utilization.",
        default_factory=lambda: [0.0] * 3,
    )

    # Episode bookkeeping.
    incoming_index: int = Field(description="Index of the next scheduled shipment.", default=0)
    bottlenecks: int = Field(description="Total overload events encountered this episode.", default=0)
    score: float = Field(description="Current normalized task score.", default=0.0)

    # Per-node network vectors.
    node_loads: List[float] = Field(description="Current load at each node.", default_factory=list)
    node_utilization: List[float] = Field(description="Load/capacity for each node.", default_factory=list)

    # Shipments and disruptions in flight.
    in_transit_count: int = Field(description="Number of delayed shipments in transit.", default=0)
    active_disruptions: List[Dict[str, Any]] = Field(description="Active disruptions.", default_factory=list)

    # Delivery and service-level metrics.
    retail_delivered: float = Field(description="Total retail-arrived volume.", default=0.0)
    sla_success_rate: float = Field(description="Fraction of retail deliveries within SLA.", default=0.0)

    # Adaptive difficulty signals.
    dynamic_pressure: float = Field(description="Current adaptive network pressure in [0, 1].", default=0.0)
    adaptive_disruption_rate: float = Field(description="Effective disruption probability this step.", default=0.0)

    # Urgent-demand tracking.
    priority_target_node: int = Field(description="Retail node targeted for urgent demand.", default=10)
    priority_service_rate: float = Field(description="Urgent shipment success rate.", default=0.0)