Spaces:
Sleeping
Sleeping
File size: 8,508 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Data models for the LogiFlow-RL environment."""
from typing import Any, Dict, List, Optional
from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field
class CrisisLogisticsAction(Action):
    """Agent action that routes freight across the supply-chain network.

    The current (Phase 2) form is expressed with ``source_node``,
    ``dest_node`` and ``shipment_volume``. The legacy ``target_hub`` selector
    is still accepted so existing validators and demos keep working.

    Every field is optional and defaults to ``None``.
    """

    # Legacy selector, validated to the inclusive range 0..11.
    target_hub: Optional[int] = Field(
        description=(
            "Backward-compatible route selector. "
            "If source_node/dest_node are omitted, "
            "this chooses one outgoing edge from the current shipment source."
        ),
        default=None,
        ge=0,
        le=11,
    )

    # Explicit routing: both node ids are validated to the range 0..11.
    source_node: Optional[int] = Field(
        description="Node id that dispatches freight.",
        default=None,
        ge=0,
        le=11,
    )
    dest_node: Optional[int] = Field(
        description=(
            "Connected downstream node id that receives freight "
            "after transit delay."
        ),
        default=None,
        ge=0,
        le=11,
    )

    # Strictly positive and capped at 60.0 when provided.
    shipment_volume: Optional[float] = Field(
        description=(
            "Freight volume to move. "
            "Defaults to the active shipment volume."
        ),
        default=None,
        gt=0.0,
        le=60.0,
    )

    # Free-form text; no validation constraints are attached.
    reasoning: Optional[str] = Field(
        description=(
            "Optional natural-language rationale from the agent "
            "before the structured action."
        ),
        default=None,
    )
class CrisisLogisticsObservation(Observation):
    """Agent-facing, partially observable view of the supply-chain network.

    Every field carries a default, so an observation can be built without
    arguments; list and dict fields use ``default_factory`` so instances do
    not share mutable defaults.
    """

    # --- Task identity ---------------------------------------------------
    task_id: str = Field(
        description="Active benchmark task identifier.",
        default="easy",
    )
    difficulty: str = Field(
        description="Difficulty label for the active task.",
        default="easy",
    )
    objective: str = Field(
        description="Task objective visible to the agent.",
        default="",
    )

    # --- Legacy tier summary ---------------------------------------------
    # Three-value aggregates kept for backward compatibility with the
    # existing visualizer.
    hub_loads: List[float] = Field(
        description=(
            "Aggregate utilization for supplier, warehouse, "
            "and downstream tiers."
        ),
        default_factory=lambda: [0.0] * 3,
    )
    drain_rates: List[float] = Field(
        description=(
            "Aggregate drain rates for supplier, warehouse, "
            "and downstream tiers."
        ),
        default_factory=lambda: [0.0] * 3,
    )

    # --- Episode progress and scoring ------------------------------------
    incoming_load: float = Field(
        description=(
            "Volume of the active shipment introduced "
            "at the current source node."
        ),
        default=0.0,
    )
    step_count: int = Field(
        description="Current step in the episode.",
        default=0,
    )
    max_steps: int = Field(
        description="Maximum episode length.",
        default=80,
    )
    overloaded_hubs: int = Field(
        description="Number of nodes currently above capacity.",
        default=0,
    )
    cumulative_score: float = Field(
        description=(
            "Normalized score in the range [0.0, 1.0] for progress so far."
        ),
        default=0.0,
    )
    last_reward: float = Field(
        description="Reward from the previous action.",
        default=0.0,
    )

    # --- Network condition and adaptive pressure -------------------------
    event_label: str = Field(
        description=(
            "Current network condition, e.g. normal, port_closure, "
            "or supplier_failure."
        ),
        default="normal",
    )
    dynamic_pressure: float = Field(
        description=(
            "Adaptive stress estimate in [0, 1] driven by congestion, "
            "SLA risk, and delivery shortfall."
        ),
        default=0.0,
    )
    adaptive_disruption_rate: float = Field(
        description=(
            "Effective disruption probability after adaptive pressure scaling."
        ),
        default=0.0,
    )

    # --- Priority (urgent) demand ----------------------------------------
    priority_target_node: int = Field(
        description=(
            "Retail node that should receive urgent demand "
            "in the current step window."
        ),
        default=10,
    )
    priority_target_name: str = Field(
        description=(
            "Human-readable name for the active priority retail target."
        ),
        default="Retail North",
    )
    priority_queue_depth: int = Field(
        description=(
            "Number of in-transit urgent shipments still pending completion."
        ),
        default=0,
    )
    priority_service_rate: float = Field(
        description=(
            "Fraction of urgent shipments delivered on target and within SLA."
        ),
        default=0.0,
    )
    message: str = Field(
        description="Human-readable state summary.",
        default="",
    )

    # --- Per-node vectors ------------------------------------------------
    node_names: List[str] = Field(
        description="Names for all network nodes.",
        default_factory=list,
    )
    node_types: List[str] = Field(
        description="Node type for each network node.",
        default_factory=list,
    )
    node_loads: List[float] = Field(
        description="Current load at each node.",
        default_factory=list,
    )
    node_capacities: List[float] = Field(
        description="Capacity for each node.",
        default_factory=list,
    )
    node_utilization: List[float] = Field(
        description="Load/capacity for each node.",
        default_factory=list,
    )
    node_drain_rates: List[float] = Field(
        description="Processing rate for each node.",
        default_factory=list,
    )
    node_risk_scores: List[float] = Field(
        description="Disruption risk for each node.",
        default_factory=list,
    )

    # --- Topology and partial observability ------------------------------
    connectivity: Dict[str, List[int]] = Field(
        description=(
            "Directed connectivity graph. "
            "Keys are node ids encoded as strings."
        ),
        default_factory=dict,
    )
    visible_node_ids: List[int] = Field(
        description=(
            "Nodes visible to the agent under the two-hop "
            "partial-observation rule."
        ),
        default_factory=list,
    )
    observed_node_loads: List[Optional[float]] = Field(
        description="Node loads with hidden nodes represented as null.",
        default_factory=list,
    )
    visible_connectivity: Dict[str, List[int]] = Field(
        description="Connectivity restricted to currently visible nodes.",
        default_factory=dict,
    )

    # --- Shipments, disruptions and reward detail ------------------------
    in_transit_shipments: List[Dict[str, Any]] = Field(
        description=(
            "Shipments currently moving between nodes with delayed arrival."
        ),
        default_factory=list,
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        description=(
            "Active stochastic disruptions and their remaining duration."
        ),
        default_factory=list,
    )
    # NOTE(review): "Inspectible" looks like a misspelling of "Inspectable";
    # the text is left untouched here because it is emitted in the schema.
    reward_breakdown: Dict[str, float] = Field(
        description="Inspectible per-step reward components.",
        default_factory=dict,
    )
    last_action: Dict[str, Any] = Field(
        description="Last resolved structured action.",
        default_factory=dict,
    )
    pending_source_node: int = Field(
        description="Source node for the active incoming shipment.",
        default=0,
    )

    # --- Delivery outcomes -----------------------------------------------
    retail_delivered: float = Field(
        description="Total volume that has arrived at retail sinks.",
        default=0.0,
    )
    sla_success_rate: float = Field(
        description="Fraction of retail deliveries within SLA.",
        default=0.0,
    )
class CrisisLogisticsState(State):
    """Environment-internal state surfaced for validation and debugging.

    Every field has a default; list fields use ``default_factory`` so
    instances do not share mutable defaults.
    """

    # --- Task identity ---------------------------------------------------
    task_id: str = Field(
        description="Active task identifier.",
        default="easy",
    )
    difficulty: str = Field(
        description="Task difficulty label.",
        default="easy",
    )

    # Three-value aggregate kept for backward compatibility.
    hub_loads: List[float] = Field(
        description="Backward-compatible aggregate tier utilization.",
        default_factory=lambda: [0.0] * 3,
    )

    # --- Episode bookkeeping ---------------------------------------------
    incoming_index: int = Field(
        description="Index of the next scheduled shipment.",
        default=0,
    )
    bottlenecks: int = Field(
        description="Total overload events encountered this episode.",
        default=0,
    )
    score: float = Field(
        description="Current normalized task score.",
        default=0.0,
    )

    # --- Per-node vectors and transit ------------------------------------
    node_loads: List[float] = Field(
        description="Current load at each node.",
        default_factory=list,
    )
    node_utilization: List[float] = Field(
        description="Load/capacity for each node.",
        default_factory=list,
    )
    in_transit_count: int = Field(
        description="Number of delayed shipments in transit.",
        default=0,
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        description="Active disruptions.",
        default_factory=list,
    )

    # --- Delivery outcomes -----------------------------------------------
    retail_delivered: float = Field(
        description="Total retail-arrived volume.",
        default=0.0,
    )
    sla_success_rate: float = Field(
        description="Fraction of retail deliveries within SLA.",
        default=0.0,
    )

    # --- Adaptive pressure and priority demand ---------------------------
    dynamic_pressure: float = Field(
        description="Current adaptive network pressure in [0, 1].",
        default=0.0,
    )
    adaptive_disruption_rate: float = Field(
        description="Effective disruption probability this step.",
        default=0.0,
    )
    priority_target_node: int = Field(
        description="Retail node targeted for urgent demand.",
        default=10,
    )
    priority_service_rate: float = Field(
        description="Urgent shipment success rate.",
        default=0.0,
    )
|