# logiflow-rl / models.py
# roshan5emerald's picture
# Upload folder using huggingface_hub
# 47ee65f verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Data models for the LogiFlow-RL environment."""
from typing import Any, Dict, List, Optional
from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field
class CrisisLogisticsAction(Action):
    """Move freight through the supply-chain network.

    Two ways of addressing a move are accepted:

    * Phase 2 form: ``source_node`` / ``dest_node`` / ``shipment_volume``.
    * Legacy form: ``target_hub`` on its own, retained so existing
      validators and demos keep working.

    Every field is optional and defaults to ``None``. The node selectors are
    validated to the inclusive range 0-11, and ``shipment_volume`` must be
    strictly positive and at most 60.0. This class only declares and
    validates fields; it contains no logic for resolving omitted values.
    """

    # --- Legacy selector -------------------------------------------------
    target_hub: Optional[int] = Field(
        default=None,
        ge=0,
        le=11,
        description=(
            "Backward-compatible route selector. "
            "If source_node/dest_node are omitted, "
            "this chooses one outgoing edge from the current shipment source."
        ),
    )

    # --- Phase 2 structured move ------------------------------------------
    source_node: Optional[int] = Field(
        default=None, ge=0, le=11, description="Node id that dispatches freight."
    )
    dest_node: Optional[int] = Field(
        default=None,
        ge=0,
        le=11,
        description=(
            "Connected downstream node id that receives freight "
            "after transit delay."
        ),
    )
    shipment_volume: Optional[float] = Field(
        default=None,
        gt=0.0,  # zero or negative volumes are rejected
        le=60.0,
        description=(
            "Freight volume to move. Defaults to the active shipment volume."
        ),
    )

    # --- Free-text rationale ----------------------------------------------
    reasoning: Optional[str] = Field(
        default=None,
        description=(
            "Optional natural-language rationale from the agent "
            "before the structured action."
        ),
    )
class CrisisLogisticsObservation(Observation):
    """Partially observable state of the supply-chain network.

    The fields fall into these groups:

    * task metadata (``task_id``, ``difficulty``, ``objective``);
    * a 3-value tier summary (``hub_loads``, ``drain_rates``) kept for the
      existing visualizer;
    * episode progress and scoring (``step_count``, ``max_steps``,
      ``cumulative_score``, ``last_reward``, ...);
    * disruption / priority-demand indicators;
    * per-node lists (``node_*``) and the ``connectivity`` graph, whose keys
      are node ids encoded as strings;
    * the partially observed view (``visible_node_ids``,
      ``observed_node_loads`` with hidden nodes as ``None``,
      ``visible_connectivity``);
    * in-flight shipments, active disruptions, reward breakdown and the last
      resolved action.

    Every field has a default, so an instance can be built with no
    arguments. List and dict defaults come from factories, so instances do
    not share mutable containers.
    """

    # --- Task metadata ----------------------------------------------------
    task_id: str = Field(default="easy", description="Active benchmark task identifier.")
    difficulty: str = Field(default="easy", description="Difficulty label for the active task.")
    objective: str = Field(default="", description="Task objective visible to the agent.")

    # Backward-compatible 3-value tier summary used by the existing visualizer.
    hub_loads: List[float] = Field(
        default_factory=lambda: [0.0, 0.0, 0.0],
        description="Aggregate utilization for supplier, warehouse, and downstream tiers.",
    )
    drain_rates: List[float] = Field(
        default_factory=lambda: [0.0, 0.0, 0.0],
        description="Aggregate drain rates for supplier, warehouse, and downstream tiers.",
    )

    # --- Episode progress and scoring -------------------------------------
    incoming_load: float = Field(
        default=0.0,
        description="Volume of the active shipment introduced at the current source node.",
    )
    step_count: int = Field(default=0, description="Current step in the episode.")
    max_steps: int = Field(default=80, description="Maximum episode length.")
    overloaded_hubs: int = Field(
        default=0,
        description="Number of nodes currently above capacity.",
    )
    cumulative_score: float = Field(
        default=0.0,
        description="Normalized score in the range [0.0, 1.0] for progress so far.",
    )
    last_reward: float = Field(default=0.0, description="Reward from the previous action.")

    # --- Disruption and priority-demand indicators ------------------------
    event_label: str = Field(
        default="normal",
        description="Current network condition, e.g. normal, port_closure, or supplier_failure.",
    )
    dynamic_pressure: float = Field(
        default=0.0,
        description="Adaptive stress estimate in [0, 1] driven by congestion, SLA risk, and delivery shortfall.",
    )
    adaptive_disruption_rate: float = Field(
        default=0.0,
        description="Effective disruption probability after adaptive pressure scaling.",
    )
    priority_target_node: int = Field(
        default=10,
        description="Retail node that should receive urgent demand in the current step window.",
    )
    priority_target_name: str = Field(
        default="Retail North",
        description="Human-readable name for the active priority retail target.",
    )
    priority_queue_depth: int = Field(
        default=0,
        description="Number of in-transit urgent shipments still pending completion.",
    )
    priority_service_rate: float = Field(
        default=0.0,
        description="Fraction of urgent shipments delivered on target and within SLA.",
    )
    message: str = Field(default="", description="Human-readable state summary.")

    # --- Per-node detail --------------------------------------------------
    node_names: List[str] = Field(default_factory=list, description="Names for all network nodes.")
    node_types: List[str] = Field(default_factory=list, description="Node type for each network node.")
    node_loads: List[float] = Field(default_factory=list, description="Current load at each node.")
    node_capacities: List[float] = Field(default_factory=list, description="Capacity for each node.")
    node_utilization: List[float] = Field(default_factory=list, description="Load/capacity for each node.")
    node_drain_rates: List[float] = Field(default_factory=list, description="Processing rate for each node.")
    node_risk_scores: List[float] = Field(default_factory=list, description="Disruption risk for each node.")
    connectivity: Dict[str, List[int]] = Field(
        default_factory=dict,
        description="Directed connectivity graph. Keys are node ids encoded as strings.",
    )

    # --- Partially observed view ------------------------------------------
    visible_node_ids: List[int] = Field(
        default_factory=list,
        description="Nodes visible to the agent under the two-hop partial-observation rule.",
    )
    observed_node_loads: List[Optional[float]] = Field(
        default_factory=list,
        description="Node loads with hidden nodes represented as null.",
    )
    visible_connectivity: Dict[str, List[int]] = Field(
        default_factory=dict,
        description="Connectivity restricted to currently visible nodes.",
    )

    # --- Shipments, disruptions and reward detail -------------------------
    in_transit_shipments: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Shipments currently moving between nodes with delayed arrival.",
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Active stochastic disruptions and their remaining duration.",
    )
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        # Spelling corrected from "Inspectible"; this text appears in the schema.
        description="Inspectable per-step reward components.",
    )
    last_action: Dict[str, Any] = Field(default_factory=dict, description="Last resolved structured action.")
    pending_source_node: int = Field(default=0, description="Source node for the active incoming shipment.")
    retail_delivered: float = Field(default=0.0, description="Total volume that has arrived at retail sinks.")
    sla_success_rate: float = Field(default=0.0, description="Fraction of retail deliveries within SLA.")
class CrisisLogisticsState(State):
    """Internal environment state exposed for validation and debugging.

    Every field carries a default, so the model can be constructed with no
    arguments. Mutable defaults are produced by factories, so each instance
    gets its own list objects.
    """

    # --- Task identity ----------------------------------------------------
    task_id: str = Field(
        default="easy",
        description="Active task identifier.",
    )
    difficulty: str = Field(
        default="easy",
        description="Task difficulty label.",
    )

    # --- Legacy tier summary ----------------------------------------------
    hub_loads: List[float] = Field(
        default_factory=lambda: [0.0] * 3,  # fresh three-element list per instance
        description="Backward-compatible aggregate tier utilization.",
    )

    # --- Episode bookkeeping ----------------------------------------------
    incoming_index: int = Field(
        default=0,
        description="Index of the next scheduled shipment.",
    )
    bottlenecks: int = Field(
        default=0,
        description="Total overload events encountered this episode.",
    )
    score: float = Field(
        default=0.0,
        description="Current normalized task score.",
    )

    # --- Network snapshot -------------------------------------------------
    node_loads: List[float] = Field(
        default_factory=list,
        description="Current load at each node.",
    )
    node_utilization: List[float] = Field(
        default_factory=list,
        description="Load/capacity for each node.",
    )
    in_transit_count: int = Field(
        default=0,
        description="Number of delayed shipments in transit.",
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Active disruptions.",
    )

    # --- Delivery and pressure metrics ------------------------------------
    retail_delivered: float = Field(
        default=0.0,
        description="Total retail-arrived volume.",
    )
    sla_success_rate: float = Field(
        default=0.0,
        description="Fraction of retail deliveries within SLA.",
    )
    dynamic_pressure: float = Field(
        default=0.0,
        description="Current adaptive network pressure in [0, 1].",
    )
    adaptive_disruption_rate: float = Field(
        default=0.0,
        description="Effective disruption probability this step.",
    )
    priority_target_node: int = Field(
        default=10,
        description="Retail node targeted for urgent demand.",
    )
    priority_service_rate: float = Field(
        default=0.0,
        description="Urgent shipment success rate.",
    )