File size: 8,508 Bytes
ddb6ffa
 
 
 
 
 
 
 
47ee65f
ddb6ffa
 
 
 
 
 
47ee65f
ddb6ffa
47ee65f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb6ffa
47ee65f
 
 
 
 
 
 
 
 
 
 
 
ddb6ffa
 
 
 
47ee65f
ddb6ffa
 
 
 
47ee65f
 
ddb6ffa
 
47ee65f
ddb6ffa
 
47ee65f
 
ddb6ffa
47ee65f
ddb6ffa
 
47ee65f
ddb6ffa
 
47ee65f
ddb6ffa
 
47ee65f
ddb6ffa
 
 
47ee65f
ddb6ffa
 
 
 
47ee65f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb6ffa
 
 
47ee65f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb6ffa
 
 
 
 
 
 
 
47ee65f
ddb6ffa
 
47ee65f
ddb6ffa
47ee65f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Data models for the LogiFlow-RL environment."""

from typing import Any, Dict, List, Optional

from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field


class CrisisLogisticsAction(Action):
    """Freight-movement command sent to the supply-chain environment.

    Every field is optional and defaults to ``None``. The Phase 2 form spells
    out ``source_node``, ``dest_node`` and ``shipment_volume``; the older
    ``target_hub`` selector is still accepted so existing validators and
    demos keep working.
    """

    # Legacy selector; validated to the integer range 0..11.
    target_hub: Optional[int] = Field(
        None,
        ge=0,
        le=11,
        description=(
            "Backward-compatible route selector. If source_node/dest_node are omitted, "
            "this chooses one outgoing edge from the current shipment source."
        ),
    )

    # Explicit routing: both endpoints are validated to the range 0..11.
    source_node: Optional[int] = Field(
        None, ge=0, le=11, description="Node id that dispatches freight."
    )
    dest_node: Optional[int] = Field(
        None,
        ge=0,
        le=11,
        description="Connected downstream node id that receives freight after transit delay.",
    )

    # Volume must be strictly positive and at most 60.0 when supplied.
    shipment_volume: Optional[float] = Field(
        None,
        gt=0.0,
        le=60.0,
        description="Freight volume to move. Defaults to the active shipment volume.",
    )

    # Free-text field; no validation constraints are attached to it.
    reasoning: Optional[str] = Field(
        None,
        description="Optional natural-language rationale from the agent before the structured action.",
    )


class CrisisLogisticsObservation(Observation):
    """Agent-facing snapshot of the supply-chain network (partially observable).

    Every field carries a default, so an observation can be constructed
    without arguments. List and dict fields use ``default_factory`` so each
    instance gets its own container.
    """

    # --- Task metadata -----------------------------------------------------
    task_id: str = Field("easy", description="Active benchmark task identifier.")
    difficulty: str = Field("easy", description="Difficulty label for the active task.")
    objective: str = Field("", description="Task objective visible to the agent.")

    # --- Legacy tier summary -------------------------------------------------
    # Three-value aggregates kept for backward compatibility with the
    # existing visualizer.
    hub_loads: List[float] = Field(
        default_factory=lambda: [0.0] * 3,
        description="Aggregate utilization for supplier, warehouse, and downstream tiers.",
    )
    drain_rates: List[float] = Field(
        default_factory=lambda: [0.0] * 3,
        description="Aggregate drain rates for supplier, warehouse, and downstream tiers.",
    )

    # --- Episode progress and scoring ----------------------------------------
    incoming_load: float = Field(
        0.0,
        description="Volume of the active shipment introduced at the current source node.",
    )
    step_count: int = Field(0, description="Current step in the episode.")
    max_steps: int = Field(80, description="Maximum episode length.")
    overloaded_hubs: int = Field(
        0, description="Number of nodes currently above capacity."
    )
    cumulative_score: float = Field(
        0.0,
        description="Normalized score in the range [0.0, 1.0] for progress so far.",
    )
    last_reward: float = Field(0.0, description="Reward from the previous action.")

    # --- Network condition and adaptive stress -------------------------------
    event_label: str = Field(
        "normal",
        description="Current network condition, e.g. normal, port_closure, or supplier_failure.",
    )
    dynamic_pressure: float = Field(
        0.0,
        description="Adaptive stress estimate in [0, 1] driven by congestion, SLA risk, and delivery shortfall.",
    )
    adaptive_disruption_rate: float = Field(
        0.0,
        description="Effective disruption probability after adaptive pressure scaling.",
    )

    # --- Priority (urgent) demand tracking -----------------------------------
    priority_target_node: int = Field(
        10,
        description="Retail node that should receive urgent demand in the current step window.",
    )
    priority_target_name: str = Field(
        "Retail North",
        description="Human-readable name for the active priority retail target.",
    )
    priority_queue_depth: int = Field(
        0,
        description="Number of in-transit urgent shipments still pending completion.",
    )
    priority_service_rate: float = Field(
        0.0,
        description="Fraction of urgent shipments delivered on target and within SLA.",
    )
    message: str = Field("", description="Human-readable state summary.")

    # --- Per-node attributes -------------------------------------------------
    node_names: List[str] = Field(
        default_factory=list,
        description="Names for all network nodes.",
    )
    node_types: List[str] = Field(
        default_factory=list,
        description="Node type for each network node.",
    )
    node_loads: List[float] = Field(
        default_factory=list,
        description="Current load at each node.",
    )
    node_capacities: List[float] = Field(
        default_factory=list,
        description="Capacity for each node.",
    )
    node_utilization: List[float] = Field(
        default_factory=list,
        description="Load/capacity for each node.",
    )
    node_drain_rates: List[float] = Field(
        default_factory=list,
        description="Processing rate for each node.",
    )
    node_risk_scores: List[float] = Field(
        default_factory=list,
        description="Disruption risk for each node.",
    )
    connectivity: Dict[str, List[int]] = Field(
        default_factory=dict,
        description="Directed connectivity graph. Keys are node ids encoded as strings.",
    )

    # --- Partial-observation view --------------------------------------------
    visible_node_ids: List[int] = Field(
        default_factory=list,
        description="Nodes visible to the agent under the two-hop partial-observation rule.",
    )
    observed_node_loads: List[Optional[float]] = Field(
        default_factory=list,
        description="Node loads with hidden nodes represented as null.",
    )
    visible_connectivity: Dict[str, List[int]] = Field(
        default_factory=dict,
        description="Connectivity restricted to currently visible nodes.",
    )

    # --- Shipments, disruptions and reward detail ----------------------------
    in_transit_shipments: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Shipments currently moving between nodes with delayed arrival.",
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Active stochastic disruptions and their remaining duration.",
    )
    # NOTE(review): "Inspectible" below is likely a typo for "Inspectable";
    # the literal is left untouched so the published schema text is unchanged.
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        description="Inspectible per-step reward components.",
    )
    last_action: Dict[str, Any] = Field(
        default_factory=dict,
        description="Last resolved structured action.",
    )
    pending_source_node: int = Field(
        0, description="Source node for the active incoming shipment."
    )
    retail_delivered: float = Field(
        0.0, description="Total volume that has arrived at retail sinks."
    )
    sla_success_rate: float = Field(
        0.0, description="Fraction of retail deliveries within SLA."
    )


class CrisisLogisticsState(State):
    """Environment-internal state surfaced for validation and debugging.

    Every field carries a default, so the model can be built with no
    arguments; list fields use ``default_factory`` to avoid shared instances.
    """

    # --- Task identity and legacy tier summary -------------------------------
    task_id: str = Field("easy", description="Active task identifier.")
    difficulty: str = Field("easy", description="Task difficulty label.")
    hub_loads: List[float] = Field(
        default_factory=lambda: [0.0] * 3,
        description="Backward-compatible aggregate tier utilization.",
    )

    # --- Episode bookkeeping -------------------------------------------------
    incoming_index: int = Field(0, description="Index of the next scheduled shipment.")
    bottlenecks: int = Field(
        0, description="Total overload events encountered this episode."
    )
    score: float = Field(0.0, description="Current normalized task score.")

    # --- Per-node and in-flight detail ---------------------------------------
    node_loads: List[float] = Field(
        default_factory=list,
        description="Current load at each node.",
    )
    node_utilization: List[float] = Field(
        default_factory=list,
        description="Load/capacity for each node.",
    )
    in_transit_count: int = Field(
        0, description="Number of delayed shipments in transit."
    )
    active_disruptions: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Active disruptions.",
    )

    # --- Delivery, SLA and adaptive-pressure metrics -------------------------
    retail_delivered: float = Field(0.0, description="Total retail-arrived volume.")
    sla_success_rate: float = Field(
        0.0, description="Fraction of retail deliveries within SLA."
    )
    dynamic_pressure: float = Field(
        0.0, description="Current adaptive network pressure in [0, 1]."
    )
    adaptive_disruption_rate: float = Field(
        0.0, description="Effective disruption probability this step."
    )
    priority_target_node: int = Field(
        10, description="Retail node targeted for urgent demand."
    )
    priority_service_rate: float = Field(
        0.0, description="Urgent shipment success rate."
    )