HyeminGu committed on
Commit
ec7be33
·
1 Parent(s): 035a7f1

initial demo

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Anonymous
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,15 +1,60 @@
  ---
- title: ISOMORPH Demo
- emoji: 🦀
- colorFrom: indigo
+ title: ISOMORPH Supply Chain Digital Twin
+ emoji: 🏭
+ colorFrom: blue
  colorTo: indigo
  sdk: gradio
- sdk_version: 6.14.0
- python_version: '3.13'
+ sdk_version: 4.44.1
  app_file: app.py
  pinned: false
- license: mit
- short_description: Interactive supply-chain stress-testing simulator
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ISOMORPH Supply Chain Digital Twin
+
+ **Interactive simulation environment for stress-testing supply chains under demand shocks, disruptions, and cascading transport congestion.**
+
+ ISOMORPH is a stochastic digital twin of a 13-node multi-echelon US logistics network. Configure parameters, run the simulation, and observe how local operational decisions propagate through the network over time.
+
+ ## What you can explore
+
+ - **🗺️ Network Map** — animated shipment propagation across the US network; nodes colored by backlog stress, moving dots colored by SKU. Export as an animated GIF.
+ - **📊 Node Detail** — per-node time series of inventory, backlog, inflow, outflow, and demand with disruption event markers.
+ - **📈 Bullwhip** — tier-level amplification chart (B = Var(inflow) / Var(outflow)); shows how demand variability grows upstream through the network.
+ - **🔥 Edge Util** — heatmap of daily shipping-lane utilization; highlights congestion and disruption events.
+ - **⬇️ Download** — full CSV export of all state variables for every node, item, and day.
+
+ ## Preset scenarios
+
+ | Preset | What it demonstrates |
+ |--------|----------------------|
+ | 🟢 Baseline | Mild bullwhip emerging internally from (s, S) ordering and lead-time delays alone |
+ | ⚡ Demand Shock | Correlated macro shocks and per-item bursts amplify variability upstream |
+ | 🔴 Disruption | A lane is randomly blocked; goods reroute and a catch-up wave propagates on recovery |
+ | 📦 Low Capacity | Cascading transport congestion from the last-mile inward; systemic stockouts and extreme bullwhip |
+
+ Use the preset buttons to instantly load a scenario, then tune individual parameters with the left-panel sliders and click **▶ Run Simulation** to re-run.
+
+ ## Paper
+
+ *ISOMORPH: A Supply Chain Digital Twin for Simulation, Dataset Generation, and Forecasting Benchmarks*
+ Zhang et al., 2026 — [arXiv:2605.12768](https://arxiv.org/abs/2605.12768)
+
+ Full simulator and datasets: [github.com/tuhinsahai/ISOMORPH](https://github.com/tuhinsahai/ISOMORPH)
+
+ ## Acknowledgements
+
+ This material is based upon work supported by the Defense Advanced Research Projects Agency (DARPA) under Agreement No. HR00112590112. Approved for public release; distribution is unlimited.
+
+ ## Citation
+
+ ```bibtex
+ @misc{zhang2026isomorphsupplychaindigital,
+     title={ISOMORPH: A Supply Chain Digital Twin for Simulation, Dataset Generation, and Forecasting Benchmarks},
+     author={Zhizhen Zhang and Hyemin Gu and Benjamin J. Zhang and Daniel Elenius and Michael Tyrrell and Theo J. Bourdais and Houman Owhadi and Markos A. Katsoulakis and Tuhin Sahai},
+     year={2026},
+     eprint={2605.12768},
+     archivePrefix={arXiv},
+     primaryClass={stat.ML},
+     url={https://arxiv.org/abs/2605.12768},
+ }
+ ```
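
Reviewer note: the README defines the bullwhip measure as B = Var(inflow) / Var(outflow). The following is a minimal sketch of that ratio using only the standard library. The function name `bullwhip_ratio` and the toy series are illustrative assumptions, not taken from the repository; sample variance is assumed here.

```python
from statistics import variance


def bullwhip_ratio(inflow, outflow):
    """Return Var(inflow) / Var(outflow) using the sample variance."""
    var_out = variance([float(x) for x in outflow])
    if var_out == 0:
        # The ratio is undefined when the downstream series never changes.
        raise ValueError("outflow has zero variance; ratio is undefined")
    return variance([float(x) for x in inflow]) / var_out


# Made-up toy data: the node ships out steadily but receives in large lumps.
outflow = [10, 12, 9, 11, 10, 12]
inflow = [0, 24, 0, 22, 0, 18]
B = bullwhip_ratio(inflow, outflow)
print(f"B = {B:.1f}")
```

A value of B above 1 means the series entering the node is more variable than the series leaving it, which is the amplification the Bullwhip tab is described as charting.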
analysis/bullwhip_analysis.py ADDED
@@ -0,0 +1,203 @@
+ """
+ Per-node and per-tier bullwhip analysis on the digital-twin sim output.
+ B_n = Var(inflow_n) / Var(outflow_n) (Cachon-style amplification ratio)
+ """
+ import os
+ from pathlib import Path
+ import numpy as np
+ import pandas as pd
+
+ REPO = Path(__file__).resolve().parents[1]
+ DATA = os.environ.get("DATA", str(REPO / "data" / "output_item50"))
+ OUT = os.environ.get("OUT", str(REPO / "results" / "bullwhip"))
+ os.makedirs(OUT, exist_ok=True)
+
+ # ---------- 1. Network ----------
+ edges = pd.read_csv(os.path.join(DATA, "edge_list.csv"))
+ nodes = set(edges["from"]) | set(edges["to"])
+ parents, children = {}, {}
+ for _, r in edges.iterrows():
+     children.setdefault(r["from"], []).append(r["to"])
+     parents.setdefault(r["to"], []).append(r["from"])
+
+ sinks = [n for n in nodes if n not in children]
+
+ # Tier labels follow the inventory-parameter table of the paper
+ # (manuscript Table tab:inventory-params, App. C.4). A longest-path-from-sink
+ # BFS would put Memphis at depth 2 (same as Columbus/Richmond, i.e. Tier-4),
+ # but the paper labels Memphis as Tier-3 alongside Charlotte and Chicago.
+ TIER = {
+     # Destination
+     "NewYork": 0,
+     # Last-mile DCs
+     "Baltimore": 1, # Tier-5 (LM)
+     "Philadelphia": 1, # Tier-5 (LM)
+     # Tier-4
+     "Columbus": 2,
+     "Richmond": 2,
+     # Tier-3 (Memphis included by paper convention; topological depth = 2)
+     "Charlotte": 3,
+     "Chicago": 3,
+     "Memphis": 3,
+     # Tier-2
+     "Atlanta": 4,
+     # Hub
+     "Nashville": 5,
+     # Sources -- no inflow on the released edges, dropped by compute_B below
+     "SanFrancisco": 6,
+     "StLouis": 6,
+     "Orlando": 6,
+ }
+ missing = nodes - set(TIER)
+ assert not missing, f"TIER mapping missing nodes: {missing}"
+ tier = {n: TIER[n] for n in nodes}
+
+ print("Tiers:")
+ for t in sorted(set(tier.values())):
+     print(f" T{t}: {sorted(n for n in nodes if tier[n]==t)}")
+
+ # ---------- 2. Shipments → per-(node, item) daily inflow & outflow ----------
+ print("\nLoading shipments ...")
+ ship = pd.read_csv(
+     os.path.join(DATA, "shipments.csv"),
+     usecols=["day", "arrival_day", "from", "to", "item", "units"],
+ )
+ print(f" shipments rows: {len(ship):,}")
+
+ # Outflow at node n on day t (item k): units shipped FROM n on day t
+ out = (
+     ship.groupby(["from", "item", "day"])["units"].sum()
+     .rename("units").reset_index()
+     .rename(columns={"from": "node", "day": "t"})
+ )
+ # Inflow at node n on day t (item k): units arriving AT n on arrival_day t
+ inn = (
+     ship.groupby(["to", "item", "arrival_day"])["units"].sum()
+     .rename("units").reset_index()
+     .rename(columns={"to": "node", "arrival_day": "t"})
+ )
+ del ship
+
+ # ---------- 3. Customer demand at retail sink (NewYork) ----------
+ dr = pd.read_csv(os.path.join(DATA, "daily_records.csv"),
+                  usecols=["day", "item", "demand"])
+
+ # ---------- 4. Per-(node, item) variances ----------
+ T_MAX = int(dr["day"].max()) + 1
+ items = sorted(dr["item"].unique())
+ print(f" horizon: {T_MAX} days, items: {len(items)}")
+
+ def to_dense(df, value_col="units"):
+     """pivot a (node,item,t,value) frame to {node:{item: ndarray[T_MAX]}}."""
+     out = {}
+     for (node, item), g in df.groupby(["node", "item"]):
+         v = np.zeros(T_MAX, dtype=np.float64)
+         idx = g["t"].values.astype(int)
+         # Some arrival days may exceed T_MAX (shipments still in transit at end).
+         mask = idx < T_MAX
+         v[idx[mask]] = g[value_col].values[mask]
+         out.setdefault(node, {})[item] = v
+     return out
+
+ print("Pivoting outflow ...")
+ outflow = to_dense(out)
+ print("Pivoting inflow ...")
+ inflow = to_dense(inn)
+
+ # Sink: outflow = customer demand
+ demand_NY = {}
+ for item, g in dr.groupby("item"):
+     v = np.zeros(T_MAX, dtype=np.float64)
+     v[g["day"].values.astype(int)] = g["demand"].values
+     demand_NY[item] = v
+ sink = sinks[0]
+ outflow.setdefault(sink, {})
+ for item in items:
+     outflow[sink][item] = demand_NY[item] # override: customer-facing outflow
+
+ # ---------- 5. Compute B per (node, item) at a given aggregation window ----------
+ # Burn-in: drop first 365 days to avoid initialization transients
+ BURN = 365
+
+ def aggregate(series, window):
+     """Sum a daily ndarray into non-overlapping `window`-day bins (drops trailing partial bin)."""
+     if window <= 1:
+         return series
+     n = (len(series) // window) * window
+     return series[:n].reshape(-1, window).sum(axis=1)
+
+ def compute_B(window, label, suffix):
+     """Compute per-(node,item) bullwhip ratios at a given temporal aggregation."""
+     rows = []
+     for node in nodes:
+         for item in items:
+             d = outflow.get(node, {}).get(item)
+             o = inflow.get(node, {}).get(item)
+             if d is None or o is None:
+                 continue
+             d, o = d[BURN:], o[BURN:]
+             d, o = aggregate(d, window), aggregate(o, window)
+             if len(d) < 2:
+                 continue
+             vd, vo = d.var(ddof=1), o.var(ddof=1)
+             if vd == 0 or vo == 0:
+                 continue
+             rows.append(dict(
+                 node=node, tier=tier[node], item=item,
+                 var_demand=vd, var_inflow=vo, B=vo / vd,
+                 mean_demand=d.mean(), mean_inflow=o.mean(),
+             ))
+     df = pd.DataFrame(rows)
+     df.to_csv(os.path.join(OUT, f"bullwhip_per_node_item{suffix}.csv"), index=False)
+
+     print(f"\n########## Aggregation: {label} (window={window}d, n_bins={(T_MAX-BURN)//window}) ##########")
+     print(f"Per-(node,item) rows: {len(df)}")
+
+     print(f"\n=== Per-node mean B ({label}) ===")
+     node_summary = (
+         df.groupby(["tier", "node"])
+         .agg(B_mean=("B", "mean"),
+              B_median=("B", "median"),
+              B_p10=("B", lambda s: s.quantile(0.1)),
+              B_p90=("B", lambda s: s.quantile(0.9)),
+              n_items=("B", "count"))
+         .reset_index()
+         .sort_values(["tier", "node"])
+     )
+     print(node_summary.to_string(index=False))
+     node_summary.to_csv(os.path.join(OUT, f"bullwhip_per_node{suffix}.csv"), index=False)
+
+     print(f"\n=== Per-tier summary ({label}) ===")
+     tier_summary = (
+         df.groupby("tier")
+         .agg(B_mean=("B", "mean"),
+              B_median=("B", "median"),
+              mean_var_demand=("var_demand", "mean"),
+              mean_var_inflow=("var_inflow", "mean"),
+              n_obs=("B", "count"))
+         .reset_index()
+         .sort_values("tier")
+     )
+     print(tier_summary.to_string(index=False))
+     tier_summary.to_csv(os.path.join(OUT, f"bullwhip_per_tier{suffix}.csv"), index=False)
+
+     print(f"\n=== Between-tier variance amplification ({label}) ===")
+     btw = tier_summary[["tier", "mean_var_demand"]].copy()
+     btw["amplification_to_next_upstream_tier"] = (
+         btw["mean_var_demand"].shift(-1) / btw["mean_var_demand"]
+     )
+     print(btw.to_string(index=False))
+     btw.to_csv(os.path.join(OUT, f"bullwhip_between_tier{suffix}.csv"), index=False)
+     return df, node_summary, tier_summary
+
+ # Daily (original): suffix "" keeps backward-compatible filenames
+ compute_B(window=1, label="daily", suffix="")
+ # Monthly (Cachon-faithful, 30-day bins on simulation time)
+ compute_B(window=30, label="monthly", suffix="_monthly")
+
+ print("\nSaved:")
+ for f in ["bullwhip_per_node_item.csv","bullwhip_per_node.csv",
+           "bullwhip_per_tier.csv","bullwhip_between_tier.csv",
+           "bullwhip_per_node_item_monthly.csv","bullwhip_per_node_monthly.csv",
+           "bullwhip_per_tier_monthly.csv","bullwhip_between_tier_monthly.csv"]:
+     print(" ", os.path.join(OUT, f))
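
Reviewer note: the script above bins daily series into non-overlapping windows before taking variances, and skips any (node, item) pair whose binned variance is zero. The sketch below illustrates both behaviours on plain Python lists rather than NumPy arrays. It is a standalone approximation of `aggregate`, not the committed code, and the toy series are made up.

```python
def aggregate(series, window):
    """Sum a daily list into non-overlapping `window`-day bins.

    A trailing partial bin is dropped, mirroring the docstring of the
    NumPy version in the diff above.
    """
    if window <= 1:
        return list(series)
    n = (len(series) // window) * window
    return [sum(series[i:i + window]) for i in range(0, n, window)]


# Ten days binned by 3: bins are days 0-2, 3-5, 6-8; day 9 is dropped.
daily = [float(d) for d in range(10)]
binned = aggregate(daily, 3)
print(binned)

# A series that alternates 0, 2 is variable at daily resolution but
# constant once summed over 2-day bins, so its binned variance is zero.
alternating = [0.0, 2.0] * 6
flat = aggregate(alternating, 2)
print(flat)
```

The second case shows why a ratio computed at one aggregation window can be undefined at another, which is consistent with the zero-variance guard in `compute_B`.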
analysis/make_baseline_overview.py ADDED
@@ -0,0 +1,194 @@
+ """Generate baseline_overview.pdf for §3 from output_item50.
+
+ 3 x 2 layout combining catalogue heterogeneity with mechanism:
+
+   Left column (3 panels)  -- raw_item_series.png style: each row is
+                              one catalogue item showing demand,
+                              served, and unmet overlaid over the
+                              full T = 52,560-day horizon. Items:
+                              I01, I20, I40 sample the catalogue.
+
+   Right column (3 panels) -- internal network state for the focal
+                              item I01 in a 5-year zoom window
+                              centred on the largest macro-shock:
+                                (b) destination on-hand inventory
+                                (d) destination backlog
+                                (f) last-mile edge utilisation
+                                    (PHL->NYC, BAL->NYC) with 0.95
+                                    saturation threshold
+
+ The yellow band on each left panel marks the right column's window.
+ Together the two columns expose both the catalogue's per-item
+ heterogeneity and the mechanism that produces fill-rate drop events.
+ """
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+
+ REPO = Path(__file__).resolve().parents[1]
+ DATA = REPO / "data" / "output_item50"
+ OUT = REPO / "results" / "figures"
+
+ # ---------- styling ----------
+ plt.rcParams.update({
+     "font.family": "serif",
+     "font.size": 10,
+     "axes.labelsize": 9.5,
+     "axes.titlesize": 9.5,
+     "axes.titleweight": "bold",
+     "xtick.labelsize": 8.5,
+     "ytick.labelsize": 8.5,
+     "legend.fontsize": 8.0,
+     "mathtext.fontset": "cm",
+     "axes.spines.top": False,
+     "axes.spines.right": False,
+     "axes.grid": True,
+     "grid.alpha": 0.22,
+     "grid.linewidth": 0.5,
+     "lines.linewidth": 0.55,
+ })
+
+ C_DEMAND = "#3a7cb8"
+ C_SERVED = "#d97706"
+ C_UNMET = "#b91c1c"
+ C_ONHAND = "#0e7c66"
+ C_BACKLOG = "#7c3aed"
+ C_PHIL = "#c0392b"
+ C_BALT = "#e67e22"
+ C_THRESH = "#666666"
+ C_WIN = "#f1c40f"
+
+ ITEMS_LEFT = ["I01", "I02", "I03"]
+ FOCAL = "I01"
+
+ # ---------- load ----------
+ print("Loading daily records ...")
+ records = pd.read_csv(DATA / "daily_records.csv")
+ records["unmet"] = records["demand"] - records["served_from_stock"]
+
+ demand_pv = records.pivot(index="day", columns="item", values="demand")
+ served_pv = records.pivot(index="day", columns="item",
+                           values="served_from_stock")
+ unmet_pv = records.pivot(index="day", columns="item", values="unmet")
+ onhand_pv = records.pivot(index="day", columns="item",
+                           values="dest_on_hand_end_before_ship")
+ backlog_pv = records.pivot(index="day", columns="item",
+                            values="dest_backlog_end_before_ship")
+ T = demand_pv.shape[0]
+ print(f"Horizon T = {T} days; left items: {ITEMS_LEFT}; "
+       f"focal item: {FOCAL}")
+
+ # edge utilisation
+ edge_util = np.load(DATA / "edge_utilisation.npy")
+ edge_list = pd.read_csv(DATA / "edge_list.csv")
+ phil_idx = int(edge_list.index[(edge_list["from"] == "Philadelphia")
+                                & (edge_list["to"] == "NewYork")][0])
+ balt_idx = int(edge_list.index[(edge_list["from"] == "Baltimore")
+                                & (edge_list["to"] == "NewYork")][0])
+
+ # ---------- zoom window ----------
+ agg = demand_pv.sum(axis=1).values
+ WIN = 5 * 365
+ roll = pd.Series(agg).rolling(WIN, center=True).mean()
+ peak = int(roll.idxmax())
+ start = max(0, peak - WIN // 2)
+ end = min(T, start + WIN)
+ start = max(0, end - WIN)
+ print(f"Zoom window: days {start}..{end} (peak at day {peak})")
+
+ # focal-item slices
+ foc_onhand = onhand_pv[FOCAL].values
+ foc_backlog = backlog_pv[FOCAL].values
+ foc_onhand_z = foc_onhand[start:end]
+ foc_backlog_z = foc_backlog[start:end]
+ util_phil_z = edge_util[start:end, phil_idx]
+ util_balt_z = edge_util[start:end, balt_idx]
+
+ # ---------- figure ----------
+ fig, axes = plt.subplots(3, 2, figsize=(13, 7.4),
+                          gridspec_kw={"hspace": 0.55, "wspace": 0.20,
+                                       "width_ratios": [1.7, 1.0]})
+
+ panel_letters_left = ["a", "c", "e"]
+
+ # -------- left column: per-item demand/served/unmet --------
+ for r, item in enumerate(ITEMS_LEFT):
+     d = demand_pv[item].values
+     s = served_pv[item].values
+     u = unmet_pv[item].values
+     fr = s.sum() / max(d.sum(), 1e-9)
+     ymax = max(d.max(), s.max()) * 1.10
+
+     ax = axes[r, 0]
+     ax.fill_between(np.arange(T), 0, u, color=C_UNMET, alpha=0.18,
+                     lw=0)
+     ax.plot(np.arange(T), d, color=C_DEMAND, lw=0.45, alpha=0.85,
+             label="demand" if r == 0 else None)
+     ax.plot(np.arange(T), s, color=C_SERVED, lw=0.45, alpha=0.85,
+             label="served" if r == 0 else None)
+     ax.plot(np.arange(T), u, color=C_UNMET, lw=0.45, alpha=0.95,
+             label="unmet" if r == 0 else None)
+     ax.axvspan(start, end, color=C_WIN, alpha=0.22, lw=0,
+                label="zoom window" if r == 0 else None)
+     ax.set_xlim(0, T - 1)
+     ax.set_ylim(0, ymax)
+     ax.set_ylabel(f"{item} demand (units)")
+     ax.set_title(f"({panel_letters_left[r]}) {item} full horizon "
+                  f"($T={T}$, fill rate {fr:.3f})", loc="left")
+     if r == 0:
+         ax.legend(loc="upper right", ncol=4, frameon=False,
+                   columnspacing=1.0, handlelength=1.4)
+     if r == 2:
+         ax.set_xlabel("time unit")
+
+ # -------- right column: focal-item mechanism in zoom --------
+ # y-axis scaled to the zoom-window range, not the full horizon, so
+ # the dynamics inside the window are readable
+ ymax_onh = max(1.0, foc_onhand_z.max() * 1.10)
+ ymax_bk = max(1.0, foc_backlog_z.max() * 1.10)
+ days_z = np.arange(start, end)
+
+ # (b) destination on-hand for focal item
+ ax = axes[0, 1]
+ ax.fill_between(days_z, 0, foc_onhand_z, color=C_ONHAND, alpha=0.18,
+                 lw=0)
+ ax.plot(days_z, foc_onhand_z, color=C_ONHAND, lw=0.85, alpha=0.95)
+ ax.set_xlim(start, end - 1)
+ ax.set_ylim(0, ymax_onh)
+ ax.set_ylabel(f"{FOCAL} on-hand")
+ ax.set_title(f"(b) Zoom: {FOCAL} destination on-hand inventory",
+              loc="left")
+
+ # (d) destination backlog for focal item
+ ax = axes[1, 1]
+ ax.fill_between(days_z, 0, foc_backlog_z, color=C_BACKLOG, alpha=0.18,
+                 lw=0)
+ ax.plot(days_z, foc_backlog_z, color=C_BACKLOG, lw=0.85, alpha=0.95)
+ ax.set_xlim(start, end - 1)
+ ax.set_ylim(0, ymax_bk)
+ ax.set_ylabel(f"{FOCAL} backlog")
+ ax.set_title(f"(d) Zoom: {FOCAL} destination backlog", loc="left")
+
+ # (f) last-mile edge utilisation in zoom
+ ax = axes[2, 1]
+ ax.plot(days_z, util_phil_z, color=C_PHIL, lw=0.85, alpha=0.95,
+         label=r"PHL$\to$NYC")
+ ax.plot(days_z, util_balt_z, color=C_BALT, lw=0.85, alpha=0.85,
+         label=r"BAL$\to$NYC")
+ ax.axhline(0.95, color=C_THRESH, ls="--", lw=0.9, alpha=0.7,
+            label=r"$U=0.95$")
+ ax.set_xlim(start, end - 1)
+ ax.set_ylim(0, 1.05)
+ ax.set_ylabel(r"$U_{e,t}$")
+ ax.set_xlabel("time unit")
+ ax.set_title("(f) Zoom: last-mile edge utilisation", loc="left")
+ ax.legend(loc="upper left", ncol=3, frameon=False,
+           columnspacing=1.0, handlelength=1.6)
+
+ plt.savefig(OUT / "baseline_overview.pdf", bbox_inches="tight")
+ plt.savefig(OUT / "baseline_overview.png", dpi=160,
+             bbox_inches="tight")
+ print(f"Saved {OUT / 'baseline_overview.pdf'}")
+ print(f"Saved {OUT / 'baseline_overview.png'}")
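
Reviewer note: the zoom window in this script is chosen by locating the peak of a centred rolling mean of aggregate demand and then clamping a fixed-width window to the horizon. The sketch below reproduces that selection logic with the standard library only. The helper name `zoom_window`, the odd window width (so the centre is unambiguous), and the toy demand pulse are all assumptions made for illustration; the committed script uses `pandas.Series.rolling` instead.

```python
def zoom_window(total, win):
    """Pick a `win`-wide window around the peak of a centred rolling sum.

    Assumes `win` is odd and `len(total) >= win`. Maximising the rolling
    sum is equivalent to maximising the rolling mean for a fixed width.
    """
    T = len(total)
    half = win // 2
    # Prefix sums give each window total in O(1).
    prefix = [0.0]
    for x in total:
        prefix.append(prefix[-1] + x)
    best_sum, peak = None, None
    for c in range(half, T - half):
        s = prefix[c + half + 1] - prefix[c - half]
        if best_sum is None or s > best_sum:
            best_sum, peak = s, c
    # Same clamping steps as the script: centre on the peak, then keep
    # the window inside [0, T] without shrinking it.
    start = max(0, peak - half)
    end = min(T, start + win)
    start = max(0, end - win)
    return start, end, peak


# Made-up demand: flat at zero except an 11-day pulse on days 60..70.
total = [0.0] * 100
for day in range(60, 71):
    total[day] = 5.0
start, end, peak = zoom_window(total, 11)
print(start, end, peak)
```

The final `start = max(0, end - win)` step matters when the peak sits near the end of the horizon: the window is shifted left rather than truncated.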
analysis/make_scenario_family.py ADDED
@@ -0,0 +1,255 @@
+ """
+ Figure 4 (replacement): single-row small multiples, rendered as
+ two separate figures so each can be placed independently in the
+ paper.
+
+ For one item (I36, with high macro-shock sensitivity g_i and high
+ burst rate r_i), under four demand-side scenarios:
+
+     scenario_family_demand.pdf
+         realised demand y_{i,t} + deterministic intensity
+         lambda_{i,t} (Eq.\\ ref{eq:intensity})
+
+     scenario_family_oh.pdf
+         destination on-hand OH^{d*,i}_t at NewYork
+         (end-of-step value; the network's response to demand)
+
+ Supply-side sweeps leave y_{i,t} unchanged by construction and
+ are omitted.
+ """
+
+ from pathlib import Path
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import matplotlib as mpl
+ from matplotlib.ticker import MaxNLocator, MultipleLocator, FuncFormatter
+ import matplotlib.patheffects as pe
+
+
+ def _short_count(x: float, _pos: int) -> str:
+     """Compact integer formatter: 1.2M, 30K, 5K, 0."""
+     if x == 0:
+         return "0"
+     ax = abs(x)
+     if ax >= 1e6:
+         return f"{x/1e6:g}M"
+     if ax >= 1e3:
+         return f"{x/1e3:g}K"
+     return f"{int(x)}"
+
+
+ SHORT = FuncFormatter(_short_count)
+
+ REPO = Path(__file__).resolve().parents[1]
+ DATA = REPO / "data" / "output_mixture"
+ OUT = REPO / "results" / "figures"
+ OUT.mkdir(parents=True, exist_ok=True)
+
+ ITEM = "I36"
+ T_FULL = 52560
+
+ # (dir name, panel title, accent colour for line, shadow tint)
+ SCENARIOS = [
+     ("baseline", r"baseline",
+      "#3470a8", "#cddbeb"), # cobalt / pale blue
+     ("drift_hi", r"drift $(\phi^{\mathrm{AR}}\!=\!0.99)$",
+      "#7e5b9a", "#dcd2e5"), # plum / pale mauve
+     ("shock_xhi", r"shock $(N{\times}h^{G}\!=\!3{\times}4)$",
+      "#c08438", "#f0d8b6"), # caramel / pale cream
+     ("burst_xhi", r"burst $(r{\times}h^{P}\!=\!3{\times}4)$",
+      "#a64141", "#e7c4c2"), # brick / pale rose
+ ]
+
+ # ---------------------------------------------------------------- style
+ mpl.rcParams.update({
+     "font.family": "serif",
+     "font.serif": ["Times New Roman", "Nimbus Roman", "DejaVu Serif"],
+     "mathtext.fontset": "cm",
+     "font.size": 9.5,
+     "axes.labelsize": 9.5,
+     "axes.titlesize": 10.0,
+     "axes.titleweight": "normal",
+     "xtick.labelsize": 8.0,
+     "ytick.labelsize": 8.0,
+     "axes.spines.top": False,
+     "axes.spines.right": False,
+     "axes.linewidth": 0.5,
+     "xtick.major.width": 0.5,
+     "ytick.major.width": 0.5,
+     "xtick.minor.width": 0.3,
+     "ytick.minor.width": 0.3,
+     "xtick.major.size": 2.4,
+     "ytick.major.size": 2.4,
+     "xtick.minor.size": 1.3,
+     "ytick.minor.size": 1.3,
+     "xtick.direction": "out",
+     "ytick.direction": "out",
+     "axes.grid": False,
+     "axes.axisbelow": True,
+     "pdf.fonttype": 42,
+     "ps.fonttype": 42,
+     "text.usetex": False,
+ })
+
+ # ---------------------------------------------------------------- data
+ def load_panel(scenario: str, item: str):
+     df = pd.read_csv(
+         DATA / scenario / "seed2025" / "daily_records.csv",
+         usecols=("day", "item", "demand", "dest_on_hand_end_before_ship"),
+     )
+     s = df[df["item"] == item].sort_values("day")
+     y = s["demand"].to_numpy()
+     oh = s["dest_on_hand_end_before_ship"].to_numpy()
+     assert y.size == T_FULL, f"{scenario}: got {y.size} rows for {item}"
+
+     cols = (DATA / scenario / "seed2025" / "demand_signals_cols.txt").read_text().strip().split(",")
+     j = cols.index(item)
+     lam = np.load(DATA / scenario / "seed2025" / "demand_signals.npy")[:, j]
+     return y, lam, oh
+
+
+ panels = [(scen, title, accent, shadow, *load_panel(scen, ITEM))
+           for scen, title, accent, shadow in SCENARIOS]
+
+ # ---------------------------------------------------------------- figure
+ t_years = np.arange(T_FULL) / 1e4
+ decim = max(1, T_FULL // 6000)
+
+
+ # =================================================================
+ # Figure A: realised demand y_{i,t} + intensity lambda_{i,t}
+ # =================================================================
+ figA, axesA = plt.subplots(
+     1, len(SCENARIOS), figsize=(12.0, 2.55),
+     sharey=False,
+     gridspec_kw={"wspace": 0.22},
+ )
+
+ for ax, (scen, title, accent, shadow, y, lam, _oh) in zip(axesA, panels):
+     ax.plot(t_years[::decim], y[::decim],
+             color=shadow, lw=0.35, alpha=0.85,
+             zorder=1, rasterized=True)
+     ax.plot(t_years, lam,
+             color=accent, lw=0.95, alpha=1.0,
+             solid_capstyle="round", solid_joinstyle="round",
+             zorder=3,
+             path_effects=[pe.Stroke(linewidth=1.9, foreground="white"),
+                           pe.Normal()])
+     ax.set_title(title, style="italic")
+     ax.set_xlim(0, T_FULL / 1e4)
+     ax.xaxis.set_major_locator(MultipleLocator(1.0))
+     ax.xaxis.set_minor_locator(MultipleLocator(0.5))
+     ymax_panel = max(y.max(), lam.max())
+     ax.set_ylim(0, ymax_panel * 1.06)
+     ax.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=True))
+     ax.yaxis.set_minor_locator(MaxNLocator(nbins=8, integer=True))
+
+ axesA[0].set_ylabel(rf"item {ITEM} demand")
+ figA.supxlabel(r"time $t$ ($\times 10^{4}$ time units)",
+                y=0.04, fontsize=9.5)
+ figA.subplots_adjust(left=0.055, right=0.995, top=0.86, bottom=0.22)
+
+ outA = OUT / "scenario_family_demand.pdf"
+ figA.savefig(outA, bbox_inches="tight")
+ figA.savefig(outA.with_suffix(".png"), dpi=200, bbox_inches="tight")
+ plt.close(figA)
+
+
+ # =================================================================
+ # Figure B: destination (NewYork) on-hand OH^{d*, i}_t
+ # =================================================================
+ figB, axesB = plt.subplots(
+     1, len(SCENARIOS), figsize=(12.0, 2.55),
+     sharey=False,
+     gridspec_kw={"wspace": 0.22},
+ )
+
+ for ax, (scen, title, accent, shadow, _y, _lam, oh) in zip(axesB, panels):
+     ax.plot(t_years[::decim], oh[::decim],
+             color=accent, lw=0.85, alpha=1.0,
+             solid_capstyle="round", solid_joinstyle="round",
+             zorder=3,
+             path_effects=[pe.Stroke(linewidth=1.7, foreground="white"),
+                           pe.Normal()])
+     ax.set_title(title, style="italic")
+     ax.set_xlim(0, T_FULL / 1e4)
+     ax.xaxis.set_major_locator(MultipleLocator(1.0))
+     ax.xaxis.set_minor_locator(MultipleLocator(0.5))
+     ymax_panel = max(1.0, float(oh.max()))
+     ax.set_ylim(0, ymax_panel * 1.08)
+     ax.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=True))
+     ax.yaxis.set_minor_locator(MaxNLocator(nbins=8, integer=True))
+     ax.yaxis.set_major_formatter(SHORT)
+
+ axesB[0].set_ylabel(rf"destination on-hand, item {ITEM}")
+ figB.supxlabel(r"time $t$ ($\times 10^{4}$ time units)",
+                y=0.04, fontsize=9.5)
+ figB.subplots_adjust(left=0.055, right=0.995, top=0.86, bottom=0.22)
+
+ outB = OUT / "scenario_family_oh.pdf"
+ figB.savefig(outB, bbox_inches="tight")
+ figB.savefig(outB.with_suffix(".png"), dpi=200, bbox_inches="tight")
+ plt.close(figB)
+
+
+ # =================================================================
+ # Figure C: combined 2x4 (demand row on top, on-hand row on bottom).
+ # Same data as Figures A and B, sharing the x-axis within each column.
+ # =================================================================
+ figC, axesC = plt.subplots(
+     2, len(SCENARIOS), figsize=(12.0, 4.6),
+     sharex="col", sharey=False,
+     gridspec_kw={"wspace": 0.22, "hspace": 0.18},
+ )
+
+ for col, (scen, title, accent, shadow, y, lam, oh) in enumerate(panels):
+     ax_top = axesC[0, col]
+     ax_bot = axesC[1, col]
+
+     ax_top.plot(t_years[::decim], y[::decim],
+                 color=shadow, lw=0.30, alpha=0.70,
+                 zorder=1, rasterized=True)
+     ax_top.plot(t_years, lam,
+                 color=accent, lw=0.75, alpha=0.85,
+                 solid_capstyle="round", solid_joinstyle="round",
+                 zorder=3,
+                 path_effects=[pe.Stroke(linewidth=1.5, foreground="white"),
+                               pe.Normal()])
+     ax_top.set_title(title, style="italic")
+     ax_top.set_xlim(0, T_FULL / 1e4)
+     ax_top.xaxis.set_major_locator(MultipleLocator(1.0))
+     ax_top.xaxis.set_minor_locator(MultipleLocator(0.5))
+     ymax_top = max(y.max(), lam.max())
+     ax_top.set_ylim(0, ymax_top * 1.06)
+     ax_top.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=True))
+     ax_top.yaxis.set_minor_locator(MaxNLocator(nbins=8, integer=True))
+
+     ax_bot.plot(t_years[::decim], oh[::decim],
+                 color=accent, lw=0.75, alpha=0.85,
+                 solid_capstyle="round", solid_joinstyle="round",
+                 zorder=3,
+                 path_effects=[pe.Stroke(linewidth=1.5, foreground="white"),
+                               pe.Normal()])
+     ax_bot.set_xlim(0, T_FULL / 1e4)
+     ax_bot.xaxis.set_major_locator(MultipleLocator(1.0))
+     ax_bot.xaxis.set_minor_locator(MultipleLocator(0.5))
+     ymax_bot = max(1.0, float(oh.max()))
+     ax_bot.set_ylim(0, ymax_bot * 1.08)
+     ax_bot.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=True))
+     ax_bot.yaxis.set_minor_locator(MaxNLocator(nbins=8, integer=True))
+     ax_bot.yaxis.set_major_formatter(SHORT)
+
+ axesC[0, 0].set_ylabel(rf"item {ITEM} demand")
+ axesC[1, 0].set_ylabel(rf"destination on-hand, item {ITEM}")
+ figC.supxlabel(r"time $t$ ($\times 10^{4}$ time units)",
+                y=0.02, fontsize=9.5)
+ figC.subplots_adjust(left=0.055, right=0.995, top=0.93, bottom=0.11)
+
+ outC = OUT / "scenario_family.pdf"
+ figC.savefig(outC, bbox_inches="tight")
+ figC.savefig(outC.with_suffix(".png"), dpi=200, bbox_inches="tight")
+ plt.close(figC)
+
+ print(f"wrote {outA}, {outB}, and {outC} item={ITEM} "
+       f"scenarios={[s for s, *_ in SCENARIOS]}")
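
Reviewer note: the on-hand axes in this script use a compact tick formatter that abbreviates thousands and millions. The sketch below restates that formatting rule as a plain function so its behaviour can be checked without matplotlib. The name `short_count` and the optional position argument are illustrative assumptions; the committed helper is `_short_count` wrapped in `FuncFormatter`.

```python
def short_count(x, _pos=0):
    """Format a tick value compactly, e.g. 30000 -> '30K'.

    The second argument mimics the tick position that matplotlib passes
    to formatter callbacks; it is unused here.
    """
    if x == 0:
        return "0"
    magnitude = abs(x)
    if magnitude >= 1e6:
        return f"{x / 1e6:g}M"
    if magnitude >= 1e3:
        return f"{x / 1e3:g}K"
    return f"{int(x)}"


labels = [short_count(v) for v in (0, 999, 5000, 30000, 1200000)]
print(labels)
```

Because the `g` format keeps up to six significant digits, tick values that are not round multiples of a thousand or a million produce longer labels, so the formatter works best with locators that place ticks on round values.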
app.py ADDED
@@ -0,0 +1,789 @@
1
+ """
2
+ ISOMORPH Supply Chain Digital Twin — Interactive Demo
3
+ ======================================================
4
+ Run locally:
5
+ python app.py
6
+
7
+ Deploy to Hugging Face Spaces:
8
+ Upload this file together with the simulator/ and demo/ directories.
9
+ Set SDK: gradio in the Space README.
10
+
11
+ Interaction flow:
12
+ Configure parameters → ▶ Run Simulation
13
+ → Network Map (animated shipment propagation)
14
+ → Node Detail (per-node time series)
15
+ → Bullwhip (tier-level amplification chart)
16
+ → Edge Util (heatmap of capacity usage)
17
+ → Download (CSV export)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import io
23
+ import tempfile
24
+ import time
25
+ from typing import Optional
26
+
27
+ import gradio as gr
28
+ import numpy as np
29
+ import pandas as pd
30
+
31
+ from simulator.demo_simulator import run_demo_simulation
32
+ from demo.visualize import (
33
+ make_network_animation_html,
34
+ make_network_animation_gif,
35
+ make_node_timeseries,
36
+ make_bullwhip_chart,
37
+ make_edge_heatmap,
38
+ )
39
+
40
+ # ============================================================================
41
+ # Static reference data
42
+ # ============================================================================
43
+
44
+ # All directed edges in the fixed ISOMORPH topology
45
+ _EDGE_STRINGS = [
46
+ "None (no disruption)",
47
+ "SanFrancisco → Nashville",
48
+ "StLouis → Nashville",
49
+ "Orlando → Nashville",
50
+ "Nashville → Atlanta",
51
+ "Atlanta → Chicago",
52
+ "Atlanta → Charlotte",
53
+ "Atlanta → Memphis",
54
+ "Chicago → Columbus",
55
+ "Charlotte → Richmond",
56
+ "Columbus → Philadelphia",
57
+ "Richmond → Philadelphia",
58
+ "Richmond → Baltimore",
59
+ "Columbus → Baltimore",
60
+ "Memphis → Baltimore",
61
+ "Philadelphia → NewYork",
62
+ "Baltimore → NewYork",
63
+ ]
64
+
65
+ # Node choices ordered: destination first, then by tier
66
+ _NODE_TIER = {
67
+ "NewYork": 0,
68
+ "Philadelphia": 1, "Baltimore": 1,
69
+ "Columbus": 2, "Richmond": 2,
70
+ "Charlotte": 3, "Chicago": 3, "Memphis": 3,
71
+ "Atlanta": 4,
72
+ "Nashville": 5,
73
+ "SanFrancisco": 6, "StLouis": 6, "Orlando": 6,
74
+ }
75
+ _NODE_CHOICES = sorted(_NODE_TIER, key=lambda n: (_NODE_TIER[n], n))
76
+
77
+ _TIER_LABEL = {
78
+ 0: "Destination", 1: "Last-mile",
79
+ 2: "Tier-4", 3: "Tier-3",
80
+ 4: "Tier-2 (Atlanta)", 5: "Hub (Nashville)", 6: "Source",
81
+ }
82
+
83
+ # Narrative preset configurations (T, n_items, seed, pipeline_mult,
84
+ # phi_lo, shock_height_scale, burst_rate_scale, burst_height_scale,
85
+ # containers_scale, ss_scale, leadtime_scale,
86
+ # disruption_edge_str, disruption_prob, disruption_duration,
87
+ # holding_cost, backlog_penalty)
88
+ _PRESETS = {
89
+ "baseline": (
90
+ 200, 3, 42, 7.0, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
91
+ "None (no disruption)", 0.0, 10, 1.0, 5.0,
92
+ ),
93
+ "demand_shock": (
94
+ 200, 3, 42, 7.0, 0.95, 4.0, 2.0, 3.0, 1.0, 1.0, 1.0,
95
+ "None (no disruption)", 0.0, 10, 1.0, 5.0,
96
+ ),
97
+ "disruption": (
98
+ 200, 3, 42, 7.0, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
99
+ "Atlanta → Chicago", 0.08, 14, 1.0, 5.0,
100
+ ),
101
+ "low_capacity": (
102
+ 200, 3, 42, 7.0, 0.95, 1.0, 1.0, 1.0, 0.4, 1.0, 1.0,
103
+ "None (no disruption)", 0.0, 10, 1.0, 5.0,
104
+ ),
105
+ }
106
+
107
+
108
+ # ============================================================================
109
+ # Helper utilities
110
+ # ============================================================================
111
+
112
+ def _parse_edge(edge_str: str):
113
+ """'Atlanta → Chicago' → ('Atlanta', 'Chicago'), 'None...' → None."""
114
+ if edge_str.startswith("None"):
115
+ return None
116
+ parts = edge_str.split(" → ")
117
+ return (parts[0].strip(), parts[1].strip()) if len(parts) == 2 else None
118
+
119
+
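A minimal standalone sketch of how the dropdown labels in `_EDGE_STRINGS` map to edge tuples. The body mirrors `_parse_edge` from the diff above; the name `parse_edge`, the return annotation, and the ASCII-arrow input are illustrative assumptions rather than part of the commit.

```python
from typing import Optional, Tuple


def parse_edge(edge_str: str) -> Optional[Tuple[str, str]]:
    """Split 'A → B' into ('A', 'B'); return None for the 'None ...' sentinel."""
    if edge_str.startswith("None"):
        return None
    parts = edge_str.split(" → ")
    # Anything that does not split into exactly two parts is treated as "no edge".
    return (parts[0].strip(), parts[1].strip()) if len(parts) == 2 else None


print(parse_edge("Atlanta → Chicago"))      # ('Atlanta', 'Chicago')
print(parse_edge("None (no disruption)"))   # None
print(parse_edge("Atlanta -> Chicago"))     # None (ASCII arrow is not the separator)
```

One consequence worth noting: a label that does not contain the exact `" → "` separator is silently read as "no disruption" rather than raising an error.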
120
+ def _build_config(T, n_items, seed, pipeline_mult,
121
+ phi_lo, shock_height_scale,
122
+ burst_rate_scale, burst_height_scale,
123
+ containers_scale, ss_scale, leadtime_scale,
124
+ disruption_edge_str, disruption_prob, disruption_duration,
125
+ holding_cost, backlog_penalty) -> dict:
126
+ return {
127
+ "T": int(T),
128
+ "n_items": int(n_items),
129
+ "seed": int(seed),
130
+ "pipeline_mult": float(pipeline_mult),
131
+ "phi_lo": float(phi_lo),
132
+ "phi_hi": min(float(phi_lo) + 0.02, 0.9999),
133
+ "shock_height_scale": float(shock_height_scale),
134
+ "burst_rate_scale": float(burst_rate_scale),
135
+ "burst_height_scale": float(burst_height_scale),
136
+ "containers_scale": float(containers_scale),
137
+ "ss_scale": float(ss_scale),
138
+ "leadtime_scale": float(leadtime_scale),
139
+ "disruption_edge": _parse_edge(str(disruption_edge_str)),
140
+ "disruption_prob": float(disruption_prob),
141
+ "disruption_duration": int(disruption_duration),
142
+ "holding_cost": float(holding_cost),
143
+ "backlog_penalty": float(backlog_penalty),
144
+ }
145
+
146
+
147
+ def _format_summary(result, elapsed_s: float) -> str:
148
+ """Return a Markdown summary table."""
149
+ lines = [
150
+ f"**Simulation complete** in `{elapsed_s:.2f}s` &nbsp;·&nbsp; "
151
+ f"T = {result.T} days &nbsp;·&nbsp; "
152
+ f"Items = {len(result.item_ids)} &nbsp;·&nbsp; "
153
+ f"Shipments logged = {len(result.shipments):,}",
154
+ "",
155
+ "| Item | Fill Rate | Total Demand | Backlogs |",
156
+ "|------|----------:|-------------:|---------:|",
157
+ ]
158
+ for iid in result.item_ids:
159
+ fr = result.fill_rate.get(iid, 0.0)
160
+ dem = int(result.demand[iid].sum())
161
+ bl = int(result.backlog["NewYork"][iid].sum())
162
+ lines.append(f"| `{iid}` | {fr:.1%} | {dem:,} | {bl:,} |")
163
+
164
+ n_ev = len(result.disruption_log)
165
+ if n_ev:
166
+ ded = result.disruption_log[0]["edge"]
167
+ lines += [
168
+ "",
169
+ f"⚠️ **Disruption** on `{ded[0]} → {ded[1]}`"
170
+ f" · {n_ev} event(s) triggered over {result.T} days",
171
+ ]
172
+ return "\n".join(lines)
173
+
174
+
175
+ def _make_csv_bytes(result) -> bytes:
176
+ """Return a CSV as bytes for download."""
177
+ rows = []
178
+ for nid in result.node_ids:
179
+ for iid in result.item_ids:
180
+ for t in range(result.T):
181
+ rows.append({
182
+ "day": t,
183
+ "node": nid,
184
+ "item": iid,
185
+ "inventory": int(result.inventory[nid][iid][t]),
186
+ "backlog": int(result.backlog[nid][iid][t]),
187
+ "inflow": int(result.inflow[nid][iid][t]),
188
+ "outflow": int(result.outflow[nid][iid][t]),
189
+ "demand": int(result.demand[iid][t])
190
+ if nid == "NewYork" else "",
191
+ })
192
+ buf = io.StringIO()
193
+ pd.DataFrame(rows).to_csv(buf, index=False)
194
+ return buf.getvalue().encode()
195
+
196
+
197
+ # ============================================================================
198
+ # Gradio callbacks
199
+ # ============================================================================
200
+
201
+ def run_sim(T, n_items, seed, pipeline_mult,
202
+ phi_lo, shock_height_scale,
203
+ burst_rate_scale, burst_height_scale,
204
+ containers_scale, ss_scale, leadtime_scale,
205
+ disruption_edge_str, disruption_prob, disruption_duration,
206
+ holding_cost, backlog_penalty):
207
+ """
208
+ Main callback: builds config, runs simulation, generates all figures.
209
+ Returns (sim_state, anim_html, ts_fig, bw_fig, heat_fig, summary_md,
210
+ item_filter_update).
211
+ """
212
+ try:
213
+ config = _build_config(
214
+ T, n_items, seed, pipeline_mult,
215
+ phi_lo, shock_height_scale,
216
+ burst_rate_scale, burst_height_scale,
217
+ containers_scale, ss_scale, leadtime_scale,
218
+ disruption_edge_str, disruption_prob, disruption_duration,
219
+ holding_cost, backlog_penalty,
220
+ )
221
+ t0 = time.time()
222
+ result = run_demo_simulation(config)
223
+ elapsed = time.time() - t0
224
+
225
+ fig_anim = make_network_animation_html(result)
226
+ fig_ts = make_node_timeseries(result, "NewYork")
227
+ fig_bw = make_bullwhip_chart(result)
228
+ fig_heat = make_edge_heatmap(result)
229
+ summary = _format_summary(result, elapsed)
230
+
231
+ item_choices = ["All items"] + result.item_ids
232
+ return (
233
+ result,
234
+ fig_anim,
235
+ fig_ts,
236
+ fig_bw,
237
+ fig_heat,
238
+ summary,
239
+ gr.update(choices=item_choices, value="All items"),
240
+ )
241
+
242
+ except Exception as exc:
243
+ err = f"**Error during simulation:** `{exc}`"
244
+ return None, None, None, None, None, err, gr.update()
245
+
246
+
247
+ def update_timeseries(result, node_id: str, item_filter: str):
248
+ """Redraws the node time-series when node or item selection changes."""
249
+ if result is None:
250
+ return None
251
+ item_ids = None if item_filter == "All items" else [item_filter]
252
+ return make_node_timeseries(result, node_id, item_ids=item_ids)
253
+
254
+
255
+ def generate_gif(result):
256
+ """Renders the network animation as an animated GIF and returns the file path."""
257
+ if result is None:
258
+ return None
259
+ return make_network_animation_gif(result)
260
+
261
+
262
+ def prepare_download(result):
263
+ """Writes a CSV to a temp file and returns the path for gr.File."""
264
+ if result is None:
265
+ return None
266
+ data = _make_csv_bytes(result)
267
+ tmp = tempfile.NamedTemporaryFile(
268
+ suffix=".csv", delete=False, mode="wb", prefix="isomorph_")
269
+ tmp.write(data)
270
+ tmp.close()
271
+ return tmp.name
272
+
273
+
274
+ # Collect all slider/dropdown inputs in order (must match run_sim signature)
275
+ def _all_inputs(components: dict) -> list:
276
+ return [
277
+ components["T"], components["n_items"], components["seed"],
278
+ components["pipeline_mult"], components["phi_lo"],
279
+ components["shock_height_scale"],
280
+ components["burst_rate_scale"], components["burst_height_scale"],
281
+ components["containers_scale"], components["ss_scale"],
282
+ components["leadtime_scale"],
283
+ components["disruption_edge"], components["disruption_prob"],
284
+ components["disruption_duration"],
285
+ components["holding_cost"], components["backlog_penalty"],
286
+ ]
287
+
288
+
289
+ # ============================================================================
290
+ # Gradio application layout
291
+ # ============================================================================
292
+
293
+ _CSS = """
294
+ #run-btn { font-size: 1.1em; }
295
+ .panel-header { font-weight: 600; color: #1a1a2e; }
296
+ .result-tab > div { padding-top: 4px; }
297
+ footer { display: none !important; }
298
+
299
+ /* Shared accent accordion style — used for highlighted collapsibles */
300
+ #markov-accordion > .label-wrap,
301
+ #learn-more-accordion > .label-wrap,
302
+ #state-space-accordion > .label-wrap {
303
+ background: linear-gradient(90deg, #1a3a5c 0%, #2471a3 100%);
304
+ border-radius: 6px;
305
+ padding: 8px 14px;
306
+ }
307
+ #markov-accordion > .label-wrap span,
308
+ #markov-accordion > .label-wrap .icon,
309
+ #learn-more-accordion > .label-wrap span,
310
+ #learn-more-accordion > .label-wrap .icon,
311
+ #state-space-accordion > .label-wrap span,
312
+ #state-space-accordion > .label-wrap .icon {
313
+ color: #ffffff !important;
314
+ font-weight: 700;
315
+ font-size: 1.05em;
316
+ letter-spacing: 0.01em;
317
+ }
318
+ #markov-accordion,
319
+ #learn-more-accordion,
320
+ #state-space-accordion {
321
+ border: 2px solid #2471a3 !important;
322
+ border-radius: 8px;
323
+ margin-top: 10px;
324
+ }
325
+ """
326
+
327
+ _DESCRIPTION = """
328
+ ***Interactive simulation environment for stress-testing supply chains under demand shocks, disruptions, and cascading transport congestion.***
329
+
330
+ Modern supply chains are vulnerable to delays, congestion, shortages, and cascading disruptions. This demo provides an interpretable digital twin for configuring and simulating the real-time evolution of a **stochastic multi-echelon supply-chain network**, studying how local operational decisions propagate through large logistics networks over time.
331
+
332
+ This demo runs a **fixed 13-node US network**: three suppliers (San Francisco, St. Louis, Orlando) → regional hub (Nashville) → distribution warehouses (Atlanta, Chicago, Charlotte, Memphis, Columbus, Richmond) → last-mile DCs (Philadelphia, Baltimore) → destination (New York).
333
+
334
+ **Users can interactively configure:**
335
+ **Demand characteristics** · **Network capacity** · **Edge disruptions** · **Simulation scope** — then observe how operational effects propagate through the system over time.
336
+
337
+ **The platform is designed to support:**
338
+ - Stress-testing supply chains under demand shocks, congestion, and disruptions
339
+ - Studying bullwhip effect amplification across network tiers
340
+ - Visualizing cascading congestion, bottleneck formation, and inventory depletion
341
+ - Building intuition for how local operational decisions affect global network behavior
342
+
343
+ All within a fixed 13-node network and a product catalogue of up to 5 SKUs.
344
+
345
+ ---
346
+
347
+ **Presets and parameter tuning**
348
+
349
+ Four preset buttons at the top instantly load and run a scenario. There is one baseline and three stress-test scenarios:
350
+
351
+ - 🟢 **Baseline** — stable operating conditions; observe inventory cycles and the mild bullwhip effect emerging internally from (s, S) ordering and lead-time delays alone
352
+ - ⚡ **Demand Shock** — correlated macro shocks and per-item bursts amplify demand variability; backlogs build at NewYork (destination) and variability amplifies upstream through the network
353
+ - 🔴 **Disruption** — the Atlanta→Chicago lane is randomly blocked; goods reroute and inventory depletes downstream of the outage, then a catch-up wave propagates on recovery
354
+ - 📦 **Low Capacity** — all edges at 40% capacity; cascading transport congestion propagates from the last-mile inward, causing systemic stockouts and extreme bullwhip amplification
355
+
356
+ To stress-test further, adjust individual parameters using the sliders in the left panel and click **▶ Run Simulation**. The sliders are organized into groups:
357
+
358
+ - **⚙️ Simulation Settings** — horizon length, number of SKUs, random seed, and pipeline multiplier; these define the scope of any run and are not specific to a scenario
359
+ - **📈 Demand Scenario** — shock amplitude, burst frequency and size, AR persistence; primary levers for the ⚡ Demand Shock scenario
360
+ - **🔴 Disruption** — edge to block, outage probability, duration; primary levers for the 🔴 Disruption scenario
361
+ - **🚚 Supply & Network** — edge capacity, reorder thresholds, source lead times; primary levers for the 📦 Low Capacity scenario
362
+ - **💰 Costs**: holding cost and backlog penalty; display-only annotations that do not affect simulation dynamics and are not written to the CSV export; apply them to the exported inventory and backlog columns for post-hoc cost analysis
363
+
364
+ The full simulator — where topology, catalogue size, demand model, replenishment policy, and routing are all user-configurable — is open-source at **https://github.com/tuhinsahai/ISOMORPH**.
365
+ """
366
+
367
+ _LEARN_MORE = """
368
+ ISOMORPH simulates the flow of multiple product types (each called a **Stock Keeping Unit, SKU**) through a directed network of factories, intermediate warehouses, and a customer-facing destination, advancing every location and link forward in discrete time.
369
+
370
+ In each time step, random customer demand arrives at the destination, is served from available stock or recorded as **backlog**, and triggers replenishment orders that propagate back through the network. Shipments travel along directed edges with fixed transit times, are routed via **Dijkstra shortest-path**, packed greedily into **finite-capacity containers**, and restocked at each warehouse under **(s, S) reorder policies** — reorder up to level S when stock drops below reorder point s.
371
+
372
+ Demand is driven by a **five-component signal**: a persistent AR(1) trend, a long-run drift, rare macro shocks shared across all SKUs (e.g. a holiday surge), independent per-item burst events, and Gaussian noise — producing the correlated, lumpy demand patterns characteristic of real logistics networks.
373
+
374
+ The simulation is **stochastic**: demand draws, source lead times, and disruption timing are all random. Changing the **random seed** gives a different realization of the same scenario, enabling ensemble analysis and forward uncertainty quantification.
375
+ """
376
+
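The (s, S) rule described in `_LEARN_MORE` ("reorder up to level S when stock drops below reorder point s") can be sketched as follows. This is an illustration under stated assumptions, not the simulator's code: `order_quantity` is a hypothetical name, and the diff does not show whether the real policy compares on-hand stock or an inventory position that also counts outstanding orders.

```python
def order_quantity(stock: int, s: int, S: int) -> int:
    """Units to order under an (s, S) policy.

    Assumption: `stock` is the quantity the policy compares against s;
    the actual simulator may use a different inventory measure.
    """
    if s > S:
        raise ValueError("reorder point s must not exceed order-up-to level S")
    # Order only when stock is strictly below s, and then refill up to S.
    return S - stock if stock < s else 0


print(order_quantity(stock=35, s=20, S=100))  # 0
print(order_quantity(stock=20, s=20, S=100))  # 0
print(order_quantity(stock=19, s=20, S=100))  # 81
```

Because nothing is ordered until stock crosses s and the order then jumps to S minus stock, orders arrive in lumps, which is one source of the "lumpy (s,S) ordering" mentioned in the Bullwhip tab text.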
377
+ _STATE_SPACE = """
378
+ **Why ISOMORPH tracks more than just inventory — the "right state space"**
379
+
380
+ Most supply-chain datasets record only how much stock sits at each location. ISOMORPH also records, at every simulated day: the replenishment orders currently outstanding with suppliers, the shipments traveling through the network, and a running smoothed estimate of recent demand — alongside on-hand inventory and backlog at every node. Together these five types of information form the **complete state** of the simulation.
381
+
382
+ This matters because knowing only today's warehouse stock is not enough to predict tomorrow: a shipment already in transit will arrive regardless of what else happens, and a pending supplier order determines when restocking will occur. Without tracking these "invisible" quantities, the simulation would need to remember weeks of past history to make predictions. With all five, the distribution of tomorrow's state depends only on today's state. This is the **Markov property**, which makes the simulation a Markov chain.
383
+
384
+ This "right state" design also enforces **three exact conservation laws** that hold on every simulated day, for every random scenario, with no approximation:
385
+ 1. Each node's inventory changes by exactly what arrives minus what ships.
386
+ 2. Total units inside the network change only when a supplier delivers or a customer is served.
387
+ 3. Backlog only grows when customer demand exceeds on-hand stock.
388
+
389
+ These laws are built into the simulator's structure, not imposed afterward — so they serve as verification tools: any violation would indicate a bug, not a modeling choice.
390
+ """
391
+
392
+
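The first conservation law in `_STATE_SPACE` (inventory changes by exactly what arrives minus what ships) lends itself to a simple verification pass. The sketch below is a hypothetical helper, not code from this commit. It assumes three equal-length daily series for one node and one item, aligned so that day `t` receipts and dispatches produce day `t+1` on-hand stock; the simulator's actual timing convention is not shown in the diff.

```python
from typing import List, Sequence


def conservation_violations(on_hand: Sequence[int],
                            receipts: Sequence[int],
                            dispatches: Sequence[int]) -> List[int]:
    """Return the days t where OH[t+1] != OH[t] + R[t] - D[t]."""
    return [
        t for t in range(len(on_hand) - 1)
        if on_hand[t + 1] != on_hand[t] + receipts[t] - dispatches[t]
    ]


on_hand = [10, 12, 9, 9]
receipts = [5, 0, 3, 0]
dispatches = [3, 3, 3, 0]
print(conservation_violations(on_hand, receipts, dispatches))  # []

tampered = [10, 12, 8, 9]  # day-2 stock altered by one unit
print(conservation_violations(tampered, receipts, dispatches))  # [1, 2]
```

A single corrupted stock value breaks the identity on both the step into it and the step out of it, which is why the text treats any violation as a bug signal rather than a modeling choice.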
393
+ with gr.Blocks(
394
+ title="ISOMORPH Supply Chain Digital Twin",
395
+ css=_CSS,
396
+ theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
397
+ ) as demo:
398
+
399
+ # ── State ────────────────────────────────────────────────────────────────
400
+ sim_state = gr.State(None)
401
+
402
+ # ── Header ───────────────────────────────────────────────────────────────
403
+ gr.Markdown("# ISOMORPH Supply Chain Digital Twin")
404
+ gr.Markdown(_DESCRIPTION)
405
+
406
+ with gr.Accordion("📖 Learn more about the simulator dynamics", open=False, elem_id="learn-more-accordion"):
407
+ gr.Markdown(_LEARN_MORE)
408
+ with gr.Accordion("🔢 Mathematical structure of the state space — the \"right state space\"", open=False, elem_id="state-space-accordion"):
409
+ gr.Markdown(_STATE_SPACE)
410
+
411
+ # ── Quick-preset buttons ─────────────────────────────────────────────────
412
+ with gr.Row():
413
+ gr.Markdown("**Quick presets:**")
414
+ btn_baseline = gr.Button("🟢 Baseline", size="sm", variant="secondary")
415
+ btn_shock = gr.Button("⚡ Demand Shock", size="sm", variant="secondary")
416
+ btn_disrupt = gr.Button("🔴 Disruption", size="sm", variant="secondary")
417
+ btn_low_cap = gr.Button("📦 Low Capacity", size="sm", variant="secondary")
418
+
419
+ gr.Markdown("---")
420
+
421
+ # ── Main layout: controls | results ──────────────────────────────────────
422
+ with gr.Row(equal_height=False):
423
+
424
+ # ── LEFT: Configuration panel ─────────────────────────────────────
425
+ with gr.Column(scale=1, min_width=280):
426
+
427
+ with gr.Accordion("⚙️ Simulation Settings", open=True):
428
+ c = {} # dict of all input components (for easy callback wiring)
429
+
430
+ c["T"] = gr.Slider(
431
+ 10, 500, value=200, step=5,
432
+ label="Horizon T (days)",
433
+ info="Total simulation length. 200 days ≈ 7 months; enough to see seasonal patterns and bullwhip cycles. Keep ≤ 500 for fast response.",
434
+ )
435
+ c["n_items"] = gr.Slider(
436
+ 1, 5, value=3, step=1,
437
+ label="Number of SKUs (Stock Keeping Units)",
438
+ info="Each SKU is an independent product type with its own AR(1) demand, (s,S) reorder policy, and physical volume. More SKUs increases total network load.",
439
+ )
440
+ c["seed"] = gr.Number(
441
+ value=42, label="Random seed", precision=0,
442
+ info="Controls all stochastic draws (demand, shocks, bursts, lead times). Change this to get a different random realization of the same scenario.",
443
+ )
444
+ c["pipeline_mult"] = gr.Slider(
445
+ 0, 15, value=7.0, step=0.5,
446
+ label="Pipeline multiplier",
447
+ info="How many days of smoothed (EMA) demand are kept in the in-transit replenishment pipeline. Higher = more inventory pre-positioned; 0 = purely reactive (s,S) ordering.",
448
+ )
449
+
450
+ with gr.Accordion("📈 Demand Scenario", open=True):
451
+ c["phi_lo"] = gr.Slider(
452
+ 0.50, 0.999, value=0.95, step=0.01,
453
+ label="Demand memory φ (AR coefficient)",
454
+ info="Auto-regressive coefficient controlling demand persistence. Near 1.0 = slow, long trends (hard to forecast); near 0.5 = rapid mean-reverting fluctuations.",
455
+ )
456
+ c["shock_height_scale"] = gr.Slider(
457
+ 0.0, 6.0, value=1.0, step=0.1,
458
+ label="Macro shock amplitude",
459
+ info="Scales the height of rare, correlated demand spikes that hit all SKUs simultaneously (e.g. a holiday surge). 0 = no shocks; 4+ = severe shocks.",
460
+ )
461
+ c["burst_rate_scale"] = gr.Slider(
462
+ 0.0, 6.0, value=1.0, step=0.1,
463
+ label="Idiosyncratic burst frequency",
464
+ info="Scales how often individual-SKU demand bursts occur (independent across items). Higher = more frequent localized spikes.",
465
+ )
466
+ c["burst_height_scale"] = gr.Slider(
467
+ 0.0, 6.0, value=1.0, step=0.1,
468
+ label="Idiosyncratic burst amplitude",
469
+ info="Scales how large each individual-SKU burst is when it occurs.",
470
+ )
471
+
472
+ with gr.Accordion("🔴 Disruption", open=False):
473
+ c["disruption_edge"] = gr.Dropdown(
474
+ choices=_EDGE_STRINGS,
475
+ value="None (no disruption)",
476
+ label="Edge to disrupt",
477
+ info="Select a directed shipping lane to subject to random outages. When triggered, that edge's capacity drops to zero for the disruption duration.",
478
+ )
479
+ c["disruption_prob"] = gr.Slider(
480
+ 0.0, 0.30, value=0.0, step=0.01,
481
+ label="Disruption probability per day",
482
+ info="Each day, a Bernoulli draw with this probability triggers a new disruption event (if none is currently active). 0.08 ≈ a mean wait of 1/0.08 = 12.5 outage-free days before the next event.",
483
+ )
484
+ c["disruption_duration"] = gr.Slider(
485
+ 1, 60, value=10, step=1,
486
+ label="Disruption duration (days)",
487
+ info="Number of consecutive days the edge is completely blocked once triggered. Marked as red dashed lines in the Node Detail and Edge Util plots.",
488
+ )
489
+
490
+ with gr.Accordion("🚚 Supply & Network", open=False):
491
+ c["containers_scale"] = gr.Slider(
492
+ 0.1, 3.0, value=1.0, step=0.1,
493
+ label="Edge capacity scale",
494
+ info="Multiplies the number of shipping containers available per edge per day. <1 = congested network (queuing, missed shipments); >1 = excess capacity (lower utilization).",
495
+ )
496
+ c["ss_scale"] = gr.Slider(
497
+ 0.1, 3.0, value=1.0, step=0.1,
498
+ label="Reorder threshold (s, S) scale",
499
+ info="Scales both the reorder point s and the order-up-to level S at every warehouse. <1 = lean / just-in-time; >1 = large safety-stock buffers.",
500
+ )
501
+ c["leadtime_scale"] = gr.Slider(
502
+ 0.5, 10.0, value=1.0, step=0.5,
503
+ label="Source lead-time scale",
504
+ info="Multiplies the replenishment lead time at source nodes (SanFrancisco, StLouis, Orlando). High values simulate distant or slow suppliers and create longer supply gaps.",
505
+ )
506
+
507
+ with gr.Accordion("💰 Costs (display only)", open=False):
508
+ c["holding_cost"] = gr.Slider(
509
+ 0.1, 10.0, value=1.0, step=0.1,
510
+ label="Inventory holding cost per unit per day",
511
+ info="Cost of carrying one unit of stock for one day (warehousing, capital). Display only — does not affect simulation dynamics.",
512
+ )
513
+ c["backlog_penalty"] = gr.Slider(
514
+ 1.0, 20.0, value=5.0, step=0.5,
515
+ label="Backlog penalty per unit per day",
516
+ info="Penalty for each unit of unfulfilled demand per day (lost-sale cost, SLA breach). Display only — does not affect simulation dynamics.",
517
+ )
518
+
519
+ run_btn = gr.Button(
520
+ "▶ Run Simulation", variant="primary",
521
+ size="lg", elem_id="run-btn",
522
+ )
523
+ status_md = gr.Markdown(
524
+ "_Configure parameters above and click ▶ Run Simulation._"
525
+ )
526
+
527
+ # ── RIGHT: Results panel ──────────────────────────────────────────
528
+ with gr.Column(scale=2, min_width=560):
529
+
530
+ with gr.Tabs():
531
+
532
+ # Tab 1 — animated network map
533
+ with gr.Tab("🗺️ Network Map"):
534
+ gr.Markdown(
535
+ "This map shows the physical flow of goods across the US supply chain in real time."
536
+ " Node colors reflect inventory health at every simulation day;"
537
+ " moving dots trace shipments as they travel between facilities.\n\n"
538
+ "**Node color** = backlog stress —"
539
+ " <span style='color:#4CAF50'>■ green</span> (healthy)"
540
+ " → <span style='color:#FFC107'>■ yellow</span> (building backlog)"
541
+ " → <span style='color:#F44336'>■ red</span> (stockout) \n"
542
+ "**Node shape** = facility role —"
543
+ " ★ destination · ■ last-mile DC · ● warehouse · ◆ Tier-2 · ⬡ hub · ▲ supplier \n"
544
+ "**Moving dot color** = which Stock Keeping Unit (SKU) is being shipped \n"
545
+ "**Edge thickness** = proportional to the daily shipping capacity of that lane\n\n\n"
546
+ "**How to interact:**\n\n"
547
+ "- ▶ **Play / scrub:** press **▶ Play** to animate through the simulation, or drag the day slider to jump to any point.\n"
548
+ "- 🏭 **Hover over a node:** see its current total inventory and backlog counts.\n"
549
+ "- 📦 **Hover over a moving dot:** see the shipment's origin, destination, quantity, and arrival day."
550
+ )
551
+ anim_plot = gr.HTML(label="Shipment propagation")
552
+ with gr.Row():
553
+ gif_btn = gr.Button(
554
+ "⬇️ Export as GIF", variant="secondary", size="sm",
555
+ )
556
+ gr.Markdown(
557
+ "_Renders up to 80 frames as a portable animated GIF "
558
+ "(requires `kaleido` and `Pillow`). May take ~30 s._"
559
+ )
560
+ gif_dl = gr.File(label="Animated GIF", file_types=[".gif"])
561
+
562
+ # Tab 2 — node detail time series
563
+ with gr.Tab("📊 Node Detail"):
564
+ gr.Markdown(
565
+ "Shows the time history of inventory and material flows at one node across all simulation days."
566
+ " Use this to trace how a facility responds to demand shocks, disruptions, or capacity constraints.\n\n"
567
+ "**Inventory** = on-hand stock at this node (units)"
568
+ " — state component (1), <b>OH<sup>n,i</sup><sub>t</sub></b>: integer count of item <i>i</i> held at node <i>n</i> on day <i>t</i> \n"
569
+ "**Backlog** = <span style='color:#c0392b'>unfulfilled demand</span> waiting to be served (units; non-zero means stock ran out)"
570
+ " — state component (2), <b>B<sup>n,i</sup><sub>t</sub></b>: demand units owed; non-zero only at the destination \n"
571
+ "**Inflow** = units <span style='color:#2980b9'>arriving</span> per day"
572
+ " — daily receipts <b>R<sup>n,i</sup><sub>t</sub></b>: the realization of component (4) IT<sub>t</sub>;"
573
+ " at non-destination nodes R<sup>n,i</sup><sub>t</sub> = <i>q</i> when a pending order Out<sup>n,i</sup> = (<i>t</i>, <i>q</i>) matures;"
574
+ " at the destination it is the subset of IT<sub>t</sub> whose arrival timestamp equals <i>t</i> \n"
575
+ "**Outflow** = units <span style='color:#27ae60'>dispatched</span> per day"
576
+ " — daily dispatches <b>D<sup>n,i</sup><sub>t</sub></b>: units packed onto outgoing edges by the greedy bin-packing algorithm,"
577
+ " driven by component (3) Out<sub>t</sub> and Dijkstra routing \n"
578
+ "**Demand** = actual customer orders received per day, <b><i>y</i><sub>i,t</sub></b>"
579
+ " — external random input to the chain, not a state component; shown only at the destination (NewYork)."
580
+ " Note: the internal <b>smoothed demand estimate</b> (state component 5),"
581
+ " λ̃<sup>i</sup><sub>t+1</sub> = 0.05·<i>y</i><sub>i,t</sub> + 0.95·λ̃<sup>i</sup><sub>t</sub>,"
582
+ " is the exponential moving average of <i>y</i><sub>i,t</sub> and is what drives replenishment order sizing — it is not plotted here \n"
583
+ "**<span style='color:rgba(255,80,80,0.9)'>Red dashed vertical line</span>**"
584
+ " = day a disruption event was triggered on a connected edge\n\n\n"
585
+ "Select a **Node** and an **Item (SKU)** from the dropdowns below. **How to read the panels:**\n\n"
586
+ "- 📦 **Inventory + Backlog:** when inventory (panel 1) hits zero, backlog (panel 2) spikes — the node has run out of stock and is accumulating unfulfilled demand.\n"
587
+ "- 🔄 **Inflow + Outflow:** when inflow (panel 3) persistently exceeds outflow (panel 4), the node is building up stock; when outflow exceeds inflow, it is drawing down reserves.\n"
588
+ "- ⚡ **Demand (destination only):** compare demand (panel 5) with inflow (panel 3) — a lag between a demand spike and the inflow response reveals the end-to-end replenishment delay through the network."
589
+ )
590
+ with gr.Accordion("📐 About the Markov State — Mathematical Details", open=False, elem_id="markov-accordion"):
591
+ gr.Markdown(
592
+ "The simulation is a Markov chain whose full state at each day is "
593
+ "ξ<sub>t</sub> = (OH<sub>t</sub>, B<sub>t</sub>, Out<sub>t</sub>, IT<sub>t</sub>, λ̃<sub>t</sub>) — five components: "
594
+ "(1) on-hand inventory, (2) backlog, (3) outstanding supplier orders, (4) scheduled in-transit arrivals, and (5) smoothed demand estimate. "
595
+ "Together these give C(3N+1) = **120 fixed scalar dimensions** "
596
+ "(with N = 13 nodes and C = 3 SKUs, the default catalogue size in this demo), "
597
+ "plus a variable-length in-transit list IT<sub>t</sub>, which "
598
+ "contain everything needed to determine the next day's state, with no reliance on history. "
599
+ "The full ISOMORPH release provides datasets at catalogue sizes C = 50 and C = 200, corresponding to "
600
+ "**≥ 2,000 and ≥ 8,000 fixed scalar state dimensions** respectively "
601
+ "(see the <a href='https://arxiv.org/pdf/2605.12768' target='_blank'>ISOMORPH paper</a> for the complete state-space specification).\n\n"
602
+ "This panel directly plots components (1) and (2). "
603
+ "The remaining three components are not plotted directly, but their effects appear as the daily increment terms Inflow and Outflow, "
604
+ "which satisfy the per-node conservation law:\n\n"
605
+ "OH<sup>n,i</sup><sub>t+1</sub> = OH<sup>n,i</sup><sub>t</sub> + R<sup>n,i</sup><sub>t</sub> − D<sup>n,i</sup><sub>t</sub>\n\n"
606
+ "This identity holds exactly on every simulated day — it is not an approximation but a structural invariant of the Markov chain. "
607
+ "Any deviation would indicate a simulation bug, not a modeling choice."
608
+ )
609
+ with gr.Row():
610
+ node_dd = gr.Dropdown(
611
+ choices=_NODE_CHOICES,
612
+ value="NewYork",
613
+ label="Node",
614
+ info="13 nodes by tier — NewYork: destination · Philadelphia, Baltimore: last-mile DCs · Columbus, Richmond: tier-4 · Charlotte, Chicago, Memphis: tier-3 · Atlanta: tier-2 · Nashville: hub · SanFrancisco, StLouis, Orlando: suppliers",
615
+ scale=2,
616
+ )
617
+ item_filter_dd = gr.Dropdown(
618
+ choices=["All items"],
619
+ value="All items",
620
+ label="Item (SKU)",
621
+ info="Stock Keeping Unit — a distinct product type. Show all SKUs overlaid, or isolate one to remove clutter.",
622
+ scale=1,
623
+ )
624
+ ts_plot = gr.Plot(label="Time series")
625
+
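The conservation law quoted in the Markov-state accordion can be checked numerically on any exported run. The following is a minimal sketch, assuming three per-day series for a single (node, item) pair that share one time index; the helper name `conservation_violations` and the index alignment (flows on day t move stock between day t and day t+1) are illustrative assumptions, not part of the demo code.

```python
from typing import List, Sequence


def conservation_violations(
    inventory: Sequence[int],
    inflow: Sequence[int],
    outflow: Sequence[int],
) -> List[int]:
    """Return the days t where OH[t+1] != OH[t] + R[t] - D[t].

    Hypothetical helper: assumes all three series share one time index
    and that inflow/outflow on day t move stock between day t and t+1.
    """
    bad_days = []
    for t in range(len(inventory) - 1):
        expected = inventory[t] + inflow[t] - outflow[t]
        if inventory[t + 1] != expected:
            bad_days.append(t)
    return bad_days


# Toy series for one (node, item) pair.
oh = [10, 12, 9, 9]
r = [5, 0, 4, 0]
d = [3, 3, 4, 0]
print(conservation_violations(oh, r, d))
```

An empty list means the identity held on every day of the toy series; any listed day would point at a bookkeeping error under the stated alignment.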
626
+ # Tab 3 — bullwhip chart
627
+ with gr.Tab("📈 Bullwhip"):
628
+ gr.Markdown(
629
+ 'Measures how much demand variability amplifies as orders travel upstream — the "bullwhip effect."'
630
+ " Replenishment decisions made with delayed inventory information, lead-time uncertainty, and lumpy (s,S) ordering"
631
+ " can cause upstream order variability to exceed downstream demand variability — but the pattern is not always a"
632
+ " simple monotone staircase. **Amplification can be uneven, selective, or even locally suppressed depending on"
633
+ " the scenario and the network topology** — which nodes connect to which, and how many paths exist between tiers.\n\n"
634
+ "**Bar height** = B = Var(inflow) / Var(outflow), averaged over all nodes in that tier \n"
635
+ "**<span style='color:#555'>&#9135;&#9135; Dashed line at B = 1</span>**"
636
+ " = no-amplification baseline; each tier passes demand through unchanged \n"
637
+ "**<span style='color:#c0392b'>B &gt; 1</span>**"
638
+ " = orders placed at that tier are more variable than the downstream demand that triggered them \n"
639
+ "**<span style='color:#2980b9'>B &lt; 1</span>**"
640
+ " = that tier smooths variability — can occur when bottlenecks suppress flow propagation \n"
641
+ "**Tier axis** = <span style='color:#2980b9'>NewYork (downstream, left)</span>"
642
+ " → <span style='color:#c0392b'>Suppliers (upstream, right)</span> \n"
643
+ "**Bar color** = one distinct color per Stock Keeping Unit (SKU)\n\n\n"
644
+ "**What to expect across scenarios:**\n\n"
645
+ "- 🟢 **Baseline:** the network already exhibits a clear bullwhip effect without any externally imposed shocks."
646
+ " Amplification is not uniform — specific tiers, particularly last-mile replenishment nodes, may become dominant"
647
+ " variability amplifiers as local ordering policies interact with transportation delays and inventory thresholds."
648
+ " Complex upstream variability emerges internally from otherwise stable conditions.\n"
649
+ "- ⚡ **Demand Shock:** macro shocks and per-SKU bursts inject additional variability at the destination."
650
+ " Upstream tiers respond with larger replenishment swings as delayed information and transport lags compound the fluctuations.\n"
651
+ "- 🔴 **Disruption:** the pattern becomes mixed. Some tiers may temporarily show B &lt; 1 because transport bottlenecks"
652
+ " suppress downstream flow propagation. Other tiers exhibit sharp amplification as delayed replenishment creates"
653
+ " catch-up ordering waves once the disruption clears.\n"
654
+ "- 📦 **Low Capacity:** congestion and stockouts generate highly lumpy replenishment behavior."
655
+ " Extreme amplification can appear at selected tiers, where recovery orders become much larger and more variable"
656
+ " than the original downstream demand signal."
657
+ )
658
+ bw_plot = gr.Plot(label="Bullwhip amplification")
659
+
660
+ # Tab 4 — edge utilization heatmap
661
+ with gr.Tab("🔥 Edge Util"):
662
+ gr.Markdown(
663
+ "Shows which shipping lanes are busy, congested, or completely blocked at each simulation day."
664
+ " Persistent congestion on a lane starves downstream nodes and propagates stockouts forward.\n\n"
665
+ "**Row** = one directed shipping lane (upstream node → downstream node) \n"
666
+ "**Column** = simulation day \n"
667
+ "**Cell color** = fraction of daily shipping capacity used —"
668
+ " <span style='color:#5b9bd5'>■ light blue</span> (near-empty)"
669
+ " → <span style='color:#e67e22'>■ orange</span> (~50%)"
670
+ " → <span style='color:#c0392b'>■ red</span> (saturated at 100%) \n"
671
+ "**<span style='color:#2980b9'>Blue dashed vertical line</span>**"
672
+ " = day a disruption event blocked this lane (capacity dropped to zero)\n\n\n"
673
+ "**How to read the heatmap:**\n\n"
674
+ "- 🔴 **Red bands:** persistent red on a lane means it is chronically saturated, forcing upstream nodes to stockpile or reroute goods via longer alternate paths.\n"
675
+ "- 🔍 **Locate the bottleneck:** compare the last-mile lanes (Philadelphia/Baltimore → NewYork) with upstream lanes — the lane that stays red longest is the binding constraint.\n"
676
+ "- 🔵 **Disruption events:** enable a disruption in the left panel and re-run to see blue dashed markers appear on the day each outage was triggered."
677
+ )
678
+ heat_plot = gr.Plot(label="Edge utilization")
679
+
680
+ # Tab 5 — download
681
+ with gr.Tab("⬇️ Download"):
682
+ gr.Markdown(
683
+ "Click **Prepare CSV** to generate a downloadable file "
684
+ "containing inventory, backlog, inflow, and outflow "
685
+ "for every node and item at every time step."
686
+ )
687
+ dl_btn = gr.Button("Prepare CSV", variant="secondary")
688
+ csv_dl = gr.File(label="Download", file_types=[".csv"])
689
+
690
+ # =========================================================================
691
+ # Callback wiring
692
+ # =========================================================================
693
+
694
+ _inputs = _all_inputs(c)
695
+
696
+ # ── Outputs returned by run_sim ───────────────────────────────────────────
697
+ _run_outputs = [
698
+ sim_state, anim_plot, ts_plot, bw_plot, heat_plot,
699
+ status_md, item_filter_dd,
700
+ ]
701
+
702
+ # api_name=False on every handler suppresses Gradio's JSON-schema
703
+ # introspection, which crashes on Python 3.9 + Gradio 4.44 when a
704
+ # gr.State holds a complex dataclass (TypeError: 'bool' not iterable).
705
+ run_btn.click(
706
+ fn=run_sim,
707
+ inputs=_inputs,
708
+ outputs=_run_outputs,
709
+ api_name=False,
710
+ )
711
+
712
+ # ── Quick preset buttons (set knobs then run) ─────────────────────────────
713
+ def _preset(name: str):
714
+ vals = _PRESETS[name]
715
+ result_vals = run_sim(*vals)
716
+ return list(vals) + list(result_vals)
717
+
718
+ _preset_slider_outputs = [
719
+ c["T"], c["n_items"], c["seed"], c["pipeline_mult"],
720
+ c["phi_lo"], c["shock_height_scale"],
721
+ c["burst_rate_scale"], c["burst_height_scale"],
722
+ c["containers_scale"], c["ss_scale"], c["leadtime_scale"],
723
+ c["disruption_edge"], c["disruption_prob"], c["disruption_duration"],
724
+ c["holding_cost"], c["backlog_penalty"],
725
+ ]
726
+
727
+ for btn, name in [
728
+ (btn_baseline, "baseline"),
729
+ (btn_shock, "demand_shock"),
730
+ (btn_disrupt, "disruption"),
731
+ (btn_low_cap, "low_capacity"),
732
+ ]:
733
+ btn.click(
734
+ fn=lambda n=name: _preset(n),
735
+ inputs=[],
736
+ outputs=_preset_slider_outputs + _run_outputs,
737
+ api_name=False,
738
+ )
739
+
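The preset loop above binds each button with `lambda n=name: _preset(n)` rather than a bare `lambda: _preset(name)`. A short standalone sketch of why the default argument matters follows; the two factory functions are hypothetical names used only for illustration.

```python
from typing import Callable, List


def make_handlers_late(names: List[str]) -> List[Callable[[], str]]:
    # Each lambda looks up `name` when it is called, after the loop has finished.
    handlers = []
    for name in names:
        handlers.append(lambda: name)
    return handlers


def make_handlers_bound(names: List[str]) -> List[Callable[[], str]]:
    # The default argument is evaluated once per iteration, freezing the value.
    handlers = []
    for name in names:
        handlers.append(lambda n=name: n)
    return handlers


presets = ["baseline", "demand_shock", "disruption", "low_capacity"]
late = [h() for h in make_handlers_late(presets)]
bound = [h() for h in make_handlers_bound(presets)]
print(late)
print(bound)
```

Without the default argument every handler would resolve to the last preset in the list, so all four buttons would run the same scenario.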
740
+ # ── Node / item selection updates time series ─────────────────────────────
741
+ node_dd.change(
742
+ fn=update_timeseries,
743
+ inputs=[sim_state, node_dd, item_filter_dd],
744
+ outputs=[ts_plot],
745
+ api_name=False,
746
+ )
747
+ item_filter_dd.change(
748
+ fn=update_timeseries,
749
+ inputs=[sim_state, node_dd, item_filter_dd],
750
+ outputs=[ts_plot],
751
+ api_name=False,
752
+ )
753
+
754
+ # ── CSV download ──────────────────────────────────────────────────────────
755
+ dl_btn.click(
756
+ fn=prepare_download,
757
+ inputs=[sim_state],
758
+ outputs=[csv_dl],
759
+ api_name=False,
760
+ )
761
+
762
+ # ── GIF export ────────────────────────────────────────────────────────────
763
+ gif_btn.click(
764
+ fn=generate_gif,
765
+ inputs=[sim_state],
766
+ outputs=[gif_dl],
767
+ api_name=False,
768
+ )
769
+
770
+ # ── Auto-run baseline on page load ────────────────────────────────────────
771
+ demo.load(
772
+ fn=lambda: _preset("baseline"),
773
+ inputs=[],
774
+ outputs=_preset_slider_outputs + _run_outputs,
775
+ api_name=False,
776
+ )
777
+
778
+
779
+ # ============================================================================
780
+ # Entry point
781
+ # ============================================================================
782
+
783
+ if __name__ == "__main__":
784
+ demo.launch(
785
+ server_name="0.0.0.0",
786
+ server_port=7860,
787
+ share=False,
788
+ show_error=True,
789
+ )
demo/.DS_Store ADDED
Binary file (6.15 kB).
 
demo/__init__.py ADDED
File without changes
demo/visualize.py ADDED
@@ -0,0 +1,946 @@
1
+ """
2
+ ISOMORPH Demo — Visualization Module
3
+ =====================================
4
+ All functions return plotly.graph_objects.Figure objects for rendering in
5
+ Gradio gr.Plot components.
6
+
7
+ Public API
8
+ ----------
9
+ make_network_animation(result, frame_step=None)
10
+ Animated US map: shipment particles travel along edges, nodes colored
11
+ by backlog stress. Play / Pause controls + scrubber slider.
12
+
13
+ make_node_timeseries(result, node_id, item_ids=None)
14
+ Subplot panel for one node: inventory, backlog, inflow, outflow,
15
+ and (at the destination) realized demand.
16
+
17
+ make_bullwhip_chart(result)
18
+ Bar chart of mean B = Var(inflow)/Var(outflow) per tier, from
19
+ destination (left) to hub (right); source nodes are excluded.
20
+
21
+ make_edge_heatmap(result)
22
+ Heatmap of edge utilization (fraction of daily capacity) over time.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import Dict, List, Optional, Tuple
28
+
29
+ import numpy as np
30
+ import plotly.graph_objects as go
31
+ from plotly.subplots import make_subplots
32
+
33
+ # ── Color constants ───────────────────────────────────────────────────────────
34
+
35
+ # One color per item (up to 5) — matplotlib tab10 subset
36
+ ITEM_COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd", "#d62728"]
37
+
38
+ # Node stress: green (0 = no backlog) → yellow → red (1 = all backlog)
39
+ NODE_COLORSCALE = [[0.0, "#4CAF50"], [0.35, "#FFC107"], [1.0, "#F44336"]]
40
+
41
+ # Visual size of node marker by tier
42
+ _TIER_MARKER_SIZE = {0: 22, 1: 16, 2: 13, 3: 13, 4: 15, 5: 18, 6: 11}
43
+
44
+ # Human-readable tier labels (for bullwhip x-axis)
45
+ _TIER_LABEL = {
46
+ 0: "Destination",
47
+ 1: "Last-mile",
48
+ 2: "Tier-4",
49
+ 3: "Tier-3",
50
+ 4: "Tier-2<br>(Atlanta)",
51
+ 5: "Hub<br>(Nashville)",
52
+ 6: "Sources",
53
+ }
54
+
55
+ # Descriptive role name shown in hover tooltip and node-type legend
56
+ _TIER_TYPE_NAME = {
57
+ 0: "Destination (end customer)",
58
+ 1: "Last-mile DC",
59
+ 2: "Tier-4 Warehouse",
60
+ 3: "Tier-3 Warehouse",
61
+ 4: "Tier-2 Warehouse (Atlanta)",
62
+ 5: "Regional Hub (Nashville)",
63
+ 6: "Source / Supplier",
64
+ }
65
+
66
+ # Plotly Scattergeo marker symbol per tier — visually distinguishes role
67
+ _TIER_SYMBOL = {
68
+ 0: "star", # Destination — prominent star
69
+ 1: "square", # Last-mile DCs
70
+ 2: "circle", # Tier-4 warehouses
71
+ 3: "circle", # Tier-3 warehouses
72
+ 4: "diamond", # Tier-2 (Atlanta pivot)
73
+ 5: "hexagram", # Hub (Nashville)
74
+ 6: "triangle-up", # Source suppliers
75
+ }
76
+
77
+
78
+ # ============================================================================
79
+ # Private helpers
80
+ # ============================================================================
81
+
82
+ def _edge_travel_times(result) -> Dict[Tuple[str, str], float]:
83
+ """Derive edge travel times from the first shipment that uses each hop."""
84
+ tt: Dict[Tuple[str, str], float] = {}
85
+ for s in result.shipments:
86
+ path = s["path_nodes"]
87
+ ets = s["edge_times"]
88
+ for h in range(len(ets)):
89
+ hop = (path[h], path[h + 1])
90
+ if hop not in tt:
91
+ tt[hop] = float(ets[h])
92
+ if len(tt) == len(result.edge_ids):
93
+ return tt # early exit once all edges covered
94
+ return tt
95
+
96
+
97
+ def _node_stress(result, t: int) -> List[float]:
98
+ """
99
+ Return a per-node backlog-fraction in [0, 1] at timestep t.
100
+ stress = total_backlog / max(total_backlog + total_inventory, 1)
101
+ Ordered to match result.node_ids.
102
+ """
103
+ fracs = []
104
+ for nid in result.node_ids:
105
+ bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids)
106
+ inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids)
107
+ fracs.append(bl / max(bl + inv, 1))
108
+ return fracs
109
+
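A few worked values make the stress measure in `_node_stress` concrete. This is a minimal sketch of the same ratio for a single node; the function name `backlog_stress` is an assumption introduced for illustration.

```python
def backlog_stress(backlog: int, inventory: int) -> float:
    """Backlog share of (backlog + inventory), with a zero-safe denominator."""
    return backlog / max(backlog + inventory, 1)


# (backlog, inventory) pairs covering the edge cases.
cases = [(0, 0), (0, 50), (30, 10), (20, 0)]
for bl, inv in cases:
    print(bl, inv, backlog_stress(bl, inv))
```

An empty node (no stock, no backlog) scores 0 rather than raising a division error, and a node holding only backlog scores exactly 1, which matches the ends of the green-to-red color scale.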
110
+
111
+ def _node_hover_custom(result, t: int,
112
+ type_names: Optional[List[str]] = None) -> List[List]:
113
+ """
114
+ Return per-node customdata rows [inv, bl, type_name] at time t.
115
+ type_names (static) is passed once and repeated in every frame so the
116
+ hovertemplate can reference %{customdata[2]}.
117
+ """
118
+ rows = []
119
+ for i, nid in enumerate(result.node_ids):
120
+ inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids)
121
+ bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids)
122
+ tname = type_names[i] if type_names else ""
123
+ rows.append([inv, bl, tname])
124
+ return rows
125
+
126
+
127
+ def _interpolate_position(
128
+ shipment: dict,
129
+ t: int,
130
+ node_coords: Dict[str, Tuple[float, float]],
131
+ ) -> Optional[Tuple[float, float]]:
132
+ """
133
+ Return (lat, lon) for a shipment at integer day t, or None if not active.
134
+ The particle travels along path_nodes using edge_times for timing.
135
+ """
136
+ d = shipment["day"]
137
+ a = shipment["arrival_day"]
138
+ if t < d or t >= a:
139
+ return None
140
+ path = shipment["path_nodes"]
141
+ ets = shipment["edge_times"]
142
+ elapsed = float(t - d)
143
+ cum = 0.0
144
+ for h, tt in enumerate(ets):
145
+ tt = float(tt)
146
+ if elapsed <= cum + tt + 1e-9:
147
+ frac = (elapsed - cum) / max(tt, 1e-6)
148
+ frac = max(0.0, min(1.0, frac))
149
+ lat_a, lon_a = node_coords[path[h]]
150
+ lat_b, lon_b = node_coords[path[h + 1]]
151
+ return lat_a + frac * (lat_b - lat_a), lon_a + frac * (lon_b - lon_a)
152
+ cum += tt
153
+ return None
154
+
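The hop-by-hop interpolation in `_interpolate_position` can be exercised in isolation. The sketch below is a simplified standalone version that takes elapsed time and a list of waypoints directly; the name `position_on_path` and the toy coordinates are assumptions, and the departure/arrival window check of the original is left out.

```python
from typing import List, Optional, Tuple

Coord = Tuple[float, float]


def position_on_path(
    elapsed: float,
    waypoints: List[Coord],
    edge_times: List[float],
) -> Optional[Coord]:
    """Linear interpolation along a multi-hop path.

    `elapsed` is time since departure; hop h takes edge_times[h] days.
    Returns None once the elapsed time exceeds the total travel time.
    """
    cum = 0.0
    for h, tt in enumerate(edge_times):
        if elapsed <= cum + tt:
            # Fraction of the current hop already covered, clamped to [0, 1].
            frac = (elapsed - cum) / max(tt, 1e-6)
            frac = max(0.0, min(1.0, frac))
            lat_a, lon_a = waypoints[h]
            lat_b, lon_b = waypoints[h + 1]
            return (lat_a + frac * (lat_b - lat_a),
                    lon_a + frac * (lon_b - lon_a))
        cum += tt
    return None


# Two hops: 2 days north, then 4 days east (toy coordinates).
path = [(0.0, 0.0), (10.0, 0.0), (10.0, 20.0)]
times = [2.0, 4.0]
print(position_on_path(1.0, path, times))
print(position_on_path(3.0, path, times))
print(position_on_path(7.0, path, times))
```

Interpolating in raw latitude/longitude is a straight-line approximation; it is adequate for drawing particles on a map but is not a great-circle route.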
155
+
156
+ def _frame_particles(result, t: int, item_filter: Optional[str] = None):
157
+ """
158
+ Return (lats, lons, texts, colors) for all active shipment particles
159
+ at timestep t. colors is a list of CSS color strings keyed by item.
160
+ """
161
+ item_color_map = {iid: ITEM_COLORS[i % len(ITEM_COLORS)]
162
+ for i, iid in enumerate(result.item_ids)}
163
+ lats, lons, texts, colors = [], [], [], []
164
+ for s in result.shipments:
165
+ if item_filter and s["item"] != item_filter:
166
+ continue
167
+ pos = _interpolate_position(s, t, result.node_coords)
168
+ if pos is None:
169
+ continue
170
+ lat, lon = pos
171
+ lats.append(lat)
172
+ lons.append(lon)
173
+ texts.append(
174
+ f"<b>{s['item']}</b><br>{s['from']}→{s['to']}<br>"
175
+ f"qty: {s['units']} day {s['day']}→{s['arrival_day']}"
176
+ )
177
+ colors.append(item_color_map[s["item"]])
178
+ return lats, lons, texts, colors
179
+
180
+
181
+ def _bullwhip_ratios(result) -> Dict[str, Dict[str, float]]:
182
+ """
183
+ Compute B_(n,i) = Var(inflow_(n,i)) / Var(outflow_(n,i)) per node/item.
184
+ At the destination node outflow is replaced by customer demand.
185
+ Source nodes (tier 6) are excluded (no inflow on network edges).
186
+ """
187
+ dest_id = next(
188
+ nid for nid in result.node_ids if result.tier.get(nid) == 0
189
+ )
190
+ ratios: Dict[str, Dict[str, float]] = {}
191
+ for nid in result.node_ids:
192
+ tier = result.tier.get(nid, -1)
193
+ if tier == 6:
194
+ continue
195
+ ratios[nid] = {}
196
+ for iid in result.item_ids:
197
+ inflow_arr = result.inflow[nid][iid].astype(float)
198
+ if nid == dest_id:
199
+ outflow_arr = result.demand[iid].astype(float)
200
+ else:
201
+ outflow_arr = result.outflow[nid][iid].astype(float)
202
+ var_in = float(np.var(inflow_arr, ddof=1)) if result.T > 1 else 0.0
203
+ var_out = float(np.var(outflow_arr, ddof=1)) if result.T > 1 else 0.0
204
+ if var_out > 1.0: # skip (node, item) pairs whose outflow variance is <= 1.0; the ratio would be unstable
205
+ ratios[nid][iid] = var_in / var_out
206
+ return ratios
207
+
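The ratio computed in `_bullwhip_ratios` can be reproduced on a toy pair of series. This is a minimal standard-library sketch, assuming sample variance (the same convention as `ddof=1`) and the same variance threshold as the guard above; the function name `bullwhip_ratio` and the `min_var_out` parameter are illustrative, not part of the module.

```python
from statistics import variance
from typing import Optional, Sequence


def bullwhip_ratio(
    inflow: Sequence[float],
    outflow: Sequence[float],
    min_var_out: float = 1.0,
) -> Optional[float]:
    """Sample-variance ratio Var(inflow) / Var(outflow).

    Returns None when fewer than two observations exist or when the
    outflow variance does not exceed `min_var_out`.
    """
    if len(inflow) < 2 or len(outflow) < 2:
        return None
    var_out = variance(outflow)  # sample variance (divides by n - 1)
    if var_out <= min_var_out:
        return None
    return variance(inflow) / var_out


# Steady daily shipments out, lumpy batch replenishment in.
steady_out = [10, 12, 8, 11, 9, 10, 12, 8]
lumpy_in = [0, 40, 0, 0, 40, 0, 0, 0]
ratio = bullwhip_ratio(lumpy_in, steady_out)
print(ratio)
```

Both toy series average 10 units per day, yet the batch-ordered inflow is far more variable, which is the kind of amplification a tall bar in the chart reflects.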
208
+
209
+ # ============================================================================
210
+ # 1. Animated network map
211
+ # ============================================================================
212
+
213
+ def make_network_animation(
214
+ result,
215
+ frame_step: Optional[int] = None,
216
+ frame_duration_ms: int = 150,
217
+ ) -> go.Figure:
218
+ """
219
+ Animated Scattergeo map of the supply-chain network.
220
+
221
+ Parameters
222
+ ----------
223
+ result : SimResult
224
+ frame_step : int or None
225
+ Days between animation frames. When None, defaults to max(1, T // 200), which keeps the frame count below 400.
226
+ frame_duration_ms : int
227
+ Milliseconds per frame during playback.
228
+ """
229
+ T = result.T
230
+ max_frames = 200
231
+ if frame_step is None:
232
+ frame_step = max(1, T // max_frames)
233
+ frame_times = list(range(0, T, frame_step))
234
+
235
+ node_ids = result.node_ids
236
+ coords = result.node_coords
237
+ edge_tt = _edge_travel_times(result)
238
+
239
+ # ── Node metadata ──────────────────────────────────────────────────────
240
+ node_lats = [coords[n][0] for n in node_ids]
241
+ node_lons = [coords[n][1] for n in node_ids]
242
+ node_sizes = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13)
243
+ for n in node_ids]
244
+ node_labels = node_ids
245
+ node_type_names = [_TIER_TYPE_NAME.get(result.tier.get(n, -1), "Warehouse")
246
+ for n in node_ids]
247
+ node_symbols = [_TIER_SYMBOL.get(result.tier.get(n, 3), "circle")
248
+ for n in node_ids]
249
+
250
+ # ── Edge traces (static) ───────────────────────────────────────────────
251
+ edge_traces = []
252
+ max_cap = max(result.edge_cap.values()) if result.edge_cap else 1.0
253
+ for eid in result.edge_ids:
254
+ u, v = eid
255
+ if u not in coords or v not in coords:
256
+ continue
257
+ cap = result.edge_cap.get(eid, 1.0)
258
+ lw = 1.0 + 3.5 * (cap / max_cap) ** 0.5
259
+ tt = edge_tt.get(eid, "?")
260
+ edge_traces.append(go.Scattergeo(
261
+ lat=[coords[u][0], coords[v][0], None],
262
+ lon=[coords[u][1], coords[v][1], None],
263
+ mode="lines",
264
+ line=dict(width=lw, color="rgba(140,140,160,0.55)"),
265
+ hoverinfo="text",
266
+ text=f"{u} → {v}<br>travel: {tt} day(s)<br>daily cap: {cap:.0f} units",
267
+ showlegend=False,
268
+ name=f"{u}→{v}",
269
+ ))
270
+ n_edge_traces = len(edge_traces)
271
+
272
+ # ── Initial node trace (t=0) ───────────────────────────────────────────
273
+ stress_0 = _node_stress(result, 0)
274
+ custom_0 = _node_hover_custom(result, 0, node_type_names)
275
+ node_trace = go.Scattergeo(
276
+ lat=node_lats,
277
+ lon=node_lons,
278
+ mode="markers+text",
279
+ text=node_labels,
280
+ textposition="top center",
281
+ textfont=dict(size=9, color="black"),
282
+ marker=dict(
283
+ size=node_sizes,
284
+ symbol=node_symbols,
285
+ color=stress_0,
286
+ colorscale=NODE_COLORSCALE,
287
+ cmin=0.0, cmax=1.0,
288
+ colorbar=dict(
289
+ title="Backlog<br>stress",
290
+ thickness=12,
291
+ len=0.5,
292
+ x=1.01,
293
+ tickvals=[0, 0.5, 1],
294
+ ticktext=["0 (healthy)", "0.5", "1 (stockout)"],
295
+ tickfont=dict(size=9),
296
+ ),
297
+ line=dict(width=1.5, color="white"),
298
+ ),
299
+ customdata=custom_0,
300
+ hovertemplate=(
301
+ "<b>%{text}</b><br>"
302
+ "Role: %{customdata[2]}<br>"
303
+ "Total inventory: %{customdata[0]:,} units<br>"
304
+ "Total backlog: %{customdata[1]:,} units<br>"
305
+ "<extra></extra>"
306
+ ),
307
+ showlegend=False,
308
+ name="nodes",
309
+ )
310
+
311
+ # ── Initial shipment trace (t=0) ───────────────────────────────────────
312
+ lats0, lons0, texts0, colors0 = _frame_particles(result, 0)
313
+ ship_trace = go.Scattergeo(
314
+ lat=lats0,
315
+ lon=lons0,
316
+ mode="markers",
317
+ marker=dict(size=7, color=colors0, opacity=0.85,
318
+ line=dict(width=0.5, color="white")),
319
+ hoverinfo="text",
320
+ text=texts0,
321
+ showlegend=False,
322
+ name="shipments",
323
+ )
324
+
325
+ # ── Legend: item colors (shipment dots) ──────────────────────────────
326
+ item_legend_traces = []
327
+ for i, iid in enumerate(result.item_ids):
328
+ item_legend_traces.append(go.Scattergeo(
329
+ lat=[None], lon=[None],
330
+ mode="markers",
331
+ marker=dict(size=9, color=ITEM_COLORS[i % len(ITEM_COLORS)],
332
+ symbol="circle"),
333
+ name=iid,
334
+ legendgrouptitle_text="Shipment items" if i == 0 else None,
335
+ legendgroup="items",
336
+ showlegend=True,
337
+ ))
338
+
339
+ # ── Legend: node-type symbols ─────────────────────────────────────────
340
+ # One proxy trace per distinct tier present in this result.
341
+ seen_tiers: dict = {}
342
+ for n in node_ids:
343
+ t_val = result.tier.get(n, -1)
344
+ if t_val not in seen_tiers:
345
+ seen_tiers[t_val] = (_TIER_TYPE_NAME.get(t_val, "Unknown"),
346
+ _TIER_SYMBOL.get(t_val, "circle"))
347
+ node_type_legend_traces = []
348
+ for i, (t_val, (tname, sym)) in enumerate(
349
+ sorted(seen_tiers.items(), key=lambda x: x[0])):
350
+ node_type_legend_traces.append(go.Scattergeo(
351
+ lat=[None], lon=[None],
352
+ mode="markers",
353
+ marker=dict(size=10, color="gray", symbol=sym),
354
+ name=tname,
355
+ legendgrouptitle_text="Node types" if i == 0 else None,
356
+ legendgroup="node_types",
357
+ showlegend=True,
358
+ ))
359
+
360
+ # ── Assemble base figure ───────────────────────────────────────────────
361
+ all_traces = (edge_traces + [node_trace, ship_trace]
362
+ + item_legend_traces + node_type_legend_traces)
363
+ fig = go.Figure(data=all_traces)
364
+ # Indices of the two animated traces: edge traces come first, then nodes, then shipments; legend proxies follow.
365
+ idx_nodes = n_edge_traces
366
+ idx_ships = n_edge_traces + 1
367
+
368
+ # ── Pre-compute frames ─────────────────────────────────────────────────
369
+ frames = []
370
+ for t in frame_times:
371
+ stress_t = _node_stress(result, t)
372
+ custom_t = _node_hover_custom(result, t, node_type_names)
373
+ lats_t, lons_t, texts_t, colors_t = _frame_particles(result, t)
374
+
375
+ frames.append(go.Frame(
376
+ data=[
377
+ # Plain dicts avoid serialising default None lat/lon values
378
+ # that go.Scattergeo() would inject and confuse Plotly.js.
379
+ {"type": "scattergeo",
380
+ "marker": {"color": stress_t},
381
+ "customdata": custom_t},
382
+ {"type": "scattergeo",
383
+ "lat": lats_t, "lon": lons_t,
384
+ "text": texts_t,
385
+ "marker": {"color": colors_t, "size": 7, "opacity": 0.85}},
386
+ ],
387
+ traces=[idx_nodes, idx_ships],
388
+ layout={"title": {"text": f"<b>ISOMORPH Supply Chain</b> — Day {t} / {T - 1}"}},
389
+ name=str(t),
390
+ ))
391
+ fig.frames = frames
392
+
393
+ # ── Slider steps ──────────────────────────────────────────────────────
394
+ slider_steps = [
395
+ dict(
396
+ method="animate",
397
+ args=[[str(t)],
398
+ dict(mode="immediate",
399
+ frame=dict(duration=0, redraw=True),
400
+ transition=dict(duration=0))],
401
+ label=str(t),
402
+ )
403
+ for t in frame_times
404
+ ]
405
+
406
+ # ── Layout ────────────────────────────────────────────────────────────
407
+ fig.update_layout(
408
+ title=dict(
409
+ text=f"<b>ISOMORPH Supply Chain</b> — Day 0 / {T - 1}",
410
+ x=0.5, xanchor="center", font=dict(size=14),
411
+ ),
412
+ geo=dict(
413
+ scope="usa",
414
+ projection_type="albers usa",
415
+ showland=True, landcolor="rgb(243,243,243)",
416
+ showlakes=True, lakecolor="rgb(210,230,255)",
417
+ showrivers=True, rivercolor="rgb(210,230,255)",
418
+ showcoastlines=True, coastlinecolor="rgb(180,180,200)",
419
+ showsubunits=True, subunitcolor="rgb(200,200,215)",
420
+ bgcolor="rgba(255,255,255,0)",
421
+ ),
422
+ legend=dict(
423
+ x=0.01, y=0.01,
424
+ bgcolor="rgba(255,255,255,0.82)",
425
+ bordercolor="lightgray", borderwidth=1,
426
+ font=dict(size=9),
427
+ tracegroupgap=6,
428
+ ),
429
+ updatemenus=[dict(
430
+ type="buttons",
431
+ showactive=False,
432
+ y=1.08, x=0.0, xanchor="left",
433
+ buttons=[
434
+ dict(
435
+ label="▶ Play",
436
+ method="animate",
437
+ args=[None, dict(
438
+ frame=dict(duration=frame_duration_ms, redraw=True),
439
+ fromcurrent=True,
440
+ transition=dict(duration=0),
441
+ )],
442
+ ),
443
+ dict(
444
+ label="⏸ Pause",
445
+ method="animate",
446
+ args=[[None], dict(
447
+ frame=dict(duration=0, redraw=False),
448
+ mode="immediate",
449
+ transition=dict(duration=0),
450
+ )],
451
+ ),
452
+ ],
453
+ )],
454
+ sliders=[dict(
455
+ currentvalue=dict(
456
+ prefix="Day: ",
457
+ font=dict(size=11),
458
+ visible=True,
459
+ xanchor="center",
460
+ ),
461
+ pad=dict(t=50, b=10),
462
+ len=0.9,
463
+ x=0.05,
464
+ steps=slider_steps,
465
+ transition=dict(duration=0),
466
+ )],
467
+ margin=dict(l=0, r=0, t=80, b=60),
468
+ height=540,
469
+ paper_bgcolor="white",
470
+ )
471
+ return fig
472
+
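The frame-sampling rule at the top of `make_network_animation` is easy to check on its own. A small sketch, assuming the same floor-division step; the helper name `frame_times` is hypothetical and simply mirrors the local variable of the same name.

```python
from typing import List


def frame_times(T: int, max_frames: int = 200) -> List[int]:
    """Days sampled for animation, using the floor-division step rule."""
    step = max(1, T // max_frames)
    return list(range(0, T, step))


for horizon in (100, 399, 400, 1000):
    print(horizon, len(frame_times(horizon)))
```

Because the step is rounded down, a horizon just under a multiple of `max_frames` (for example T = 399) still produces nearly twice as many frames as the nominal target; rounding the step up instead would hold the count at or below `max_frames`.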
473
+
474
+ # ============================================================================
475
+ # 2. Node detail time-series panel
476
+ # ============================================================================
477
+
478
+ def make_node_timeseries(
479
+ result,
480
+ node_id: str,
481
+ item_ids: Optional[List[str]] = None,
482
+ ) -> go.Figure:
483
+ """
484
+ 4–5 subplot panel for a single node showing inventory, backlog, inflow,
485
+ outflow, and (destination only) realized demand.
486
+
487
+ Parameters
488
+ ----------
489
+ result : SimResult
490
+ node_id : str
491
+ Node to visualise.
492
+ item_ids : list[str] or None
493
+ Subset of items to plot. Defaults to all items in result.
494
+ """
495
+ if item_ids is None:
496
+ item_ids = result.item_ids
497
+
498
+ dest_id = next(n for n in result.node_ids if result.tier.get(n) == 0)
499
+ is_dest = (node_id == dest_id)
500
+ # Use plain Python lists — avoids numpy int32/int64 serialization issues
501
+ # inside Gradio 4.x's gr.Plot JSON path.
502
+ days = list(range(result.T))
503
+ n_panels = 5 if is_dest else 4
504
+
505
+ panel_titles = [
506
+ "On-hand inventory (units stored at this node)",
507
+ "Backlog (unfulfilled demand pending)",
508
+ "Inflow (units arriving per day)",
509
+ "Outflow (units shipped out per day)",
510
+ ]
511
+ if is_dest:
512
+ panel_titles.append("Realized demand (customer orders received per day)")
513
+
514
+ fig = make_subplots(
515
+ rows=n_panels, cols=1,
516
+ shared_xaxes=True,
517
+ subplot_titles=panel_titles,
518
+ vertical_spacing=0.06,
519
+ )
520
+
521
+ for idx, iid in enumerate(item_ids):
522
+ color = ITEM_COLORS[idx % len(ITEM_COLORS)]
523
+
524
+ # Capture loop variables by value via default args to avoid closure issues.
525
+ def _add(row, y_arr, dash="solid", _iid=iid, _color=color):
526
+ fig.add_trace(
527
+ go.Scatter(
528
+ x=days,
529
+ y=y_arr.tolist(), # convert np.int32 → Python ints
530
+ mode="lines",
531
+ name=_iid,
532
+ line=dict(color=_color, width=1.5, dash=dash),
533
+ legendgroup=_iid,
534
+ showlegend=(row == 1),
535
+ hovertemplate=f"Day %{{x}}<br>{_iid}: %{{y:,.0f}}<extra></extra>",
536
+ ),
537
+ row=row, col=1,
538
+ )
539
+
540
+ _add(1, result.inventory[node_id][iid])
541
+ _add(2, result.backlog[node_id][iid])
542
+ _add(3, result.inflow[node_id][iid])
543
+ _add(4, result.outflow[node_id][iid])
544
+ if is_dest:
545
+ _add(5, result.demand[iid], dash="dot")
546
+
547
+ # Disruption event markers — one vline per panel.
548
+ # add_vline(row=, col=) was not available in all Plotly 5.x builds;
549
+ # iterate over subplot y-axis references instead.
550
+ if result.disruption_log:
551
+ yref_list = ["y"] + [f"y{i}" for i in range(2, n_panels + 1)]
552
+ for ev in result.disruption_log:
553
+ for yref in yref_list:
554
+ fig.add_shape(
555
+ type="line",
556
+ x0=ev["day"], x1=ev["day"],
557
+ y0=0, y1=1,
558
+ xref="x", yref=f"{yref} domain",
559
+ line=dict(dash="dash", color="rgba(255,80,80,0.5)", width=1),
560
+ )
561
+
562
+ # Panel y-axis labels
563
+ y_labels = ["Units", "Units", "Units/day", "Units/day"]
564
+ if is_dest:
565
+ y_labels.append("Units/day")
566
+ for row, lbl in enumerate(y_labels, start=1):
567
+ fig.update_yaxes(title_text=lbl, row=row, col=1,
568
+ title_font=dict(size=10), tickfont=dict(size=9))
569
+
570
+ fig.update_xaxes(title_text="Day", row=n_panels, col=1,
571
+ tickfont=dict(size=9))
572
+
573
+ tier_name = _TIER_LABEL.get(result.tier.get(node_id, -1), "")
574
+ type_name = _TIER_TYPE_NAME.get(result.tier.get(node_id, -1), "")
575
+ disruption_note = (
576
+ " · <span style='color:rgba(255,80,80,0.9)'>red dashes = disruption events</span>"
577
+ if result.disruption_log else ""
578
+ )
579
+ fig.update_layout(
580
+ title=dict(
581
+ text=(f"<b>{node_id}</b> — {type_name}"
582
+ f"{disruption_note}"),
583
+ x=0.5, xanchor="center", font=dict(size=13),
584
+ ),
585
+ legend=dict(
586
+ title="Items (SKUs)",
587
+ x=1.01, y=1.0,
588
+ font=dict(size=10),
589
+ bgcolor="rgba(255,255,255,0.8)",
590
+ bordercolor="lightgray", borderwidth=1,
591
+ ),
592
+ height=180 * n_panels + 80,
593
+ paper_bgcolor="white",
594
+ plot_bgcolor="rgba(248,248,252,1)",
595
+ margin=dict(l=60, r=120, t=60, b=50),
596
+ )
597
+ # Light grid
598
+ fig.update_xaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
599
+ fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
600
+ return fig
601
+
602
+
603
+ # ============================================================================
604
+ # 3. Bullwhip amplification chart
605
+ # ============================================================================
606
+
607
+ def make_bullwhip_chart(result) -> go.Figure:
608
+ """
609
+ Bar chart of tier-level bullwhip ratio B = Var(inflow)/Var(outflow).
610
+ Bars are grouped by item; tiers run from destination (left) to hub (right).
611
+ A dashed reference line at B = 1 marks the no-amplification baseline.
612
+ """
613
+ ratios = _bullwhip_ratios(result)
614
+
615
+ # Collect tiers present (excluding sources, tier 6)
616
+ tiers_present = sorted(
617
+ {result.tier[n] for n in ratios if ratios[n]},
618
+ )
619
+
620
+ # Per-item bar traces
621
+ fig = go.Figure()
622
+ for idx, iid in enumerate(result.item_ids):
623
+ tier_means = []
624
+ for t in tiers_present:
625
+ nodes_in_tier = [n for n in ratios
626
+ if result.tier.get(n) == t and iid in ratios[n]]
627
+ if nodes_in_tier:
628
+ tier_means.append(
629
+ float(np.mean([ratios[n][iid] for n in nodes_in_tier]))
630
+ )
631
+ else:
632
+ tier_means.append(None)
633
+
634
+ x_labels = [_TIER_LABEL.get(t, f"Tier {t}") for t in tiers_present]
635
+ fig.add_trace(go.Bar(
636
+ x=x_labels,
637
+ y=tier_means,
638
+ name=iid,
639
+ marker_color=ITEM_COLORS[idx % len(ITEM_COLORS)],
640
+ opacity=0.82,
641
+ text=[f"{v:.2f}" if v is not None else "" for v in tier_means],
642
+ textposition="outside",
643
+ textfont=dict(size=9),
644
+ ))
645
+
646
+ # Reference line at B = 1
647
+ fig.add_hline(
648
+ y=1.0, line_dash="dash",
649
+ line_color="rgba(80,80,80,0.6)", line_width=1.5,
650
+ annotation_text="B = 1 (no amplification)",
651
+ annotation_position="top right",
652
+ annotation_font_size=9,
653
+ )
654
+
655
+ fig.update_layout(
656
+ title=dict(
657
+ text=(
658
+ "<b>Bullwhip Amplification by Tier</b><br>"
659
+ "<sup>B &gt; 1 means demand variability grows as orders travel upstream</sup>"
660
+ ),
661
+ x=0.5, xanchor="center", font=dict(size=13),
662
+ ),
663
+ xaxis=dict(
664
+ title="Network tier (NewYork = downstream → Suppliers = upstream)",
665
+ title_font=dict(size=11),
666
+ tickfont=dict(size=10),
667
+ ),
668
+ yaxis=dict(
669
+ title="B = Var(inflow) / Var(outflow)",
670
+ title_font=dict(size=11),
671
+ tickfont=dict(size=9),
672
+ rangemode="tozero",
673
+ ),
674
+ barmode="group",
675
+ legend=dict(
676
+ title="Items",
677
+ x=1.01, y=1.0,
678
+ font=dict(size=10),
679
+ bgcolor="rgba(255,255,255,0.8)",
680
+ bordercolor="lightgray", borderwidth=1,
681
+ ),
682
+ height=400,
683
+ paper_bgcolor="white",
684
+ plot_bgcolor="rgba(248,248,252,1)",
685
+ margin=dict(l=60, r=120, t=60, b=60),
686
+ )
687
+ fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
688
+ return fig
689
+
690
+
691
+ # ============================================================================
692
+ # 4. Edge utilization heatmap
693
+ # ============================================================================
694
+
695
+ def make_edge_heatmap(result) -> go.Figure:
696
+ """
697
+ Heatmap of edge utilization (fraction of daily capacity) over time.
698
+ Y-axis: edges sorted downstream→upstream.
699
+ X-axis: simulation day.
700
+     Color: 0 (very light blue) = empty → 0.5 (yellow-orange) → 1 (red) = at capacity.
701
+ Disruption events are shown as vertical dashed lines.
702
+ """
703
+ # Sort edges: downstream edges first (by tier of source node)
704
+ def _edge_tier(eid):
705
+ return result.tier.get(eid[0], 99)
706
+
707
+ sorted_edges = sorted(result.edge_ids, key=_edge_tier)
708
+
709
+ edge_labels = [f"{u} → {v}" for u, v in sorted_edges]
710
+ z = np.array([result.edge_util[eid] for eid in sorted_edges],
711
+ dtype=np.float64) # shape (n_edges, T)
712
+
713
+ days = list(range(result.T))
714
+
715
+ fig = go.Figure(go.Heatmap(
716
+ z=z.tolist(),
717
+ x=days,
718
+ y=edge_labels,
719
+ colorscale=[
720
+ [0.0, "rgb(240,248,255)"], # near-empty: very light blue
721
+ [0.5, "rgb(255,200,100)"], # half full: yellow-orange
722
+ [1.0, "rgb(220,50,50)"], # saturated: red
723
+ ],
724
+ zmin=0.0, zmax=1.0,
725
+ colorbar=dict(
726
+ title="Utilization<br>fraction",
727
+ thickness=13,
728
+ len=0.7,
729
+ tickvals=[0, 0.5, 1],
730
+ ticktext=["0%", "50%", "100%"],
731
+ tickfont=dict(size=9),
732
+ ),
733
+ hovertemplate=(
734
+ "Edge: %{y}<br>Day: %{x}<br>Utilization: %{z:.1%}<extra></extra>"
735
+ ),
736
+ ))
737
+
738
+ # Disruption event markers
739
+ for ev in result.disruption_log:
740
+ fig.add_vline(
741
+ x=ev["day"], line_dash="dash",
742
+ line_color="rgba(80,80,255,0.6)", line_width=1.5,
743
+ )
744
+ fig.add_annotation(
745
+ x=ev["day"], y=1.01, xref="x", yref="paper",
746
+ text="disruption", showarrow=False,
747
+ font=dict(size=8, color="blue"), textangle=-90,
748
+ )
749
+
750
+ fig.update_layout(
751
+ title=dict(
752
+ text=(
753
+ "<b>Edge Utilization over Time</b><br>"
754
+ "<sup>Fraction of daily shipping capacity used per edge "
755
+ "(0 = empty, 1 = fully saturated). "
756
+ "Blue dashes mark disruption events.</sup>"
757
+ ),
758
+ x=0.5, xanchor="center", font=dict(size=13),
759
+ ),
760
+ xaxis=dict(
761
+ title="Day",
762
+ title_font=dict(size=11),
763
+ tickfont=dict(size=9),
764
+ ),
765
+ yaxis=dict(
766
+ title="Edge",
767
+ title_font=dict(size=11),
768
+ tickfont=dict(size=9),
769
+ autorange="reversed", # downstream edges at top
770
+ ),
771
+ height=max(320, 30 * len(sorted_edges) + 120),
772
+ paper_bgcolor="white",
773
+ margin=dict(l=160, r=80, t=60, b=60),
774
+ )
775
+ return fig
776
+
777
+
778
+ # ============================================================================
779
+ # 5. HTML wrapper for animated network map (Gradio gr.HTML compatible)
780
+ # ============================================================================
781
+
782
+ def make_network_animation_html(
783
+ result,
784
+ frame_step: Optional[int] = None,
785
+ frame_duration_ms: int = 150,
786
+ ) -> str:
787
+ """
788
+ Return an <iframe srcdoc=...> string for the animated network map.
789
+
790
+ Gradio's gr.Plot uses Plotly.react() internally, which strips the frames
791
+ array and breaks Play/Pause. An <iframe> with a full standalone HTML page
792
+ (full_html=True) calls Plotly.newPlot() directly so frames are preserved
793
+ and the animation works. The iframe also avoids the height:100% collapse
794
+ that occurs when embedding partial HTML in gr.HTML.
795
+ """
796
+ import html as _html
797
+ import plotly.io as pio
798
+
799
+ fig = make_network_animation(result, frame_step, frame_duration_ms)
800
+ full_html = pio.to_html(
801
+ fig,
802
+ include_plotlyjs="cdn",
803
+ full_html=True,
804
+ config={"responsive": True},
805
+ )
806
+ escaped = _html.escape(full_html, quote=True)
807
+ return (
808
+ f'<iframe srcdoc="{escaped}" '
809
+ f'width="100%" height="600" frameborder="0" scrolling="no" '
810
+ f'style="border:none;display:block;"></iframe>'
811
+ )
812
+
813
+
814
+ # ============================================================================
815
+ # 6. Animated GIF export
816
+ # ============================================================================
817
+
818
+ def make_network_animation_gif(
819
+ result,
820
+ frame_step: Optional[int] = None,
821
+ frame_duration_ms: int = 150,
822
+ max_frames: int = 80,
823
+ ) -> str:
824
+ """
825
+ Render the network animation as an animated GIF and return the temp file path.
826
+
827
+ Requires:
828
+ pip install kaleido Pillow
829
+
830
+ Parameters
831
+ ----------
832
+ result : SimResult
833
+ frame_step : int or None
834
+ Days between frames. Auto-computed to cap at max_frames.
835
+ frame_duration_ms : int
836
+ Milliseconds per frame during playback.
837
+ max_frames : int
838
+ Maximum number of frames to render (keeps file size reasonable).
839
+ """
840
+ import io as _io
841
+ import tempfile
842
+ import plotly.io as pio
843
+
844
+ try:
845
+ from PIL import Image
846
+ except ImportError:
847
+ raise ImportError("Pillow is required for GIF export: pip install Pillow")
848
+
849
+ T = result.T
850
+ if frame_step is None:
851
+ frame_step = max(1, T // max_frames)
852
+ frame_times = list(range(0, T, frame_step))
853
+
854
+ coords = result.node_coords
855
+ node_ids = result.node_ids
856
+ node_lats = [coords[n][0] for n in node_ids]
857
+ node_lons = [coords[n][1] for n in node_ids]
858
+ node_sizes = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13) for n in node_ids]
859
+ node_labels = node_ids
860
+ node_symbols = [_TIER_SYMBOL.get(result.tier.get(n, 3), "circle") for n in node_ids]
861
+ max_cap = max(result.edge_cap.values()) if result.edge_cap else 1.0
862
+
863
+ # Static edge traces (reused across frames)
864
+ edge_traces = []
865
+ for eid in result.edge_ids:
866
+ u, v = eid
867
+ if u not in coords or v not in coords:
868
+ continue
869
+ cap = result.edge_cap.get(eid, 1.0)
870
+ lw = 1.0 + 3.5 * (cap / max_cap) ** 0.5
871
+ edge_traces.append(go.Scattergeo(
872
+ lat=[coords[u][0], coords[v][0], None],
873
+ lon=[coords[u][1], coords[v][1], None],
874
+ mode="lines",
875
+ line=dict(width=lw, color="rgba(140,140,160,0.55)"),
876
+ hoverinfo="skip",
877
+ showlegend=False,
878
+ ))
879
+
880
+ pil_frames = []
881
+ for t in frame_times:
882
+ stress_t = _node_stress(result, t)
883
+ lats_t, lons_t, _, colors_t = _frame_particles(result, t)
884
+
885
+ traces = list(edge_traces) # shallow copy of static traces
886
+
887
+ traces.append(go.Scattergeo(
888
+ lat=node_lats, lon=node_lons,
889
+ mode="markers+text",
890
+ text=node_labels,
891
+ textposition="top center",
892
+ textfont=dict(size=9, color="black"),
893
+ marker=dict(
894
+ size=node_sizes, symbol=node_symbols,
895
+ color=stress_t, colorscale=NODE_COLORSCALE,
896
+ cmin=0.0, cmax=1.0,
897
+ line=dict(width=1.5, color="white"),
898
+ ),
899
+ hoverinfo="skip",
900
+ showlegend=False,
901
+ ))
902
+
903
+ if lats_t:
904
+ traces.append(go.Scattergeo(
905
+ lat=lats_t, lon=lons_t,
906
+ mode="markers",
907
+ marker=dict(size=7, color=colors_t, opacity=0.85,
908
+ line=dict(width=0.5, color="white")),
909
+ hoverinfo="skip",
910
+ showlegend=False,
911
+ ))
912
+
913
+ fig = go.Figure(data=traces)
914
+ fig.update_layout(
915
+ title=dict(
916
+ text=f"<b>ISOMORPH</b> — Day {t} / {T - 1}",
917
+ x=0.5, xanchor="center", font=dict(size=12),
918
+ ),
919
+ geo=dict(
920
+ scope="usa", projection_type="albers usa",
921
+ showland=True, landcolor="rgb(243,243,243)",
922
+ showlakes=True, lakecolor="rgb(210,230,255)",
923
+ showcoastlines=True, coastlinecolor="rgb(180,180,200)",
924
+ showsubunits=True, subunitcolor="rgb(200,200,215)",
925
+ bgcolor="white",
926
+ ),
927
+ margin=dict(l=0, r=0, t=40, b=5),
928
+ height=380, width=680,
929
+ paper_bgcolor="white",
930
+ showlegend=False,
931
+ )
932
+
933
+ img_bytes = pio.to_image(fig, format="png", width=680, height=380, scale=1)
934
+ pil_frames.append(Image.open(_io.BytesIO(img_bytes)).convert("RGB"))
935
+
936
+ tmp = tempfile.NamedTemporaryFile(suffix=".gif", delete=False, prefix="isomorph_")
937
+ pil_frames[0].save(
938
+ tmp.name,
939
+ save_all=True,
940
+ append_images=pil_frames[1:],
941
+ loop=0,
942
+ duration=frame_duration_ms,
943
+ optimize=False,
944
+ )
945
+ tmp.close()
946
+ return tmp.name
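The frame selection in `make_network_animation_gif` can be looked at in isolation. The sketch below is a hypothetical helper (`pick_frame_days` is not part of the repository) that chooses which simulation days to render so that the frame count stays within `max_frames`. It assumes ceiling division for the auto-computed step and treats an explicitly passed `frame_step` as taking precedence over the cap.

```python
from __future__ import annotations

import math


def pick_frame_days(T: int, max_frames: int = 80,
                    frame_step: int | None = None) -> list[int]:
    """Return the simulation days to render as GIF frames."""
    if T <= 0:
        return []  # nothing to render
    if frame_step is None:
        # Ceiling division keeps len(range(0, T, step)) <= max_frames.
        frame_step = max(1, math.ceil(T / max_frames))
    return list(range(0, T, frame_step))


if __name__ == "__main__":
    days = pick_frame_days(365, max_frames=80)
    print(len(days), days[:3])
```

With floor division (`T // max_frames`) a 100-day run and `max_frames=80` would yield a step of 1 and therefore 100 frames, which is why the sketch rounds the step up instead.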
eval/chronos_run.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ CLI driver: run a foundation model zero-shot on one ISOMORPH release.
3
+
+ Outputs:
+   results/{model_short}_{dataset}.csv          long-format per-channel metrics
+   results/{model_short}_{dataset}_summary.csv  cross-channel mean/median
+   results/{model_short}_{dataset}_tensors.npz  raw y_pred / y_true arrays
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from data_utils import load_dataset, iter_test_windows
19
+ from metrics import metrics_at_horizons, to_long_dataframe, HORIZONS
20
+ from chronos_runner import (
21
+ load_chronos, predict_rolling_origin, collect_y_true,
22
+ )
23
+
24
+
25
+ def short_name(model_id: str) -> str:
26
+ return model_id.split("/")[-1].replace("-", "_")
27
+
28
+
29
+ def run(out_dir: Path, model_id: str, results_dir: Path,
30
+ L: int, H: int, stride: int,
31
+ num_samples: int, channel_batch: int,
32
+ max_windows: int | None = None,
33
+ label: str | None = None) -> None:
34
+ print(f"=== {out_dir.name} with {model_id} ===", file=sys.stderr)
35
+ split = load_dataset(out_dir)
36
+ if label is not None:
37
+ split.label = label
38
+ print(f" T={split.T} n_items={split.n_items} "
39
+ f"train_end={split.train_end} val_end={split.val_end} "
40
+ f"test=[{split.test_start}, {split.T})", file=sys.stderr)
41
+
42
+ starts = list(iter_test_windows(split, L=L, H=H, stride=stride))
43
+ if max_windows is not None:
44
+ starts = starts[:max_windows]
45
+ print(f" rolling-origin windows: {len(starts)} "
46
+ f"(L={L}, H={H}, stride={stride})", file=sys.stderr)
47
+
48
+ pipe = load_chronos(model_id)
49
+ t0 = time.time()
50
+ y_pred = predict_rolling_origin(
51
+ pipe, split.D, starts, L=L, H=H,
52
+ num_samples=num_samples, channel_batch=channel_batch,
53
+ )
54
+ y_true = collect_y_true(split.D, starts, H)
55
+ elapsed = time.time() - t0
56
+ print(f" inference done in {elapsed/60:.1f} min", file=sys.stderr)
57
+
58
+ metric_dict = metrics_at_horizons(y_true, y_pred, split.mase_denom,
59
+ horizons=HORIZONS)
60
+ long_df = to_long_dataframe(metric_dict, split.item_ids,
61
+ model=model_id, dataset=split.label)
62
+
63
+ sn = short_name(model_id)
64
+ results_dir.mkdir(parents=True, exist_ok=True)
65
+ out_long = results_dir / f"{sn}_{split.label}.csv"
66
+ long_df.to_csv(out_long, index=False)
67
+ print(f" -> {out_long}", file=sys.stderr)
68
+
69
+ # Persist raw tensors for post-hoc slicing (e.g. stationary-vs-shock).
70
+ out_npz = results_dir / f"{sn}_{split.label}_tensors.npz"
71
+ np.savez_compressed(
72
+ out_npz, y_pred=y_pred, y_true=y_true,
73
+ window_starts=np.asarray(starts, dtype=np.int64),
74
+ item_ids=np.asarray(split.item_ids),
75
+ L=L, H=H, stride=stride, model=model_id, dataset=split.label,
76
+ )
77
+ print(f" -> {out_npz}", file=sys.stderr)
78
+
79
+ # Cross-channel summary at each (metric, h).
80
+ summary = (long_df
81
+ .groupby(["model", "dataset", "metric", "h"])["value"]
82
+ .agg(mean="mean", median="median",
83
+ q25=lambda x: x.quantile(0.25),
84
+ q75=lambda x: x.quantile(0.75),
85
+ n="count")
86
+ .reset_index())
87
+ out_sum = results_dir / f"{sn}_{split.label}_summary.csv"
88
+ summary.to_csv(out_sum, index=False)
89
+ print(f" -> {out_sum}", file=sys.stderr)
90
+
91
+ # Print a compact body-table preview.
92
+ print("\n Headline (median across channels):", file=sys.stderr)
93
+ print(summary.pivot_table(index="metric", columns="h",
94
+ values="median").to_string(),
95
+ file=sys.stderr)
96
+
97
+
98
+ def main():
99
+ repo = Path(__file__).resolve().parents[1]
100
+ ap = argparse.ArgumentParser()
101
+ ap.add_argument("--root", default=str(repo / "data"))
102
+ ap.add_argument("--dataset", default="output_item50")
103
+ ap.add_argument("--scenario_path", default=None,
104
+ help="Path to a scenario directory "
105
+ "(e.g. data/output_mixture/baseline/seed2025). "
106
+ "Overrides --root/--dataset when set.")
107
+ ap.add_argument("--label", default=None,
108
+ help="Output filename label. Defaults to the directory "
109
+ "name. For scenario paths, set this to the scenario "
110
+ "name so results from different scenarios don't collide.")
111
+ ap.add_argument("--model_id", default="amazon/chronos-t5-base")
112
+ ap.add_argument("--out", default=str(
113
+ repo / "results" / "eval" / "baseline_and_scenarios"))
114
+ ap.add_argument("--L", type=int, default=512)
115
+ ap.add_argument("--H", type=int, default=30)
116
+ ap.add_argument("--stride", type=int, default=30)
117
+ ap.add_argument("--num_samples", type=int, default=20)
118
+ ap.add_argument("--channel_batch", type=int, default=16)
119
+ ap.add_argument("--max_windows", type=int, default=None,
120
+ help="cap for smoke testing")
121
+ args = ap.parse_args()
122
+ if args.scenario_path is not None:
123
+ out_dir = Path(args.scenario_path)
124
+ else:
125
+ out_dir = Path(args.root) / args.dataset
126
+ run(out_dir, args.model_id, Path(args.out),
127
+ L=args.L, H=args.H, stride=args.stride,
128
+ num_samples=args.num_samples,
129
+ channel_batch=args.channel_batch,
130
+ max_windows=args.max_windows,
131
+ label=args.label)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ main()
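The cross-channel summary that `run` writes to `*_summary.csv` reduces one value per channel to a mean, median, quartiles and a count. A minimal standard-library sketch of that reduction follows. The name `summarize_channels` is hypothetical, and the sketch assumes the script relies on pandas' default linear-interpolation quantile, which `statistics.quantiles(..., method="inclusive")` reproduces for quartiles.

```python
from statistics import mean, median, quantiles


def summarize_channels(values: list[float]) -> dict[str, float]:
    """Summarize per-channel metric values for one (model, dataset, metric, h)."""
    if not values:
        raise ValueError("need at least one channel value")
    if len(values) == 1:
        # A single channel has no spread; both quartiles equal the value.
        q25 = q75 = float(values[0])
    else:
        # "inclusive" interpolates at position q * (n - 1), like pandas' default.
        q25, _, q75 = quantiles(values, n=4, method="inclusive")
    return {
        "mean": mean(values),
        "median": median(values),
        "q25": q25,
        "q75": q75,
        "n": len(values),
    }


if __name__ == "__main__":
    print(summarize_channels([1.0, 2.0, 3.0, 4.0, 5.0]))
```

Reporting the median next to the mean is useful here because a few channels with very large errors can dominate the mean.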
eval/chronos_runner.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Chronos zero-shot rolling-origin inference.
3
+
+ Loads a Chronos pipeline from HuggingFace (uses the HF_HOME cache, so no
+ network is needed once the checkpoint is local), runs inference for each
+ rolling-origin window with channels processed in batches (defaults:
+ L=512 context, H=30 horizon), and reduces the sample paths (20 by
+ default) to the per-day median for point-forecast metrics.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import sys
12
+ from pathlib import Path
13
+ import time
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from chronos import ChronosPipeline
19
+
20
+
21
+ def load_chronos(model_id: str, device: str | None = None,
22
+ dtype: torch.dtype | None = None) -> ChronosPipeline:
23
+ """Load Chronos pipeline; auto-detect device/dtype if not given."""
24
+ if device is None:
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ if dtype is None:
27
+ dtype = (torch.bfloat16 if device == "cuda"
28
+ else torch.float32)
29
+ print(f" loading {model_id} on {device} ({dtype})", file=sys.stderr)
30
+ return ChronosPipeline.from_pretrained(
31
+ model_id, device_map=device, torch_dtype=dtype,
32
+ )
33
+
34
+
35
+ def predict_rolling_origin(
36
+ pipe: ChronosPipeline,
37
+ D: np.ndarray, # (T, n_items)
38
+ window_starts: list[int], # rolling-origin t values; forecast [t, t+H)
39
+ L: int = 512, H: int = 30,
40
+ num_samples: int = 20,
41
+ channel_batch: int = 16,
42
+ ) -> np.ndarray:
43
+ """Returns y_pred of shape (n_windows, H, n_items), point=median."""
44
+ n_windows = len(window_starts)
45
+ n_items = D.shape[1]
46
+ y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32)
47
+
48
+ t0 = time.time()
49
+ for wi, t in enumerate(window_starts):
50
+ ctx = D[t - L:t, :] # (L, n_items)
51
+ # Predict all channels for this window in batches.
52
+ for j0 in range(0, n_items, channel_batch):
53
+ j1 = min(j0 + channel_batch, n_items)
54
+ # ChronosPipeline.predict expects a list of 1-D tensors.
55
+ ctxs = [torch.tensor(ctx[:, j], dtype=torch.float32)
56
+ for j in range(j0, j1)]
57
+ samples = pipe.predict(
58
+ ctxs, prediction_length=H,
59
+ num_samples=num_samples, limit_prediction_length=False,
60
+ )
61
+ # samples: (batch, num_samples, H) -- median over samples
62
+ samples = samples.cpu().to(torch.float32).numpy()
63
+ med = np.median(samples, axis=1) # (batch, H)
64
+ y_pred[wi, :, j0:j1] = med.T # (H, batch)
65
+ if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows:
66
+ elapsed = time.time() - t0
67
+ rate = (wi + 1) / max(elapsed, 1e-9)
68
+ eta = (n_windows - wi - 1) / max(rate, 1e-9)
69
+ print(f" window {wi+1:4d}/{n_windows} "
70
+ f"elapsed={elapsed:6.1f}s "
71
+ f"rate={rate:5.2f} win/s "
72
+ f"eta={eta/60:5.1f}min", file=sys.stderr)
73
+ return y_pred
74
+
75
+
76
+ def collect_y_true(D: np.ndarray, window_starts: list[int],
77
+ H: int) -> np.ndarray:
78
+ n_windows = len(window_starts)
79
+ n_items = D.shape[1]
80
+ y_true = np.zeros((n_windows, H, n_items), dtype=np.float32)
81
+ for wi, t in enumerate(window_starts):
82
+ y_true[wi] = D[t:t + H]
83
+ return y_true
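`predict_rolling_origin` collapses the sample dimension of a `(batch, num_samples, H)` forecast to a per-step median and stores it transposed as `(H, batch)`. The sketch below reproduces that reduction on nested lists using only the standard library. `median_point_forecast` is a hypothetical helper for illustration and not part of the Chronos API.

```python
from statistics import median


def median_point_forecast(samples: list[list[list[float]]]) -> list[list[float]]:
    """Map samples[b][s][h] to out[h][b], the median over sample paths s."""
    n_batch = len(samples)
    horizon = len(samples[0][0])
    out = []
    for h in range(horizon):
        row = []
        for b in range(n_batch):
            # Median across all sample paths of channel b at forecast step h.
            row.append(median(path[h] for path in samples[b]))
        out.append(row)
    return out


if __name__ == "__main__":
    demo = [
        [[1.0, 10.0], [3.0, 20.0], [2.0, 30.0]],  # channel 0: 3 paths, H=2
        [[5.0, 0.0], [5.0, 0.0], [8.0, 1.0]],     # channel 1: 3 paths, H=2
    ]
    print(median_point_forecast(demo))
```

The transposed layout matters because each window's slot in `y_pred` is indexed as `(H, n_items)`, so a batch of channels fills a block of columns.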
eval/data_utils.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ Dataset I/O for §4 zero-shot foundation-model evaluation.
3
+
4
+ Loads per-channel daily demand from the released CSVs, computes the
5
+ chronological 70/15/15 split, and the MASE denominator (lag-7
6
+ seasonal-naive in-sample MAE per channel, computed on the train slice;
7
+ matches the GIFT-Eval-style definition used in paper §F.1).
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+
18
+ @dataclass
19
+ class DemandSplit:
20
+ """A single dataset's demand series with its chronological splits."""
21
+ label: str
22
+ item_ids: list[str]
23
+ D: np.ndarray # shape (T, n_items), float32
24
+ train_end: int # exclusive
25
+ val_end: int # exclusive — test = [val_end, T)
26
+ mase_denom: np.ndarray # shape (n_items,), per-channel MAE of
27
+ # lag-7 seasonal-naive on the train slice
28
+
29
+ @property
30
+ def T(self) -> int:
31
+ return self.D.shape[0]
32
+
33
+ @property
34
+ def n_items(self) -> int:
35
+ return self.D.shape[1]
36
+
37
+ @property
38
+ def test_start(self) -> int:
39
+ return self.val_end
40
+
41
+
42
+ def load_dataset(out_dir: Path,
43
+ split_train: float = 0.70,
44
+ split_val: float = 0.15) -> DemandSplit:
45
+ """Load demand from daily_records.csv into (T, n_items) array."""
46
+ cols_path = out_dir / "demand_signals_cols.txt"
47
+ item_ids = cols_path.read_text().strip().split(",")
48
+ n_items = len(item_ids)
49
+ item_to_col = {iid: j for j, iid in enumerate(item_ids)}
50
+
51
+ dr = pd.read_csv(
52
+ out_dir / "daily_records.csv",
53
+ usecols=["day", "item", "demand"],
54
+ dtype={"day": np.int32, "demand": np.int64},
55
+ )
56
+ T = int(dr["day"].max() + 1)
57
+ D = np.zeros((T, n_items), dtype=np.float32)
58
+ days = dr["day"].to_numpy()
59
+ cols = dr["item"].map(item_to_col).to_numpy()
60
+ if pd.isna(cols).any():
61
+ raise ValueError("unknown items in daily_records.csv")
62
+ cols = cols.astype(np.int64)
63
+ D[days, cols] = dr["demand"].to_numpy(dtype=np.float32)
64
+
65
+ train_end = int(round(T * split_train))
66
+ val_end = int(round(T * (split_train + split_val)))
67
+
68
+ # MASE denominator: per-channel mean absolute lag-7 first difference
69
+ # of the TRAIN slice (seasonal-naive at weekly lag, the GluonTS default
70
+ # for daily frequency and the convention reported in paper §F.1).
71
+ train = D[:train_end]
72
+ SEASONAL_LAG = 7
73
+ diff = np.abs(train[SEASONAL_LAG:] - train[:-SEASONAL_LAG])
74
+ mase_denom = diff.mean(axis=0).astype(np.float32)
75
+ # Guard against zero (constant channel); fall back to 1.0 to avoid div-0
76
+ mase_denom = np.where(mase_denom > 0, mase_denom, 1.0)
77
+
78
+ return DemandSplit(
79
+ label=out_dir.name,
80
+ item_ids=item_ids,
81
+ D=D,
82
+ train_end=train_end,
83
+ val_end=val_end,
84
+ mase_denom=mase_denom,
85
+ )
86
+
87
+
88
+ def iter_test_windows(split: DemandSplit, L: int = 512, H: int = 30,
89
+ stride: int = 30):
90
+ """Yield rolling-origin windows whose forecast horizon lies in test.
91
+
92
+ Each window: context indices [t - L, t), forecast indices [t, t + H).
93
+ The first t is split.test_start; the last t satisfies t + H <= T.
94
+ Context is allowed to span train/val/test boundaries (rolling-origin).
95
+ """
96
+ T = split.T
97
+ t = split.test_start
98
+ while t + H <= T:
99
+ if t - L < 0:
100
+ t += stride
101
+ continue
102
+ yield t
103
+ t += stride
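To make the split and windowing rules in `load_dataset` and `iter_test_windows` concrete, here is a small pure-Python sketch. The helper names (`split_points`, `window_starts`, `mase_denominator`) are illustrative only, and plain lists stand in for the NumPy arrays used above.

```python
def split_points(T: int, train: float = 0.70, val: float = 0.15) -> tuple[int, int]:
    """Chronological split: train = [0, train_end), val = [train_end, val_end)."""
    train_end = int(round(T * train))
    val_end = int(round(T * (train + val)))
    return train_end, val_end


def window_starts(T: int, test_start: int, L: int = 512, H: int = 30,
                  stride: int = 30) -> list[int]:
    """Forecast origins t with context [t - L, t) and horizon [t, t + H) inside [0, T)."""
    return [t for t in range(test_start, T - H + 1, stride) if t - L >= 0]


def mase_denominator(series: list[float], m: int = 7) -> float:
    """Mean absolute lag-m difference of one channel; 1.0 for a constant series."""
    diffs = [abs(series[i] - series[i - m]) for i in range(m, len(series))]
    denom = sum(diffs) / len(diffs) if diffs else 0.0
    return denom if denom > 0 else 1.0


if __name__ == "__main__":
    train_end, val_end = split_points(1000)
    print(train_end, val_end, window_starts(1000, val_end))
```

Origins whose context would start before day 0 are skipped rather than padded, which mirrors the `t - L < 0` guard in `iter_test_windows`.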
eval/gift_style_mase.py ADDED
@@ -0,0 +1,176 @@
1
+ """Compute GIFT-Eval-style aggregate MASE on the two ISOMORPH baseline
2
+ releases for the four foundation models, so the values are directly
3
+ comparable in scale to the GIFT-Eval leaderboard.
4
+
5
+ Definition (matches GluonTS standard MASE + GIFT-Eval aggregation):
6
+ MASE_per_channel(model) = MAE_model_per_channel / D_seasonal
7
+ MASE_per_channel(SN) = MAE_SN_per_channel / D_seasonal
8
+ RelMASE_per_channel(M) = MASE_model / MASE_SN (= MAE_model / MAE_SN)
9
+ Aggregate = geometric_mean over channels
10
+
11
+ D_seasonal: per-channel mean abs lag-m first difference on the train
12
+ slice, with m=7 (weekly), the GluonTS default for daily
13
+ frequency.
14
+
15
+ SN baseline: at test window starting day t with horizon H=30,
16
+ prediction y_pred[t+h] = y[t+h-7] for h in [0, H).
17
+
18
+ The MAE values for the four models per channel per horizon are read
19
+ directly from the per-channel CSVs already produced by the foundation
20
+ evaluation runs (no re-inference). Only Seasonal Naive is computed
21
+ fresh.
22
+ """
23
+ from __future__ import annotations
24
+
25
+ from pathlib import Path
26
+ import numpy as np
27
+ import pandas as pd
28
+
29
+ REPO = Path(__file__).resolve().parents[1]
30
+ ROOT = REPO / "data"
31
+ RESULT_DIR = REPO / "results" / "eval" / "baseline_and_scenarios"
32
+
33
+ L = 512
34
+ H = 30
35
+ STRIDE = 30
36
+ SEASONAL_M = 7
37
+
38
+ DATASETS = ["output_item50", "output_item200"]
39
+ HORIZONS = [1, 7, 14, 30]
40
+
41
+ MODELS = {
42
+ "Chronos": "chronos_t5_base_{ds}.csv",
43
+ "Moirai": "moirai_1_1_R_base_{ds}.csv",
44
+ "TimesFM": "timesfm_2_0_500m_pytorch_{ds}.csv",
45
+ "Lag-Llama": "lag_llama_{ds}.csv",
46
+ }
47
+
48
+
49
+ def load_demand(ds_dir: Path) -> tuple[np.ndarray, list[str]]:
50
+ item_ids = (ds_dir / "demand_signals_cols.txt").read_text().strip().split(",")
51
+ item_to_col = {iid: j for j, iid in enumerate(item_ids)}
52
+ dr = pd.read_csv(ds_dir / "daily_records.csv",
53
+ usecols=["day", "item", "demand"],
54
+ dtype={"day": np.int32, "demand": np.int64})
55
+ T = int(dr["day"].max() + 1)
56
+ D = np.zeros((T, len(item_ids)), dtype=np.float32)
57
+ D[dr["day"].to_numpy(),
58
+ dr["item"].map(item_to_col).to_numpy(dtype=np.int64)] = \
59
+ dr["demand"].to_numpy(dtype=np.float32)
60
+ return D, item_ids
61
+
62
+
63
+ def seasonal_denom(D_train: np.ndarray, m: int) -> np.ndarray:
64
+ """Mean abs lag-m first difference per channel on the train slice."""
65
+ diff = np.abs(D_train[m:] - D_train[:-m])
66
+ den = diff.mean(axis=0).astype(np.float32)
67
+ return np.where(den > 0, den, 1.0)
68
+
69
+
+ def seasonal_naive_mae(D: np.ndarray, train_end: int, val_end: int,
+                        L: int, H: int, stride: int,
+                        m: int) -> tuple[np.ndarray, int]:
+     """Per-channel-per-horizon MAE of SN(m) under the rolling-origin protocol.
+
+     Returns (mae_per_h, n_W). mae_per_h has shape (len(HORIZONS), C): the
+     absolute error averaged over the first h forecast days and over windows.
+     n_W is the number of rolling-origin windows. train_end and L are
+     currently unused.
+     """
+     T, C = D.shape
+     test_start = val_end
+     starts = list(range(test_start, T - H + 1, stride))
+     n_W = len(starts)
+     abs_err = np.zeros((n_W, H, C), dtype=np.float32)
+     for w, t in enumerate(starts):
+         for h in range(H):
+             true = D[t + h]
+             # NOTE: for h >= m this reads D[t + h - m], which lies inside the
+             # forecast horizon [t, t + H), not in the context before t.
+             pred = D[t + h - m]
+             abs_err[w, h] = np.abs(true - pred)
+     mae_per_h = np.zeros((len(HORIZONS), C), dtype=np.float32)
+     for i, h in enumerate(HORIZONS):
+         mae_per_h[i] = abs_err[:, :h, :].mean(axis=(0, 1))
+     return mae_per_h, n_W
91
+
92
+
+ def model_mae(csv_path: Path) -> dict[int, np.ndarray]:
+     """Load per-channel MAE at each horizon from a model CSV.
+
+     Rows are ordered by sorted channel name, so the caller must make sure
+     this order matches the column order of D (demand_signals_cols.txt).
+     """
+     df = pd.read_csv(csv_path)
+     out = {}
+     for h in HORIZONS:
+         sub = df[(df["metric"] == "MAE") & (df["h"] == h)]
+         sub = sub.sort_values("channel")
+         out[h] = sub["value"].to_numpy(dtype=np.float32)
+     return out
102
+
103
+
104
+ def gift_aggregate(rel_per_channel: np.ndarray) -> float:
105
+ """Geometric mean across channels."""
106
+ rel_per_channel = np.maximum(rel_per_channel, 1e-12)
107
+ return float(np.exp(np.log(rel_per_channel).mean()))
108
+
109
+
110
+ def main():
111
+ rows = []
112
+ for ds in DATASETS:
113
+ ds_dir = ROOT / ds
114
+ print(f"\n=== {ds} ===")
115
+ D, item_ids = load_demand(ds_dir)
116
+ T, C = D.shape
117
+ train_end = int(round(T * 0.70))
118
+ val_end = int(round(T * 0.85))
119
+ print(f" T={T}, C={C}, train_end={train_end}, val_end={val_end}")
120
+
121
+ d_seasonal = seasonal_denom(D[:train_end], SEASONAL_M)
122
+ print(f" seasonal-{SEASONAL_M} denom: "
123
+ f"min={d_seasonal.min():.3f} "
124
+ f"mean={d_seasonal.mean():.3f} max={d_seasonal.max():.3f}")
125
+
126
+ # Seasonal Naive baseline on rolling windows
127
+ sn_mae, n_W = seasonal_naive_mae(D, train_end, val_end,
128
+ L=L, H=H, stride=STRIDE,
129
+ m=SEASONAL_M)
130
+ print(f" windows: {n_W}; SN MAE @ h=30 (mean over channels): "
131
+ f"{sn_mae[3].mean():.3f}")
132
+ sn_mase = sn_mae / d_seasonal[None, :] # (4, C)
133
+
134
+ for model_name, csv_template in MODELS.items():
135
+ csv_name = csv_template.format(ds=ds)
136
+ csv_path = RESULT_DIR / csv_name
137
+ if not csv_path.exists():
138
+ print(f" [{model_name}] MISSING {csv_name}")
139
+ continue
140
+ mae_dict = model_mae(csv_path)
141
+ for i, h in enumerate(HORIZONS):
142
+ model_mae_arr = mae_dict[h]
143
+ model_mase = model_mae_arr / d_seasonal
144
+ rel = model_mase / sn_mase[i]
145
+ # Cap absurd outliers to avoid inf in geom mean
146
+ rel = np.clip(rel, 1e-3, 1e3)
147
+ gift_mase = gift_aggregate(rel)
148
+ rows.append({
149
+ "dataset": ds,
150
+ "model": model_name,
151
+ "h": h,
152
+ "MAE_mean": float(model_mae_arr.mean()),
153
+ "MASE_seasonal_mean": float(model_mase.mean()),
154
+ "RelMASE_geom_over_channels": gift_mase,
155
+ })
156
+
157
+ out = pd.DataFrame(rows)
158
+ print("\n\n=== Summary (GIFT-style RelMASE = geom mean over channels of "
159
+ "[MAE_model / MAE_SeasonalNaive(m=7)]) ===")
160
+ pivot = out.pivot_table(index=["dataset", "model"], columns="h",
161
+ values="RelMASE_geom_over_channels")
162
+ print(pivot.round(3).to_string())
163
+
164
+ # Headline aggregate (across horizons): geom mean over h
165
+ print("\n=== Per-(dataset, model) aggregate over horizons "
166
+ "(geom mean over h ∈ {1,7,14,30}) ===")
167
+ agg = (pivot.apply(lambda r: np.exp(np.log(r).mean()), axis=1)
168
+ .round(3))
169
+ print(agg.to_string())
170
+
171
+ out.to_csv(RESULT_DIR / "gift_style_mase.csv", index=False)
172
+ print(f"\nSaved: {RESULT_DIR / 'gift_style_mase.csv'}")
173
+
174
+
175
+ if __name__ == "__main__":
176
+ main()
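The aggregation described in the module docstring divides each channel's model MASE by the seasonal-naive MASE (the shared denominator cancels, leaving a ratio of MAEs), clips the ratio, and takes a geometric mean over channels. A minimal sketch follows, assuming plain Python lists in place of the arrays and a hypothetical function name `rel_mase_geomean`. It also assumes every seasonal-naive MAE and every denominator is strictly positive, and it raises otherwise.

```python
import math


def rel_mase_geomean(mae_model: list[float], mae_sn: list[float],
                     denom: list[float],
                     lo: float = 1e-3, hi: float = 1e3) -> float:
    """Geometric mean over channels of MASE_model / MASE_SN, clipped to [lo, hi]."""
    ratios = []
    for a, s, d in zip(mae_model, mae_sn, denom):
        if s <= 0 or d <= 0:
            raise ValueError("sketch assumes positive SN MAE and denominator")
        mase_model = a / d
        mase_sn = s / d
        rel = mase_model / mase_sn        # the denominator d cancels: a / s
        ratios.append(min(max(rel, lo), hi))
    # Geometric mean computed in log space for numerical stability.
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


if __name__ == "__main__":
    print(rel_mase_geomean([1.0, 4.0], [2.0, 2.0], [3.0, 5.0]))
```

A geometric mean treats a channel where the model halves the baseline error and a channel where it doubles it as cancelling out, which an arithmetic mean of ratios would not do.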
eval/lagllama_run.py ADDED
@@ -0,0 +1,131 @@
1
+ """
2
+ CLI driver: run Lag-Llama zero-shot on one ISOMORPH release.
3
+
+ Outputs (mirror Chronos / Moirai / TimesFM runners):
+   results/lag_llama_{dataset}.csv          per-channel long-format metrics
+   results/lag_llama_{dataset}_summary.csv  cross-channel mean/median
+   results/lag_llama_{dataset}_tensors.npz  raw y_pred / y_true arrays
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from data_utils import load_dataset, iter_test_windows
19
+ from metrics import metrics_at_horizons, to_long_dataframe, HORIZONS
20
+ from lagllama_runner import (
21
+ load_lagllama, predict_rolling_origin, collect_y_true,
22
+ )
23
+
24
+
25
+ MODEL_LABEL = "time-series-foundation-models/Lag-Llama"
26
+ SHORT = "lag_llama"
27
+
28
+
29
+ def run(out_dir: Path, results_dir: Path,
30
+ L: int, H: int, stride: int,
31
+ num_samples: int, batch_size: int,
32
+ max_windows: int | None = None,
33
+ label: str | None = None) -> None:
34
+ print(f"=== {out_dir.name} with {MODEL_LABEL} ===", file=sys.stderr)
35
+ split = load_dataset(out_dir)
36
+ if label is not None:
37
+ split.label = label
38
+ print(f" T={split.T} n_items={split.n_items} "
39
+ f"train_end={split.train_end} val_end={split.val_end} "
40
+ f"test=[{split.test_start}, {split.T})", file=sys.stderr)
41
+
42
+ starts = list(iter_test_windows(split, L=L, H=H, stride=stride))
43
+ if max_windows is not None:
44
+ starts = starts[:max_windows]
45
+ print(f" rolling-origin windows: {len(starts)} "
46
+ f"(L={L}, H={H}, stride={stride})", file=sys.stderr)
47
+
48
+ predictor = load_lagllama(
49
+ prediction_length=H, context_length=L,
50
+ num_samples=num_samples, batch_size=batch_size,
51
+ )
52
+
53
+ t0 = time.time()
54
+ y_pred = predict_rolling_origin(predictor, split.D, starts, L=L, H=H)
55
+ y_true = collect_y_true(split.D, starts, H)
56
+ elapsed = time.time() - t0
57
+ print(f" inference done in {elapsed/60:.1f} min", file=sys.stderr)
58
+
59
+ metric_dict = metrics_at_horizons(y_true, y_pred, split.mase_denom,
60
+ horizons=HORIZONS)
61
+ long_df = to_long_dataframe(metric_dict, split.item_ids,
62
+ model=MODEL_LABEL, dataset=split.label)
63
+
64
+ results_dir.mkdir(parents=True, exist_ok=True)
65
+ out_long = results_dir / f"{SHORT}_{split.label}.csv"
66
+ long_df.to_csv(out_long, index=False)
67
+ print(f" -> {out_long}", file=sys.stderr)
68
+
69
+ # Persist raw tensors for post-hoc slicing (e.g. stationary-vs-shock).
70
+ out_npz = results_dir / f"{SHORT}_{split.label}_tensors.npz"
71
+ np.savez_compressed(
72
+ out_npz, y_pred=y_pred, y_true=y_true,
73
+ window_starts=np.asarray(starts, dtype=np.int64),
74
+ item_ids=np.asarray(split.item_ids),
75
+ L=L, H=H, stride=stride, model=MODEL_LABEL, dataset=split.label,
76
+ )
77
+ print(f" -> {out_npz}", file=sys.stderr)
78
+
79
+ summary = (long_df
80
+ .groupby(["model", "dataset", "metric", "h"])["value"]
81
+ .agg(mean="mean", median="median",
82
+ q25=lambda x: x.quantile(0.25),
83
+ q75=lambda x: x.quantile(0.75),
84
+ n="count")
85
+ .reset_index())
86
+ out_sum = results_dir / f"{SHORT}_{split.label}_summary.csv"
87
+ summary.to_csv(out_sum, index=False)
88
+ print(f" -> {out_sum}", file=sys.stderr)
89
+
90
+ print("\n Headline (median across channels):", file=sys.stderr)
91
+ print(summary.pivot_table(index="metric", columns="h",
92
+ values="median").to_string(),
93
+ file=sys.stderr)
94
+
95
+
96
+ def main():
97
+ repo = Path(__file__).resolve().parents[1]
98
+ ap = argparse.ArgumentParser()
99
+ ap.add_argument("--root", default=str(repo / "data"))
100
+ ap.add_argument("--dataset", default="output_item50")
101
+ ap.add_argument("--scenario_path", default=None,
102
+ help="Path to a scenario directory; "
103
+ "overrides --root/--dataset when set.")
104
+ ap.add_argument("--label", default=None,
105
+ help="Output filename label; defaults to out_dir.name.")
106
+ ap.add_argument("--out", default=str(
107
+ repo / "results" / "eval" / "baseline_and_scenarios"))
108
+ ap.add_argument("--L", type=int, default=512,
109
+ help="context length; RoPE scaling is enabled when L + H > 32")
110
+ ap.add_argument("--H", type=int, default=30)
111
+ ap.add_argument("--stride", type=int, default=30)
112
+ ap.add_argument("--num_samples", type=int, default=100,
113
+ help="probabilistic samples for the median point forecast")
114
+ ap.add_argument("--batch_size", type=int, default=32)
115
+ ap.add_argument("--max_windows", type=int, default=None,
116
+ help="cap for smoke testing")
117
+ args = ap.parse_args()
118
+ if args.scenario_path is not None:
119
+ out_dir = Path(args.scenario_path)
120
+ else:
121
+ out_dir = Path(args.root) / args.dataset
122
+ run(out_dir, Path(args.out),
123
+ L=args.L, H=args.H, stride=args.stride,
124
+ num_samples=args.num_samples,
125
+ batch_size=args.batch_size,
126
+ max_windows=args.max_windows,
127
+ label=args.label)
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
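The summary step in `run` reduces the long-format table to one row per (model, dataset, metric, h). A minimal standard-library sketch of the same mean / median / quartile reduction for a single group is shown below. `summarize` is a hypothetical helper, not part of this repo, and it assumes at least two values per group; `method="inclusive"` is used because it matches the linear interpolation that pandas applies by default in `quantile`.

```python
import statistics


def summarize(values):
    """Mean, median, quartiles and count for one (metric, h) group."""
    # "inclusive" interpolates linearly between order statistics,
    # like the default behaviour of pandas' Series.quantile.
    q25, _, q75 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "mean": statistics.fmean(values),
        "median": statistics.median(values),
        "q25": q25,
        "q75": q75,
        "n": len(values),
    }


# Toy per-channel MAE values for a single horizon (illustrative only).
row = summarize([1.0, 2.0, 3.0, 4.0, 5.0])
print(row)
```

In the driver itself the same reduction is done once per group by the pandas `groupby(...).agg(...)` call.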
eval/lagllama_runner.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ Lag-Llama zero-shot rolling-origin inference.
3
+
4
+ Wraps time-series-foundation-models/Lag-Llama via the official GluonTS-style
5
+ LagLlamaEstimator. Univariate (like Chronos / TimesFM): per window we hand
6
+ the model C 1-D context arrays and reduce the num_samples sample paths (default 100) to the per-day
7
+ median for the point forecast.
8
+
9
+ CRITICAL: Lag-Llama was trained on context length 32. Any L + H > 32
10
+ requires RoPE scaling (linear factor = (L + H) / 32). Without it the position
11
+ embeddings extrapolate beyond the trained range and the forecast degrades sharply.
12
+
13
+ Designed to mirror chronos_runner.py's I/O contract so metrics.py stays
14
+ unchanged.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import sys
19
+ import time
20
+ from pathlib import Path
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+ import torch
25
+ from gluonts.dataset.common import ListDataset
26
+
27
+ from huggingface_hub import hf_hub_download
28
+ from lag_llama.gluon.estimator import LagLlamaEstimator
29
+
30
+
31
+ _LAGLLAMA_TRAINING_CTX = 32 # what the public checkpoint was trained on
32
+
33
+
34
+ def load_lagllama(prediction_length: int, context_length: int,
35
+ num_samples: int = 100, batch_size: int = 32,
36
+ ckpt_path: str | None = None,
37
+ device: str | None = None):
38
+ """Load Lag-Llama as a GluonTS PyTorchPredictor."""
39
+ if device is None:
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ if ckpt_path is None:
42
+ ckpt_path = hf_hub_download(
43
+ repo_id="time-series-foundation-models/Lag-Llama",
44
+ filename="lag-llama.ckpt",
45
+ )
46
+ print(f" loading Lag-Llama from {ckpt_path} on {device} "
47
+ f"(L={context_length}, H={prediction_length}, "
48
+ f"num_samples={num_samples}, batch_size={batch_size})",
49
+ file=sys.stderr)
50
+
51
+ # Pull the model architecture knobs out of the checkpoint so we
52
+ # construct a matching estimator regardless of release.
53
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
54
+ args = ckpt["hyper_parameters"]["model_kwargs"]
55
+
56
+ # RoPE scaling so positional encodings cover L + H > 32 (the training context).
57
+ rope_scaling = None
58
+ needed_extent = context_length + prediction_length
59
+ if needed_extent > _LAGLLAMA_TRAINING_CTX:
60
+ rope_scaling = {
61
+ "type": "linear",
62
+ "factor": float(needed_extent) / float(_LAGLLAMA_TRAINING_CTX),
63
+ }
64
+ print(f" enabling RoPE scaling: factor="
65
+ f"{rope_scaling['factor']:.3f} "
66
+ f"(L+H={needed_extent} > training ctx {_LAGLLAMA_TRAINING_CTX})",
67
+ file=sys.stderr)
68
+
69
+ estimator = LagLlamaEstimator(
70
+ ckpt_path=ckpt_path,
71
+ prediction_length=prediction_length,
72
+ context_length=context_length,
73
+ input_size=args["input_size"],
74
+ n_layer=args["n_layer"],
75
+ n_embd_per_head=args["n_embd_per_head"],
76
+ n_head=args["n_head"],
77
+ scaling=args["scaling"],
78
+ time_feat=args["time_feat"],
79
+ rope_scaling=rope_scaling,
80
+ batch_size=batch_size,
81
+ num_parallel_samples=num_samples,
82
+ )
83
+ lightning_module = estimator.create_lightning_module()
84
+ transformation = estimator.create_transformation()
85
+ predictor = estimator.create_predictor(transformation, lightning_module)
86
+ return predictor
87
+
88
+
89
+ def predict_rolling_origin(predictor, D: np.ndarray,
90
+ window_starts: list[int],
91
+ L: int, H: int) -> np.ndarray:
92
+ """Returns y_pred of shape (n_windows, H, n_items); point=median."""
93
+ n_windows = len(window_starts)
94
+ n_items = D.shape[1]
95
+ y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32)
96
+ anchor = pd.Period("2000-01-01", freq="D")
97
+
98
+ t0 = time.time()
99
+ for wi, t in enumerate(window_starts):
100
+ ctx = D[t - L:t, :] # (L, n_items)
101
+ items = [
102
+ {"target": ctx[:, j].astype(np.float32), "start": anchor}
103
+ for j in range(n_items)
104
+ ]
105
+ ds = ListDataset(items, freq="D")
106
+ forecasts = list(predictor.predict(ds))
107
+ for j, fc in enumerate(forecasts):
108
+ # fc.samples: (num_parallel_samples, H)
109
+ y_pred[wi, :, j] = np.median(
110
+ np.asarray(fc.samples, dtype=np.float32), axis=0
111
+ )
112
+ if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows:
113
+ elapsed = time.time() - t0
114
+ rate = (wi + 1) / max(elapsed, 1e-9)
115
+ eta = (n_windows - wi - 1) / max(rate, 1e-9)
116
+ print(f" window {wi+1:4d}/{n_windows} "
117
+ f"elapsed={elapsed:6.1f}s "
118
+ f"rate={rate:5.2f} win/s "
119
+ f"eta={eta/60:5.1f}min", file=sys.stderr)
120
+ return y_pred
121
+
122
+
123
+ def collect_y_true(D: np.ndarray, window_starts: list[int],
124
+ H: int) -> np.ndarray:
125
+ n_windows = len(window_starts)
126
+ n_items = D.shape[1]
127
+ y_true = np.zeros((n_windows, H, n_items), dtype=np.float32)
128
+ for wi, t in enumerate(window_starts):
129
+ y_true[wi] = D[t:t + H]
130
+ return y_true
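The RoPE-scaling rule in `load_lagllama` comes down to a single ratio. The sketch below isolates that arithmetic so it can be checked by hand; `rope_linear_factor` is a hypothetical helper name and is not part of this runner or of the Lag-Llama package.

```python
from typing import Optional

TRAINING_CTX = 32  # context length of the public checkpoint, per the runner above


def rope_linear_factor(context_length: int, prediction_length: int,
                       training_ctx: int = TRAINING_CTX) -> Optional[float]:
    """Return the linear RoPE factor, or None when no scaling is needed."""
    needed_extent = context_length + prediction_length
    if needed_extent <= training_ctx:
        return None
    return needed_extent / training_ctx


print(rope_linear_factor(512, 30))  # extent 542 exceeds 32, so scaling applies
print(rope_linear_factor(2, 30))    # extent 32 fits, so no scaling
```

With the CLI defaults (L=512, H=30) the extent is 542, so the factor is 542 / 32.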
eval/metrics.py ADDED
@@ -0,0 +1,86 @@
1
+ """
2
+ Forecast metrics computed at horizons h ∈ {1, 7, 14, 30}.
3
+
4
+ Inputs are y_true and y_pred of shape (n_windows, H, n_items),
5
+ plus mase_denom of shape (n_items,). For each horizon h we average the
6
+ per-day errors over windows and the first h forecast days, per item.
7
+
8
+ NOTE on the headline numbers in the paper.
9
+ The per-channel ``MASE = MAE / mase_denom`` column written by this
10
+ module is a convenience output; it is *not* what populates paper
11
+ Table 1 / 6 / 7. Those tables report the GIFT-Eval-style aggregate
12
+ (geometric mean over channels of ``MAE_model / MAE_SeasonalNaive``),
13
+ which is recomputed post-hoc from the per-channel MAE column by
14
+ ``gift_style_mase.py``. Downstream consumers of these CSVs (the
15
+ analysis scripts in this repo) likewise only read the MAE rows.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import numpy as np
20
+
21
+
22
+ HORIZONS = [1, 7, 14, 30] # 1-indexed: h=1 means first forecast day
23
+
24
+
25
+ def _smape_term(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
26
+ """Standard sMAPE: 200 * |y - y_hat| / (|y| + |y_hat|).
27
+
28
+ Convention: when y and y_hat are both zero (zero denominator), the term
29
+ is zero (perfect forecast on a zero ground truth).
30
+ """
31
+ num = 2.0 * np.abs(y_true - y_pred)
32
+ den = np.abs(y_true) + np.abs(y_pred)
33
+ out = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
34
+ return 100.0 * out
35
+
36
+
37
+ def metrics_at_horizons(
38
+ y_true: np.ndarray, y_pred: np.ndarray, mase_denom: np.ndarray,
39
+ horizons: list[int] = HORIZONS,
40
+ ) -> dict:
41
+ """Returns a dict keyed by (metric, h) -> per-channel array.
42
+
43
+ y_true, y_pred: shape (n_windows, H, n_items), float
44
+ mase_denom: shape (n_items,)
45
+
46
+ For h=k, we average the per-day error over the first k forecast days
47
+ (indices [0, k)) and over windows, separately for each channel, so
48
+ MAE@k for channel c = mean over (window, day in [0, k)) of |error|.
49
+ Each (metric, h) entry is therefore an array of shape (n_items,).
50
+ """
51
+ n_windows, H, n_items = y_true.shape
52
+ out = {}
53
+ abs_err = np.abs(y_true - y_pred) # (W, H, C)
54
+ sq_err = (y_true - y_pred) ** 2 # (W, H, C)
55
+ smape = _smape_term(y_true, y_pred) # (W, H, C)
56
+
57
+ for h in horizons:
58
+ if h > H:
59
+ continue
60
+ # Cumulative-mean over the first h forecast days, then average
61
+ # over windows. Result is per-channel: shape (n_items,).
62
+ mae = abs_err[:, :h, :].mean(axis=(0, 1))
63
+ rmse = np.sqrt(sq_err[:, :h, :].mean(axis=(0, 1)))
64
+ smap = smape[:, :h, :].mean(axis=(0, 1))
65
+ mase = mae / mase_denom
66
+
67
+ out[("MAE", h)] = mae
68
+ out[("RMSE", h)] = rmse
69
+ out[("SMAPE", h)] = smap
70
+ out[("MASE", h)] = mase
71
+ return out
72
+
73
+
74
+ def to_long_dataframe(metric_dict: dict, item_ids: list[str],
75
+ model: str, dataset: str):
76
+ """Flatten the (metric, h) -> (n_items,) dict into long format."""
77
+ import pandas as pd
78
+ rows = []
79
+ for (metric, h), arr in metric_dict.items():
80
+ for i, iid in enumerate(item_ids):
81
+ rows.append({
82
+ "model": model, "dataset": dataset,
83
+ "channel": iid, "metric": metric, "h": h,
84
+ "value": float(arr[i]),
85
+ })
86
+ return pd.DataFrame(rows)
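The module docstring describes the paper's headline aggregate as a geometric mean over channels of `MAE_model / MAE_SeasonalNaive`. `gift_style_mase.py` is not shown here, so the following is only a sketch of that formula, not the repo's implementation. The function name and the choice to raise on non-positive MAE values are assumptions.

```python
import math


def geometric_mean_relative_mae(mae_model, mae_baseline):
    """Geometric mean over channels of mae_model[c] / mae_baseline[c]."""
    if not mae_model or len(mae_model) != len(mae_baseline):
        raise ValueError("need two non-empty sequences of equal length")
    log_sum = 0.0
    for m, b in zip(mae_model, mae_baseline):
        if m <= 0 or b <= 0:
            # Assumption: a zero MAE makes the log undefined, so reject it.
            raise ValueError("MAE values must be strictly positive")
        log_sum += math.log(m / b)
    return math.exp(log_sum / len(mae_model))


# Two toy channels with ratios 2 and 8; their geometric mean is sqrt(16).
score = geometric_mean_relative_mae([2.0, 8.0], [1.0, 1.0])
print(score)
```

Averaging in log space keeps one badly forecast channel from dominating the aggregate the way it would in an arithmetic mean of ratios.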
eval/moirai_run.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ CLI driver: run Moirai zero-shot on one ISOMORPH release.
3
+
4
+ Outputs (mirror Chronos runner naming):
5
+ results/{model_short}_{dataset}.csv per-channel long-format metrics
6
+ results/{model_short}_{dataset}_summary.csv cross-channel mean/median
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from data_utils import load_dataset
19
+ from metrics import metrics_at_horizons, to_long_dataframe, HORIZONS
20
+ from moirai_runner import (
21
+ load_moirai, build_test_data,
22
+ predict_rolling_origin, collect_y_true,
23
+ )
24
+
25
+
26
+ def short_name(model_id: str) -> str:
27
+ return model_id.split("/")[-1].replace("-", "_").replace(".", "_")
28
+
29
+
30
+ def run(out_dir: Path, model_id: str, results_dir: Path,
31
+ L: int, H: int, stride: int,
32
+ num_samples: int, batch_size: int, patch_size,
33
+ max_windows: int | None = None,
34
+ label: str | None = None) -> None:
35
+ print(f"=== {out_dir.name} with {model_id} ===", file=sys.stderr)
36
+ split = load_dataset(out_dir)
37
+ if label is not None:
38
+ split.label = label
39
+ print(f" T={split.T} n_items={split.n_items} "
40
+ f"train_end={split.train_end} val_end={split.val_end} "
41
+ f"test=[{split.test_start}, {split.T})", file=sys.stderr)
42
+
43
+ test_data, n_windows = build_test_data(
44
+ split.D, split.item_ids, split.val_end,
45
+ H=H, stride=stride, max_windows=max_windows,
46
+ )
47
+ print(f" rolling-origin windows: {n_windows} "
48
+ f"(L={L}, H={H}, stride={stride})", file=sys.stderr)
49
+
50
+ predictor = load_moirai(
51
+ model_id,
52
+ prediction_length=H, context_length=L,
53
+ target_dim=split.n_items,
54
+ num_samples=num_samples, patch_size=patch_size,
55
+ batch_size=batch_size,
56
+ )
57
+
58
+ t0 = time.time()
59
+ y_pred = predict_rolling_origin(
60
+ predictor, test_data, n_windows, H, split.n_items,
61
+ )
62
+ y_true = collect_y_true(test_data, n_windows, H, split.n_items)
63
+ elapsed = time.time() - t0
64
+ print(f" inference done in {elapsed/60:.1f} min", file=sys.stderr)
65
+
66
+ metric_dict = metrics_at_horizons(y_true, y_pred, split.mase_denom,
67
+ horizons=HORIZONS)
68
+ long_df = to_long_dataframe(metric_dict, split.item_ids,
69
+ model=model_id, dataset=split.label)
70
+
71
+ sn = short_name(model_id)
72
+ results_dir.mkdir(parents=True, exist_ok=True)
73
+ out_long = results_dir / f"{sn}_{split.label}.csv"
74
+ long_df.to_csv(out_long, index=False)
75
+ print(f" -> {out_long}", file=sys.stderr)
76
+
77
+ # Persist raw tensors for post-hoc slicing (e.g. stationary-vs-shock).
78
+ # Moirai's GluonTS test_data hides explicit window starts; reconstruct
79
+ # them from the (val_end, stride, n_windows) triple — the first
80
+ # forecast horizon begins at val_end and successive ones shift by stride.
81
+ starts = [split.val_end + i * stride for i in range(n_windows)]
82
+ out_npz = results_dir / f"{sn}_{split.label}_tensors.npz"
83
+ np.savez_compressed(
84
+ out_npz, y_pred=y_pred, y_true=y_true,
85
+ window_starts=np.asarray(starts, dtype=np.int64),
86
+ item_ids=np.asarray(split.item_ids),
87
+ L=L, H=H, stride=stride, model=model_id, dataset=split.label,
88
+ )
89
+ print(f" -> {out_npz}", file=sys.stderr)
90
+
91
+ summary = (long_df
92
+ .groupby(["model", "dataset", "metric", "h"])["value"]
93
+ .agg(mean="mean", median="median",
94
+ q25=lambda x: x.quantile(0.25),
95
+ q75=lambda x: x.quantile(0.75),
96
+ n="count")
97
+ .reset_index())
98
+ out_sum = results_dir / f"{sn}_{split.label}_summary.csv"
99
+ summary.to_csv(out_sum, index=False)
100
+ print(f" -> {out_sum}", file=sys.stderr)
101
+
102
+ print("\n Headline (median across channels):", file=sys.stderr)
103
+ print(summary.pivot_table(index="metric", columns="h",
104
+ values="median").to_string(),
105
+ file=sys.stderr)
106
+
107
+
108
+ def parse_patch_size(s: str):
109
+ if s == "auto":
110
+ return "auto"
111
+ return int(s)
112
+
113
+
114
+ def main():
115
+ repo = Path(__file__).resolve().parents[1]
116
+ ap = argparse.ArgumentParser()
117
+ ap.add_argument("--root", default=str(repo / "data"))
118
+ ap.add_argument("--dataset", default="output_item50")
119
+ ap.add_argument("--scenario_path", default=None,
120
+ help="Path to a scenario directory; "
121
+ "overrides --root/--dataset when set.")
122
+ ap.add_argument("--label", default=None,
123
+ help="Output filename label; defaults to out_dir.name.")
124
+ ap.add_argument("--model_id", default="Salesforce/moirai-1.1-R-base")
125
+ ap.add_argument("--out", default=str(
126
+ repo / "results" / "eval" / "baseline_and_scenarios"))
127
+ ap.add_argument("--L", type=int, default=1024,
128
+ help="context length (Moirai recommends 512-2048)")
129
+ ap.add_argument("--H", type=int, default=30)
130
+ ap.add_argument("--stride", type=int, default=30)
131
+ ap.add_argument("--num_samples", type=int, default=100,
132
+ help="probabilistic samples; 100 needed for stable median")
133
+ ap.add_argument("--batch_size", type=int, default=8,
134
+ help="windows per inference batch")
135
+ ap.add_argument("--patch_size", type=parse_patch_size, default=32,
136
+ help="'auto' or one of {8,16,32,64,128}")
137
+ ap.add_argument("--max_windows", type=int, default=None,
138
+ help="cap for smoke testing")
139
+ args = ap.parse_args()
140
+ if args.scenario_path is not None:
141
+ out_dir = Path(args.scenario_path)
142
+ else:
143
+ out_dir = Path(args.root) / args.dataset
144
+ run(out_dir,
145
+ args.model_id, Path(args.out),
146
+ L=args.L, H=args.H, stride=args.stride,
147
+ num_samples=args.num_samples,
148
+ batch_size=args.batch_size, patch_size=args.patch_size,
149
+ max_windows=args.max_windows,
150
+ label=args.label)
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
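Output filenames in this driver are derived from the model id through `short_name`. The snippet below repeats that function unchanged and applies it to the two default model ids used by the Moirai and TimesFM drivers; the `label` value is the default `--dataset` name and is used here only for illustration.

```python
def short_name(model_id: str) -> str:
    # Keep the part after the organisation, then make it filename-safe.
    return model_id.split("/")[-1].replace("-", "_").replace(".", "_")


label = "output_item50"  # default --dataset value, illustrative
moirai = short_name("Salesforce/moirai-1.1-R-base")
timesfm = short_name("google/timesfm-2.0-500m-pytorch")
print(f"{moirai}_{label}.csv")
print(f"{timesfm}_{label}_summary.csv")
```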
eval/moirai_runner.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Moirai zero-shot rolling-origin inference.
3
+
4
+ Wraps Salesforce moirai (uni2ts.model.moirai) to forecast our supply-chain
5
+ demand release. Multivariate-native: a single forward predicts all C items.
6
+
7
+ Designed to mirror chronos_runner.py's I/O contract so metrics.py can stay
8
+ unchanged: produces y_pred of shape (n_windows, H, n_items) with point
9
+ forecast = median over num_samples draws.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import sys
14
+ import time
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ import torch
19
+ from gluonts.dataset.multivariate_grouper import MultivariateGrouper
20
+ from gluonts.dataset.pandas import PandasDataset
21
+ from gluonts.dataset.split import split as gluonts_split
22
+
23
+ from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
24
+
25
+
26
+ def load_moirai(model_id: str, prediction_length: int, context_length: int,
27
+ target_dim: int, num_samples: int = 100,
28
+ patch_size: int | str = 32, batch_size: int = 32,
29
+ device: str | None = None):
30
+ """Build a Moirai predictor from a pretrained checkpoint."""
31
+ if device is None:
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ print(f" loading {model_id} on {device} "
34
+ f"(L={context_length}, H={prediction_length}, "
35
+ f"target_dim={target_dim}, patch_size={patch_size}, "
36
+ f"num_samples={num_samples})", file=sys.stderr)
37
+ module = MoiraiModule.from_pretrained(model_id)
38
+ model = MoiraiForecast(
39
+ module=module,
40
+ prediction_length=prediction_length,
41
+ context_length=context_length,
42
+ patch_size=patch_size,
43
+ num_samples=num_samples,
44
+ target_dim=target_dim,
45
+ feat_dynamic_real_dim=0,
46
+ past_feat_dynamic_real_dim=0,
47
+ )
48
+ predictor = model.create_predictor(batch_size=batch_size)
49
+ return predictor
50
+
51
+
52
+ def build_test_data(D: np.ndarray, item_ids: list[str],
53
+ val_end: int, H: int, stride: int,
54
+ max_windows: int | None = None):
55
+ """Build a GluonTS multivariate test dataset for rolling-origin eval.
56
+
57
+ The test region is [val_end, T). Window i has its forecast horizon at
58
+ [val_end + i*stride, val_end + i*stride + H), with full-history input
59
+ (the model itself trims to the last `context_length` steps).
60
+ """
61
+ T, n_items = D.shape
62
+ test_len = T - val_end
63
+
64
+ # Wide DataFrame: each column is one item; daily frequency.
65
+ df = pd.DataFrame(
66
+ D,
67
+ columns=item_ids,
68
+ index=pd.date_range("2000-01-01", periods=T, freq="D"),
69
+ )
70
+ ds = PandasDataset(dict(df))
71
+ grouper = MultivariateGrouper(len(ds))
72
+ multivar_ds = grouper(ds)
73
+
74
+ train, test_template = gluonts_split(multivar_ds, offset=-test_len)
75
+
76
+ n_windows = (test_len - H) // stride + 1
77
+ if max_windows is not None:
78
+ n_windows = min(n_windows, max_windows)
79
+
80
+ test_data = test_template.generate_instances(
81
+ prediction_length=H,
82
+ windows=n_windows,
83
+ distance=stride,
84
+ )
85
+ return test_data, n_windows
86
+
87
+
88
+ def predict_rolling_origin(predictor, test_data, n_windows: int,
89
+ H: int, n_items: int) -> np.ndarray:
90
+ """Returns y_pred of shape (n_windows, H, n_items); point = median over samples."""
91
+ y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32)
92
+ t0 = time.time()
93
+ forecasts = predictor.predict(test_data.input)
94
+ for wi, fc in enumerate(forecasts):
95
+ # fc.samples shape: (num_samples, H, target_dim)
96
+ samples = np.asarray(fc.samples, dtype=np.float32)
97
+ med = np.median(samples, axis=0) # (H, target_dim)
98
+ y_pred[wi] = med
99
+ if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows:
100
+ elapsed = time.time() - t0
101
+ rate = (wi + 1) / max(elapsed, 1e-9)
102
+ eta = (n_windows - wi - 1) / max(rate, 1e-9)
103
+ print(f" window {wi+1:4d}/{n_windows} "
104
+ f"elapsed={elapsed:6.1f}s "
105
+ f"rate={rate:5.2f} win/s "
106
+ f"eta={eta/60:5.1f}min", file=sys.stderr)
107
+ return y_pred
108
+
109
+
110
+ def collect_y_true(test_data, n_windows: int, H: int,
111
+ n_items: int) -> np.ndarray:
112
+ """Extract ground truth labels: shape (n_windows, H, n_items)."""
113
+ y_true = np.zeros((n_windows, H, n_items), dtype=np.float32)
114
+ for wi, lbl in enumerate(test_data.label):
115
+ # lbl["target"] shape: (target_dim, H)
116
+ y_true[wi] = np.asarray(lbl["target"], dtype=np.float32).T
117
+ return y_true
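The window arithmetic used by `build_test_data`, together with the start reconstruction in `moirai_run.py`, can be checked in isolation. The sketch below uses a hypothetical helper `rolling_window_starts`; the early return for a test region shorter than H is an added guard, not something the runner does.

```python
def rolling_window_starts(T, val_end, H, stride, max_windows=None):
    """Forecast-origin indices for rolling-origin evaluation on [val_end, T)."""
    test_len = T - val_end
    if test_len < H:
        return []  # added guard: no full horizon fits in the test region
    n_windows = (test_len - H) // stride + 1
    if max_windows is not None:
        n_windows = min(n_windows, max_windows)
    # Window i forecasts [val_end + i*stride, val_end + i*stride + H).
    return [val_end + i * stride for i in range(n_windows)]


starts = rolling_window_starts(T=1000, val_end=800, H=30, stride=30)
print(starts)
```

With these toy numbers the last window ends at step 980, so the final 20 steps of the test region are not scored.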
eval/timesfm_run.py ADDED
@@ -0,0 +1,137 @@
1
+ """
2
+ CLI driver: run TimesFM v2 zero-shot on one ISOMORPH release.
3
+
4
+ Outputs (mirror Chronos / Moirai runners):
5
+ results/{model_short}_{dataset}.csv per-channel long-format metrics
6
+ results/{model_short}_{dataset}_summary.csv cross-channel mean/median
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from data_utils import load_dataset, iter_test_windows
19
+ from metrics import metrics_at_horizons, to_long_dataframe, HORIZONS
20
+ from timesfm_runner import (
21
+ load_timesfm, predict_rolling_origin, collect_y_true,
22
+ )
23
+
24
+
25
+ def short_name(model_id: str) -> str:
26
+ return model_id.split("/")[-1].replace("-", "_").replace(".", "_")
27
+
28
+
29
+ def run(out_dir: Path, model_id: str, results_dir: Path,
30
+ L: int, H: int, stride: int,
31
+ per_core_batch_size: int, horizon_len: int,
32
+ max_windows: int | None = None,
33
+ label: str | None = None) -> None:
34
+ print(f"=== {out_dir.name} with {model_id} ===", file=sys.stderr)
35
+ split = load_dataset(out_dir)
36
+ if label is not None:
37
+ split.label = label
38
+ print(f" T={split.T} n_items={split.n_items} "
39
+ f"train_end={split.train_end} val_end={split.val_end} "
40
+ f"test=[{split.test_start}, {split.T})", file=sys.stderr)
41
+
42
+ starts = list(iter_test_windows(split, L=L, H=H, stride=stride))
43
+ if max_windows is not None:
44
+ starts = starts[:max_windows]
45
+ print(f" rolling-origin windows: {len(starts)} "
46
+ f"(L={L}, H={H}, stride={stride})", file=sys.stderr)
47
+
48
+ tfm = load_timesfm(
49
+ model_id=model_id,
50
+ context_len=L, horizon_len=horizon_len,
51
+ per_core_batch_size=per_core_batch_size,
52
+ )
53
+
54
+ t0 = time.time()
55
+ y_pred = predict_rolling_origin(tfm, split.D, starts, L=L, H=H)
56
+ y_true = collect_y_true(split.D, starts, H)
57
+ elapsed = time.time() - t0
58
+ print(f" inference done in {elapsed/60:.1f} min", file=sys.stderr)
59
+
60
+ metric_dict = metrics_at_horizons(y_true, y_pred, split.mase_denom,
61
+ horizons=HORIZONS)
62
+ long_df = to_long_dataframe(metric_dict, split.item_ids,
63
+ model=model_id, dataset=split.label)
64
+
65
+ sn = short_name(model_id)
66
+ results_dir.mkdir(parents=True, exist_ok=True)
67
+ out_long = results_dir / f"{sn}_{split.label}.csv"
68
+ long_df.to_csv(out_long, index=False)
69
+ print(f" -> {out_long}", file=sys.stderr)
70
+
71
+ # Persist raw tensors for post-hoc slicing (e.g. stationary-vs-shock).
72
+ out_npz = results_dir / f"{sn}_{split.label}_tensors.npz"
73
+ np.savez_compressed(
74
+ out_npz, y_pred=y_pred, y_true=y_true,
75
+ window_starts=np.asarray(starts, dtype=np.int64),
76
+ item_ids=np.asarray(split.item_ids),
77
+ L=L, H=H, stride=stride, model=model_id, dataset=split.label,
78
+ )
79
+ print(f" -> {out_npz}", file=sys.stderr)
80
+
81
+ summary = (long_df
82
+ .groupby(["model", "dataset", "metric", "h"])["value"]
83
+ .agg(mean="mean", median="median",
84
+ q25=lambda x: x.quantile(0.25),
85
+ q75=lambda x: x.quantile(0.75),
86
+ n="count")
87
+ .reset_index())
88
+ out_sum = results_dir / f"{sn}_{split.label}_summary.csv"
89
+ summary.to_csv(out_sum, index=False)
90
+ print(f" -> {out_sum}", file=sys.stderr)
91
+
92
+ print("\n Headline (median across channels):", file=sys.stderr)
93
+ print(summary.pivot_table(index="metric", columns="h",
94
+ values="median").to_string(),
95
+ file=sys.stderr)
96
+
97
+
98
+ def main():
99
+ repo = Path(__file__).resolve().parents[1]
100
+ ap = argparse.ArgumentParser()
101
+ ap.add_argument("--root", default=str(repo / "data"))
102
+ ap.add_argument("--dataset", default="output_item50")
103
+ ap.add_argument("--scenario_path", default=None,
104
+ help="Path to a scenario directory; "
105
+ "overrides --root/--dataset when set.")
106
+ ap.add_argument("--label", default=None,
107
+ help="Output filename label; defaults to out_dir.name.")
108
+ ap.add_argument("--model_id", default="google/timesfm-2.0-500m-pytorch")
109
+ ap.add_argument("--out", default=str(
110
+ repo / "results" / "eval" / "baseline_and_scenarios"))
111
+ ap.add_argument("--L", type=int, default=2048,
112
+ help="context length (TimesFM v2 trained up to 2048)")
113
+ ap.add_argument("--H", type=int, default=30,
114
+ help="reported horizon (model decodes horizon_len, sliced)")
115
+ ap.add_argument("--horizon_len", type=int, default=128,
116
+ help="model decode length; default 128 = one output patch")
117
+ ap.add_argument("--stride", type=int, default=30)
118
+ ap.add_argument("--per_core_batch_size", type=int, default=32,
119
+ help="channels processed in parallel per forward")
120
+ ap.add_argument("--max_windows", type=int, default=None,
121
+ help="cap for smoke testing")
122
+ args = ap.parse_args()
123
+ if args.scenario_path is not None:
124
+ out_dir = Path(args.scenario_path)
125
+ else:
126
+ out_dir = Path(args.root) / args.dataset
127
+ run(out_dir,
128
+ args.model_id, Path(args.out),
129
+ L=args.L, H=args.H, stride=args.stride,
130
+ per_core_batch_size=args.per_core_batch_size,
131
+ horizon_len=args.horizon_len,
132
+ max_windows=args.max_windows,
133
+ label=args.label)
134
+
135
+
136
+ if __name__ == "__main__":
137
+ main()
eval/timesfm_runner.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ TimesFM zero-shot rolling-origin inference.
3
+
4
+ Wraps google/timesfm-2.0-500m-pytorch (univariate, like Chronos). Per
5
+ window we hand the model a list of C univariate context arrays and it
6
+ returns a (C, horizon_len) median-point forecast that we slice to H=30.
7
+
8
+ Designed to mirror chronos_runner.py's I/O contract so metrics.py can
9
+ stay unchanged.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import sys
14
+ import time
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ import timesfm
20
+
21
+
22
+ def load_timesfm(model_id: str = "google/timesfm-2.0-500m-pytorch",
23
+ context_len: int = 2048, horizon_len: int = 128,
24
+ per_core_batch_size: int = 32):
25
+ """Build a TimesFM v2 predictor (PyTorch backend, GPU-aware)."""
26
+ backend = "gpu" if torch.cuda.is_available() else "cpu"
27
+ print(f" loading {model_id} on {backend} "
28
+ f"(L={context_len}, horizon_len={horizon_len}, "
29
+ f"per_core_batch_size={per_core_batch_size})", file=sys.stderr)
30
+
31
+ # v2-500m-pytorch hparams (per timesfm finetuning_example).
32
+ hparams = timesfm.TimesFmHparams(
33
+ backend=backend,
34
+ per_core_batch_size=per_core_batch_size,
35
+ horizon_len=horizon_len,
36
+ num_layers=50,
37
+ use_positional_embedding=False,
38
+ context_len=context_len,
39
+ # defaults: model_dims=1280, num_heads=16,
40
+ # input_patch_len=32, output_patch_len=128,
41
+ # quantiles=(.1,.2,...,.9), point_forecast_mode='median'
42
+ )
43
+ tfm = timesfm.TimesFm(
44
+ hparams=hparams,
45
+ checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_id),
46
+ )
47
+ return tfm
48
+
49
+
50
+ def predict_rolling_origin(tfm, D: np.ndarray, window_starts: list[int],
51
+ L: int, H: int) -> np.ndarray:
52
+ """Returns y_pred of shape (n_windows, H, n_items); point=median."""
53
+ n_windows = len(window_starts)
54
+ n_items = D.shape[1]
55
+ y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32)
56
+
57
+ t0 = time.time()
58
+ for wi, t in enumerate(window_starts):
59
+ ctx = D[t - L:t, :] # (L, n_items)
60
+ # forecast() takes a list of 1-D arrays, one per channel.
61
+ inputs = [ctx[:, j].astype(np.float32) for j in range(n_items)]
62
+ # freq=0 ("high frequency") matches daily; harmless either way for
63
+ # zero-shot eval since TimesFM uses freq only as a coarse covariate.
64
+ freqs = [0] * n_items
65
+ point_fc, _ = tfm.forecast(inputs=inputs, freq=freqs)
66
+ # point_fc shape: (n_items, horizon_len). Slice to first H steps.
67
+ y_pred[wi] = np.asarray(point_fc[:, :H], dtype=np.float32).T
68
+ if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows:
69
+ elapsed = time.time() - t0
70
+ rate = (wi + 1) / max(elapsed, 1e-9)
71
+ eta = (n_windows - wi - 1) / max(rate, 1e-9)
72
+ print(f" window {wi+1:4d}/{n_windows} "
73
+ f"elapsed={elapsed:6.1f}s "
74
+ f"rate={rate:5.2f} win/s "
75
+ f"eta={eta/60:5.1f}min", file=sys.stderr)
76
+ return y_pred
77
+
78
+
79
+ def collect_y_true(D: np.ndarray, window_starts: list[int],
80
+ H: int) -> np.ndarray:
81
+ """Mirror chronos_runner.collect_y_true."""
82
+ n_windows = len(window_starts)
83
+ n_items = D.shape[1]
84
+ y_true = np.zeros((n_windows, H, n_items), dtype=np.float32)
85
+ for wi, t in enumerate(window_starts):
86
+ y_true[wi] = D[t:t + H]
87
+ return y_true
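Both `predict_rolling_origin` and `collect_y_true` index the series at the same origin `t`: the context is `D[t - L:t]` and the ground truth is `D[t:t + H]`. A minimal sketch of that split on a single toy series, using plain lists, follows. `slice_window` is a hypothetical helper, and its bounds check is an assumption; the runner relies on `iter_test_windows` (not shown here) to yield valid origins.

```python
def slice_window(series, t, L, H):
    """Split one series at origin t into (context, truth)."""
    if t - L < 0 or t + H > len(series):
        raise ValueError("window does not fit inside the series")
    return series[t - L:t], series[t:t + H]


series = list(range(10))
context, truth = slice_window(series, t=6, L=4, H=3)
print(context, truth)
```

The context ends exactly where the truth begins, so no forecast target leaks into the model input.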
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ # Demo dependencies: used by app.py (Hugging Face Spaces / local run).
+ # The full research pipeline (eval/, uq/, analysis/) requires additional
+ # packages (torch, gluonts, chronos-forecasting, uni2ts, timesfm).
+ # See the GitHub repo for the complete research requirements.
+
+ gradio>=4.44
+ plotly>=5.0
+ numpy>=1.24
+ pandas>=2.0
+ kaleido>=1.0 # GIF export in the Network Map tab
+ Pillow>=9.2 # GIF export in the Network Map tab
simulator/.DS_Store ADDED
Binary file (6.15 kB).
 
simulator/Supplychaingeo_item200.py ADDED
@@ -0,0 +1,1102 @@
+ """
+ Supply chain simulation: day-level, 52560 steps, 200 items.
+
+ 1 step = 1 day | 365 steps = 1 year | 52560 steps = 144 years
+
+ Logic:
+   - (s,S) inventory policy at warehouses
+   - Dijkstra routing (edge weight = travel_time_days / daily_total_capacity)
+   - Greedy first-fit bin packing
+   - Proactive shipping via pipeline_multiplier
+   - Streaming CSV for large runs
+
+ Multi-echelon changes:
+   - Only source nodes (SanFrancisco, StLouis, Orlando) retain magic
+     replenishment
+   - Intermediate nodes pull inventory from upstream via network edges
+   - Per-tier (s,S) parameters calibrated to demand flow
+ """
+
+ from __future__ import annotations
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Tuple, Optional, Any, Callable
+ import math
+ import heapq
+ import random
+ import os
+ import csv
+ import time as _time_module
+ from datetime import datetime, timedelta
+ import io
+ import base64
+
+ # Optional dependencies: the imports are skipped if they are not installed.
+ try:
+     import matplotlib.pyplot as plt
+     from branca.element import Element
+ except ImportError:
+     pass
+
+ import numpy as np
+ import pandas as pd
+
+ # folium is optional; HAS_FOLIUM records whether it could be imported.
+ try:
+     import folium
+     from folium.plugins import TimestampedGeoJson
+     HAS_FOLIUM = True
+ except ImportError:
+     HAS_FOLIUM = False
+
+
51
+ # ============================================================================
52
+ # Data Structures
53
+ # ============================================================================
54
+
55
+ @dataclass
56
+ class Item:
57
+ item_id: str
58
+ volume: float
59
+
60
+
61
+ @dataclass
62
+ class Edge:
63
+ u: str
64
+ v: str
65
+ travel_time_days: float
66
+ container_volume: float
67
+ num_containers_per_day: int
68
+ daily_containers: List[float] = field(default_factory=list)
69
+
70
+ def reset_daily(self, capacity_factor: float = 1.0) -> None:
71
+ effective_vol = self.container_volume * capacity_factor
72
+ self.daily_containers = [effective_vol] * self.num_containers_per_day
73
+
74
+ def find_container_slot(self, item_volume: float) -> Optional[int]:
75
+ for idx, rem in enumerate(self.daily_containers):
76
+ if rem >= item_volume:
77
+ return idx
78
+ return None
79
+
80
+ def allocate_in_container(self, idx: int, item_volume: float) -> bool:
81
+ if 0 <= idx < len(self.daily_containers) and \
82
+ self.daily_containers[idx] >= item_volume:
83
+ self.daily_containers[idx] -= item_volume
84
+ return True
85
+ return False
86
+
87
+ @property
88
+ def daily_total_capacity(self) -> float:
89
+ return self.container_volume * max(self.num_containers_per_day, 0)
90
+
91
+
92
+ @dataclass
93
+ class Node:
94
+ node_id: str
95
+ lat: float
96
+ lon: float
97
+ is_destination: bool = False
98
+ is_source: bool = False
99
+ inventory: Dict[str, int] = field(default_factory=dict)
100
+ s_levels: Dict[str, int] = field(default_factory=dict)
101
+ S_levels: Dict[str, int] = field(default_factory=dict)
102
+ lead_time_mean: Dict[str, float] = field(default_factory=dict)
103
+ lead_time_std_frac: float = 0.2
104
+ outstanding_orders: Dict[str, Optional[Tuple[int, int]]] = \
105
+ field(default_factory=dict)
106
+ backlog: Dict[str, int] = field(default_factory=dict)
107
+
108
+ def receive_orders_today(self, day: int) -> None:
109
+ to_clear = []
110
+ for item_id, order in self.outstanding_orders.items():
111
+ if order is None:
112
+ continue
113
+ arrival_day, qty = order
114
+ if day >= arrival_day:
115
+ self.inventory[item_id] = \
116
+ self.inventory.get(item_id, 0) + qty
117
+ to_clear.append(item_id)
118
+ for iid in to_clear:
119
+ self.outstanding_orders[iid] = None
120
+
121
+ def maybe_place_orders(self, day: int, rng: random.Random) -> None:
122
+ if self.is_destination:
123
+ return
124
+ if not self.is_source:
125
+ return
126
+ for item_id, s in self.s_levels.items():
127
+ on_hand = self.inventory.get(item_id, 0)
128
+ if on_hand < s and \
129
+ self.outstanding_orders.get(item_id) is None:
130
+ S = self.S_levels.get(item_id, on_hand)
131
+ qty = max(S - on_hand, 0)
132
+ if qty <= 0:
133
+ continue
134
+ mean_lt = max(self.lead_time_mean.get(item_id, 1.0), 0.1)
135
+ std = self.lead_time_std_frac * mean_lt
136
+ sampled = rng.normalvariate(mean_lt, std)
137
+ lt_days = max(1, int(math.ceil(sampled)))
138
+ self.outstanding_orders[item_id] = (day + lt_days, qty)
139
+
140
+
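
Reviewer note: the reorder rule in `Node.maybe_place_orders` can be read in isolation as a plain (s,S) policy: order only when on-hand stock is below `s` and no order is outstanding, order up to `S`, and draw the lead time from a normal distribution that is rounded up to at least one day. The sketch below restates that rule as two free functions. The names `order_quantity` and `sample_lead_time_days` are hypothetical and are not part of this commit.

```python
import math
import random


def order_quantity(on_hand: int, s: int, S: int, has_outstanding: bool) -> int:
    """(s,S) rule: order up to S only if stock is below s and nothing is on order."""
    if on_hand >= s or has_outstanding:
        return 0
    return max(S - on_hand, 0)


def sample_lead_time_days(mean_lt: float, std_frac: float,
                          rng: random.Random) -> int:
    """Normal lead time with std = std_frac * mean, rounded up, at least 1 day."""
    mean_lt = max(mean_lt, 0.1)
    sampled = rng.normalvariate(mean_lt, std_frac * mean_lt)
    return max(1, int(math.ceil(sampled)))


rng = random.Random(42)
qty = order_quantity(on_hand=300, s=400, S=4000, has_outstanding=False)
lead = sample_lead_time_days(mean_lt=3.0, std_frac=0.2, rng=rng)
print(qty, lead)
```

Rounding the sampled lead time up and clamping it to one day means an order can never arrive on the day it is placed, even when the normal draw is zero or negative.
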
141
+ # ============================================================================
142
+ # Network + Dijkstra
143
+ # ============================================================================
144
+
145
+ class Network:
146
+ def __init__(self) -> None:
147
+ self.nodes: Dict[str, Node] = {}
148
+ self.edges: Dict[Tuple[str, str], Edge] = {}
149
+ self.adj: Dict[str, List[str]] = {}
150
+ self.weight_cache: Dict[Tuple[str, str], float] = {}
151
+ self.paths_to_dest: Dict[
152
+ str, Tuple[float, List[str], List[Tuple[str, str]]]] = {}
153
+
154
+ def add_node(self, node: Node) -> None:
155
+ self.nodes[node.node_id] = node
156
+ self.adj.setdefault(node.node_id, [])
157
+
158
+ def add_edge(self, edge: Edge) -> None:
159
+ self.edges[(edge.u, edge.v)] = edge
160
+ self.adj.setdefault(edge.u, []).append(edge.v)
161
+ cap = max(edge.daily_total_capacity, 1e-9)
162
+ self.weight_cache[(edge.u, edge.v)] = edge.travel_time_days / cap
163
+
164
+ def reset_daily_edges(self, capacity_factor: float = 1.0) -> None:
165
+ for e in self.edges.values():
166
+ e.reset_daily(capacity_factor)
167
+
168
+ def dijkstra(self, source: str, target: str) -> Tuple[float, List[str]]:
169
+ if source == target:
170
+ return 0.0, [source]
171
+ dist: Dict[str, float] = {source: 0.0}
172
+ prev: Dict[str, Optional[str]] = {source: None}
173
+ pq = [(0.0, source)]
174
+ visited: set = set()
175
+ while pq:
176
+ d, u = heapq.heappop(pq)
177
+ if u in visited:
178
+ continue
179
+ visited.add(u)
180
+ if u == target:
181
+ break
182
+ for v in self.adj.get(u, []):
183
+ w = self.weight_cache[(u, v)]
184
+ nd = d + w
185
+ if nd < dist.get(v, float('inf')):
186
+ dist[v] = nd
187
+ prev[v] = u
188
+ heapq.heappush(pq, (nd, v))
189
+ if target not in dist:
190
+ return float('inf'), []
191
+ path: List[str] = []
192
+ cur: Optional[str] = target
193
+ while cur is not None:
194
+ path.append(cur)
195
+ cur = prev.get(cur)
196
+ path.reverse()
197
+ return dist[target], path
198
+
199
+ def compute_paths_to_destination(self, dest_id: str) -> None:
200
+ self.paths_to_dest.clear()
201
+ for nid in self.nodes:
202
+ if nid == dest_id:
203
+ self.paths_to_dest[nid] = (0.0, [nid], [])
204
+ continue
205
+ d, pn = self.dijkstra(nid, dest_id)
206
+ if not pn:
207
+ self.paths_to_dest[nid] = (float('inf'), [], [])
208
+ else:
209
+ pe = [(pn[i], pn[i+1]) for i in range(len(pn)-1)]
210
+ self.paths_to_dest[nid] = (d, pn, pe)
211
+
212
+
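
Reviewer note: `Network.add_edge` caches the routing weight as `travel_time_days / daily_total_capacity`, so Dijkstra minimizes a capacity-adjusted cost rather than raw travel time. The standalone sketch below uses a made-up three-node graph to show one consequence: a route with more travel days can win if its edges have much more capacity. The function `shortest_path` and the toy numbers are illustrative assumptions and are not taken from the commit.

```python
import heapq


def shortest_path(edges, source, target):
    """Dijkstra over edges {(u, v): (travel_time_days, daily_capacity)}.

    The edge weight is travel_time_days / daily_capacity, as in add_edge.
    Returns (cost, path); (inf, []) if the target is unreachable.
    """
    adj = {}
    for (u, v), (tt, cap) in edges.items():
        adj.setdefault(u, []).append((v, tt / max(cap, 1e-9)))
    dist = {source: 0.0}
    prev = {source: None}
    pq = [(0.0, source)]
    done = set()
    while pq:
        d, u = heapq.heappop(pq)
        if u in done:
            continue
        done.add(u)
        if u == target:
            break
        for v, w in adj.get(u, []):
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                prev[v] = u
                heapq.heappush(pq, (nd, v))
    if target not in dist:
        return float("inf"), []
    path, cur = [], target
    while cur is not None:
        path.append(cur)
        cur = prev[cur]
    return dist[target], path[::-1]


# Toy graph: the direct edge takes 2 days but has low capacity;
# the detour takes 3 days in total over high-capacity edges.
edges = {
    ("A", "C"): (2.0, 10.0),
    ("A", "B"): (1.0, 100.0),
    ("B", "C"): (2.0, 100.0),
}
cost, path = shortest_path(edges, "A", "C")
print(cost, path)
```

Arrival days in the simulator are still computed from the summed `travel_time_days` along the chosen path, so the weight only affects which path is selected.
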
213
+ # ============================================================================
214
+ # Greedy First-Fit Bin Packing
215
+ # ============================================================================
216
+
217
+ def allocate_units_along_path_greedy(
218
+ item: Item,
219
+ max_units: int,
220
+ path_edges: List[Tuple[str, str]],
221
+ network_edges: Dict[Tuple[str, str], Edge],
222
+ ) -> int:
223
+ if max_units <= 0 or not path_edges:
224
+ return 0
225
+ edges_objs = [network_edges[eid] for eid in path_edges]
226
+ placed = 0
227
+ ivol = item.volume
228
+ for _ in range(max_units):
229
+ slots: List[Tuple[Edge, int]] = []
230
+ ok = True
231
+ for e in edges_objs:
232
+ si = e.find_container_slot(ivol)
233
+ if si is None:
234
+ ok = False
235
+ break
236
+ slots.append((e, si))
237
+ if not ok:
238
+ break
239
+ good = True
240
+ for e, si in slots:
241
+ if not e.allocate_in_container(si, ivol):
242
+ good = False
243
+ break
244
+ if not good:
245
+ break
246
+ placed += 1
247
+ return placed
248
+
249
+
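
Reviewer note: `allocate_units_along_path_greedy` places one unit at a time and requires a free slot on every edge of the path before committing. The sketch below shows the first-fit step for a single edge only, using a plain list of remaining container volumes. `first_fit` is a hypothetical helper written for illustration; it combines what `find_container_slot` and `allocate_in_container` do in the commit.

```python
from typing import List, Optional


def first_fit(remaining: List[float], item_volume: float) -> Optional[int]:
    """Put one unit into the first container with enough room.

    Mutates `remaining` and returns the container index, or None if the
    unit fits nowhere.
    """
    for idx, rem in enumerate(remaining):
        if rem >= item_volume:
            remaining[idx] = rem - item_volume
            return idx
    return None


# Two containers of volume 10; try to place seven units of volume 3.
containers = [10.0, 10.0]
placed = []
for _ in range(7):
    slot = first_fit(containers, 3.0)
    if slot is None:
        break  # same stopping rule as the path allocator: stop at first failure
    placed.append(slot)
print(placed, containers)
```

First-fit leaves one unit of volume unused in each container here, which is the usual trade-off of the greedy rule: it is fast and simple but does not try to repack for a tighter fill.
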
250
+ # ============================================================================
251
+ # Simulation Engine
252
+ # ============================================================================
253
+
254
+ class SupplyChainSimulation:
255
+
256
+ def __init__(
257
+ self,
258
+ network: Network,
259
+ items: Dict[str, Item],
260
+ destination_id: str,
261
+ demand_fn: Callable[[int], Dict[str, int]],
262
+ horizon_days: int,
263
+ seed: int = 42,
264
+ pipeline_multiplier: float = 0.0,
265
+ streaming_out_dir: Optional[str] = None,
266
+ packing: str = "greedy",
267
+ ) -> None:
268
+ assert destination_id in network.nodes and \
269
+ network.nodes[destination_id].is_destination
270
+ self.network = network
271
+ self.items = items
272
+ self.item_order: List[str] = sorted(self.items.keys())
273
+ self.round_robin_items: bool = True
274
+ self.per_item_daily_cap_units: Optional[int] = None
275
+
276
+ self.destination_id = destination_id
277
+ self.demand_fn = demand_fn
278
+ self.horizon_days = horizon_days
279
+ self.rng = random.Random(seed)
280
+ self.packing = packing
281
+ self.pipeline_multiplier = pipeline_multiplier
282
+
283
+ # EMA warm-start at approximate per-day mean demand
284
+ self.demand_ema: Dict[str, float] = {iid: 165.0 for iid in items}
285
+ self.ema_alpha = 0.05
286
+ self.item_intransit: Dict[str, int] = {iid: 0 for iid in items}
287
+
288
+ self.streaming_out_dir = streaming_out_dir
289
+ self._csv_files: Dict = {}
290
+ self._csv_writers: Dict = {}
291
+ self._csv_buffers: Dict[str, list] = {}
292
+ self.dest_in_transit: Dict[int, Dict[str, int]] = {}
293
+
294
+ self.daily_records: list = []
295
+ self.shipments_log: list = []
296
+ self.inventory_history: list = []
297
+ self.backlog_history: list = []
298
+ self.intransit_history: list = []
299
+
300
+ self.svc_demand: Dict[str, int] = {iid: 0 for iid in items}
301
+ self.svc_served: Dict[str, int] = {iid: 0 for iid in items}
302
+ self.svc_backlog: Dict[str, int] = {iid: 0 for iid in items}
303
+
304
+ network.compute_paths_to_destination(destination_id)
305
+ self.sorted_warehouses = sorted(
306
+ [(nid, d) for nid, (d, _, _) in network.paths_to_dest.items()
307
+ if nid != destination_id and math.isfinite(d)],
308
+ key=lambda x: x[1])
309
+
310
+ # precompute supplier relationships for intermediate nodes
311
+ self.node_suppliers: Dict[
312
+ str,
313
+ List[Tuple[str, float, List[str], List[Tuple[str, str]]]]
314
+ ] = {}
315
+ for nid, node in network.nodes.items():
316
+ if node.is_destination or node.is_source:
317
+ continue
318
+ suppliers = []
319
+ for (u, v) in network.edges:
320
+ if v == nid:
321
+ d, path = network.dijkstra(u, nid)
322
+ if path and math.isfinite(d):
323
+ pe = [(path[i], path[i + 1])
324
+ for i in range(len(path) - 1)]
325
+ tt = sum(network.edges[e].travel_time_days
326
+ for e in pe)
327
+ suppliers.append((u, tt, path, pe))
328
+ suppliers.sort(key=lambda x: x[1])
329
+ self.node_suppliers[nid] = suppliers
330
+
331
+ # intermediate nodes ordered upstream-first for replenishment
332
+ self.replenish_order: List[str] = [
333
+ nid for nid, _ in reversed(self.sorted_warehouses)
334
+ if nid in self.node_suppliers]
335
+
336
+ self.avg_travel_time = 6.0
337
+ if self.sorted_warehouses:
338
+ _, _, pe = network.paths_to_dest[self.sorted_warehouses[0][0]]
339
+ if pe:
340
+ self.avg_travel_time = sum(
341
+ network.edges[e].travel_time_days for e in pe)
342
+
343
+ def _total_tt(self, pe):
344
+ return sum(self.network.edges[e].travel_time_days for e in pe)
345
+
346
+ def _replenish_warehouses(self, day: int) -> None:
347
+ """Inter-warehouse replenishment: intermediate nodes pull
348
+ from upstream suppliers using (s,S) trigger + edge capacity."""
349
+ net = self.network
350
+
351
+ if self.round_robin_items and self.item_order:
352
+ k = day % len(self.item_order)
353
+ ids_today = self.item_order[k:] + self.item_order[:k]
354
+ else:
355
+ ids_today = self.item_order[:]
356
+
357
+ for nid in self.replenish_order:
358
+ node = net.nodes[nid]
359
+ for iid in ids_today:
360
+ on_hand = node.inventory.get(iid, 0)
361
+ s = node.s_levels.get(iid, 0)
362
+ if on_hand >= s:
363
+ continue
364
+ if node.outstanding_orders.get(iid) is not None:
365
+ continue
366
+ S = node.S_levels.get(iid, on_hand)
367
+ qty_needed = max(S - on_hand, 0)
368
+ if qty_needed <= 0:
369
+ continue
370
+
371
+ for sup_id, tt, path, pe in \
372
+ self.node_suppliers.get(nid, []):
373
+ sup_node = net.nodes[sup_id]
374
+ avail = sup_node.inventory.get(iid, 0)
375
+ if avail <= 0:
376
+ continue
377
+ attempt = min(avail, qty_needed)
378
+ placed = allocate_units_along_path_greedy(
379
+ self.items[iid], attempt, pe, net.edges)
380
+ if placed <= 0:
381
+ continue
382
+
383
+ sup_node.inventory[iid] -= placed
384
+ arr = day + max(1, int(math.ceil(tt)))
385
+ node.outstanding_orders[iid] = (arr, placed)
386
+
387
+ r = [day, arr, sup_id, nid, iid, placed,
388
+ str(path),
389
+ str([net.edges[e].travel_time_days
390
+ for e in pe])]
391
+ if self.streaming_out_dir:
392
+ self._csv_buffers.setdefault(
393
+ 'ship', []).append(r)
394
+ else:
395
+ self.shipments_log.append({
396
+ "day": day, "arrival_day": arr,
397
+ "from": sup_id, "to": nid,
398
+ "item": iid, "units": placed,
399
+ "path_nodes": path,
400
+ "edge_times": [
401
+ net.edges[e].travel_time_days
402
+ for e in pe]})
403
+ break
404
+
405
+ def step(self, day: int) -> None:
406
+ net = self.network
407
+ dest = net.nodes[self.destination_id]
408
+
409
+ # 1) Receive (s,S) replenishment at warehouses
410
+ for node in net.nodes.values():
411
+ node.receive_orders_today(day)
412
+
413
+ # 2) Arrivals at destination
414
+ arrivals = self.dest_in_transit.pop(day, {})
415
+ for iid, qty in arrivals.items():
416
+ self.item_intransit[iid] = max(
417
+ 0, self.item_intransit.get(iid, 0) - qty)
418
+ bl = dest.backlog.get(iid, 0)
419
+ if bl > 0:
420
+ use = min(qty, bl)
421
+ dest.backlog[iid] = bl - use
422
+ qty -= use
423
+ if qty > 0:
424
+ dest.inventory[iid] = dest.inventory.get(iid, 0) + qty
425
+
426
+ # 3) Reset edge containers
427
+ net.reset_daily_edges(1.0)
428
+
429
+ # 4) Demand at destination
430
+ td = self.demand_fn(day)
431
+ for iid in self.items:
432
+ dq = int(td.get(iid, 0))
433
+ self.demand_ema[iid] = (
434
+ self.ema_alpha * dq +
435
+ (1 - self.ema_alpha) * self.demand_ema[iid])
436
+ oh = dest.inventory.get(iid, 0)
437
+ if oh >= dq:
438
+ served, unfilled = dq, 0
439
+ dest.inventory[iid] = oh - dq
440
+ else:
441
+ served, unfilled = oh, dq - oh
442
+ dest.inventory[iid] = 0
443
+ dest.backlog[iid] = dest.backlog.get(iid, 0) + unfilled
444
+ self.svc_demand[iid] += dq
445
+ self.svc_served[iid] += served
446
+ self.svc_backlog[iid] += unfilled
447
+ rec = [day, iid, dq, served, unfilled,
448
+ dest.inventory.get(iid, 0), dest.backlog.get(iid, 0)]
449
+ if self.streaming_out_dir:
450
+ self._csv_buffers.setdefault('daily', []).append(rec)
451
+ else:
452
+ self.daily_records.append({
453
+ "day": day, "item": iid, "demand": dq,
454
+ "served_from_stock": served,
455
+ "new_backlog_today": unfilled,
456
+ "dest_on_hand_end_before_ship":
457
+ dest.inventory.get(iid, 0),
458
+ "dest_backlog_end_before_ship":
459
+ dest.backlog.get(iid, 0)})
460
+
461
+ # 5) Ship
462
+ if self.round_robin_items and self.item_order:
463
+ k = day % len(self.item_order)
464
+ ids_today = self.item_order[k:] + self.item_order[:k]
465
+ else:
466
+ ids_today = self.item_order[:]
467
+
468
+ for iid in ids_today:
469
+ item = self.items[iid]
470
+ cb = dest.backlog.get(iid, 0)
471
+ it = self.item_intransit.get(iid, 0)
472
+ oh = dest.inventory.get(iid, 0)
473
+
474
+ if self.pipeline_multiplier > 0:
475
+ pt = self.demand_ema[iid] * self.pipeline_multiplier
476
+ ship_target = max(0, int(math.ceil(cb + pt - it - oh)))
477
+ else:
478
+ S_dest = max(1, int(self.demand_ema[iid] * 3))
479
+ ship_target = max(0, cb + S_dest - oh - it)
480
+
481
+ if ship_target <= 0:
482
+ continue
483
+
484
+ remaining = ship_target
485
+ shipped = 0
486
+ for wid, _ in self.sorted_warehouses:
487
+ if remaining <= 0:
488
+ break
489
+ if self.per_item_daily_cap_units is not None and \
490
+ shipped >= self.per_item_daily_cap_units:
491
+ break
492
+ wn = net.nodes[wid]
493
+ avail = wn.inventory.get(iid, 0)
494
+ if avail <= 0:
495
+ continue
496
+ _, pn, pe = net.paths_to_dest[wid]
497
+ if not pe:
498
+ continue
499
+ attempt = min(avail, remaining)
500
+ if self.per_item_daily_cap_units is not None:
501
+ attempt = min(attempt,
502
+ self.per_item_daily_cap_units - shipped)
503
+ placed = allocate_units_along_path_greedy(
504
+ item, attempt, pe, net.edges)
505
+ if placed <= 0:
506
+ continue
507
+ wn.inventory[iid] -= placed
508
+ remaining -= placed
509
+ shipped += placed
510
+ arr = day + max(1, int(math.ceil(self._total_tt(pe))))
511
+ self.dest_in_transit.setdefault(arr, {})
512
+ self.dest_in_transit[arr][iid] = \
513
+ self.dest_in_transit[arr].get(iid, 0) + placed
514
+ self.item_intransit[iid] = \
515
+ self.item_intransit.get(iid, 0) + placed
516
+ r = [day, arr, wid, self.destination_id, iid, placed,
517
+ str(pn),
518
+ str([net.edges[e].travel_time_days for e in pe])]
519
+ if self.streaming_out_dir:
520
+ self._csv_buffers.setdefault('ship', []).append(r)
521
+ else:
522
+ self.shipments_log.append({
523
+ "day": day, "arrival_day": arr,
524
+ "from": wid, "to": self.destination_id,
525
+ "item": iid, "units": placed,
526
+ "path_nodes": pn,
527
+ "edge_times": [net.edges[e].travel_time_days
528
+ for e in pe]})
529
+
530
+ # 5b) Inter-warehouse replenishment
531
+ self._replenish_warehouses(day)
532
+
533
+ # 6) (s,S) orders at source warehouses
534
+ for node in net.nodes.values():
535
+ node.maybe_place_orders(day, self.rng)
536
+
537
+ # 7) Snapshots
538
+ did = self.destination_id
539
+ itc: Dict[str, int] = {}
540
+ for ad, im in self.dest_in_transit.items():
541
+ if ad > day:
542
+ for iid, q in im.items():
543
+ itc[iid] = itc.get(iid, 0) + int(q)
544
+
545
+ if self.streaming_out_dir:
546
+ for node in net.nodes.values():
547
+ for iid in self.items:
548
+ self._csv_buffers.setdefault('inv', []).append(
549
+ [day, node.node_id, iid,
550
+ int(node.inventory.get(iid, 0))])
551
+ self._csv_buffers.setdefault('bl', []).append(
552
+ [day, node.node_id, iid,
553
+ int(node.backlog.get(iid, 0))])
554
+ for iid in self.items:
555
+ self._csv_buffers.setdefault('it', []).append(
556
+ [day, did, iid, itc.get(iid, 0)])
557
+ else:
558
+ for node in net.nodes.values():
559
+ for iid in self.items:
560
+ self.inventory_history.append({
561
+ "day": day, "node": node.node_id, "item": iid,
562
+ "on_hand": int(node.inventory.get(iid, 0))})
563
+ self.backlog_history.append({
564
+ "day": day, "node": node.node_id, "item": iid,
565
+ "backlog": int(node.backlog.get(iid, 0))})
566
+ for iid in self.items:
567
+ self.intransit_history.append({
568
+ "day": day, "node": did, "item": iid,
569
+ "in_transit": itc.get(iid, 0)})
570
+
571
+ # --- CSV streaming ---
572
+ def _open_csv_files(self):
573
+ os.makedirs(self.streaming_out_dir, exist_ok=True)
574
+ hdrs = {
575
+ 'daily': ["day", "item", "demand", "served_from_stock",
576
+ "new_backlog_today", "dest_on_hand_end_before_ship",
577
+ "dest_backlog_end_before_ship"],
578
+ 'ship': ["day", "arrival_day", "from", "to", "item",
579
+ "units", "path_nodes", "edge_times"],
580
+ 'inv': ["day", "node", "item", "on_hand"],
581
+ 'bl': ["day", "node", "item", "backlog"],
582
+ 'it': ["day", "node", "item", "in_transit"],
583
+ }
584
+ fn = {'daily': 'daily_records', 'ship': 'shipments',
585
+ 'inv': 'inventory_history', 'bl': 'backlog_history',
586
+ 'it': 'intransit_history'}
587
+ for k in hdrs:
588
+ f = open(os.path.join(self.streaming_out_dir,
589
+ f"{fn[k]}.csv"),
590
+ "w", newline="", buffering=65536)
591
+ w = csv.writer(f)
592
+ w.writerow(hdrs[k])
593
+ self._csv_files[k] = f
594
+ self._csv_writers[k] = w
595
+ self._csv_buffers[k] = []
596
+
597
+ def _flush_csv(self):
598
+ for k in self._csv_buffers:
599
+ if self._csv_buffers[k] and k in self._csv_writers:
600
+ self._csv_writers[k].writerows(self._csv_buffers[k])
601
+ self._csv_buffers[k] = []
602
+ for f in self._csv_files.values():
603
+ f.flush()
604
+
605
+ def _close_csv(self):
606
+ self._flush_csv()
607
+ for f in self._csv_files.values():
608
+ f.close()
609
+
610
+ def run(self):
611
+ if self.streaming_out_dir:
612
+ self._open_csv_files()
613
+ t0 = _time_module.time()
614
+ ri = max(1, self.horizon_days // 100)
615
+ for day in range(self.horizon_days):
616
+ self.step(day)
617
+ if self.streaming_out_dir and day % 500 == 0:
618
+ self._flush_csv()
619
+ if day % 5000 == 0:
620
+ for k in [k for k in self.dest_in_transit if k < day]:
621
+ del self.dest_in_transit[k]
622
+ if day % ri == 0 or day == self.horizon_days - 1:
623
+ el = _time_module.time() - t0
624
+ pct = (day + 1) / self.horizon_days * 100
625
+ rate = (day + 1) / max(el, 0.001)
626
+ eta = (self.horizon_days - day - 1) / max(rate, 0.001)
627
+ print(f"\r Day {day+1:>6}/{self.horizon_days} "
628
+ f"({pct:5.1f}%) {rate:6.1f} days/s "
629
+ f"ETA {eta:6.0f}s", end="", flush=True)
630
+ print(f"\n Simulation complete in "
631
+ f"{_time_module.time()-t0:.1f}s")
632
+
633
+ if self.streaming_out_dir:
634
+ self._close_csv()
635
+ rows = []
636
+ for iid in sorted(self.items):
637
+ td = self.svc_demand[iid]
638
+ sv = self.svc_served[iid]
639
+ bl = self.svc_backlog[iid]
640
+ rows.append({
641
+ "item": iid, "total_demand": td,
642
+ "served_from_stock": sv,
643
+ "new_backlog_added": bl,
644
+ "fill_rate_stock_only":
645
+ round(sv / td, 6) if td > 0 else 0.0})
646
+ svc = pd.DataFrame(rows)
647
+ svc.to_csv(os.path.join(self.streaming_out_dir,
648
+ "service_summary.csv"), index=False)
649
+ return (pd.DataFrame(), pd.DataFrame(), svc,
650
+ pd.DataFrame(), pd.DataFrame(), pd.DataFrame())
651
+
652
+ dd = pd.DataFrame(self.daily_records)
653
+ ds = pd.DataFrame(self.shipments_log) if self.shipments_log \
654
+ else pd.DataFrame(columns=[
655
+ "day", "arrival_day", "from", "to",
656
+ "item", "units", "path_nodes", "edge_times"])
657
+ svc = dd.groupby("item", as_index=False).agg(
658
+ total_demand=("demand", "sum"),
659
+ served_from_stock=("served_from_stock", "sum"),
660
+ new_backlog_added=("new_backlog_today", "sum"))
661
+ svc["fill_rate_stock_only"] = (
662
+ svc["served_from_stock"] / svc["total_demand"]).fillna(0)
663
+ di = pd.DataFrame(self.inventory_history,
664
+ columns=["day", "node", "item", "on_hand"])
665
+ db = pd.DataFrame(self.backlog_history,
666
+ columns=["day", "node", "item", "backlog"])
667
+ dt = pd.DataFrame(self.intransit_history,
668
+ columns=["day", "node", "item", "in_transit"])
669
+ return dd, ds, svc, di, db, dt
670
+
671
+
672
+ # ============================================================================
673
+ # Build network from adjacency
674
+ # ============================================================================
675
+
676
+ def build_network_from_adjacency(nodes_meta, adjacency):
677
+ n = len(nodes_meta)
678
+ assert all(len(r) == n for r in adjacency)
679
+ net = Network()
680
+ for meta in nodes_meta:
681
+ net.add_node(Node(
682
+ node_id=meta["id"],
683
+ lat=float(meta["lat"]), lon=float(meta["lon"]),
684
+ is_destination=bool(meta.get("is_destination", False)),
685
+ is_source=bool(meta.get("is_source", False)),
686
+ inventory=dict(meta.get("inventory", {})),
687
+ s_levels=dict(meta.get("s_levels", {})),
688
+ S_levels=dict(meta.get("S_levels", {})),
689
+ lead_time_mean=dict(meta.get("lead_time_mean", {})),
690
+ backlog=dict(meta.get("backlog", {}))))
691
+ ids = [m["id"] for m in nodes_meta]
692
+ for i in range(n):
693
+ for j in range(n):
694
+ s = adjacency[i][j]
695
+ if s is None:
696
+ continue
697
+ tt, cv, nc = s
698
+ net.add_edge(Edge(
699
+ u=ids[i], v=ids[j],
700
+ travel_time_days=float(tt),
701
+ container_volume=float(cv),
702
+ num_containers_per_day=int(nc)))
703
+ return net
704
+
705
+
706
+ # ============================================================================
707
+ # Demand Generator — day-level
708
+ # ============================================================================
709
+
+ def build_demand_fn(
+     item_ids, n_steps, seed=42,
+     base_lambda_range=(80, 250),
+ ):
+     """
+     Build a day-level stochastic demand generator.
+
+     Returns (demand_fn, lam): demand_fn(day) maps each item id to a Poisson
+     draw, and lam is the (n_steps, n_items) array of Poisson rates.
+
+     Structure:
+       - Yearly cycle T=365 days (moderate amplitude, two harmonics)
+       - Weekly cycle T=7 days (small texture)
+       - AR(1) drift (dominant low-frequency, decade-scale)
+       - Per-item spikes (sustained 1-6 months, clear amplitude)
+       - Global macro events (rare, wide, correlated across items)
+       - Poisson sampling
+     """
+     rng = np.random.default_rng(seed)
+     n_items = len(item_ids)
+
+     spy = 365  # steps per year
+     spw = 7    # steps per week
+
+     t = np.arange(n_steps, dtype=np.float64)
+     yearly_phase = 2 * np.pi * t / spy
+     weekly_phase = 2 * np.pi * (t % spw) / spw
+
+     lam = np.zeros((n_steps, n_items), dtype=np.float64)
+
+     # Global macro events: ramp up (first 15%), plateau, decay (last 25%)
+     gs = np.zeros(n_steps)
+     n_global = int(rng.integers(5, 12))
+     for _ in range(n_global):
+         si = int(rng.integers(0, n_steps))
+         dur = int(rng.integers(180, 1100))
+         end = min(si + dur, n_steps)
+         h = rng.uniform(0.20, 0.60)
+         for k in range(si, end):
+             p = (k - si) / max(dur, 1)
+             if p < 0.15:
+                 gs[k] += h * (p / 0.15)
+             elif p < 0.75:
+                 gs[k] += h
+             else:
+                 gs[k] += h * (1.0 - (p - 0.75) / 0.25)
+
+     for j, iid in enumerate(item_ids):
+         base = rng.uniform(*base_lambda_range)
+
+         # Yearly seasonality: two harmonics with a shared random offset
+         ya1 = rng.uniform(0.12, 0.28)
+         ya2 = rng.uniform(0.04, 0.10)
+         yo = rng.uniform(0, 2 * np.pi)
+         yr = (ya1 * np.sin(yearly_phase + yo) +
+               ya2 * np.sin(2 * yearly_phase + yo * 0.7))
+
+         # Weekly seasonality
+         wa = rng.uniform(0.04, 0.10)
+         wo = rng.uniform(0, 2 * np.pi)
+         wy = wa * np.sin(weekly_phase + wo)
+
+         # AR(1) drift, clipped to +/-0.60
+         ac = rng.uniform(0.9990, 0.9996)
+         ar_std = rng.uniform(0.008, 0.018)
+         dr = np.zeros(n_steps)
+         dr[0] = rng.normal(0, 0.10)
+         for i in range(1, n_steps):
+             dr[i] = ac * dr[i-1] + rng.normal(0, ar_std)
+         dr = np.clip(dr, -0.60, 0.60)
+
+         # Per-item spikes with the same ramp/plateau/decay shape
+         sr = rng.uniform(0.0002, 0.001)
+         sm = rng.random(n_steps) < sr
+         sp = np.zeros(n_steps)
+         for si in np.where(sm)[0]:
+             dur = int(rng.integers(30, 180))
+             end = min(si + dur, n_steps)
+             h = rng.uniform(0.20, 0.70)
+             for k in range(si, end):
+                 p = (k - si) / max(dur, 1)
+                 if p < 0.15:
+                     sp[k] += h * (p / 0.15)
+                 elif p < 0.75:
+                     sp[k] += h
+                 else:
+                     sp[k] += h * (1.0 - (p - 0.75) / 0.25)
+
+         # Combine multiplicative factors; floor at 0.08 to keep rates positive
+         gsens = rng.uniform(0.4, 1.2)
+         fac = 1.0 + yr + wy + dr + sp + gsens * gs
+         fac = np.clip(fac, 0.08, None)
+         lam[:, j] = base * fac
+
+     def demand_fn(day):
+         # Days beyond n_steps wrap around to the start of the rate table.
+         idx = day % n_steps
+         return {iid: int(rng.poisson(lam=max(0.01, lam[idx, j])))
+                 for j, iid in enumerate(item_ids)}
+
+     return demand_fn, lam
801
+
802
+
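
Reviewer note: both the global macro events and the per-item spikes in `build_demand_fn` use the same piecewise envelope: a linear ramp over the first 15% of the event, a plateau until 75%, and a linear decay over the last 25%. The sketch below is a vectorized restatement of that shape for one event. `trapezoid_envelope` is a hypothetical name and this is not how the commit computes it (the commit uses explicit Python loops).

```python
import numpy as np


def trapezoid_envelope(dur: int, height: float) -> np.ndarray:
    """Ramp (p < 0.15), plateau (0.15 <= p < 0.75), decay (p >= 0.75).

    p is the fraction of the event elapsed, (k - start) / dur.
    """
    p = np.arange(dur) / max(dur, 1)
    ramp = height * (p / 0.15)
    decay = height * (1.0 - (p - 0.75) / 0.25)
    return np.where(p < 0.15, ramp, np.where(p < 0.75, height, decay))


# Add one event of 100 days and height 0.5 starting at day 950 of a
# 1000-day series; the event is truncated at the end of the series.
n_steps, start, dur, height = 1000, 950, 100, 0.5
series = np.zeros(n_steps)
end = min(start + dur, n_steps)
series[start:end] += trapezoid_envelope(dur, height)[:end - start]
print(series[start], series[start + 20], series[-1])
```

Slicing the envelope to `end - start` reproduces the truncation that the loop version gets from `range(si, end)`.
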
803
+ # ============================================================================
804
+ # Build example simulation (200-item variant)
805
+ # ============================================================================
806
+
807
+ def build_example_simulation_from_adjacency(
808
+ seed=123,
809
+ horizon_days=52560,
810
+ pipeline_multiplier=3.0,
811
+ streaming_out_dir=None,
812
+ packing="greedy",
813
+ ):
814
+ """
815
+ Day-level supply chain: 1 step = 1 day, 52560 days = 144 years.
816
+ Multi-echelon design:
817
+ Sources: magic replenishment (factory)
818
+ Intermediate nodes: pull from upstream via network edges
819
+ Per-tier (s,S) calibrated to flow rate and upstream lead time
820
+ """
821
+ random.seed(2025)
822
+ item_ids = [f"I{i:03d}" for i in range(1, 201)] # 200 items
823
+ items = {iid: Item(iid, round(random.uniform(1.0, 4.0), 2))
824
+ for iid in item_ids}
825
+
826
+ def make_pol(inv_base=4000, inv_var=500,
827
+ s_base=600, s_var=100,
828
+ S_base=6000, S_var=500,
829
+ lt_mean=5, lt_var=1):
830
+ inv, s, S, lt = {}, {}, {}, {}
831
+ for iid in item_ids:
832
+ si = max(0, int(round(s_base + random.uniform(-s_var, s_var))))
833
+ Si = max(si + 1, int(round(
834
+ S_base + random.uniform(-S_var, S_var))))
835
+ ii = int(round(inv_base + random.uniform(-inv_var, inv_var)))
836
+ ii = max(si, min(Si, max(0, ii)))
837
+ li = max(1, int(round(
838
+ lt_mean + random.uniform(-lt_var, lt_var))))
839
+ s[iid], S[iid], inv[iid], lt[iid] = si, Si, ii, li
840
+ return inv, s, S, lt
841
+
842
+ nodes_meta = [
843
+ {"id": "NewYork", "lat": 40.7128, "lon": -74.0060,
844
+ "is_destination": True,
845
+ "inventory": {iid: 600 for iid in item_ids},
846
+ "backlog": {iid: 0 for iid in item_ids}},
847
+
848
+ {"id": "SanFrancisco", "lat": 37.7749, "lon": -122.4194,
849
+ "is_source": True},
850
+ {"id": "StLouis", "lat": 38.6270, "lon": -90.1994,
851
+ "is_source": True},
852
+ {"id": "Orlando", "lat": 28.5383, "lon": -81.3792,
853
+ "is_source": True},
854
+ {"id": "Nashville", "lat": 36.1627, "lon": -86.7816},
855
+ {"id": "Atlanta", "lat": 33.7490, "lon": -84.3880},
856
+ {"id": "Chicago", "lat": 41.8781, "lon": -87.6298},
857
+ {"id": "Charlotte", "lat": 35.2271, "lon": -80.8431},
858
+ {"id": "Columbus", "lat": 39.9612, "lon": -82.9988},
859
+ {"id": "Richmond", "lat": 37.5407, "lon": -77.4360},
860
+ {"id": "Philadelphia", "lat": 39.9526, "lon": -75.1652},
861
+ {"id": "Baltimore", "lat": 39.2904, "lon": -76.6122},
862
+ {"id": "Memphis", "lat": 35.1495, "lon": -90.0490},
863
+ ]
864
+
865
+ tier_params = {
866
+ "SanFrancisco": dict(inv_base=4000, inv_var=400,
867
+ s_base=400, s_var=60,
868
+ S_base=4000, S_var=400,
869
+ lt_mean=3, lt_var=1),
870
+ "StLouis": dict(inv_base=4000, inv_var=400,
871
+ s_base=400, s_var=60,
872
+ S_base=4000, S_var=400,
873
+ lt_mean=3, lt_var=1),
874
+ "Orlando": dict(inv_base=4000, inv_var=400,
875
+ s_base=400, s_var=60,
876
+ S_base=4000, S_var=400,
877
+ lt_mean=3, lt_var=1),
878
+ "Nashville": dict(inv_base=8000, inv_var=800,
879
+ s_base=1000, s_var=150,
880
+ S_base=8000, S_var=800,
881
+ lt_mean=3, lt_var=1),
882
+ "Atlanta": dict(inv_base=6000, inv_var=600,
883
+ s_base=500, s_var=80,
884
+ S_base=6000, S_var=600,
885
+ lt_mean=1, lt_var=0),
886
+ "Chicago": dict(inv_base=5000, inv_var=500,
887
+ s_base=1000, s_var=150,
888
+ S_base=5000, S_var=500,
889
+ lt_mean=8, lt_var=1),
890
+ "Charlotte": dict(inv_base=5000, inv_var=500,
891
+ s_base=1000, s_var=150,
892
+ S_base=5000, S_var=500,
893
+ lt_mean=7, lt_var=1),
894
+ "Memphis": dict(inv_base=3000, inv_var=300,
895
+ s_base=500, s_var=80,
896
+ S_base=3000, S_var=300,
897
+ lt_mean=7, lt_var=1),
898
+ "Columbus": dict(inv_base=4000, inv_var=400,
899
+ s_base=500, s_var=80,
900
+ S_base=4000, S_var=400,
901
+ lt_mean=2, lt_var=0),
902
+ "Richmond": dict(inv_base=4000, inv_var=400,
903
+ s_base=500, s_var=80,
904
+ S_base=4000, S_var=400,
905
+ lt_mean=2, lt_var=0),
906
+ "Philadelphia": dict(inv_base=3000, inv_var=300,
907
+ s_base=500, s_var=80,
908
+ S_base=3000, S_var=300,
909
+ lt_mean=1, lt_var=0),
910
+ "Baltimore": dict(inv_base=3000, inv_var=300,
911
+ s_base=500, s_var=80,
912
+ S_base=3000, S_var=300,
913
+ lt_mean=2, lt_var=0),
914
+ }
915
+
916
+ for m in nodes_meta:
917
+ nid = m["id"]
918
+ if m.get("is_destination", False):
919
+ continue
920
+ inv, s, S, lt = make_pol(**tier_params[nid])
921
+ m["inventory"] = inv
922
+ m["s_levels"] = s
923
+ m["S_levels"] = S
924
+ m["lead_time_mean"] = lt
925
+
926
+ n = len(nodes_meta)
927
+ adj = [[None]*n for _ in range(n)]
928
+ idx = {m["id"]: i for i, m in enumerate(nodes_meta)}
929
+
930
+ def se(u, v, tt, cv, nc):
931
+ adj[idx[u]][idx[v]] = (tt, cv, nc)
932
+
933
+     # se(u, v, travel_days, container_volume, containers_per_day)
+     # Travel times are in days; upstream capacity is generous (not a bottleneck).
934
+ se("SanFrancisco", "Nashville", 4, 20000.0, 3)
935
+ se("StLouis", "Nashville", 2, 20000.0, 3)
936
+ se("Orlando", "Nashville", 2, 20000.0, 3)
937
+ se("Nashville", "Atlanta", 1, 60000.0, 3)
938
+ se("Atlanta", "Chicago", 8, 16000.0, 3)
939
+ se("Atlanta", "Charlotte", 7, 16000.0, 3)
940
+ se("Atlanta", "Memphis", 7, 16000.0, 3)
941
+ se("Chicago", "Columbus", 2, 16000.0, 3)
942
+ se("Charlotte", "Richmond", 2, 16000.0, 3)
943
+ se("Columbus", "Philadelphia", 2, 16000.0, 3)
944
+ se("Richmond", "Philadelphia", 1, 16000.0, 3)
945
+ se("Richmond", "Baltimore", 3, 12000.0, 3)
946
+ se("Columbus", "Baltimore", 3, 12000.0, 3)
947
+ se("Memphis", "Baltimore", 2, 12000.0, 3)
948
+ # Last-mile: placeholder, overwritten dynamically below
949
+ se("Philadelphia", "NewYork", 1, 4000.0, 3)
950
+ se("Baltimore", "NewYork", 2, 4000.0, 3)
951
+
952
+ net = build_network_from_adjacency(nodes_meta, adj)
953
+
954
+ print("Building day-level demand signals (200 items)...")
955
+ demand_fn, demand_signals = build_demand_fn(
956
+ item_ids, horizon_days, seed,
957
+ base_lambda_range=(80, 250))
958
+ print(f" Shape: {demand_signals.shape} "
959
+ f"({horizon_days} days ≈ {horizon_days/365:.0f} years)")
960
+ print(f" Lambda: mean={demand_signals.mean():.1f} "
961
+ f"min={demand_signals.min():.1f} "
962
+ f"max={demand_signals.max():.1f}")
963
+
964
+ actual_mean_lam = float(demand_signals.mean())
965
+ avg_vol = 2.5
966
+ n_items = len(item_ids)
967
+ total_demand_vol = n_items * actual_mean_lam * avg_vol
968
+ target_ratio = 1.20
969
+ packing_eff = 0.93
970
+ raw_needed = total_demand_vol * target_ratio / packing_eff
971
+ philadelphia_cv = round(raw_needed * 0.55 / 3 / 100) * 100
972
+ baltimore_cv = round(raw_needed * 0.45 / 3 / 100) * 100
973
+ net.edges[("Philadelphia", "NewYork")].container_volume = float(philadelphia_cv)
974
+ net.edges[("Baltimore", "NewYork")].container_volume = float(baltimore_cv)
975
+ for eid in [("Philadelphia", "NewYork"), ("Baltimore", "NewYork")]:
976
+ e = net.edges[eid]
977
+ net.weight_cache[eid] = e.travel_time_days / max(
978
+ e.daily_total_capacity, 1e-9)
979
+ last_mile_cap = (philadelphia_cv + baltimore_cv) * 3
980
+ print(f" Demand vol: {total_demand_vol:.0f}/day "
981
+ f"Last-mile cap: {last_mile_cap:.0f}/day "
982
+ f"Ratio: {last_mile_cap/total_demand_vol:.1%} "
983
+ f"(Philadelphia={philadelphia_cv:.0f} Baltimore={baltimore_cv:.0f})")
984
+
985
+ sim = SupplyChainSimulation(
986
+ network=net,
987
+ items=items,
988
+ destination_id="NewYork",
989
+ demand_fn=demand_fn,
990
+ horizon_days=horizon_days,
991
+ seed=seed,
992
+ pipeline_multiplier=pipeline_multiplier,
993
+ streaming_out_dir=streaming_out_dir,
994
+ packing=packing)
995
+
996
+ return sim, net, items, demand_signals
997
+
998
+
999
+ # ============================================================================
1000
+ # Folium helpers
1001
+ # ============================================================================
1002
+
1003
+ def plot_network_folium(network, tiles="cartodbpositron", zoom_start=5):
1004
+ lats = [n.lat for n in network.nodes.values()]
1005
+ lons = [n.lon for n in network.nodes.values()]
1006
+ c = (sum(lats)/len(lats), sum(lons)/len(lons))
1007
+ m = folium.Map(location=c, zoom_start=zoom_start, tiles=tiles)
1008
+ for (u, v), e in network.edges.items():
1009
+ nu, nv = network.nodes[u], network.nodes[v]
1010
+ folium.PolyLine(
1011
+ [(nu.lat, nu.lon), (nv.lat, nv.lon)],
1012
+ weight=2, opacity=0.7,
1013
+ tooltip=f"{u}→{v} t={e.travel_time_days:.0f}d "
1014
+ f"cap={e.daily_total_capacity:.0f}/day").add_to(m)
1015
+ for nid, n in network.nodes.items():
1016
+ folium.CircleMarker(
1017
+ (n.lat, n.lon),
1018
+ radius=6 if n.is_destination else 4, fill=True,
1019
+ tooltip=f"{nid} dest={n.is_destination}").add_to(m)
1020
+ return m
1021
+
1022
+
1023
+ def export_map_with_animation(
1024
+ network, sdf, inv_df=None, bl_df=None, it_df=None,
1025
+ out_html="supply_chain_map.html", start_date="2025-01-01",
1026
+ zoom_start=5,
1027
+ ):
1028
+ m = plot_network_folium(network, zoom_start=zoom_start)
1029
+ m.save(out_html)
1030
+ return out_html
1031
+
1032
+
1033
+ # ============================================================================
1034
+ # CLI
1035
+ # ============================================================================
1036
+
1037
+ if __name__ == "__main__":
1038
+ import argparse
1039
+ ap = argparse.ArgumentParser(
1040
+ description="Supply Chain Simulation — day-level, 52560-step, "
1041
+ "200-item variant")
1042
+ ap.add_argument("--days", type=int, default=52560)
1043
+ ap.add_argument("--seed", type=int, default=2025)
1044
+ ap.add_argument("--out_dir", type=str, default="test_output")
1045
+ ap.add_argument("--pipeline_mult", type=float, default=0.0,
1046
+ help="Days of EMA demand to keep in pipeline. "
1047
+ "0 = reactive mode (backlog + 3-day buffer).")
1048
+ ap.add_argument("--no_streaming", action="store_true")
1049
+ args = ap.parse_args()
1050
+
1051
+ streaming = not args.no_streaming and args.days > 500
1052
+
1053
+ print("=== Supply Chain Simulation (Multi-Echelon, 200 items) ===")
1054
+ print(f" Days: {args.days:,} ({args.days/365:.1f} years) "
1055
+ f"Seed: {args.seed}")
1056
+ print(f" 1 step = 1 day | 365 = 1 year | 52560 = 144 years")
1057
+ print(f" Items: 200 Pipeline: {args.pipeline_mult} "
1058
+ f"Streaming: {streaming}")
1059
+ print()
1060
+
1061
+ sim, net, items, dsig = build_example_simulation_from_adjacency(
1062
+ seed=args.seed,
1063
+ horizon_days=args.days,
1064
+ pipeline_multiplier=args.pipeline_mult,
1065
+ streaming_out_dir=args.out_dir if streaming else None,
1066
+ packing="greedy")
1067
+
1068
+ dd, ds, svc, di, db, dt = sim.run()
1069
+
1070
+ os.makedirs(args.out_dir, exist_ok=True)
1071
+
1072
+ # Save demand signals
1073
+ print("Saving demand signals...")
1074
+ np.save(os.path.join(args.out_dir, "demand_signals.npy"),
1075
+ dsig[:args.days])
1076
+ with open(os.path.join(args.out_dir, "demand_signals_cols.txt"), "w") as f:
1077
+ f.write(",".join(sorted(items.keys())) + "\n")
1078
+ print(f" Saved shape={dsig[:args.days].shape}")
1079
+
1080
+ if not streaming:
1081
+ dd.to_csv(os.path.join(args.out_dir, "daily_records.csv"),
1082
+ index=False)
1083
+ ds.to_csv(os.path.join(args.out_dir, "shipments.csv"), index=False)
1084
+ svc.to_csv(os.path.join(args.out_dir, "service_summary.csv"),
1085
+ index=False)
1086
+ if args.days <= 500:
1087
+ di.to_csv(os.path.join(args.out_dir, "inventory_history.csv"),
1088
+ index=False)
1089
+ db.to_csv(os.path.join(args.out_dir, "backlog_history.csv"),
1090
+ index=False)
1091
+ dt.to_csv(os.path.join(args.out_dir, "intransit_history.csv"),
1092
+ index=False)
1093
+
1094
+ fr = svc['fill_rate_stock_only']
1095
+ print(f"\n=== Service Summary (200 items) ===")
1096
+ print(f" Fill rate: mean={fr.mean():.3f} "
1097
+ f"median={fr.median():.3f} "
1098
+ f"min={fr.min():.3f} max={fr.max():.3f}")
1099
+ print(f" Total demand: {svc['total_demand'].sum():,}")
1100
+ print(f" Total served: {svc['served_from_stock'].sum():,}")
1101
+ print(f" Total backlog: {svc['new_backlog_added'].sum():,}")
1102
+ print(f"\nOutputs → {args.out_dir}/")
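The last-mile sizing in `build_example_simulation_from_adjacency` (the `target_ratio` / `packing_eff` block) is easier to follow as standalone arithmetic. The sketch below reproduces that calculation under stated assumptions: the helper name `last_mile_container_volumes` is hypothetical, and the mean demand rate of 165.0 is only the midpoint of `base_lambda_range=(80, 250)`, not a value measured from `demand_signals`.

```python
def last_mile_container_volumes(mean_lambda, n_items=200, avg_vol=2.5,
                                target_ratio=1.20, packing_eff=0.93,
                                shares=(0.55, 0.45), containers_per_day=3):
    """Size the per-container volume on each last-mile edge (hypothetical helper)."""
    # Average daily demand volume summed over all items.
    total_demand_vol = n_items * mean_lambda * avg_vol
    # Capacity needed after the safety margin and the packing-loss allowance.
    raw_needed = total_demand_vol * target_ratio / packing_eff
    # Split across the edges, divide by containers per day, round to the nearest 100.
    volumes = tuple(round(raw_needed * share / containers_per_day / 100) * 100
                    for share in shares)
    return volumes, total_demand_vol


# Assumed mean demand rate: midpoint of (80, 250), not a simulator output.
volumes, total_demand_vol = last_mile_container_volumes(mean_lambda=165.0)
last_mile_cap = sum(volumes) * 3  # 3 containers per day on each edge
print(volumes, f"{last_mile_cap / total_demand_vol:.1%}")
# prints: (19500, 16000) 129.1%
```

Under these assumptions the nominal capacity-to-demand ratio comes out near 1.29 rather than 1.20, because the 1.20 target is divided by the 0.93 packing efficiency before the split and rounding.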
simulator/Supplychaingeo_item50.py ADDED
@@ -0,0 +1,1195 @@
1
+ """
2
+ Supply Chain Simulation — Day-level, 52560-step.
3
+
4
+ 1 step = 1 day | 365 steps = 1 year | 52560 steps = 144 years
5
+
6
+ Logic:
7
+ - (s,S) inventory policy at warehouses
8
+ - Dijkstra routing (weight = travel_time / daily_capacity)
9
+ - Greedy first-fit bin packing
10
+ - Proactive shipping via pipeline_multiplier
11
+ - Streaming CSV for large runs
12
+
13
+ Multi-echelon:
14
+   - Only source nodes (SanFrancisco, StLouis, Orlando) retain magic replenishment
15
+ - Intermediate nodes pull inventory from upstream via network edges
16
+ - Per-tier (s,S) parameters calibrated to demand flow
17
+ """
18
+
19
+ from __future__ import annotations
20
+ from dataclasses import dataclass, field
21
+ from typing import Dict, List, Tuple, Optional, Any, Callable
22
+ import math
23
+ import heapq
24
+ import random
25
+ import os
26
+ import csv
27
+ import time as _time_module
28
+ from datetime import datetime, timedelta
29
+ import io
30
+ import base64
31
+
32
+ try:
33
+ import matplotlib.pyplot as plt
34
+ from branca.element import Element
35
+ except ImportError:
36
+ pass
37
+
38
+ import numpy as np
39
+ import pandas as pd
40
+
41
+ try:
42
+ import folium
43
+ from folium.plugins import TimestampedGeoJson
44
+ HAS_FOLIUM = True
45
+ except ImportError:
46
+ HAS_FOLIUM = False
47
+
48
+
49
+ # ============================================================================
50
+ # Data Structures
51
+ # ============================================================================
52
+
53
+ @dataclass
54
+ class Item:
55
+ item_id: str
56
+ volume: float
57
+
58
+
59
+ @dataclass
60
+ class Edge:
61
+ u: str
62
+ v: str
63
+ travel_time_days: float
64
+ container_volume: float
65
+ num_containers_per_day: int
66
+ daily_containers: List[float] = field(default_factory=list)
67
+
68
+ def reset_daily(self, capacity_factor: float = 1.0) -> None:
69
+ effective_vol = self.container_volume * capacity_factor
70
+ self.daily_containers = [effective_vol] * self.num_containers_per_day
71
+
72
+ def find_container_slot(self, item_volume: float) -> Optional[int]:
73
+ for idx, rem in enumerate(self.daily_containers):
74
+ if rem >= item_volume:
75
+ return idx
76
+ return None
77
+
78
+ def allocate_in_container(self, idx: int, item_volume: float) -> bool:
79
+ if 0 <= idx < len(self.daily_containers) and \
80
+ self.daily_containers[idx] >= item_volume:
81
+ self.daily_containers[idx] -= item_volume
82
+ return True
83
+ return False
84
+
85
+ @property
86
+ def daily_total_capacity(self) -> float:
87
+ return self.container_volume * max(self.num_containers_per_day, 0)
88
+
89
+
90
+ @dataclass
91
+ class Node:
92
+ node_id: str
93
+ lat: float
94
+ lon: float
95
+ is_destination: bool = False
96
+ is_source: bool = False
97
+ inventory: Dict[str, int] = field(default_factory=dict)
98
+ s_levels: Dict[str, int] = field(default_factory=dict)
99
+ S_levels: Dict[str, int] = field(default_factory=dict)
100
+ lead_time_mean: Dict[str, float] = field(default_factory=dict)
101
+ lead_time_std_frac: float = 0.2
102
+ outstanding_orders: Dict[str, Optional[Tuple[int, int]]] = \
103
+ field(default_factory=dict)
104
+ backlog: Dict[str, int] = field(default_factory=dict)
105
+
106
+ def receive_orders_today(self, day: int) -> None:
107
+ to_clear = []
108
+ for item_id, order in self.outstanding_orders.items():
109
+ if order is None:
110
+ continue
111
+ arrival_day, qty = order
112
+ if day >= arrival_day:
113
+ self.inventory[item_id] = \
114
+ self.inventory.get(item_id, 0) + qty
115
+ to_clear.append(item_id)
116
+ for iid in to_clear:
117
+ self.outstanding_orders[iid] = None
118
+
119
+ def maybe_place_orders(self, day: int, rng: random.Random) -> None:
120
+ if self.is_destination:
121
+ return
122
+ if not self.is_source:
123
+ return
124
+ for item_id, s in self.s_levels.items():
125
+ on_hand = self.inventory.get(item_id, 0)
126
+ if on_hand < s and \
127
+ self.outstanding_orders.get(item_id) is None:
128
+ S = self.S_levels.get(item_id, on_hand)
129
+ qty = max(S - on_hand, 0)
130
+ if qty <= 0:
131
+ continue
132
+ mean_lt = max(self.lead_time_mean.get(item_id, 1.0), 0.1)
133
+ std = self.lead_time_std_frac * mean_lt
134
+ sampled = rng.normalvariate(mean_lt, std)
135
+ lt_days = max(1, int(math.ceil(sampled)))
136
+ self.outstanding_orders[item_id] = (day + lt_days, qty)
137
+
138
+
139
+ # ============================================================================
140
+ # Network + Dijkstra
141
+ # ============================================================================
142
+
143
+ class Network:
144
+ def __init__(self) -> None:
145
+ self.nodes: Dict[str, Node] = {}
146
+ self.edges: Dict[Tuple[str, str], Edge] = {}
147
+ self.adj: Dict[str, List[str]] = {}
148
+ self.weight_cache: Dict[Tuple[str, str], float] = {}
149
+ self.paths_to_dest: Dict[
150
+ str, Tuple[float, List[str], List[Tuple[str, str]]]] = {}
151
+
152
+ def add_node(self, node: Node) -> None:
153
+ self.nodes[node.node_id] = node
154
+ self.adj.setdefault(node.node_id, [])
155
+
156
+ def add_edge(self, edge: Edge) -> None:
157
+ self.edges[(edge.u, edge.v)] = edge
158
+ self.adj.setdefault(edge.u, []).append(edge.v)
159
+ cap = max(edge.daily_total_capacity, 1e-9)
160
+ self.weight_cache[(edge.u, edge.v)] = edge.travel_time_days / cap
161
+
162
+ def reset_daily_edges(self, capacity_factor: float = 1.0) -> None:
163
+ for e in self.edges.values():
164
+ e.reset_daily(capacity_factor)
165
+
166
+ def dijkstra(self, source: str, target: str) -> Tuple[float, List[str]]:
167
+ if source == target:
168
+ return 0.0, [source]
169
+ dist: Dict[str, float] = {source: 0.0}
170
+ prev: Dict[str, Optional[str]] = {source: None}
171
+ pq = [(0.0, source)]
172
+ visited: set = set()
173
+ while pq:
174
+ d, u = heapq.heappop(pq)
175
+ if u in visited:
176
+ continue
177
+ visited.add(u)
178
+ if u == target:
179
+ break
180
+ for v in self.adj.get(u, []):
181
+ w = self.weight_cache[(u, v)]
182
+ nd = d + w
183
+ if nd < dist.get(v, float('inf')):
184
+ dist[v] = nd
185
+ prev[v] = u
186
+ heapq.heappush(pq, (nd, v))
187
+ if target not in dist:
188
+ return float('inf'), []
189
+ path: List[str] = []
190
+ cur: Optional[str] = target
191
+ while cur is not None:
192
+ path.append(cur)
193
+ cur = prev.get(cur)
194
+ path.reverse()
195
+ return dist[target], path
196
+
197
+ def compute_paths_to_destination(self, dest_id: str) -> None:
198
+ self.paths_to_dest.clear()
199
+ for nid in self.nodes:
200
+ if nid == dest_id:
201
+ self.paths_to_dest[nid] = (0.0, [nid], [])
202
+ continue
203
+ d, pn = self.dijkstra(nid, dest_id)
204
+ if not pn:
205
+ self.paths_to_dest[nid] = (float('inf'), [], [])
206
+ else:
207
+ pe = [(pn[i], pn[i+1]) for i in range(len(pn)-1)]
208
+ self.paths_to_dest[nid] = (d, pn, pe)
209
+
210
+
211
+ # ============================================================================
212
+ # Greedy First-Fit Bin Packing
213
+ # ============================================================================
214
+
215
+ def allocate_units_along_path_greedy(
216
+ item: Item,
217
+ max_units: int,
218
+ path_edges: List[Tuple[str, str]],
219
+ network_edges: Dict[Tuple[str, str], Edge],
220
+ ) -> int:
221
+ if max_units <= 0 or not path_edges:
222
+ return 0
223
+ edges_objs = [network_edges[eid] for eid in path_edges]
224
+ placed = 0
225
+ ivol = item.volume
226
+ for _ in range(max_units):
227
+ slots: List[Tuple[Edge, int]] = []
228
+ ok = True
229
+ for e in edges_objs:
230
+ si = e.find_container_slot(ivol)
231
+ if si is None:
232
+ ok = False
233
+ break
234
+ slots.append((e, si))
235
+ if not ok:
236
+ break
237
+ good = True
238
+ for e, si in slots:
239
+ if not e.allocate_in_container(si, ivol):
240
+ good = False
241
+ break
242
+ if not good:
243
+ break
244
+ placed += 1
245
+ return placed
246
+
247
+
248
+ # ============================================================================
249
+ # Simulation Engine
250
+ # ============================================================================
251
+
252
+ class SupplyChainSimulation:
253
+
254
+ def __init__(
255
+ self,
256
+ network: Network,
257
+ items: Dict[str, Item],
258
+ destination_id: str,
259
+ demand_fn: Callable[[int], Dict[str, int]],
260
+ horizon_days: int,
261
+ seed: int = 42,
262
+ pipeline_multiplier: float = 0.0,
263
+ streaming_out_dir: Optional[str] = None,
264
+ packing: str = "greedy",
265
+ ) -> None:
266
+ assert destination_id in network.nodes and \
267
+ network.nodes[destination_id].is_destination
268
+ self.network = network
269
+ self.items = items
270
+ self.item_order: List[str] = sorted(self.items.keys())
271
+ self.round_robin_items: bool = True
272
+ self.per_item_daily_cap_units: Optional[int] = None
273
+
274
+ self.destination_id = destination_id
275
+ self.demand_fn = demand_fn
276
+ self.horizon_days = horizon_days
277
+ self.rng = random.Random(seed)
278
+ self.packing = packing
279
+ self.pipeline_multiplier = pipeline_multiplier
280
+
281
+ # EMA warm-start at approximate per-day mean demand
282
+ self.demand_ema: Dict[str, float] = {iid: 165.0 for iid in items}
283
+ self.ema_alpha = 0.05
284
+ self.item_intransit: Dict[str, int] = {iid: 0 for iid in items}
285
+
286
+ self.streaming_out_dir = streaming_out_dir
287
+ self._csv_files: Dict = {}
288
+ self._csv_writers: Dict = {}
289
+ self._csv_buffers: Dict[str, list] = {}
290
+ self.dest_in_transit: Dict[int, Dict[str, int]] = {}
291
+
292
+ self.daily_records: list = []
293
+ self.shipments_log: list = []
294
+ self.inventory_history: list = []
295
+ self.backlog_history: list = []
296
+ self.intransit_history: list = []
297
+
298
+ self.svc_demand: Dict[str, int] = {iid: 0 for iid in items}
299
+ self.svc_served: Dict[str, int] = {iid: 0 for iid in items}
300
+ self.svc_backlog: Dict[str, int] = {iid: 0 for iid in items}
301
+
302
+ network.compute_paths_to_destination(destination_id)
303
+ self.sorted_warehouses = sorted(
304
+ [(nid, d) for nid, (d, _, _) in network.paths_to_dest.items()
305
+ if nid != destination_id and math.isfinite(d)],
306
+ key=lambda x: x[1])
307
+
308
+ # ── precompute supplier relationships ──────────────
309
+ self.node_suppliers: Dict[
310
+ str,
311
+ List[Tuple[str, float, List[str], List[Tuple[str, str]]]]
312
+ ] = {}
313
+ for nid, node in network.nodes.items():
314
+ if node.is_destination or node.is_source:
315
+ continue
316
+ suppliers = []
317
+ for (u, v) in network.edges:
318
+ if v == nid:
319
+ d, path = network.dijkstra(u, nid)
320
+ if path and math.isfinite(d):
321
+ pe = [(path[i], path[i + 1])
322
+ for i in range(len(path) - 1)]
323
+ tt = sum(network.edges[e].travel_time_days
324
+ for e in pe)
325
+ suppliers.append((u, tt, path, pe))
326
+ suppliers.sort(key=lambda x: x[1])
327
+ self.node_suppliers[nid] = suppliers
328
+
329
+ # intermediate nodes ordered upstream-first for replenishment
330
+ self.replenish_order: List[str] = [
331
+ nid for nid, _ in reversed(self.sorted_warehouses)
332
+ if nid in self.node_suppliers]
333
+
334
+ self.avg_travel_time = 6.0
335
+ if self.sorted_warehouses:
336
+ _, _, pe = network.paths_to_dest[self.sorted_warehouses[0][0]]
337
+ if pe:
338
+ self.avg_travel_time = sum(
339
+ network.edges[e].travel_time_days for e in pe)
340
+
341
+ def _total_tt(self, pe):
342
+ return sum(self.network.edges[e].travel_time_days for e in pe)
343
+
344
+ def _replenish_warehouses(self, day: int) -> None:
345
+ """Inter-warehouse replenishment: intermediate nodes pull
346
+ from upstream suppliers using (s,S) trigger + edge capacity."""
347
+ net = self.network
348
+
349
+ # round-robin item fairness
350
+ if self.round_robin_items and self.item_order:
351
+ k = day % len(self.item_order)
352
+ ids_today = self.item_order[k:] + self.item_order[:k]
353
+ else:
354
+ ids_today = self.item_order[:]
355
+
356
+ for nid in self.replenish_order:
357
+ node = net.nodes[nid]
358
+ for iid in ids_today:
359
+ on_hand = node.inventory.get(iid, 0)
360
+ s = node.s_levels.get(iid, 0)
361
+ if on_hand >= s:
362
+ continue
363
+ if node.outstanding_orders.get(iid) is not None:
364
+ continue
365
+ S = node.S_levels.get(iid, on_hand)
366
+ qty_needed = max(S - on_hand, 0)
367
+ if qty_needed <= 0:
368
+ continue
369
+
370
+ for sup_id, tt, path, pe in \
371
+ self.node_suppliers.get(nid, []):
372
+ sup_node = net.nodes[sup_id]
373
+ avail = sup_node.inventory.get(iid, 0)
374
+ if avail <= 0:
375
+ continue
376
+ attempt = min(avail, qty_needed)
377
+ placed = allocate_units_along_path_greedy(
378
+ self.items[iid], attempt, pe, net.edges)
379
+ if placed <= 0:
380
+ continue
381
+
382
+ sup_node.inventory[iid] -= placed
383
+ arr = day + max(1, int(math.ceil(tt)))
384
+ node.outstanding_orders[iid] = (arr, placed)
385
+
386
+ # log (reuses existing shipments infrastructure)
387
+ r = [day, arr, sup_id, nid, iid, placed,
388
+ str(path),
389
+ str([net.edges[e].travel_time_days
390
+ for e in pe])]
391
+ if self.streaming_out_dir:
392
+ self._csv_buffers.setdefault(
393
+ 'ship', []).append(r)
394
+ else:
395
+ self.shipments_log.append({
396
+ "day": day, "arrival_day": arr,
397
+ "from": sup_id, "to": nid,
398
+ "item": iid, "units": placed,
399
+ "path_nodes": path,
400
+ "edge_times": [
401
+ net.edges[e].travel_time_days
402
+ for e in pe]})
403
+ break
404
+
405
+ def step(self, day: int) -> None:
406
+ net = self.network
407
+ dest = net.nodes[self.destination_id]
408
+
409
+ # 1) Receive (s,S) replenishment at warehouses
410
+ for node in net.nodes.values():
411
+ node.receive_orders_today(day)
412
+
413
+ # 2) Arrivals at destination
414
+ arrivals = self.dest_in_transit.pop(day, {})
415
+ for iid, qty in arrivals.items():
416
+ self.item_intransit[iid] = max(
417
+ 0, self.item_intransit.get(iid, 0) - qty)
418
+ bl = dest.backlog.get(iid, 0)
419
+ if bl > 0:
420
+ use = min(qty, bl)
421
+ dest.backlog[iid] = bl - use
422
+ qty -= use
423
+ if qty > 0:
424
+ dest.inventory[iid] = dest.inventory.get(iid, 0) + qty
425
+
426
+ # 3) Reset edge containers
427
+ net.reset_daily_edges(1.0)
428
+
429
+ # 4) Demand at destination
430
+ td = self.demand_fn(day)
431
+ for iid in self.items:
432
+ dq = int(td.get(iid, 0))
433
+ self.demand_ema[iid] = (
434
+ self.ema_alpha * dq +
435
+ (1 - self.ema_alpha) * self.demand_ema[iid])
436
+ oh = dest.inventory.get(iid, 0)
437
+ if oh >= dq:
438
+ served, unfilled = dq, 0
439
+ dest.inventory[iid] = oh - dq
440
+ else:
441
+ served, unfilled = oh, dq - oh
442
+ dest.inventory[iid] = 0
443
+ dest.backlog[iid] = dest.backlog.get(iid, 0) + unfilled
444
+ self.svc_demand[iid] += dq
445
+ self.svc_served[iid] += served
446
+ self.svc_backlog[iid] += unfilled
447
+ rec = [day, iid, dq, served, unfilled,
448
+ dest.inventory.get(iid, 0), dest.backlog.get(iid, 0)]
449
+ if self.streaming_out_dir:
450
+ self._csv_buffers.setdefault('daily', []).append(rec)
451
+ else:
452
+ self.daily_records.append({
453
+ "day": day, "item": iid, "demand": dq,
454
+ "served_from_stock": served,
455
+ "new_backlog_today": unfilled,
456
+ "dest_on_hand_end_before_ship":
457
+ dest.inventory.get(iid, 0),
458
+ "dest_backlog_end_before_ship":
459
+ dest.backlog.get(iid, 0)})
460
+
461
+ # 5) Ship
462
+ if self.round_robin_items and self.item_order:
463
+ k = day % len(self.item_order)
464
+ ids_today = self.item_order[k:] + self.item_order[:k]
465
+ else:
466
+ ids_today = self.item_order[:]
467
+
468
+ for iid in ids_today:
469
+ item = self.items[iid]
470
+ cb = dest.backlog.get(iid, 0)
471
+ it = self.item_intransit.get(iid, 0)
472
+ oh = dest.inventory.get(iid, 0)
473
+
474
+ if self.pipeline_multiplier > 0:
475
+ pt = self.demand_ema[iid] * self.pipeline_multiplier
476
+ ship_target = max(0, int(math.ceil(cb + pt - it - oh)))
477
+ else:
478
+ # Reactive: ship to cover backlog + 3 days of EMA demand
479
+ S_dest = max(1, int(self.demand_ema[iid] * 3))
480
+ ship_target = max(0, cb + S_dest - oh - it)
481
+
482
+ if ship_target <= 0:
483
+ continue
484
+
485
+ remaining = ship_target
486
+ shipped = 0
487
+ for wid, _ in self.sorted_warehouses:
488
+ if remaining <= 0:
489
+ break
490
+ if self.per_item_daily_cap_units is not None and \
491
+ shipped >= self.per_item_daily_cap_units:
492
+ break
493
+ wn = net.nodes[wid]
494
+ avail = wn.inventory.get(iid, 0)
495
+ if avail <= 0:
496
+ continue
497
+ _, pn, pe = net.paths_to_dest[wid]
498
+ if not pe:
499
+ continue
500
+ attempt = min(avail, remaining)
501
+ if self.per_item_daily_cap_units is not None:
502
+ attempt = min(attempt,
503
+ self.per_item_daily_cap_units - shipped)
504
+ placed = allocate_units_along_path_greedy(
505
+ item, attempt, pe, net.edges)
506
+ if placed <= 0:
507
+ continue
508
+ wn.inventory[iid] -= placed
509
+ remaining -= placed
510
+ shipped += placed
511
+ arr = day + max(1, int(math.ceil(self._total_tt(pe))))
512
+ self.dest_in_transit.setdefault(arr, {})
513
+ self.dest_in_transit[arr][iid] = \
514
+ self.dest_in_transit[arr].get(iid, 0) + placed
515
+ self.item_intransit[iid] = \
516
+ self.item_intransit.get(iid, 0) + placed
517
+ r = [day, arr, wid, self.destination_id, iid, placed,
518
+ str(pn),
519
+ str([net.edges[e].travel_time_days for e in pe])]
520
+ if self.streaming_out_dir:
521
+ self._csv_buffers.setdefault('ship', []).append(r)
522
+ else:
523
+ self.shipments_log.append({
524
+ "day": day, "arrival_day": arr,
525
+ "from": wid, "to": self.destination_id,
526
+ "item": iid, "units": placed,
527
+ "path_nodes": pn,
528
+ "edge_times": [net.edges[e].travel_time_days
529
+ for e in pe]})
530
+
531
+ # 5b) Inter-warehouse replenishment
532
+ self._replenish_warehouses(day)
533
+
534
+ # 6) (s,S) orders at source warehouses
535
+ for node in net.nodes.values():
536
+ node.maybe_place_orders(day, self.rng)
537
+
538
+ # 7) Snapshots
539
+ did = self.destination_id
540
+ itc: Dict[str, int] = {}
541
+ for ad, im in self.dest_in_transit.items():
542
+ if ad > day:
543
+ for iid, q in im.items():
544
+ itc[iid] = itc.get(iid, 0) + int(q)
545
+
546
+ if self.streaming_out_dir:
547
+ for node in net.nodes.values():
548
+ for iid in self.items:
549
+ self._csv_buffers.setdefault('inv', []).append(
550
+ [day, node.node_id, iid,
551
+ int(node.inventory.get(iid, 0))])
552
+ self._csv_buffers.setdefault('bl', []).append(
553
+ [day, node.node_id, iid,
554
+ int(node.backlog.get(iid, 0))])
555
+ for iid in self.items:
556
+ self._csv_buffers.setdefault('it', []).append(
557
+ [day, did, iid, itc.get(iid, 0)])
558
+ else:
559
+ for node in net.nodes.values():
560
+ for iid in self.items:
561
+ self.inventory_history.append({
562
+ "day": day, "node": node.node_id, "item": iid,
563
+ "on_hand": int(node.inventory.get(iid, 0))})
564
+ self.backlog_history.append({
565
+ "day": day, "node": node.node_id, "item": iid,
566
+ "backlog": int(node.backlog.get(iid, 0))})
567
+ for iid in self.items:
568
+ self.intransit_history.append({
569
+ "day": day, "node": did, "item": iid,
570
+ "in_transit": itc.get(iid, 0)})
571
+
572
+ # --- CSV streaming ---
573
+ def _open_csv_files(self):
574
+ os.makedirs(self.streaming_out_dir, exist_ok=True)
575
+ hdrs = {
576
+ 'daily': ["day", "item", "demand", "served_from_stock",
577
+ "new_backlog_today", "dest_on_hand_end_before_ship",
578
+ "dest_backlog_end_before_ship"],
579
+ 'ship': ["day", "arrival_day", "from", "to", "item",
580
+ "units", "path_nodes", "edge_times"],
581
+ 'inv': ["day", "node", "item", "on_hand"],
582
+ 'bl': ["day", "node", "item", "backlog"],
583
+ 'it': ["day", "node", "item", "in_transit"],
584
+ }
585
+ fn = {'daily': 'daily_records', 'ship': 'shipments',
586
+ 'inv': 'inventory_history', 'bl': 'backlog_history',
587
+ 'it': 'intransit_history'}
588
+ for k in hdrs:
589
+ f = open(os.path.join(self.streaming_out_dir,
590
+ f"{fn[k]}.csv"),
591
+ "w", newline="", buffering=65536)
592
+ w = csv.writer(f)
593
+ w.writerow(hdrs[k])
594
+ self._csv_files[k] = f
595
+ self._csv_writers[k] = w
596
+ self._csv_buffers[k] = []
597
+
598
+ def _flush_csv(self):
599
+ for k in self._csv_buffers:
600
+ if self._csv_buffers[k] and k in self._csv_writers:
601
+ self._csv_writers[k].writerows(self._csv_buffers[k])
602
+ self._csv_buffers[k] = []
603
+ for f in self._csv_files.values():
604
+ f.flush()
605
+
606
+ def _close_csv(self):
607
+ self._flush_csv()
608
+ for f in self._csv_files.values():
609
+ f.close()
610
+
611
+ def run(self):
612
+ if self.streaming_out_dir:
613
+ self._open_csv_files()
614
+ t0 = _time_module.time()
615
+ ri = max(1, self.horizon_days // 100)
616
+ for day in range(self.horizon_days):
617
+ self.step(day)
618
+ if self.streaming_out_dir and day % 500 == 0:
619
+ self._flush_csv()
620
+ if day % 5000 == 0:
621
+ for k in [k for k in self.dest_in_transit if k < day]:
622
+ del self.dest_in_transit[k]
623
+ if day % ri == 0 or day == self.horizon_days - 1:
624
+ el = _time_module.time() - t0
625
+ pct = (day + 1) / self.horizon_days * 100
626
+ rate = (day + 1) / max(el, 0.001)
627
+ eta = (self.horizon_days - day - 1) / max(rate, 0.001)
628
+ print(f"\r Day {day+1:>6}/{self.horizon_days} "
629
+ f"({pct:5.1f}%) {rate:6.1f} days/s "
630
+ f"ETA {eta:6.0f}s", end="", flush=True)
631
+ print(f"\n Simulation complete in "
632
+ f"{_time_module.time()-t0:.1f}s")
633
+
634
+ if self.streaming_out_dir:
635
+ self._close_csv()
636
+ rows = []
637
+ for iid in sorted(self.items):
638
+ td = self.svc_demand[iid]
639
+ sv = self.svc_served[iid]
640
+ bl = self.svc_backlog[iid]
641
+ rows.append({
642
+ "item": iid, "total_demand": td,
643
+ "served_from_stock": sv,
644
+ "new_backlog_added": bl,
645
+ "fill_rate_stock_only":
646
+ round(sv / td, 6) if td > 0 else 0.0})
647
+ svc = pd.DataFrame(rows)
648
+ svc.to_csv(os.path.join(self.streaming_out_dir,
649
+ "service_summary.csv"), index=False)
650
+ return (pd.DataFrame(), pd.DataFrame(), svc,
651
+ pd.DataFrame(), pd.DataFrame(), pd.DataFrame())
652
+
653
+ dd = pd.DataFrame(self.daily_records)
654
+ ds = pd.DataFrame(self.shipments_log) if self.shipments_log \
655
+ else pd.DataFrame(columns=[
656
+ "day", "arrival_day", "from", "to",
657
+ "item", "units", "path_nodes", "edge_times"])
658
+ svc = dd.groupby("item", as_index=False).agg(
659
+ total_demand=("demand", "sum"),
660
+ served_from_stock=("served_from_stock", "sum"),
661
+ new_backlog_added=("new_backlog_today", "sum"))
662
+ svc["fill_rate_stock_only"] = (
663
+ svc["served_from_stock"] / svc["total_demand"]).fillna(0)
664
+ di = pd.DataFrame(self.inventory_history,
665
+ columns=["day", "node", "item", "on_hand"])
666
+ db = pd.DataFrame(self.backlog_history,
667
+ columns=["day", "node", "item", "backlog"])
668
+ dt = pd.DataFrame(self.intransit_history,
669
+ columns=["day", "node", "item", "in_transit"])
670
+ return dd, ds, svc, di, db, dt
671
+
672
+
673
+ # ============================================================================
674
+ # Build network from adjacency
675
+ # ============================================================================
676
+
677
+ def build_network_from_adjacency(nodes_meta, adjacency):
678
+ n = len(nodes_meta)
679
+ assert len(adjacency) == n and all(len(r) == n for r in adjacency), "adjacency must be an n x n matrix"
680
+ net = Network()
681
+ for meta in nodes_meta:
682
+ net.add_node(Node(
683
+ node_id=meta["id"],
684
+ lat=float(meta["lat"]), lon=float(meta["lon"]),
685
+ is_destination=bool(meta.get("is_destination", False)),
686
+ is_source=bool(meta.get("is_source", False)),
687
+ inventory=dict(meta.get("inventory", {})),
688
+ s_levels=dict(meta.get("s_levels", {})),
689
+ S_levels=dict(meta.get("S_levels", {})),
690
+ lead_time_mean=dict(meta.get("lead_time_mean", {})),
691
+ backlog=dict(meta.get("backlog", {}))))
692
+ ids = [m["id"] for m in nodes_meta]
693
+ for i in range(n):
694
+ for j in range(n):
695
+ s = adjacency[i][j]
696
+ if s is None:
697
+ continue
698
+ tt, cv, nc = s
699
+ net.add_edge(Edge(
700
+ u=ids[i], v=ids[j],
701
+ travel_time_days=float(tt),
702
+ container_volume=float(cv),
703
+ num_containers_per_day=int(nc)))
704
+ return net
705
+
706
+
707
+ # ============================================================================
708
+ # Demand Generator — day-level, 52560-step
709
+ # ============================================================================
710
+
711
+ def build_demand_fn(
712
+ item_ids, n_steps, seed=42,
713
+ base_lambda_range=(80, 250),
714
+ scenario=None,
715
+ ):
716
+ """
717
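+ Build a day-level demand sampler. Returns (demand_fn, lam), where lam is the (n_steps, n_items) matrix of Poisson rates.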
718
+ Structure:
719
+ - Yearly cycle T=365 days (moderate amplitude, two harmonics)
720
+ - Weekly cycle T=7 days (small texture)
721
+ - AR(1) drift (dominant low-frequency, decade-scale)
722
+ - Per-item spikes (sustained 1-6 months, clear amplitude)
723
+ - Global macro events (rare, wide, correlated across items)
724
+ - Poisson sampling
725
+ """
726
+
727
+ rng = np.random.default_rng(seed)
728
+ n_items = len(item_ids)
729
+
730
+ sc = scenario or {}
731
+ phi_lo = float(sc.get("phi_lo", 0.9990))
732
+ phi_hi = float(sc.get("phi_hi", 0.9996))
733
+ shock_count_scale = float(sc.get("shock_count_scale", 1.0))
734
+ shock_height_scale = float(sc.get("shock_height_scale", 1.0))
735
+ seasonal_scale = float(sc.get("seasonal_scale", 1.0))
736
+ burst_rate_scale = float(sc.get("burst_rate_scale", 1.0))
737
+ burst_height_scale = float(sc.get("burst_height_scale", 1.0))
738
+
739
+ spy = 365 # yearly period
740
+ spw = 7 # weekly period
741
+
742
+ t = np.arange(n_steps, dtype=np.float64)
743
+ yearly_phase = 2 * np.pi * t / spy
744
+ weekly_phase = 2 * np.pi * (t % spw) / spw
745
+
746
+ lam = np.zeros((n_steps, n_items), dtype=np.float64)
747
+
748
+ # Global macro events: rare, long, meaningful
749
+ gs = np.zeros(n_steps)
750
+ base_n_global = rng.integers(5, 12)
751
+ n_global = max(0, int(round(float(base_n_global) * shock_count_scale)))
752
+ for _ in range(n_global):
753
+ si = int(rng.integers(0, n_steps))
754
+ dur = int(rng.integers(180, 1100))
755
+ end = min(si + dur, n_steps)
756
+ h = rng.uniform(0.20, 0.60) * shock_height_scale
757
+ for k in range(si, end):
758
+ p = (k - si) / max(dur, 1)
759
+ if p < 0.15:
760
+ gs[k] += h * (p / 0.15)
761
+ elif p < 0.75:
762
+ gs[k] += h
763
+ else:
764
+ gs[k] += h * (1.0 - (p - 0.75) / 0.25)
765
+
766
+ for j, iid in enumerate(item_ids):
767
+ base = rng.uniform(*base_lambda_range)
768
+
769
+ # Yearly: two harmonics so shape is not a perfect sine
770
+ ya1 = rng.uniform(0.12, 0.28) * seasonal_scale
771
+ ya2 = rng.uniform(0.04, 0.10) * seasonal_scale
772
+ yo = rng.uniform(0, 2 * np.pi)
773
+ yr = (ya1 * np.sin(yearly_phase + yo) +
774
+ ya2 * np.sin(2 * yearly_phase + yo * 0.7))
775
+
776
+ # Weekly: small texture only
777
+ wa = rng.uniform(0.04, 0.10)
778
+ wo = rng.uniform(0, 2 * np.pi)
779
+ wy = wa * np.sin(weekly_phase + wo)
780
+
781
+ # AR(1): dominant low-frequency component
782
+ ac = rng.uniform(phi_lo, phi_hi)
783
+ ar_std = rng.uniform(0.008, 0.018)
784
+ dr = np.zeros(n_steps)
785
+ dr[0] = rng.normal(0, 0.10)
786
+ for i in range(1, n_steps):
787
+ dr[i] = ac * dr[i-1] + rng.normal(0, ar_std)
788
+ dr = np.clip(dr, -0.60, 0.60)
789
+
790
+ # Per-item spikes
791
+ sr = rng.uniform(0.0002, 0.001) * burst_rate_scale
792
+ sm = rng.random(n_steps) < sr
793
+ sp = np.zeros(n_steps)
794
+ for si in np.where(sm)[0]:
795
+ dur = int(rng.integers(30, 180))
796
+ end = min(si + dur, n_steps)
797
+ h = rng.uniform(0.20, 0.70) * burst_height_scale
798
+ for k in range(si, end):
799
+ p = (k - si) / max(dur, 1)
800
+ if p < 0.15:
801
+ sp[k] += h * (p / 0.15)
802
+ elif p < 0.75:
803
+ sp[k] += h
804
+ else:
805
+ sp[k] += h * (1.0 - (p - 0.75) / 0.25)
806
+
807
+ gsens = rng.uniform(0.4, 1.2)
808
+ fac = 1.0 + yr + wy + dr + sp + gsens * gs
809
+ fac = np.clip(fac, 0.08, None)
810
+ lam[:, j] = base * fac
811
+
812
+ def demand_fn(day):
813
+ idx = day % n_steps
814
+ return {iid: int(rng.poisson(lam=max(0.01, lam[idx, j])))
815
+ for j, iid in enumerate(item_ids)}
816
+
817
+ return demand_fn, lam
818
+
819
+
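The global-event loop and the per-item spike loop in `build_demand_fn` apply the same ramp/plateau/decay envelope. A minimal sketch of that envelope as one shared helper follows. The name `add_shock` and its signature are hypothetical and not part of the committed file; the 0.15 and 0.75 breakpoints are taken from the loops above.

```python
from typing import List


def add_shock(buf: List[float], start: int, dur: int, height: float) -> None:
    """Add one trapezoidal shock to buf in place (hypothetical helper)."""
    end = min(start + dur, len(buf))        # truncate at the end of the horizon
    for k in range(start, end):
        p = (k - start) / max(dur, 1)       # progress through the event, in [0, 1)
        if p < 0.15:                        # linear ramp-up over the first 15%
            buf[k] += height * (p / 0.15)
        elif p < 0.75:                      # plateau at full height
            buf[k] += height
        else:                               # linear decay over the last 25%
            buf[k] += height * (1.0 - (p - 0.75) / 0.25)


# Example: one 100-day event of height 0.4 starting on day 0.
gs = [0.0] * 200
add_shock(gs, start=0, dur=100, height=0.4)
```

Because the helper adds into the buffer, overlapping events stack, which matches the `+=` accumulation in the original loops.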
820
+ # ============================================================================
821
+ # Build example simulation
822
+ # ============================================================================
823
+
824
+ def build_example_simulation_from_adjacency(
825
+ seed=123,
826
+ horizon_days=52560,
827
+ pipeline_multiplier=3.0,
828
+ streaming_out_dir=None,
829
+ packing="greedy",
830
+ scenario=None,
831
+ ):
832
+ """
833
+ Day-level supply chain: 1 step = 1 day, 52560 days = 144 years.
834
+
835
+ Multi-echelon design:
836
+ Sources (SF, StLouis, Orlando): magic replenishment (factory)
837
+ Intermediate nodes: pull from upstream via network edges
838
+ Per-tier (s,S) calibrated to flow rate and upstream lead time
839
+ """
840
+ random.seed(2025)
841
+ item_ids = [f"I{i:02d}" for i in range(1, 51)]
842
+ items = {iid: Item(iid, round(random.uniform(1.0, 4.0), 2))
843
+ for iid in item_ids}
844
+
845
+ sc_local = scenario or {}
846
+ ss_scale = float(sc_local.get("ss_scale", 1.0))
847
+ leadtime_scale = float(sc_local.get("leadtime_scale", 1.0))
848
+
849
+ def make_pol(inv_base=4000, inv_var=500,
850
+ s_base=600, s_var=100,
851
+ S_base=6000, S_var=500,
852
+ lt_mean=5, lt_var=1):
853
+ inv_base = inv_base * ss_scale
854
+ inv_var = inv_var * ss_scale
855
+ s_base = s_base * ss_scale
856
+ s_var = s_var * ss_scale
857
+ S_base = S_base * ss_scale
858
+ S_var = S_var * ss_scale
859
+ lt_mean = lt_mean * leadtime_scale
860
+ lt_var = lt_var * leadtime_scale
861
+ inv, s, S, lt = {}, {}, {}, {}
862
+ for iid in item_ids:
863
+ si = max(0, int(round(s_base + random.uniform(-s_var, s_var))))
864
+ Si = max(si + 1, int(round(
865
+ S_base + random.uniform(-S_var, S_var))))
866
+ ii = int(round(inv_base + random.uniform(-inv_var, inv_var)))
867
+ ii = max(si, min(Si, max(0, ii)))
868
+ li = max(1, int(round(
869
+ lt_mean + random.uniform(-lt_var, lt_var))))
870
+ s[iid], S[iid], inv[iid], lt[iid] = si, Si, ii, li
871
+ return inv, s, S, lt
872
+
873
+ nodes_meta = [
874
+ {"id": "NewYork", "lat": 40.7128, "lon": -74.0060,
875
+ "is_destination": True,
876
+ "inventory": {iid: 600 for iid in item_ids},
877
+ "backlog": {iid: 0 for iid in item_ids}},
878
+
879
+ {"id": "SanFrancisco", "lat": 37.7749, "lon": -122.4194,
880
+ "is_source": True},
881
+ {"id": "StLouis", "lat": 38.6270, "lon": -90.1994,
882
+ "is_source": True},
883
+ {"id": "Orlando", "lat": 28.5383, "lon": -81.3792,
884
+ "is_source": True},
885
+ {"id": "Nashville", "lat": 36.1627, "lon": -86.7816},
886
+ {"id": "Atlanta", "lat": 33.7490, "lon": -84.3880},
887
+ {"id": "Chicago", "lat": 41.8781, "lon": -87.6298},
888
+ {"id": "Charlotte", "lat": 35.2271, "lon": -80.8431},
889
+ {"id": "Memphis", "lat": 35.1495, "lon": -90.0490},
890
+ {"id": "Columbus", "lat": 39.9612, "lon": -82.9988},
891
+ {"id": "Richmond", "lat": 37.5407, "lon": -77.4360},
892
+ {"id": "Philadelphia", "lat": 39.9526, "lon": -75.1652},
893
+ {"id": "Baltimore", "lat": 39.2904, "lon": -76.6122},
894
+ ]
895
+
896
+ tier_params = {
897
+ # Sources — magic replenishment (factory production)
898
+ "SanFrancisco": dict(inv_base=4000, inv_var=400,
899
+ s_base=400, s_var=60,
900
+ S_base=4000, S_var=400,
901
+ lt_mean=3, lt_var=1),
902
+ "StLouis": dict(inv_base=4000, inv_var=400,
903
+ s_base=400, s_var=60,
904
+ S_base=4000, S_var=400,
905
+ lt_mean=3, lt_var=1),
906
+ "Orlando": dict(inv_base=4000, inv_var=400,
907
+ s_base=400, s_var=60,
908
+ S_base=4000, S_var=400,
909
+ lt_mean=3, lt_var=1),
910
+ # Tier-1
911
+ "Nashville": dict(inv_base=8000, inv_var=800,
912
+ s_base=1000, s_var=150,
913
+ S_base=8000, S_var=800,
914
+ lt_mean=3, lt_var=1),
915
+ # Tier-2
916
+ "Atlanta": dict(inv_base=6000, inv_var=600,
917
+ s_base=500, s_var=80,
918
+ S_base=6000, S_var=600,
919
+ lt_mean=1, lt_var=0),
920
+ # Tier-3
921
+ "Chicago": dict(inv_base=5000, inv_var=500,
922
+ s_base=1000, s_var=150,
923
+ S_base=5000, S_var=500,
924
+ lt_mean=8, lt_var=1),
925
+ "Charlotte": dict(inv_base=5000, inv_var=500,
926
+ s_base=1000, s_var=150,
927
+ S_base=5000, S_var=500,
928
+ lt_mean=7, lt_var=1),
929
+ "Memphis": dict(inv_base=3000, inv_var=300,
930
+ s_base=500, s_var=80,
931
+ S_base=3000, S_var=300,
932
+ lt_mean=7, lt_var=1),
933
+ # Tier-4
934
+ "Columbus": dict(inv_base=4000, inv_var=400,
935
+ s_base=500, s_var=80,
936
+ S_base=4000, S_var=400,
937
+ lt_mean=2, lt_var=0),
938
+ "Richmond": dict(inv_base=4000, inv_var=400,
939
+ s_base=500, s_var=80,
940
+ S_base=4000, S_var=400,
941
+ lt_mean=2, lt_var=0),
942
+ # Tier-5
943
+ "Philadelphia": dict(inv_base=3000, inv_var=300,
944
+ s_base=500, s_var=80,
945
+ S_base=3000, S_var=300,
946
+ lt_mean=1, lt_var=0),
947
+ "Baltimore": dict(inv_base=3000, inv_var=300,
948
+ s_base=500, s_var=80,
949
+ S_base=3000, S_var=300,
950
+ lt_mean=2, lt_var=0),
951
+ }
952
+
953
+ for m in nodes_meta:
954
+ nid = m["id"]
955
+ if m.get("is_destination", False):
956
+ continue
957
+ inv, s, S, lt = make_pol(**tier_params[nid])
958
+ m["inventory"] = inv
959
+ m["s_levels"] = s
960
+ m["S_levels"] = S
961
+ m["lead_time_mean"] = lt
962
+
963
+ n = len(nodes_meta)
964
+ adj = [[None]*n for _ in range(n)]
965
+ idx = {m["id"]: i for i, m in enumerate(nodes_meta)}
966
+
967
+ def se(u, v, tt, cv, nc):
968
+ adj[idx[u]][idx[v]] = (tt, cv, nc)
969
+
970
+ # Travel times in days; upstream capacity generous (not bottleneck)
971
+ se("SanFrancisco", "Nashville", 4, 5000.0, 3)
972
+ se("StLouis", "Nashville", 2, 5000.0, 3)
973
+ se("Orlando", "Nashville", 2, 5000.0, 3)
974
+ se("Nashville", "Atlanta", 1, 15000.0, 3)
975
+ se("Atlanta", "Chicago", 8, 4000.0, 3)
976
+ se("Atlanta", "Charlotte", 7, 4000.0, 3)
977
+ se("Atlanta", "Memphis", 7, 4000.0, 3)
978
+ se("Chicago", "Columbus", 2, 4000.0, 3)
979
+ se("Charlotte", "Richmond", 2, 4000.0, 3)
980
+ se("Columbus", "Philadelphia", 2, 4000.0, 3)
981
+ se("Richmond", "Philadelphia", 1, 4000.0, 3)
982
+ se("Richmond", "Baltimore", 3, 3000.0, 3)
983
+ se("Columbus", "Baltimore", 3, 3000.0, 3)
984
+ se("Memphis", "Baltimore", 2, 3000.0, 3)
985
+ # Last-mile: placeholder, overwritten dynamically below
986
+ se("Philadelphia", "NewYork", 1, 1000.0, 3)
987
+ se("Baltimore", "NewYork", 2, 1000.0, 3)
988
+
989
+ net = build_network_from_adjacency(nodes_meta, adj)
990
+
991
+ sc = scenario or {}
992
+ containers_scale = float(sc.get("containers_scale", 1.0))
993
+ if containers_scale != 1.0:
994
+ for eid, e in net.edges.items():
995
+ e.num_containers_per_day = max(
996
+ 1, int(round(e.num_containers_per_day * containers_scale)))
997
+
998
+ base_lambda_lo = float(sc.get("base_lambda_lo", 80))
999
+ base_lambda_hi = float(sc.get("base_lambda_hi", 250))
1000
+
1001
+ print("Building day-level demand signals...")
1002
+ demand_fn, demand_signals = build_demand_fn(
1003
+ item_ids, horizon_days, seed,
1004
+ base_lambda_range=(base_lambda_lo, base_lambda_hi),
1005
+ scenario=scenario)
1006
+ print(f" Shape: {demand_signals.shape} "
1007
+ f"({horizon_days} days ≈ {horizon_days/365:.0f} years)")
1008
+ print(f" Lambda: mean={demand_signals.mean():.1f} "
1009
+ f"min={demand_signals.min():.1f} "
1010
+ f"max={demand_signals.max():.1f}")
1011
+
1012
+ # Set last-mile capacity dynamically based on actual lambda mean
1013
+ actual_mean_lam = float(demand_signals.mean())
1014
+ avg_vol = 2.5
1015
+ total_demand_vol = len(item_ids) * actual_mean_lam * avg_vol
1016
+ target_ratio = 1.20
1017
+ packing_eff = 0.93
1018
+ raw_needed = total_demand_vol * target_ratio / packing_eff
1019
+ philadelphia_cv = round(raw_needed * 0.55 / 3 / 100) * 100
1020
+ baltimore_cv = round(raw_needed * 0.45 / 3 / 100) * 100
1021
+ net.edges[("Philadelphia", "NewYork")].container_volume = float(philadelphia_cv)
1022
+ net.edges[("Baltimore", "NewYork")].container_volume = float(baltimore_cv)
1023
+ for eid in [("Philadelphia", "NewYork"), ("Baltimore", "NewYork")]:
1024
+ e = net.edges[eid]
1025
+ net.weight_cache[eid] = e.travel_time_days / max(
1026
+ e.daily_total_capacity, 1e-9)
1027
+ last_mile_cap = (philadelphia_cv + baltimore_cv) * 3
1028
+ print(f" Demand vol: {total_demand_vol:.0f}/day "
1029
+ f"Last-mile cap: {last_mile_cap:.0f}/day "
1030
+ f"Ratio: {last_mile_cap/total_demand_vol:.1%} "
1031
+ f"(Philadelphia={philadelphia_cv:.0f} Baltimore={baltimore_cv:.0f})")
1032
+ sim = SupplyChainSimulation(
1033
+ network=net,
1034
+ items=items,
1035
+ destination_id="NewYork",
1036
+ demand_fn=demand_fn,
1037
+ horizon_days=horizon_days,
1038
+ seed=seed,
1039
+ pipeline_multiplier=pipeline_multiplier,
1040
+ streaming_out_dir=streaming_out_dir,
1041
+ packing=packing)
1042
+
1043
+ return sim, net, items, demand_signals
1044
+
1045
+
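The last-mile sizing in `build_example_simulation_from_adjacency` is plain arithmetic, so a worked example may help. The sketch below repeats the same steps as a standalone function. The function name `size_last_mile` and the sample mean rate of 160 units/day are illustrative assumptions; the constants (50 items, 2.5 average volume, 1.20 target ratio, 0.93 packing efficiency, 55/45 split, 3 containers per day) are copied from the code above.

```python
from typing import List, Tuple


def size_last_mile(
    mean_lambda: float,
    n_items: int = 50,
    avg_vol: float = 2.5,
    target_ratio: float = 1.20,
    packing_eff: float = 0.93,
    shares: Tuple[float, float] = (0.55, 0.45),   # Philadelphia, Baltimore
    containers_per_day: int = 3,
) -> Tuple[float, List[int], int]:
    """Return (daily demand volume, per-edge container volumes, total capacity)."""
    total_demand_vol = n_items * mean_lambda * avg_vol
    raw_needed = total_demand_vol * target_ratio / packing_eff
    # Each edge's container volume is rounded to the nearest 100.
    volumes = [round(raw_needed * s / containers_per_day / 100) * 100
               for s in shares]
    capacity = sum(volumes) * containers_per_day
    return total_demand_vol, volumes, capacity


# Hypothetical mean rate of 160 units per item per day.
demand_vol, volumes, capacity = size_last_mile(160.0)
ratio = capacity / demand_vol
```

With these inputs the demand volume is 20000 per day, the container volumes are 4700 and 3900, and the capacity is 25800 per day. The resulting ratio is about 1.29, not 1.20, because the packing-efficiency allowance is part of the raw capacity: the ratio is approximately `target_ratio / packing_eff`.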
1046
+ # ============================================================================
1047
+ # Folium helpers
1048
+ # ============================================================================
1049
+
1050
+ def plot_network_folium(network, tiles="cartodbpositron", zoom_start=5):
1051
+ lats = [n.lat for n in network.nodes.values()]
1052
+ lons = [n.lon for n in network.nodes.values()]
1053
+ c = (sum(lats)/len(lats), sum(lons)/len(lons))
1054
+ m = folium.Map(location=c, zoom_start=zoom_start, tiles=tiles)
1055
+ for (u, v), e in network.edges.items():
1056
+ nu, nv = network.nodes[u], network.nodes[v]
1057
+ folium.PolyLine(
1058
+ [(nu.lat, nu.lon), (nv.lat, nv.lon)],
1059
+ weight=2, opacity=0.7,
1060
+ tooltip=f"{u}→{v} t={e.travel_time_days:.0f}d "
1061
+ f"cap={e.daily_total_capacity:.0f}/day").add_to(m)
1062
+ for nid, n in network.nodes.items():
1063
+ folium.CircleMarker(
1064
+ (n.lat, n.lon),
1065
+ radius=6 if n.is_destination else 4, fill=True,
1066
+ tooltip=f"{nid} dest={n.is_destination}").add_to(m)
1067
+ return m
1068
+
1069
+
1070
+ def export_map_with_animation(
1071
+ network, sdf, inv_df=None, bl_df=None, it_df=None,
1072
+ out_html="supply_chain_map.html", start_date="2025-01-01",
1073
+ zoom_start=5,
1074
+ ):
1075
+ m = plot_network_folium(network, zoom_start=zoom_start)
1076
+ m.save(out_html)
1077
+ return out_html
1078
+
1079
+
1080
+ # ============================================================================
1081
+ # CLI
1082
+ # ============================================================================
1083
+
1084
+ if __name__ == "__main__":
1085
+ import argparse
1086
+ ap = argparse.ArgumentParser(
1087
+ description="Supply Chain Simulation — day-level, 52560-step")
1088
+ ap.add_argument("--days", type=int, default=52560)
1089
+ ap.add_argument("--seed", type=int, default=2025)
1090
+ ap.add_argument("--out_dir", type=str, default="test_output")
1091
+ ap.add_argument("--pipeline_mult", type=float, default=0.0,
1092
+ help="Days of EMA demand to keep in pipeline. "
1093
+ "0 = reactive mode (backlog + 3-day buffer).")
1094
+ ap.add_argument("--no_streaming", action="store_true")
1095
+ # Scenario / perturbation overrides (Exp C mixture). Defaults match the
1096
+ # released baseline configuration byte-for-byte.
1097
+ ap.add_argument("--phi_lo", type=float, default=0.9990)
1098
+ ap.add_argument("--phi_hi", type=float, default=0.9996)
1099
+ ap.add_argument("--shock_count_scale", type=float, default=1.0)
1100
+ ap.add_argument("--shock_height_scale", type=float, default=1.0)
1101
+ ap.add_argument("--seasonal_scale", type=float, default=1.0)
1102
+ ap.add_argument("--containers_scale", type=float, default=1.0)
1103
+ ap.add_argument("--ss_scale", type=float, default=1.0,
1104
+ help="Multiplier on (s_n, S_n) reorder/target levels "
1105
+ "(and initial inventory) at every non-destination "
1106
+ "node. 1.0 reproduces the released baseline.")
1107
+ ap.add_argument("--leadtime_scale", type=float, default=1.0,
1108
+ help="Multiplier on per-tier mean source lead time "
1109
+ "mu^n. 1.0 reproduces the released baseline.")
1110
+ ap.add_argument("--burst_rate_scale", type=float, default=1.0,
1111
+ help="Multiplier on per-item idiosyncratic burst "
1112
+ "Bernoulli rate sr. 1.0 reproduces baseline; "
1113
+ "values >1 increase per-item-burst frequency.")
1114
+ ap.add_argument("--burst_height_scale", type=float, default=1.0,
1115
+ help="Multiplier on per-item burst height h. "
1116
+ "1.0 reproduces baseline.")
1117
+ ap.add_argument("--base_lambda_lo", type=float, default=80.0)
1118
+ ap.add_argument("--base_lambda_hi", type=float, default=250.0)
1119
+ ap.add_argument("--scenario_name", type=str, default="baseline",
1120
+ help="Label written into <out_dir>/scenario.json")
1121
+ args = ap.parse_args()
1122
+
1123
+ scenario = {
1124
+ "name": args.scenario_name,
1125
+ "phi_lo": args.phi_lo, "phi_hi": args.phi_hi,
1126
+ "shock_count_scale": args.shock_count_scale,
1127
+ "shock_height_scale": args.shock_height_scale,
1128
+ "seasonal_scale": args.seasonal_scale,
1129
+ "containers_scale": args.containers_scale,
1130
+ "ss_scale": args.ss_scale,
1131
+ "leadtime_scale": args.leadtime_scale,
1132
+ "burst_rate_scale": args.burst_rate_scale,
1133
+ "burst_height_scale": args.burst_height_scale,
1134
+ "base_lambda_lo": args.base_lambda_lo,
1135
+ "base_lambda_hi": args.base_lambda_hi,
1136
+ "seed": args.seed, "days": args.days,
1137
+ }
1138
+
1139
+ streaming = not args.no_streaming and args.days > 500
1140
+
1141
+ print("=== Supply Chain Simulation (Multi-Echelon) ===")
1142
+ print(f" Days: {args.days:,} ({args.days/365:.1f} years) "
1143
+ f"Seed: {args.seed}")
1144
+ print(" 1 step = 1 day | 365 = 1 year | 52560 = 144 years")
1145
+ print(f" Pipeline: {args.pipeline_mult} Streaming: {streaming}")
1146
+ print()
1147
+
1148
+ sim, net, items, dsig = build_example_simulation_from_adjacency(
1149
+ seed=args.seed,
1150
+ horizon_days=args.days,
1151
+ pipeline_multiplier=args.pipeline_mult,
1152
+ streaming_out_dir=args.out_dir if streaming else None,
1153
+ packing="greedy",
1154
+ scenario=scenario)
1155
+
1156
+ os.makedirs(args.out_dir, exist_ok=True)
1157
+ import json as _json
1158
+ with open(os.path.join(args.out_dir, "scenario.json"), "w") as _f:
1159
+ _json.dump(scenario, _f, indent=2)
1160
+
1161
+ dd, ds, svc, di, db, dt = sim.run()
1162
+
1163
+ os.makedirs(args.out_dir, exist_ok=True)
1164
+
1165
+ # Save demand signals
1166
+ print("Saving demand signals...")
1167
+ np.save(os.path.join(args.out_dir, "demand_signals.npy"),
1168
+ dsig[:args.days])
1169
+ with open(os.path.join(args.out_dir, "demand_signals_cols.txt"), "w") as f:
1170
+ f.write(",".join(sorted(items.keys())) + "\n")
1171
+ print(f" Saved shape={dsig[:args.days].shape}")
1172
+
1173
+ if not streaming:
1174
+ dd.to_csv(os.path.join(args.out_dir, "daily_records.csv"),
1175
+ index=False)
1176
+ ds.to_csv(os.path.join(args.out_dir, "shipments.csv"), index=False)
1177
+ svc.to_csv(os.path.join(args.out_dir, "service_summary.csv"),
1178
+ index=False)
1179
+ if args.days <= 500:
1180
+ di.to_csv(os.path.join(args.out_dir, "inventory_history.csv"),
1181
+ index=False)
1182
+ db.to_csv(os.path.join(args.out_dir, "backlog_history.csv"),
1183
+ index=False)
1184
+ dt.to_csv(os.path.join(args.out_dir, "intransit_history.csv"),
1185
+ index=False)
1186
+
1187
+ fr = svc['fill_rate_stock_only']
1188
+ print("\n=== Service Summary ===")
1189
+ print(f" Fill rate: mean={fr.mean():.3f} "
1190
+ f"median={fr.median():.3f} "
1191
+ f"min={fr.min():.3f} max={fr.max():.3f}")
1192
+ print(f" Total demand: {svc['total_demand'].sum():,}")
1193
+ print(f" Total served: {svc['served_from_stock'].sum():,}")
1194
+ print(f" Total backlog: {svc['new_backlog_added'].sum():,}")
1195
+ print(f"\nOutputs → {args.out_dir}/")
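The fill rate reported in the service summary is the share of demand served directly from stock, with items that saw no demand reported as 0.0. A minimal standalone sketch of that aggregation follows, assuming rows shaped like `daily_records` (only the `item`, `demand`, and `served_from_stock` columns are used). The function name and the sample rows are illustrative and do not come from the repository.

```python
from collections import defaultdict
from typing import Dict, Iterable, Mapping


def stock_only_fill_rate(rows: Iterable[Mapping]) -> Dict[str, float]:
    """Per-item served_from_stock / demand, rounded to 6 places."""
    demand = defaultdict(int)
    served = defaultdict(int)
    for r in rows:
        demand[r["item"]] += r["demand"]
        served[r["item"]] += r["served_from_stock"]
    # Guard against division by zero for items with no demand at all.
    return {
        iid: round(served[iid] / demand[iid], 6) if demand[iid] > 0 else 0.0
        for iid in sorted(demand)
    }


rows = [
    {"item": "I01", "demand": 100, "served_from_stock": 90},
    {"item": "I01", "demand": 50, "served_from_stock": 30},
    {"item": "I02", "demand": 0, "served_from_stock": 0},
]
rates = stock_only_fill_rate(rows)
```

Units that are backlogged and delivered later do not count toward this metric, which is why the column is named `fill_rate_stock_only`.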
simulator/demo_simulator.py ADDED
@@ -0,0 +1,1097 @@
1
+ """
2
+ ISOMORPH Interactive Demo — Simulation Backend
3
+ ================================================
4
+ Lightweight supply-chain simulation engine adapted from
5
+ simulator/Supplychaingeo_item50.py for the interactive web demo.
6
+
7
+ Differences from the original:
8
+ 1. Returns a SimResult dataclass — no disk I/O.
9
+ 2. Supports configurable T (1–1000) and C (1–5 items).
10
+ 3. Adds a per-edge Markovian disruption mechanism (Bernoulli trigger,
11
+ fixed-duration capacity knockout).
12
+ 4. Tracks per-node inflow / outflow arrays during simulation so the
13
+ bullwhip analysis and time-series panels can be computed without
14
+ post-processing the shipments log.
15
+ 5. All core logic — AR(1) demand model, (s,S) inventory policy,
16
+ Dijkstra routing, greedy bin-packing — is kept identical to the
17
+ original simulator.
18
+
19
+ Public entry point
20
+ ------------------
21
+ result = run_demo_simulation(config)
22
+
23
+ config keys (all optional, defaults shown):
24
+   T                    int         365    simulation horizon in days
25
+   n_items              int         3      number of SKUs (1–5)
26
+   seed                 int         42
27
+   pipeline_mult        float       7.0    days of EMA demand to keep in pipeline
28
+   phi_lo               float       0.95   AR(1) coefficient lower bound
29
+   phi_hi               float       0.97   AR(1) coefficient upper bound
30
+   shock_count_scale    float       1.0
31
+   shock_height_scale   float       1.0
32
+   burst_rate_scale     float       1.0
33
+   burst_height_scale   float       1.0
34
+   seasonal_scale       float       1.0
35
+   base_lambda_lo       float       80.0
36
+   base_lambda_hi       float       250.0
37
+   containers_scale     float       1.0    edge capacity multiplier
38
+   ss_scale             float       1.0    (s, S) and initial inventory multiplier
39
+   leadtime_scale       float       1.0    lead-time multiplier at source nodes
40
+   disruption_edge      tuple|None  None   (from_node, to_node) edge to disrupt
41
+   disruption_prob      float       0.0    per-day Bernoulli trigger probability
42
+   disruption_duration  int         10     consecutive days edge is disabled
43
+   holding_cost         float       1.0    per-unit per-day cost (display only)
44
+   backlog_penalty      float       5.0    per-unit per-day penalty (display only)
45
+ """
46
+
47
+ from __future__ import annotations
48
+
49
+ import math
50
+ import heapq
51
+ import random
52
+ from dataclasses import dataclass, field
53
+ from typing import Dict, List, Optional, Tuple, Callable
54
+
55
+ import numpy as np
56
+
57
+
58
+ # ============================================================================
59
+ # Data Structures — verbatim from Supplychaingeo_item50.py
60
+ # ============================================================================
61
+
62
+ @dataclass
63
+ class Item:
64
+ item_id: str
65
+ volume: float
66
+
67
+
68
+ @dataclass
69
+ class Edge:
70
+ u: str
71
+ v: str
72
+ travel_time_days: float
73
+ container_volume: float
74
+ num_containers_per_day: int
75
+ daily_containers: List[float] = field(default_factory=list)
76
+
77
+ def reset_daily(self, capacity_factor: float = 1.0) -> None:
78
+ effective_vol = self.container_volume * capacity_factor
79
+ self.daily_containers = [effective_vol] * self.num_containers_per_day
80
+
81
+ def find_container_slot(self, item_volume: float) -> Optional[int]:
82
+ for idx, rem in enumerate(self.daily_containers):
83
+ if rem >= item_volume:
84
+ return idx
85
+ return None
86
+
87
+ def allocate_in_container(self, idx: int, item_volume: float) -> bool:
88
+ if 0 <= idx < len(self.daily_containers) and \
89
+ self.daily_containers[idx] >= item_volume:
90
+ self.daily_containers[idx] -= item_volume
91
+ return True
92
+ return False
93
+
94
+ @property
95
+ def daily_total_capacity(self) -> float:
96
+ return self.container_volume * max(self.num_containers_per_day, 0)
97
+
98
+
99
+ @dataclass
100
+ class Node:
101
+ node_id: str
102
+ lat: float
103
+ lon: float
104
+ is_destination: bool = False
105
+ is_source: bool = False
106
+ inventory: Dict[str, int] = field(default_factory=dict)
107
+ s_levels: Dict[str, int] = field(default_factory=dict)
108
+ S_levels: Dict[str, int] = field(default_factory=dict)
109
+ lead_time_mean: Dict[str, float] = field(default_factory=dict)
110
+ lead_time_std_frac: float = 0.2
111
+ outstanding_orders: Dict[str, Optional[Tuple[int, int]]] = \
112
+ field(default_factory=dict)
113
+ backlog: Dict[str, int] = field(default_factory=dict)
114
+
115
+ def receive_orders_today(self, day: int) -> None:
116
+ to_clear = []
117
+ for item_id, order in self.outstanding_orders.items():
118
+ if order is None:
119
+ continue
120
+ arrival_day, qty = order
121
+ if day >= arrival_day:
122
+ self.inventory[item_id] = \
123
+ self.inventory.get(item_id, 0) + qty
124
+ to_clear.append(item_id)
125
+ for iid in to_clear:
126
+ self.outstanding_orders[iid] = None
127
+
128
+ def maybe_place_orders(self, day: int, rng: random.Random) -> None:
129
+ if self.is_destination:
130
+ return
131
+ if not self.is_source:
132
+ return
133
+ for item_id, s in self.s_levels.items():
134
+ on_hand = self.inventory.get(item_id, 0)
135
+ if on_hand < s and \
136
+ self.outstanding_orders.get(item_id) is None:
137
+ S = self.S_levels.get(item_id, on_hand)
138
+ qty = max(S - on_hand, 0)
139
+ if qty <= 0:
140
+ continue
141
+ mean_lt = max(self.lead_time_mean.get(item_id, 1.0), 0.1)
142
+ std = self.lead_time_std_frac * mean_lt
143
+ sampled = rng.normalvariate(mean_lt, std)
144
+ lt_days = max(1, int(math.ceil(sampled)))
145
+ self.outstanding_orders[item_id] = (day + lt_days, qty)
146
+
147
+
148
+ # ============================================================================
149
+ # Network + Dijkstra — verbatim from Supplychaingeo_item50.py
150
+ # ============================================================================
151
+
152
+ class Network:
153
+ def __init__(self) -> None:
154
+ self.nodes: Dict[str, Node] = {}
155
+ self.edges: Dict[Tuple[str, str], Edge] = {}
156
+ self.adj: Dict[str, List[str]] = {}
157
+ self.weight_cache: Dict[Tuple[str, str], float] = {}
158
+ self.paths_to_dest: Dict[
159
+ str, Tuple[float, List[str], List[Tuple[str, str]]]] = {}
160
+
161
+ def add_node(self, node: Node) -> None:
162
+ self.nodes[node.node_id] = node
163
+ self.adj.setdefault(node.node_id, [])
164
+
165
+ def add_edge(self, edge: Edge) -> None:
166
+ self.edges[(edge.u, edge.v)] = edge
167
+ self.adj.setdefault(edge.u, []).append(edge.v)
168
+ cap = max(edge.daily_total_capacity, 1e-9)
169
+ self.weight_cache[(edge.u, edge.v)] = edge.travel_time_days / cap
170
+
171
+ def reset_daily_edges(self, capacity_factor: float = 1.0) -> None:
172
+ for e in self.edges.values():
173
+ e.reset_daily(capacity_factor)
174
+
175
+ def dijkstra(self, source: str, target: str) -> Tuple[float, List[str]]:
176
+ if source == target:
177
+ return 0.0, [source]
178
+ dist: Dict[str, float] = {source: 0.0}
179
+ prev: Dict[str, Optional[str]] = {source: None}
180
+ pq = [(0.0, source)]
181
+ visited: set = set()
182
+ while pq:
183
+ d, u = heapq.heappop(pq)
184
+ if u in visited:
185
+ continue
186
+ visited.add(u)
187
+ if u == target:
188
+ break
189
+ for v in self.adj.get(u, []):
190
+ w = self.weight_cache[(u, v)]
191
+ nd = d + w
192
+ if nd < dist.get(v, float('inf')):
193
+ dist[v] = nd
194
+ prev[v] = u
195
+ heapq.heappush(pq, (nd, v))
196
+ if target not in dist:
197
+ return float('inf'), []
198
+ path: List[str] = []
199
+ cur: Optional[str] = target
200
+ while cur is not None:
201
+ path.append(cur)
202
+ cur = prev.get(cur)
203
+ path.reverse()
204
+ return dist[target], path
205
+
206
+ def compute_paths_to_destination(self, dest_id: str) -> None:
207
+ self.paths_to_dest.clear()
208
+ for nid in self.nodes:
209
+ if nid == dest_id:
210
+ self.paths_to_dest[nid] = (0.0, [nid], [])
211
+ continue
212
+ d, pn = self.dijkstra(nid, dest_id)
213
+ if not pn:
214
+ self.paths_to_dest[nid] = (float('inf'), [], [])
215
+ else:
216
+ pe = [(pn[i], pn[i + 1]) for i in range(len(pn) - 1)]
217
+ self.paths_to_dest[nid] = (d, pn, pe)
218
+
219
+
220
+ # ============================================================================
221
+ # Greedy First-Fit Bin Packing — verbatim from Supplychaingeo_item50.py
222
+ # ============================================================================
223
+
224
+ def allocate_units_along_path_greedy(
225
+ item: Item,
226
+ max_units: int,
227
+ path_edges: List[Tuple[str, str]],
228
+ network_edges: Dict[Tuple[str, str], Edge],
229
+ ) -> int:
230
+ if max_units <= 0 or not path_edges:
231
+ return 0
232
+ edges_objs = [network_edges[eid] for eid in path_edges]
233
+ placed = 0
234
+ ivol = item.volume
235
+ for _ in range(max_units):
236
+ slots: List[Tuple[Edge, int]] = []
237
+ ok = True
238
+ for e in edges_objs:
239
+ si = e.find_container_slot(ivol)
240
+ if si is None:
241
+ ok = False
242
+ break
243
+ slots.append((e, si))
244
+ if not ok:
245
+ break
246
+ good = True
247
+ for e, si in slots:
248
+ if not e.allocate_in_container(si, ivol):
249
+ good = False
250
+ break
251
+ if not good:
252
+ break
253
+ placed += 1
254
+ return placed
255
+
256
+
257
+ # ============================================================================
258
+ # SimResult — returned by run_demo_simulation
259
+ # ============================================================================
260
+
261
+ @dataclass
262
+ class SimResult:
263
+ """All simulation outputs needed by the visualization layer."""
264
+
265
+ T: int
266
+ item_ids: List[str]
267
+ node_ids: List[str]
268
+ edge_ids: List[Tuple[str, str]]
269
+
270
+ # Per-node, per-item time series: dict[node_id][item_id] -> ndarray(T,)
271
+ inventory: Dict[str, Dict[str, np.ndarray]]
272
+ backlog: Dict[str, Dict[str, np.ndarray]]
273
+ # inflow[n][i][t] = units arriving AT node n for item i on day t
274
+ # outflow[n][i][t] = units dispatched FROM node n for item i on day t
275
+ inflow: Dict[str, Dict[str, np.ndarray]]
276
+ outflow: Dict[str, Dict[str, np.ndarray]]
277
+
278
+ # Realized customer demand at destination per item
279
+ demand: Dict[str, np.ndarray] # item_id -> ndarray(T,)
280
+
281
+ # Demand intensity signal (before Poisson sampling)
282
+ demand_signals: np.ndarray # shape (T, n_items)
283
+
284
+ # Shipment records for animation and bullwhip
285
+ # Each entry: {day, arrival_day, from, to, item, units,
286
+ # path_nodes (list), edge_times (list)}
287
+ shipments: List[Dict]
288
+
289
+ # Edge utilization: fraction of daily capacity used, shape (T,)
290
+ edge_util: Dict[Tuple[str, str], np.ndarray]
291
+ edge_cap: Dict[Tuple[str, str], float] # cap_per_day
292
+
293
+ # Summary
294
+ fill_rate: Dict[str, float] # item_id -> fill_rate_stock_only
295
+
296
+ # Tier assignment for bullwhip visualization
297
+ # tier 0 = destination (NewYork), higher = further upstream
298
+ tier: Dict[str, int]
299
+
300
+ # Coords for map visualization
301
+ node_coords: Dict[str, Tuple[float, float]] # node_id -> (lat, lon)
302
+
303
+ # Log of disruption events
304
+ disruption_log: List[Dict] # [{day, edge, duration}]
305
+
306
+ # Config echo
307
+ config: Dict
308
+
309
+
310
+ # ============================================================================
311
+ # Demo Simulation Engine
312
+ # Adapted from SupplyChainSimulation in Supplychaingeo_item50.py.
313
+ # Changes: no CSV streaming; per-edge disruption; direct inflow/outflow tracking.
314
+ # ============================================================================
315
+
316
+ class DemoSimulation:
317
+
318
+ def __init__(
319
+ self,
320
+ network: Network,
321
+ items: Dict[str, Item],
322
+ destination_id: str,
323
+ demand_fn: Callable[[int], Dict[str, int]],
324
+ T: int,
325
+ seed: int = 42,
326
+ pipeline_multiplier: float = 7.0,
327
+ disruption_edge: Optional[Tuple[str, str]] = None,
328
+ disruption_prob: float = 0.0,
329
+ disruption_duration: int = 10,
330
+ ) -> None:
331
+ assert destination_id in network.nodes and \
332
+ network.nodes[destination_id].is_destination
333
+
334
+ self.network = network
335
+ self.items = items
336
+ self.item_order: List[str] = sorted(self.items.keys())
337
+ self.destination_id = destination_id
338
+ self.demand_fn = demand_fn
339
+ self.T = T
340
+ self.rng = random.Random(seed)
341
+ self.pipeline_multiplier = pipeline_multiplier
342
+
343
+ # Disruption state
344
+ self.disruption_edge = disruption_edge
345
+ self.disruption_prob = disruption_prob
346
+ self.disruption_duration = disruption_duration
347
+ self._disruption_remaining: int = 0 # days remaining in current event
348
+ self.disruption_log: List[Dict] = []
349
+
350
+ # EMA warm-start (same default as original)
351
+ self.demand_ema: Dict[str, float] = {iid: 165.0 for iid in items}
352
+ self.ema_alpha = 0.05
353
+ self.item_intransit: Dict[str, int] = {iid: 0 for iid in items}
354
+
355
+ # In-transit to destination: arrival_day -> {item_id: qty}
356
+ self.dest_in_transit: Dict[int, Dict[str, int]] = {}
357
+
358
+ # ── Pre-allocate history arrays ───────────────────────────────────
359
+ node_ids = list(network.nodes.keys())
360
+ item_ids = self.item_order
361
+
362
+ def _zarr():
363
+ return np.zeros(T, dtype=np.int32)
364
+
365
+ self.inv_hist = {nid: {iid: _zarr() for iid in item_ids}
366
+ for nid in node_ids}
367
+ self.bl_hist = {nid: {iid: _zarr() for iid in item_ids}
368
+ for nid in node_ids}
369
+ self.inflow = {nid: {iid: _zarr() for iid in item_ids}
370
+ for nid in node_ids}
371
+ self.outflow = {nid: {iid: _zarr() for iid in item_ids}
372
+ for nid in node_ids}
373
+ self.demand_hist = {iid: _zarr() for iid in item_ids}
374
+
375
+ # Shipments log — populated during run
376
+ self.shipments_log: List[Dict] = []
377
+
378
+ # Service counters
379
+ self.svc_demand: Dict[str, int] = {iid: 0 for iid in items}
380
+ self.svc_served: Dict[str, int] = {iid: 0 for iid in items}
381
+
382
+ # ── Precompute routing ────────────────────────────────────────────
383
+ network.compute_paths_to_destination(destination_id)
384
+ self.sorted_warehouses = sorted(
385
+ [(nid, d) for nid, (d, _, _) in network.paths_to_dest.items()
386
+ if nid != destination_id and math.isfinite(d)],
387
+ key=lambda x: x[1])
388
+
389
+ # Supplier relationships for inter-warehouse replenishment
390
+ self.node_suppliers: Dict[
391
+ str,
392
+ List[Tuple[str, float, List[str], List[Tuple[str, str]]]]
393
+ ] = {}
394
+ for nid, node in network.nodes.items():
395
+ if node.is_destination or node.is_source:
396
+ continue
397
+ suppliers = []
398
+ for (u, v) in network.edges:
399
+ if v == nid:
400
+ d, path = network.dijkstra(u, nid)
401
+ if path and math.isfinite(d):
402
+ pe = [(path[i], path[i + 1])
403
+ for i in range(len(path) - 1)]
404
+ tt = sum(network.edges[e].travel_time_days for e in pe)
405
+ suppliers.append((u, tt, path, pe))
406
+ suppliers.sort(key=lambda x: x[1])
407
+ self.node_suppliers[nid] = suppliers
408
+
409
+ # Intermediate nodes ordered upstream-first for replenishment
410
+ self.replenish_order: List[str] = [
411
+ nid for nid, _ in reversed(self.sorted_warehouses)
412
+ if nid in self.node_suppliers]
413
+
414
+ # ── helpers ──────────────────────────────────────────────────────────────
415
+
416
+ def _total_tt(self, pe: List[Tuple[str, str]]) -> float:
417
+ return sum(self.network.edges[e].travel_time_days for e in pe)
418
+
419
+ def _reset_edges_with_disruption(self, day: int) -> None:
420
+ """Reset edge containers, knocking out the disruption edge if active."""
421
+ ded = self.disruption_edge
422
+
423
+ if ded is not None and ded in self.network.edges:
424
+ # Advance disruption state machine
425
+ if self._disruption_remaining > 0:
426
+ self._disruption_remaining -= 1
427
+ elif self.rng.random() < self.disruption_prob:
428
+ self._disruption_remaining = self.disruption_duration
429
+ self.disruption_log.append(
430
+ {"day": day, "edge": ded,
431
+ "duration": self.disruption_duration})
432
+
433
+ for eid, e in self.network.edges.items():
434
+ if ded is not None and eid == ded and \
435
+ self._disruption_remaining > 0:
436
+ e.reset_daily(0.0) # capacity = 0: edge blocked
437
+ else:
438
+ e.reset_daily(1.0)
439
+
440
+ # ── replenishment ─────────────────────────────────────────────────────────
441
+
442
+ def _replenish_warehouses(self, day: int) -> None:
443
+ """Inter-warehouse replenishment: intermediate nodes pull
444
+ from upstream suppliers using (s,S) trigger + edge capacity.
445
+ Logic verbatim from Supplychaingeo_item50.py."""
446
+ net = self.network
447
+ k = day % len(self.item_order)
448
+ ids_today = self.item_order[k:] + self.item_order[:k]
449
+
450
+ for nid in self.replenish_order:
451
+ node = net.nodes[nid]
452
+ for iid in ids_today:
453
+ on_hand = node.inventory.get(iid, 0)
454
+ s = node.s_levels.get(iid, 0)
455
+ if on_hand >= s:
456
+ continue
457
+ if node.outstanding_orders.get(iid) is not None:
458
+ continue
459
+ S = node.S_levels.get(iid, on_hand)
460
+ qty_needed = max(S - on_hand, 0)
461
+ if qty_needed <= 0:
462
+ continue
463
+
464
+ for sup_id, tt, path, pe in \
465
+ self.node_suppliers.get(nid, []):
466
+ sup_node = net.nodes[sup_id]
467
+ avail = sup_node.inventory.get(iid, 0)
468
+ if avail <= 0:
469
+ continue
470
+ attempt = min(avail, qty_needed)
471
+ placed = allocate_units_along_path_greedy(
472
+ self.items[iid], attempt, pe, net.edges)
473
+ if placed <= 0:
474
+ continue
475
+
476
+ sup_node.inventory[iid] -= placed
477
+ arr = day + max(1, int(math.ceil(tt)))
478
+ node.outstanding_orders[iid] = (arr, placed)
479
+
480
+ # Track outflow at supplier and inflow at receiver
481
+ self.outflow[sup_id][iid][day] += placed
482
+ if arr < self.T:
483
+ self.inflow[nid][iid][arr] += placed
484
+
485
+ self.shipments_log.append({
486
+ "day": day, "arrival_day": arr,
487
+ "from": sup_id, "to": nid,
488
+ "item": iid, "units": placed,
489
+ "path_nodes": path,
490
+ "edge_times": [net.edges[e].travel_time_days
491
+ for e in pe]})
492
+ break
493
+
494
+ # ── main step ─────────────────────────────────────────────────────────────
495
+
496
+ def step(self, day: int) -> None:
497
+ net = self.network
498
+ dest = net.nodes[self.destination_id]
499
+
500
+ # 1) (s,S) replenishment arrivals at warehouse nodes
501
+ for node in net.nodes.values():
502
+ node.receive_orders_today(day)
503
+
504
+ # 2) Arrivals at destination
505
+ arrivals = self.dest_in_transit.pop(day, {})
506
+ for iid, qty in arrivals.items():
507
+ self.item_intransit[iid] = max(
508
+ 0, self.item_intransit.get(iid, 0) - qty)
509
+ # Backlog absorption
510
+ bl = dest.backlog.get(iid, 0)
511
+ if bl > 0:
512
+ use = min(qty, bl)
513
+ dest.backlog[iid] = bl - use
514
+ qty -= use
515
+ if qty > 0:
516
+ dest.inventory[iid] = dest.inventory.get(iid, 0) + qty
517
+ # Record inflow at destination
518
+ self.inflow[self.destination_id][iid][day] += arrivals.get(iid, 0)
519
+
520
+ # 3) Reset edge containers (with disruption)
521
+ self._reset_edges_with_disruption(day)
522
+
523
+ # 4) Demand at destination
524
+ td = self.demand_fn(day)
525
+ k = day % len(self.item_order)
526
+ ids_today = self.item_order[k:] + self.item_order[:k]
527
+
528
+ for iid in self.items:
529
+ dq = int(td.get(iid, 0))
530
+ self.demand_ema[iid] = (
531
+ self.ema_alpha * dq +
532
+ (1 - self.ema_alpha) * self.demand_ema[iid])
533
+ oh = dest.inventory.get(iid, 0)
534
+ if oh >= dq:
535
+ served, unfilled = dq, 0
536
+ dest.inventory[iid] = oh - dq
537
+ else:
538
+ served, unfilled = oh, dq - oh
539
+ dest.inventory[iid] = 0
540
+ dest.backlog[iid] = dest.backlog.get(iid, 0) + unfilled
541
+ self.svc_demand[iid] += dq
542
+ self.svc_served[iid] += served
543
+ self.demand_hist[iid][day] = dq
544
+
545
+ # 5) Ship from warehouses to destination
546
+ for iid in ids_today:
547
+ item = self.items[iid]
548
+ cb = dest.backlog.get(iid, 0)
549
+ it = self.item_intransit.get(iid, 0)
550
+ oh = dest.inventory.get(iid, 0)
551
+
552
+ if self.pipeline_multiplier > 0:
553
+ pt = self.demand_ema[iid] * self.pipeline_multiplier
554
+ ship_target = max(0, int(math.ceil(cb + pt - it - oh)))
555
+ else:
556
+ S_dest = max(1, int(self.demand_ema[iid] * 3))
557
+ ship_target = max(0, cb + S_dest - oh - it)
558
+
559
+ if ship_target <= 0:
560
+ continue
561
+
562
+ remaining = ship_target
563
+ for wid, _ in self.sorted_warehouses:
564
+ if remaining <= 0:
565
+ break
566
+ wn = net.nodes[wid]
567
+ avail = wn.inventory.get(iid, 0)
568
+ if avail <= 0:
569
+ continue
570
+ _, pn, pe = net.paths_to_dest[wid]
571
+ if not pe:
572
+ continue
573
+ attempt = min(avail, remaining)
574
+ placed = allocate_units_along_path_greedy(
575
+ item, attempt, pe, net.edges)
576
+ if placed <= 0:
577
+ continue
578
+ wn.inventory[iid] -= placed
579
+ remaining -= placed
580
+ arr = day + max(1, int(math.ceil(self._total_tt(pe))))
581
+ self.dest_in_transit.setdefault(arr, {})
582
+ self.dest_in_transit[arr][iid] = \
583
+ self.dest_in_transit[arr].get(iid, 0) + placed
584
+ self.item_intransit[iid] = \
585
+ self.item_intransit.get(iid, 0) + placed
586
+
587
+ # Track outflow at warehouse; inflow at dest tracked in step 2
588
+ self.outflow[wid][iid][day] += placed
589
+
590
+ self.shipments_log.append({
591
+ "day": day, "arrival_day": arr,
592
+ "from": wid, "to": self.destination_id,
593
+ "item": iid, "units": placed,
594
+ "path_nodes": pn,
595
+ "edge_times": [net.edges[e].travel_time_days for e in pe]})
596
+
597
+ # 5b) Inter-warehouse replenishment
598
+ self._replenish_warehouses(day)
599
+
600
+ # 6) (s,S) orders at source nodes
601
+ for node in net.nodes.values():
602
+ node.maybe_place_orders(day, self.rng)
603
+
604
+ # 7) Snapshot inventory and backlog
605
+ for node in net.nodes.values():
606
+ for iid in self.items:
607
+ self.inv_hist[node.node_id][iid][day] = \
608
+ int(node.inventory.get(iid, 0))
609
+ self.bl_hist[node.node_id][iid][day] = \
610
+ int(node.backlog.get(iid, 0))
611
+
612
+ # ── run ───────────────────────────────────────────────────────────────────
613
+
614
+ def run(self) -> None:
615
+ for day in range(self.T):
616
+ self.step(day)
617
+ # Clean up stale in-transit entries
618
+ if day % 200 == 0:
619
+ for k in [k for k in self.dest_in_transit if k < day]:
620
+ del self.dest_in_transit[k]
621
+
622
+
623
+ # ============================================================================
624
+ # Demand Generator — same AR(1) + shocks + burst model as the original.
625
+ # phi defaults lowered for shorter demo horizons (more visible dynamics).
626
+ # ============================================================================
627
+
628
+ def build_demand_fn(
629
+ item_ids: List[str],
630
+ n_steps: int,
631
+ seed: int = 42,
632
+ base_lambda_range: Tuple[float, float] = (80.0, 250.0),
633
+ scenario: Optional[Dict] = None,
634
+ ) -> Tuple[Callable[[int], Dict[str, int]], np.ndarray]:
635
+ """
636
+ Returns (demand_fn, lam) where lam has shape (n_steps, n_items).
637
+ demand_fn(day) -> {item_id: int} via Poisson sampling from lam[day].
638
+
639
+ AR(1) defaults are phi_lo=0.95, phi_hi=0.97 for demo horizons (300–1000
640
+ days); the original used 0.999 for 52560-step runs.
641
+ """
642
+ rng = np.random.default_rng(seed)
643
+ n_items = len(item_ids)
644
+
645
+ sc = scenario or {}
646
+ phi_lo = float(sc.get("phi_lo", 0.95))
647
+ phi_hi = float(sc.get("phi_hi", 0.97))
648
+ shock_count_scale = float(sc.get("shock_count_scale", 1.0))
649
+ shock_height_scale = float(sc.get("shock_height_scale", 1.0))
650
+ seasonal_scale = float(sc.get("seasonal_scale", 1.0))
651
+ burst_rate_scale = float(sc.get("burst_rate_scale", 1.0))
652
+ burst_height_scale = float(sc.get("burst_height_scale", 1.0))
653
+
654
+ spy = 365 # yearly period in days
655
+ spw = 7 # weekly period in days
656
+
657
+ t = np.arange(n_steps, dtype=np.float64)
658
+ yearly_phase = 2 * np.pi * t / spy
659
+ weekly_phase = 2 * np.pi * (t % spw) / spw
660
+
661
+ lam = np.zeros((n_steps, n_items), dtype=np.float64)
662
+
663
+ # Global macro events: rare, long, correlated across items
664
+ gs = np.zeros(n_steps)
665
+ base_n_global = rng.integers(5, 12)
666
+ n_global = max(0, int(round(float(base_n_global) * shock_count_scale)))
667
+ for _ in range(n_global):
668
+ si = int(rng.integers(0, n_steps))
669
+ dur = int(rng.integers(min(180, n_steps), max(min(1100, n_steps), min(180, n_steps) + 1)))
670
+ end = min(si + dur, n_steps)
671
+ h = rng.uniform(0.20, 0.60) * shock_height_scale
672
+ for k in range(si, end):
673
+ p = (k - si) / max(dur, 1)
674
+ if p < 0.15:
675
+ gs[k] += h * (p / 0.15)
676
+ elif p < 0.75:
677
+ gs[k] += h
678
+ else:
679
+ gs[k] += h * (1.0 - (p - 0.75) / 0.25)
680
+
681
+ for j, iid in enumerate(item_ids):
682
+ base = rng.uniform(*base_lambda_range)
683
+
684
+ # Yearly seasonality: two harmonics
685
+ ya1 = rng.uniform(0.12, 0.28) * seasonal_scale
686
+ ya2 = rng.uniform(0.04, 0.10) * seasonal_scale
687
+ yo = rng.uniform(0, 2 * np.pi)
688
+ yr = (ya1 * np.sin(yearly_phase + yo) +
689
+ ya2 * np.sin(2 * yearly_phase + yo * 0.7))
690
+
691
+ # Weekly texture
692
+ wa = rng.uniform(0.04, 0.10)
693
+ wo = rng.uniform(0, 2 * np.pi)
694
+ wy = wa * np.sin(weekly_phase + wo)
695
+
696
+ # AR(1) drift (dominant low-frequency component)
697
+ ac = rng.uniform(phi_lo, phi_hi)
698
+ ar_std = rng.uniform(0.008, 0.018)
699
+ dr = np.zeros(n_steps)
700
+ dr[0] = rng.normal(0, 0.10)
701
+ for i in range(1, n_steps):
702
+ dr[i] = ac * dr[i - 1] + rng.normal(0, ar_std)
703
+ dr = np.clip(dr, -0.60, 0.60)
704
+
705
+ # Per-item idiosyncratic burst events
706
+ sr = rng.uniform(0.0002, 0.001) * burst_rate_scale
707
+ sm = rng.random(n_steps) < sr
708
+ sp = np.zeros(n_steps)
709
+ for si in np.where(sm)[0]:
710
+ dur = int(rng.integers(min(30, n_steps), max(min(180, n_steps), min(30, n_steps) + 1)))
711
+ end = min(si + dur, n_steps)
712
+ h = rng.uniform(0.20, 0.70) * burst_height_scale
713
+ for k in range(si, end):
714
+ p = (k - si) / max(dur, 1)
715
+ if p < 0.15:
716
+ sp[k] += h * (p / 0.15)
717
+ elif p < 0.75:
718
+ sp[k] += h
719
+ else:
720
+ sp[k] += h * (1.0 - (p - 0.75) / 0.25)
721
+
722
+ gsens = rng.uniform(0.4, 1.2)
723
+ fac = 1.0 + yr + wy + dr + sp + gsens * gs
724
+ fac = np.clip(fac, 0.08, None)
725
+ lam[:, j] = base * fac
726
+
727
+ def demand_fn(day: int) -> Dict[str, int]:
728
+ idx = day % n_steps
729
+ return {iid: int(rng.poisson(lam=max(0.01, lam[idx, j])))
730
+ for j, iid in enumerate(item_ids)}
731
+
732
+ return demand_fn, lam
733
+
734
+
735
+ # ============================================================================
736
+ # Network Builder — same fixed US topology as Supplychaingeo_item50.py,
737
+ # scaled for n_items and demo parameters.
738
+ # ============================================================================
739
+
740
+ # Tier labeling matches Table C.4 / bullwhip_analysis.py in the paper.
741
+ _TIER: Dict[str, int] = {
742
+ "NewYork": 0, # destination
743
+ "Philadelphia": 1, # last-mile
744
+ "Baltimore": 1, # last-mile
745
+ "Columbus": 2, # Tier-4
746
+ "Richmond": 2, # Tier-4
747
+ "Charlotte": 3, # Tier-3
748
+ "Chicago": 3, # Tier-3
749
+ "Memphis": 3, # Tier-3
750
+ "Atlanta": 4, # Tier-2
751
+ "Nashville": 5, # hub / Tier-1
752
+ "SanFrancisco": 6, # source
753
+ "StLouis": 6, # source
754
+ "Orlando": 6, # source
755
+ }
756
+
757
+ # Coordinates (lat, lon)
758
+ _COORDS: Dict[str, Tuple[float, float]] = {
759
+ "NewYork": (40.7128, -74.0060),
760
+ "Philadelphia": (39.9526, -75.1652),
761
+ "Baltimore": (39.2904, -76.6122),
762
+ "Columbus": (39.9612, -82.9988),
763
+ "Richmond": (37.5407, -77.4360),
764
+ "Charlotte": (35.2271, -80.8431),
765
+ "Chicago": (41.8781, -87.6298),
766
+ "Memphis": (35.1495, -90.0490),
767
+ "Atlanta": (33.7490, -84.3880),
768
+ "Nashville": (36.1627, -86.7816),
769
+ "SanFrancisco": (37.7749, -122.4194),
770
+ "StLouis": (38.6270, -90.1994),
771
+ "Orlando": (28.5383, -81.3792),
772
+ }
773
+
774
+
775
+ def _build_demo_network(
776
+ n_items: int,
777
+ item_ids: List[str],
778
+ seed: int,
779
+ scenario: Dict,
780
+ demand_signals: np.ndarray,
781
+ ) -> Network:
782
+ """
783
+ Build the canonical ISOMORPH network with (s,S) parameters adapted
784
+ for the demo item count and scenario knobs.
785
+ """
786
+ random.seed(seed)
787
+
788
+ ss_scale = float(scenario.get("ss_scale", 1.0))
789
+ leadtime_scale = float(scenario.get("leadtime_scale", 1.0))
790
+ containers_scale = float(scenario.get("containers_scale", 1.0))
791
+
792
+ def make_pol(inv_base=4000, inv_var=500,
793
+ s_base=600, s_var=100,
794
+ S_base=6000, S_var=500,
795
+ lt_mean=5, lt_var=1):
796
+ inv_base = inv_base * ss_scale
797
+ inv_var = inv_var * ss_scale
798
+ s_base = s_base * ss_scale
799
+ s_var = s_var * ss_scale
800
+ S_base = S_base * ss_scale
801
+ S_var = S_var * ss_scale
802
+ lt_mean = lt_mean * leadtime_scale
803
+ lt_var = lt_var * leadtime_scale
804
+ inv, s, S, lt = {}, {}, {}, {}
805
+ for iid in item_ids:
806
+ si = max(0, int(round(s_base + random.uniform(-s_var, s_var))))
807
+ Si = max(si + 1, int(round(
808
+ S_base + random.uniform(-S_var, S_var))))
809
+ ii = int(round(inv_base + random.uniform(-inv_var, inv_var)))
810
+ ii = max(si, min(Si, max(0, ii)))
811
+ li = max(1, int(round(lt_mean + random.uniform(-lt_var, lt_var))))
812
+ s[iid], S[iid], inv[iid], lt[iid] = si, Si, ii, li
813
+ return inv, s, S, lt
814
+
815
+ tier_params = {
816
+ "SanFrancisco": dict(inv_base=4000, inv_var=400,
817
+ s_base=400, s_var=60,
818
+ S_base=4000, S_var=400,
819
+ lt_mean=3, lt_var=1),
820
+ "StLouis": dict(inv_base=4000, inv_var=400,
821
+ s_base=400, s_var=60,
822
+ S_base=4000, S_var=400,
823
+ lt_mean=3, lt_var=1),
824
+ "Orlando": dict(inv_base=4000, inv_var=400,
825
+ s_base=400, s_var=60,
826
+ S_base=4000, S_var=400,
827
+ lt_mean=3, lt_var=1),
828
+ "Nashville": dict(inv_base=8000, inv_var=800,
829
+ s_base=1000, s_var=150,
830
+ S_base=8000, S_var=800,
831
+ lt_mean=3, lt_var=1),
832
+ "Atlanta": dict(inv_base=6000, inv_var=600,
833
+ s_base=500, s_var=80,
834
+ S_base=6000, S_var=600,
835
+ lt_mean=1, lt_var=0),
836
+ "Chicago": dict(inv_base=5000, inv_var=500,
837
+ s_base=1000, s_var=150,
838
+ S_base=5000, S_var=500,
839
+ lt_mean=8, lt_var=1),
840
+ "Charlotte": dict(inv_base=5000, inv_var=500,
841
+ s_base=1000, s_var=150,
842
+ S_base=5000, S_var=500,
843
+ lt_mean=7, lt_var=1),
844
+ "Memphis": dict(inv_base=3000, inv_var=300,
845
+ s_base=500, s_var=80,
846
+ S_base=3000, S_var=300,
847
+ lt_mean=7, lt_var=1),
848
+ "Columbus": dict(inv_base=4000, inv_var=400,
849
+ s_base=500, s_var=80,
850
+ S_base=4000, S_var=400,
851
+ lt_mean=2, lt_var=0),
852
+ "Richmond": dict(inv_base=4000, inv_var=400,
853
+ s_base=500, s_var=80,
854
+ S_base=4000, S_var=400,
855
+ lt_mean=2, lt_var=0),
856
+ "Philadelphia": dict(inv_base=3000, inv_var=300,
857
+ s_base=500, s_var=80,
858
+ S_base=3000, S_var=300,
859
+ lt_mean=1, lt_var=0),
860
+ "Baltimore": dict(inv_base=3000, inv_var=300,
861
+ s_base=500, s_var=80,
862
+ S_base=3000, S_var=300,
863
+ lt_mean=2, lt_var=0),
864
+ }
865
+
866
+ net = Network()
867
+
868
+ # Destination
869
+ net.add_node(Node(
870
+ node_id="NewYork",
871
+ lat=_COORDS["NewYork"][0], lon=_COORDS["NewYork"][1],
872
+ is_destination=True,
873
+ inventory={iid: 600 for iid in item_ids},
874
+ backlog={iid: 0 for iid in item_ids},
875
+ ))
876
+
877
+ # All other nodes
878
+ for nid, params in tier_params.items():
879
+ is_src = nid in ("SanFrancisco", "StLouis", "Orlando")
880
+ inv, s, S, lt = make_pol(**params)
881
+ lat, lon = _COORDS[nid]
882
+ net.add_node(Node(
883
+ node_id=nid,
884
+ lat=lat, lon=lon,
885
+ is_source=is_src,
886
+ inventory=inv, s_levels=s, S_levels=S, lead_time_mean=lt,
887
+ outstanding_orders={iid: None for iid in item_ids},
888
+ backlog={iid: 0 for iid in item_ids},
889
+ ))
890
+
891
+ # Edges — travel times and upstream capacities from original
892
+ edge_defs = [
893
+ ("SanFrancisco", "Nashville", 4, 5000.0, 3),
894
+ ("StLouis", "Nashville", 2, 5000.0, 3),
895
+ ("Orlando", "Nashville", 2, 5000.0, 3),
896
+ ("Nashville", "Atlanta", 1, 15000.0, 3),
897
+ ("Atlanta", "Chicago", 8, 4000.0, 3),
898
+ ("Atlanta", "Charlotte", 7, 4000.0, 3),
899
+ ("Atlanta", "Memphis", 7, 4000.0, 3),
900
+ ("Chicago", "Columbus", 2, 4000.0, 3),
901
+ ("Charlotte", "Richmond", 2, 4000.0, 3),
902
+ ("Columbus", "Philadelphia", 2, 4000.0, 3),
903
+ ("Richmond", "Philadelphia", 1, 4000.0, 3),
904
+ ("Richmond", "Baltimore", 3, 3000.0, 3),
905
+ ("Columbus", "Baltimore", 3, 3000.0, 3),
906
+ ("Memphis", "Baltimore", 2, 3000.0, 3),
907
+ # Last-mile: placeholder container_volume, back-solved below
908
+ ("Philadelphia", "NewYork", 1, 1000.0, 3),
909
+ ("Baltimore", "NewYork", 2, 1000.0, 3),
910
+ ]
911
+
912
+ for u, v, tt, cv, nc in edge_defs:
913
+ nc_scaled = max(1, int(round(nc * containers_scale)))
914
+ net.add_edge(Edge(u=u, v=v,
915
+ travel_time_days=float(tt),
916
+ container_volume=float(cv),
917
+ num_containers_per_day=nc_scaled))
918
+
919
+ # Back-solve last-mile capacity from actual demand mean (same formula as original)
920
+ actual_mean_lam = float(demand_signals.mean())
921
+ avg_vol = 2.5 # mean item volume
922
+ target_ratio = 1.20
923
+ packing_eff = 0.93
924
+ total_demand_vol = n_items * actual_mean_lam * avg_vol
925
+ raw_needed = total_demand_vol * target_ratio / packing_eff
926
+ phil_cv = round(raw_needed * 0.55 / 3 / 100) * 100
927
+ balt_cv = round(raw_needed * 0.45 / 3 / 100) * 100
928
+ phil_cv = max(phil_cv, 100.0)
929
+ balt_cv = max(balt_cv, 100.0)
930
+
931
+ for eid, cv in [(("Philadelphia", "NewYork"), phil_cv),
932
+ (("Baltimore", "NewYork"), balt_cv)]:
933
+ e = net.edges[eid]
934
+ e.container_volume = float(cv)
935
+ net.weight_cache[eid] = e.travel_time_days / max(
936
+ e.daily_total_capacity, 1e-9)
937
+
938
+ return net
939
+
940
+
941
+ # ============================================================================
942
+ # Edge Utilisation — adapted from simulator/derive_edge_files.py
943
+ # ============================================================================
944
+
945
+ def _compute_edge_util(
946
+ shipments_log: List[Dict],
947
+ edge_ids: List[Tuple[str, str]],
948
+ edge_cap: Dict[Tuple[str, str], float],
949
+ T: int,
950
+ ) -> Dict[Tuple[str, str], np.ndarray]:
951
+ """
952
+ For each edge, compute fraction of daily capacity used (0=empty, 1=full).
953
+ Attributes each shipment's volume to every edge leg on the day that leg
954
+ starts — matching the convention in derive_edge_files.py.
955
+ """
956
+ edge_to_idx = {eid: i for i, eid in enumerate(edge_ids)}
957
+ n_edges = len(edge_ids)
958
+ vol = np.zeros((T, n_edges), dtype=np.float64)
959
+
960
+ for rec in shipments_log:
961
+ path = rec["path_nodes"]
962
+ ets = rec["edge_times"]
963
+ d0 = int(rec["day"])
964
+ units = float(rec["units"])
965
+ cum = 0.0
966
+ for h in range(len(ets)):
967
+ hop = (path[h], path[h + 1])
968
+ eid_idx = edge_to_idx.get(hop)
969
+ start_day = d0 + int(round(cum))
970
+ if eid_idx is not None and 0 <= start_day < T:
971
+ vol[start_day, eid_idx] += units
972
+ cum += float(ets[h])
973
+
974
+ cap_arr = np.array([max(edge_cap.get(eid, 1.0), 1e-9)
975
+ for eid in edge_ids], dtype=np.float64)
976
+ util_mat = vol / cap_arr[np.newaxis, :]
977
+
978
+ return {eid: util_mat[:, i] for i, eid in enumerate(edge_ids)}
979
+
980
+
981
+ # ============================================================================
982
+ # Public entry point
983
+ # ============================================================================
984
+
985
+ def run_demo_simulation(config: Optional[Dict] = None) -> SimResult:
986
+ """
987
+ Run one simulation and return a SimResult.
988
+
989
+ Parameters
990
+ ----------
991
+ config : dict, optional
992
+ See module docstring for all supported keys. Missing keys use defaults.
993
+ """
994
+ cfg = config or {}
995
+
996
+ T = max(1, min(int(cfg.get("T", 365)), 1000))
997
+ n_items = max(1, min(int(cfg.get("n_items", 3)), 5))
998
+ seed = int(cfg.get("seed", 42))
999
+ pipe_mult = float(cfg.get("pipeline_mult", 7.0))
1000
+
1001
+ scenario = {
1002
+ "phi_lo": float(cfg.get("phi_lo", 0.95)),
1003
+ "phi_hi": float(cfg.get("phi_hi", 0.97)),
1004
+ "shock_count_scale": float(cfg.get("shock_count_scale", 1.0)),
1005
+ "shock_height_scale": float(cfg.get("shock_height_scale", 1.0)),
1006
+ "burst_rate_scale": float(cfg.get("burst_rate_scale", 1.0)),
1007
+ "burst_height_scale": float(cfg.get("burst_height_scale", 1.0)),
1008
+ "seasonal_scale": float(cfg.get("seasonal_scale", 1.0)),
1009
+ "base_lambda_lo": float(cfg.get("base_lambda_lo", 80.0)),
1010
+ "base_lambda_hi": float(cfg.get("base_lambda_hi", 250.0)),
1011
+ "containers_scale": float(cfg.get("containers_scale", 1.0)),
1012
+ "ss_scale": float(cfg.get("ss_scale", 1.0)),
1013
+ "leadtime_scale": float(cfg.get("leadtime_scale", 1.0)),
1014
+ }
1015
+
1016
+ disruption_edge = cfg.get("disruption_edge", None)
1017
+ disruption_prob = float(cfg.get("disruption_prob", 0.0))
1018
+ disruption_duration = int(cfg.get("disruption_duration", 10))
1019
+
1020
+ # Normalize disruption_edge: accept "NodeA,NodeB" string or (A,B) tuple
1021
+ if isinstance(disruption_edge, str) and "," in disruption_edge:
1022
+ parts = disruption_edge.split(",")
1023
+ disruption_edge = (parts[0].strip(), parts[1].strip())
1024
+
1025
+ item_ids = [f"I{i:02d}" for i in range(1, n_items + 1)]
1026
+
1027
+ # Build demand signals
1028
+ demand_fn, demand_signals = build_demand_fn(
1029
+ item_ids, T, seed=seed,
1030
+ base_lambda_range=(scenario["base_lambda_lo"],
1031
+ scenario["base_lambda_hi"]),
1032
+ scenario=scenario,
1033
+ )
1034
+
1035
+ # Build network
1036
+ net = _build_demo_network(n_items, item_ids, seed, scenario,
1037
+ demand_signals)
1038
+
1039
+ # Deterministic item volumes (same as original: seeded separately from the network build)
+ random.seed(seed)
+ items = {iid: Item(iid, round(random.uniform(1.0, 4.0), 2))
+ for iid in item_ids}
+
+     # Run simulation
+     sim = DemoSimulation(
+         network=net,
+         items=items,
+         destination_id="NewYork",
+         demand_fn=demand_fn,
+         T=T,
+         seed=seed,
+         pipeline_multiplier=pipe_mult,
+         disruption_edge=disruption_edge,
+         disruption_prob=disruption_prob,
+         disruption_duration=disruption_duration,
+     )
+     sim.run()
+
+     # Edge identifiers and capacities
+     edge_ids = list(net.edges.keys())
+     edge_cap = {eid: net.edges[eid].daily_total_capacity for eid in edge_ids}
+
+     # Edge utilization
+     edge_util = _compute_edge_util(
+         sim.shipments_log, edge_ids, edge_cap, T)
+
+     # Fill rates
+     fill_rate = {}
+     for iid in item_ids:
+         td = sim.svc_demand[iid]
+         sv = sim.svc_served[iid]
+         fill_rate[iid] = round(sv / td, 6) if td > 0 else 0.0
+
+     node_ids = list(net.nodes.keys())
+
+     return SimResult(
+         T=T,
+         item_ids=item_ids,
+         node_ids=node_ids,
+         edge_ids=edge_ids,
+         inventory=sim.inv_hist,
+         backlog=sim.bl_hist,
+         inflow=sim.inflow,
+         outflow=sim.outflow,
+         demand=sim.demand_hist,
+         demand_signals=demand_signals,
+         shipments=sim.shipments_log,
+         edge_util=edge_util,
+         edge_cap=edge_cap,
+         fill_rate=fill_rate,
+         tier=_TIER,
+         node_coords=_COORDS,
+         disruption_log=sim.disruption_log,
+         config=cfg,
+     )
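The `disruption_edge` handling above accepts either a `"NodeA,NodeB"` string or an `(A, B)` tuple. A minimal standalone sketch of that normalization follows; `normalize_edge` is a hypothetical helper name used only for illustration and is not part of the repository.

```python
def normalize_edge(edge):
    """Return an (A, B) tuple for an "A,B" string; pass other values through.

    Hypothetical helper mirroring the inline normalization in the diff.
    """
    if isinstance(edge, str) and "," in edge:
        parts = edge.split(",")
        return (parts[0].strip(), parts[1].strip())
    return edge


as_string = normalize_edge("Atlanta, Memphis")     # whitespace is stripped
as_tuple = normalize_edge(("Atlanta", "Memphis"))  # tuples pass through
unset = normalize_edge(None)                       # no disruption configured
print(as_string, as_tuple, unset)
```

As in the original code, a string with more than one comma keeps only its first two parts, so callers may want to validate the input before passing it in.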
simulator/derive_edge_files.py ADDED
@@ -0,0 +1,210 @@
+ """Derive edge_list.csv, edge_utilisation.npy, edge_saturation.npy from
+ a finished scenario directory's shipments.csv + demand_signals.npy +
+ scenario.json.
+
+ The original exp_d_edge_utilisation.py source is lost (only a stale
+ .pyc survives); this is a clean re-implementation matching the cap
+ convention used by the existing 13 scenarios (verified by spot-checking
+ baseline / cap_0.3 / cap_2.5 edge_list.csv values).
+
+ Edge capacities follow the simulator's two-tier convention:
+   - Upstream edges (the 14 entries of STATIC_UPSTREAM_EDGES below):
+     num_containers multiplied by scenario['containers_scale']
+     (rounded, min 1), with container_volume held fixed.
+   - Last-mile edges (PHL->NYC, BAL->NYC): back-solved from the
+     realized demand mean using the same formula as the simulator
+     (simulate_item50.py).
+
+ Daily edge utilisation is computed by streaming shipments.csv and, for
+ each shipment record, walking its path_nodes/edge_times to attribute
+ the shipment's units to each (from, to) hop on the day that hop
+ starts.
+
+ Outputs:
+   <scenario_dir>/edge_list.csv         per-edge cap_per_day table
+   <scenario_dir>/edge_utilisation.npy  shape (T, |E|) float32 in [0, *)
+   <scenario_dir>/edge_saturation.npy   shape (T, |E|) uint8 = (util >= tau)
+
+ Usage:
+   python derive_edge_files.py --scenario_dir <path> [--tau 0.9]
+ """
+ from __future__ import annotations
+
+ import argparse
+ import ast
+ import json
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+
+ # Static upstream graph (matches simulate_item50.py).
+ # Each entry: (from, to, travel_time_days, container_volume, num_containers).
+ STATIC_UPSTREAM_EDGES: list[tuple[str, str, int, float, int]] = [
+     ("SanFrancisco", "Nashville", 4, 5000.0, 3),
+     ("StLouis", "Nashville", 2, 5000.0, 3),
+     ("Orlando", "Nashville", 2, 5000.0, 3),
+     ("Nashville", "Atlanta", 1, 15000.0, 3),
+     ("Atlanta", "Chicago", 8, 4000.0, 3),
+     ("Atlanta", "Charlotte", 7, 4000.0, 3),
+     ("Atlanta", "Memphis", 7, 4000.0, 3),
+     ("Chicago", "Columbus", 2, 4000.0, 3),
+     ("Charlotte", "Richmond", 2, 4000.0, 3),
+     ("Columbus", "Philadelphia", 2, 4000.0, 3),
+     ("Richmond", "Philadelphia", 1, 4000.0, 3),
+     ("Richmond", "Baltimore", 3, 3000.0, 3),
+     ("Columbus", "Baltimore", 3, 3000.0, 3),
+     ("Memphis", "Baltimore", 2, 3000.0, 3),
+ ]
+
+ LAST_MILE_EDGES: list[tuple[str, str, int]] = [
+     ("Philadelphia", "NewYork", 1),
+     ("Baltimore", "NewYork", 2),
+ ]
+
+ ITEM_AVG_VOL = 2.5
+ TARGET_RATIO = 1.20
+ PACKING_EFF = 0.93
+ PHIL_SHARE = 0.55
+ BALT_SHARE = 0.45
+
+
+ def back_solve_last_mile_cv(
+     n_items: int, demand_signals_path: Path,
+ ) -> tuple[float, float]:
+     """Reproduce the simulator's last-mile container_volume back-solve.
+
+     See simulate_item50.py.
+     """
+     lam = np.load(demand_signals_path)
+     actual_mean_lam = float(lam.mean())
+     # Mean daily demand volume across all items.
+     total_demand_vol = n_items * actual_mean_lam * ITEM_AVG_VOL
+     # Capacity needed to hit the target ratio after packing losses.
+     raw_needed = total_demand_vol * TARGET_RATIO / PACKING_EFF
+     # Split by share, divide over the 3 baseline containers per last-mile
+     # edge, and round to the nearest 100.
+     phil_cv = round(raw_needed * PHIL_SHARE / 3 / 100) * 100
+     balt_cv = round(raw_needed * BALT_SHARE / 3 / 100) * 100
+     return float(phil_cv), float(balt_cv)
+
+
+ def build_edge_df(scenario_dir: Path,
+                   scenario: dict,
+                   n_items: int = 50) -> pd.DataFrame:
+     containers_scale = float(scenario.get("containers_scale", 1.0))
+     phil_cv, balt_cv = back_solve_last_mile_cv(
+         n_items, scenario_dir / "demand_signals.npy")
+
+     rows = []
+     edge_id = 0
+     for frm, to, tt, cv, nc in STATIC_UPSTREAM_EDGES:
+         nc_scaled = max(1, int(round(nc * containers_scale)))
+         rows.append({
+             "edge_id": edge_id, "from": frm, "to": to,
+             "travel_time_days": tt,
+             "container_volume": cv,
+             "num_containers": nc_scaled,
+             "cap_per_day": cv * nc_scaled,
+         })
+         edge_id += 1
+     for (frm, to, tt), cv in zip(LAST_MILE_EDGES, [phil_cv, balt_cv]):
+         nc_scaled = max(1, int(round(3 * containers_scale)))
+         rows.append({
+             "edge_id": edge_id, "from": frm, "to": to,
+             "travel_time_days": tt,
+             "container_volume": cv,
+             "num_containers": nc_scaled,
+             "cap_per_day": cv * nc_scaled,
+         })
+         edge_id += 1
+     return pd.DataFrame(rows)
+
+
+ def compute_utilisation(shipments_path: Path,
+                         edge_df: pd.DataFrame,
+                         n_days: int,
+                         chunksize: int = 500000) -> np.ndarray:
+     """Stream shipments.csv and accumulate per-edge daily volume.
+
+     For each shipment row, walks (path_nodes[h], path_nodes[h+1]) and
+     attributes its `units` to that edge on the day the hop starts.
+     Hop start day = dispatch day + cumulative edge_times of prior hops.
+     """
+     edge_to_id = {(r["from"], r["to"]): int(r["edge_id"])
+                   for _, r in edge_df.iterrows()}
+     n_edges = len(edge_df)
+     vol = np.zeros((n_days, n_edges), dtype=np.float32)
+
+     rows_seen = 0
+     hops_seen = 0
+     hops_dropped = 0
+
+     for chunk in pd.read_csv(
+         shipments_path,
+         usecols=["day", "units", "path_nodes", "edge_times"],
+         dtype={"day": np.int32, "units": np.float32},
+         chunksize=chunksize,
+     ):
+         days = chunk["day"].to_numpy()
+         units = chunk["units"].to_numpy()
+         path_strs = chunk["path_nodes"].to_numpy()
+         et_strs = chunk["edge_times"].to_numpy()
+         for i in range(len(chunk)):
+             path = ast.literal_eval(path_strs[i])
+             ets = ast.literal_eval(et_strs[i])
+             d0 = int(days[i])
+             u = float(units[i])
+             cum = 0.0
+             for h in range(len(ets)):
+                 hop = (path[h], path[h + 1])
+                 eid = edge_to_id.get(hop)
+                 start_day = d0 + int(round(cum))
+                 if eid is not None and 0 <= start_day < n_days:
+                     vol[start_day, eid] += u
+                     hops_seen += 1
+                 else:
+                     hops_dropped += 1
+                 cum += float(ets[h])
+         rows_seen += len(chunk)
+         print(f" processed {rows_seen} rows ({hops_seen} hops, "
+               f"{hops_dropped} dropped)", flush=True)
+
+     return vol
+
+
+ def main():
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--scenario_dir", required=True,
+                     help="Path to <output_mixture>/<name>/seed<seed>/")
+     ap.add_argument("--n_days", type=int, default=52560)
+     ap.add_argument("--n_items", type=int, default=50)
+     ap.add_argument("--tau", type=float, default=0.9,
+                     help="Saturation threshold; matches the convention "
+                          "used by the existing 13 scenarios.")
+     args = ap.parse_args()
+
+     sdir = Path(args.scenario_dir)
+     if not sdir.is_dir():
+         sys.exit(f"scenario_dir not found: {sdir}")
+     sc = json.loads((sdir / "scenario.json").read_text())
+
+     print(f"=== {sdir.name} (containers_scale="
+           f"{sc.get('containers_scale', 1.0)}) ===")
+
+     edge_df = build_edge_df(sdir, sc, n_items=args.n_items)
+     edge_df.to_csv(sdir / "edge_list.csv", index=False)
+     print(f" wrote edge_list.csv (|E|={len(edge_df)})")
+
+     vol = compute_utilisation(
+         sdir / "shipments.csv", edge_df, args.n_days)
+     cap = edge_df["cap_per_day"].to_numpy(dtype=np.float32)
+     util = vol / np.maximum(cap, 1e-9)
+     sat = (util >= args.tau).astype(np.uint8)
+
+     np.save(sdir / "edge_utilisation.npy", util.astype(np.float32))
+     np.save(sdir / "edge_saturation.npy", sat)
+     print(f" saved edge_utilisation.npy shape={util.shape}")
+     print(f" saved edge_saturation.npy shape={sat.shape} "
+           f"saturated_frac={sat.mean():.4f}")
+
+
+ if __name__ == "__main__":
+     main()
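The hop-attribution rule in `compute_utilisation` (hop start day = dispatch day + cumulative travel time of earlier hops) can be traced on a single row. The sketch below is a standalone illustration using only the standard library; `attribute_hops` and the sample shipment are hypothetical, and it assumes `path_nodes` and `edge_times` are stored as Python list literals, as the `ast.literal_eval` calls in the script suggest.

```python
import ast
from collections import defaultdict


def attribute_hops(day, units, path_nodes, edge_times, n_days):
    """Return {(start_day, (from, to)): units} for one shipment row.

    Hops whose start day falls outside [0, n_days) are dropped, matching
    the bounds check in the script. Edge-id lookup is omitted here.
    """
    path = ast.literal_eval(path_nodes)
    ets = ast.literal_eval(edge_times)
    out = defaultdict(float)
    cum = 0.0
    for h, travel in enumerate(ets):
        start_day = day + int(round(cum))
        if 0 <= start_day < n_days:
            out[(start_day, (path[h], path[h + 1]))] += units
        cum += float(travel)  # next hop starts after this hop's travel time
    return dict(out)


# Hypothetical shipment: dispatched on day 10, 40 units, three hops.
hops = attribute_hops(
    10, 40.0,
    "['Atlanta', 'Memphis', 'Baltimore', 'NewYork']",
    "[7, 2, 2]",
    n_days=365,
)
for (start_day, edge), u in sorted(hops.items()):
    print(start_day, edge, u)
```

Because the full `units` value is charged to each hop on its start day, a shipment is counted once per edge it traverses rather than spread over the days it spends in transit.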
uq/plot_uq_envelope.py ADDED
@@ -0,0 +1,327 @@
+ """§4 UQ forecast-envelope figure: parameter UQ propagated to forecaster output.
+
+ For each model with K=20 LHS perturbation tensors, plots median + 10/90
+ band of y_true (input) and y_pred (output) across K, at one or three
+ forecast windows for one item. The grey band is the band of physical
+ realisations the network produces under demand-side parameter perturbation;
+ the coloured band is the band of zero-shot forecasts of those realisations.
+
+     python plot_uq_envelope.py                      # 2x2, default window (w=12)
+     python plot_uq_envelope.py --multi              # 3x4 multi-window grid
+     python plot_uq_envelope.py --window 25          # explicit single window
+     python plot_uq_envelope.py --narrowest          # window with smallest cross-K spread
+     python plot_uq_envelope.py --no-zoom            # omit the zoom panels
+     python plot_uq_envelope.py --item I05 chronos   # subset of models
+ """
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ REPO = Path(__file__).resolve().parents[1]
+ RESULT_DIR = REPO / "results" / "eval" / "uq"
+ FIG_DIR = REPO / "results" / "uq" / "figures"
+
+ MODELS = {
+     "chronos": {"prefix": "chronos_t5_base",
+                 "fill": "#9BB0CC", "line": "#2F4A75",
+                 "display": "Chronos"},
+     "moirai": {"prefix": "moirai_1_1_R_base",
+                "fill": "#A8BFA0", "line": "#345531",
+                "display": "Moirai"},
+     "timesfm": {"prefix": "timesfm_2_0_500m_pytorch",
+                 "fill": "#D7A992", "line": "#693220",
+                 "display": "TimesFM"},
+     "lagllama": {"prefix": "lag_llama",
+                  "fill": "#C2A8CC", "line": "#4D2752",
+                  "display": "Lag-Llama"},
+ }
+ TRUTH_BAND_COLOR = "#9CA3AF"
+ TRUTH_LINE_COLOR = "#374151"
+ AXES_FACE = "#FFFFFF"
+ SPINE_COLOR = "#374151"
+
+
+ def load_K_tensors(prefix: str, K: int = 20):
+     yp_list, yt_list = [], []
+     item_ids = window_starts = None
+     for k in range(1, K + 1):
+         p = RESULT_DIR / f"{prefix}_perturb_k{k:02d}_tensors.npz"
+         if not p.exists():
+             return None
+         d = np.load(p)
+         yp_list.append(d["y_pred"])
+         yt_list.append(d["y_true"])
+         if item_ids is None:
+             item_ids, window_starts = d["item_ids"], d["window_starts"]
+     y_pred = np.stack(yp_list, axis=0)  # (K, W, H, C)
+     y_true = np.stack(yt_list, axis=0)
+     return y_pred, y_true, item_ids, window_starts
+
+
+ def deterministic_windows(W: int, n: int) -> list[int]:
+     """Evenly-spaced windows inside the test split, avoiding the edges."""
+     if n == 1:
+         return [W // 2]
+     return [int(round(W * (i + 1) / (n + 1))) for i in range(n)]
+
+
+ def draw_band(ax, arr_kh: np.ndarray, fill_color: str, line_color: str,
+               label: str, fill_alpha: float, lw: float, z_base: int):
+     h = np.arange(1, arr_kh.shape[1] + 1)
+     med = np.median(arr_kh, axis=0)
+     q10 = np.percentile(arr_kh, 10, axis=0)
+     q90 = np.percentile(arr_kh, 90, axis=0)
+     ax.fill_between(h, q10, q90, color=fill_color, alpha=fill_alpha,
+                     linewidth=0, zorder=z_base)
+     ax.plot(h, med, color=line_color, lw=lw, label=label, zorder=z_base + 2)
+
+
+ def style_axes(ax):
+     ax.set_facecolor(AXES_FACE)
+     for side in ("top", "right", "left", "bottom"):
+         ax.spines[side].set_color(SPINE_COLOR)
+         ax.spines[side].set_linewidth(0.7)
+     ax.tick_params(colors=SPINE_COLOR, length=3, width=0.7)
+     for label in ax.get_xticklabels() + ax.get_yticklabels():
+         label.set_color(SPINE_COLOR)
+     ax.grid(True, linestyle=':', alpha=0.35, linewidth=0.5,
+             color=SPINE_COLOR)
+
+
+ def _draw_main(ax, yp_kh, yt_kh, model_key, draw_legend):
+     draw_band(ax, yt_kh,
+               fill_color=TRUTH_BAND_COLOR, line_color=TRUTH_LINE_COLOR,
+               label=r"truth $y_{i,t}$",
+               fill_alpha=0.30, lw=1.0, z_base=1)
+     draw_band(ax, yp_kh,
+               fill_color=MODELS[model_key]["fill"],
+               line_color=MODELS[model_key]["line"],
+               label=r"forecast $\hat y_{i,t}$",
+               fill_alpha=0.30, lw=1.8, z_base=3)
+     style_axes(ax)
+     if draw_legend:
+         leg = ax.legend(loc="upper left", fontsize=8.5, frameon=True,
+                         facecolor=AXES_FACE, edgecolor=SPINE_COLOR)
+         leg.get_frame().set_linewidth(0.6)
+         for txt in leg.get_texts():
+             txt.set_color(SPINE_COLOR)
+
+
+ def _zoom_ylim(yp_kh, yt_kh, pad_frac: float = 0.08):
+     yt_med = np.median(yt_kh, axis=0)
+     yp_med = np.median(yp_kh, axis=0)
+     y_lo = float(min(yt_med.min(), yp_med.min()))
+     y_hi = float(max(yt_med.max(), yp_med.max()))
+     span = max(y_hi - y_lo, 1e-6)
+     pad = pad_frac * span
+     return y_lo - pad, y_hi + pad
+
+
+ def _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key):
+     """Sibling axes below `main_ax` with the same bands and medians, but
+     y-axis tightened to the median range. Also shades the corresponding
+     horizontal slice on `main_ax` so the link is explicit.
+     """
+     h = np.arange(1, yp_kh.shape[1] + 1)
+     yt_med = np.median(yt_kh, axis=0)
+     yp_med = np.median(yp_kh, axis=0)
+     yt_q10 = np.percentile(yt_kh, 10, axis=0)
+     yt_q90 = np.percentile(yt_kh, 90, axis=0)
+     yp_q10 = np.percentile(yp_kh, 10, axis=0)
+     yp_q90 = np.percentile(yp_kh, 90, axis=0)
+     y_lo, y_hi = _zoom_ylim(yp_kh, yt_kh)
+
+     zoom_ax.fill_between(h, yt_q10, yt_q90, color=TRUTH_BAND_COLOR,
+                          alpha=0.30, linewidth=0, zorder=1)
+     zoom_ax.plot(h, yt_med, color=TRUTH_LINE_COLOR, lw=1.0, zorder=3)
+     zoom_ax.fill_between(h, yp_q10, yp_q90,
+                          color=MODELS[model_key]["fill"],
+                          alpha=0.30, linewidth=0, zorder=2)
+     zoom_ax.plot(h, yp_med, color=MODELS[model_key]["line"], lw=1.6, zorder=4)
+     zoom_ax.set_xlim(int(h[0]), int(h[-1]))
+     zoom_ax.set_ylim(y_lo, y_hi)
+     style_axes(zoom_ax)
+     zoom_ax.tick_params(axis="y", labelsize=8)
+
+     # Mark the zoom y-slice on the parent so the reader sees exactly which
+     # part of the main panel is being zoomed.
+     main_ax.axhspan(y_lo, y_hi, color=SPINE_COLOR, alpha=0.10,
+                     linewidth=0, zorder=0.5)
+     main_ax.axhline(y_lo, color=SPINE_COLOR, lw=0.5, ls=":",
+                     alpha=0.7, zorder=0.6)
+     main_ax.axhline(y_hi, color=SPINE_COLOR, lw=0.5, ls=":",
+                     alpha=0.7, zorder=0.6)
+
+
+ def plot_2x2(data: dict, item_id: str, item_idx: int, w: int,
+              window_start: int, fig_path: Path, with_zoom: bool = True):
+     fig_h = 8.4 if with_zoom else 5.6
+     fig = plt.figure(figsize=(9.6, fig_h))
+     fig.patch.set_facecolor("white")
+     outer = fig.add_gridspec(2, 2, hspace=0.30, wspace=0.18,
+                              left=0.07, right=0.99, top=0.94, bottom=0.07)
+     items = list(data.items())
+     placements = [(0, 0), (0, 1), (1, 0), (1, 1)]
+     for (ri, ci), (model_key, (yp, yt)) in zip(placements, items):
+         yp_kh = yp[:, w, :, item_idx]
+         yt_kh = yt[:, w, :, item_idx]
+         if with_zoom:
+             inner = outer[ri, ci].subgridspec(
+                 2, 1, height_ratios=[2.6, 1.9], hspace=0.06)
+             main_ax = fig.add_subplot(inner[0])
+             zoom_ax = fig.add_subplot(inner[1], sharex=main_ax)
+         else:
+             main_ax = fig.add_subplot(outer[ri, ci])
+             zoom_ax = None
+         _draw_main(main_ax, yp_kh, yt_kh, model_key,
+                    draw_legend=(ri == 0 and ci == 0))
+         main_ax.set_title(MODELS[model_key]["display"], fontsize=11,
+                           color=SPINE_COLOR)
+         if zoom_ax is not None:
+             _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key)
+             plt.setp(main_ax.get_xticklabels(), visible=False)
+         if ri == 1:
+             (zoom_ax if zoom_ax is not None else main_ax).set_xlabel(
+                 r"forecast horizon $h$ (time units)", color=SPINE_COLOR)
+         if ci == 0:
+             main_ax.set_ylabel(f"item {item_id} demand",
+                                color=SPINE_COLOR)
+             if zoom_ax is not None:
+                 zoom_ax.set_ylabel("zoom (medians)", fontsize=8.5,
+                                    color=SPINE_COLOR)
+     fig.savefig(fig_path, bbox_inches="tight", facecolor="white")
+     fig.savefig(fig_path.with_suffix(".png"), bbox_inches="tight",
+                 dpi=160, facecolor="white")
+     plt.close(fig)
+
+
+ def plot_multi(data: dict, item_id: str, item_idx: int,
+                windows: list[int], window_starts_arr: np.ndarray,
+                fig_path: Path):
+     """Grid: rows = windows, cols = models. Each cell is a (main, zoom)
+     vertical pair sharing x; the zoom row uses tightened y-limits."""
+     n_rows = len(windows)
+     n_cols = len(data)
+     fig = plt.figure(figsize=(3.2 * n_cols, 4.0 * n_rows))
+     fig.patch.set_facecolor("white")
+     outer = fig.add_gridspec(n_rows, n_cols, hspace=0.32, wspace=0.20,
+                              left=0.06, right=0.99, top=0.95, bottom=0.06)
+     model_items = list(data.items())
+     for r, w in enumerate(windows):
+         t0 = int(window_starts_arr[w])
+         for c, (model_key, (yp, yt)) in enumerate(model_items):
+             inner = outer[r, c].subgridspec(
+                 2, 1, height_ratios=[2.6, 1.9], hspace=0.06)
+             main_ax = fig.add_subplot(inner[0])
+             zoom_ax = fig.add_subplot(inner[1], sharex=main_ax)
+             yp_kh = yp[:, w, :, item_idx]
+             yt_kh = yt[:, w, :, item_idx]
+             _draw_main(main_ax, yp_kh, yt_kh, model_key,
+                        draw_legend=(r == 0 and c == 0))
+             if r == 0:
+                 main_ax.set_title(MODELS[model_key]["display"], fontsize=11,
+                                   color=SPINE_COLOR)
+             _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key)
+             plt.setp(main_ax.get_xticklabels(), visible=False)
+             if c == 0:
+                 main_ax.set_ylabel(f"$t_0{{=}}{t0}$", fontsize=10,
+                                    color=SPINE_COLOR)
+                 zoom_ax.set_ylabel("zoom", fontsize=8.5, color=SPINE_COLOR)
+             if r == n_rows - 1:
+                 zoom_ax.set_xlabel(r"forecast horizon $h$ (time units)",
+                                    color=SPINE_COLOR)
+     fig.text(0.005, 0.5, f"item {item_id} demand",
+              rotation="vertical", va="center", ha="left",
+              fontsize=10.5, color=SPINE_COLOR)
+     fig.savefig(fig_path, bbox_inches="tight", facecolor="white")
+     fig.savefig(fig_path.with_suffix(".png"), bbox_inches="tight",
+                 dpi=160, facecolor="white")
+     plt.close(fig)
+
+
+ def main():
+     args = sys.argv[1:]
+     item = "I01"
+     window_arg: int | None = None
+     multi = False
+     narrowest = False
+     no_zoom = False
+     if "--item" in args:
+         i = args.index("--item")
+         item = args[i + 1]
+         args = args[:i] + args[i + 2:]
+     if "--window" in args:
+         i = args.index("--window")
+         window_arg = int(args[i + 1])
+         args = args[:i] + args[i + 2:]
+     if "--multi" in args:
+         multi = True
+         args.remove("--multi")
+     if "--narrowest" in args:
+         narrowest = True
+         args.remove("--narrowest")
+     if "--no-zoom" in args:
+         no_zoom = True
+         args.remove("--no-zoom")
+     # Remaining positional arguments are model keys.
+     requested = args if args else list(MODELS.keys())
+     bad = [m for m in requested if m not in MODELS]
+     if bad:
+         sys.exit(f"unknown model(s): {bad}; choose from {list(MODELS.keys())}")
+
+     FIG_DIR.mkdir(parents=True, exist_ok=True)
+     data: dict = {}
+     item_ids = window_starts = None
+     for m in requested:
+         out = load_K_tensors(MODELS[m]["prefix"])
+         if out is None:
+             print(f"[{m}] missing tensors, skipping")
+             continue
+         yp, yt, ids, ws = out
+         if item_ids is None:
+             item_ids, window_starts = ids, ws
+         data[m] = (yp, yt)
+         print(f"[{m}] y_pred={yp.shape}, y_true={yt.shape}")
+     if not data:
+         sys.exit("no models loaded")
+
+     # Fail with a readable message instead of a bare ValueError.
+     known_items = list(item_ids)
+     if item not in known_items:
+         sys.exit(f"unknown item: {item}")
+     item_idx = known_items.index(item)
+     W_total = next(iter(data.values()))[0].shape[1]
+
+     if multi:
+         windows = deterministic_windows(W_total, 3)
+         print(f"item {item} (idx={item_idx}); multi-window grid w={windows} "
+               f"(t0={[int(window_starts[w]) for w in windows]})")
+         out = FIG_DIR / f"uq_envelope_{item}_multi.pdf"
+         plot_multi(data, item, item_idx, windows,
+                    np.asarray(window_starts), out)
+         print(f"wrote {out}")
+         return
+
+     if window_arg is None:
+         if narrowest:
+             yt_any = next(iter(data.values()))[1]
+             win_mean = yt_any[:, :, :, item_idx].mean(axis=-1)  # (K, W)
+             spread = win_mean.std(axis=0)  # (W,)
+             w = int(spread.argmin())
+             print(f"min cross-K spread window: w={w}, "
+                   f"spread={spread[w]:.2f} "
+                   f"(min={spread.min():.2f}, max={spread.max():.2f})")
+         else:
+             w = 12
+             print(f"default window: w={w}")
+     else:
+         w = window_arg
+     # Guard against a window index outside the loaded tensors.
+     if not 0 <= w < W_total:
+         sys.exit(f"window w={w} out of range; valid indices are "
+                  f"0..{W_total - 1}")
+     t0 = int(window_starts[w])
+     print(f"item {item} (idx={item_idx}); window w={w}, t_start={t0}")
+
+     if window_arg is not None:
+         suffix = f"{item}_w{w:02d}"
+     elif narrowest:
+         suffix = f"{item}_narrowest"
+     else:
+         suffix = item
+     if no_zoom:
+         suffix = f"{suffix}_nozoom"
+     out = FIG_DIR / f"uq_envelope_{suffix}.pdf"
+     plot_2x2(data, item, item_idx, w, t0, out, with_zoom=not no_zoom)
+     print(f"wrote {out}")
+
+
+ if __name__ == "__main__":
+     main()
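The envelope drawn by `draw_band` is a per-horizon median with a 10/90 percentile band taken across the K perturbation runs. The sketch below reproduces that reduction on toy data with the standard library only. The `percentile` helper is a hand-written linear-interpolation rule intended to match the default behaviour of `np.percentile`; it is not the script's code, and the toy numbers are made up.

```python
from statistics import median


def percentile(values, q):
    """Linearly interpolated percentile (q in [0, 100]) of a list."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (pos - lo) * (s[hi] - s[lo])


def band(arr_kh):
    """arr_kh: K rows (perturbation runs) x H columns (horizon steps).

    Returns (median, q10, q90), each a list of length H.
    """
    cols = list(zip(*arr_kh))  # reduce across K for every horizon step
    med = [median(c) for c in cols]
    q10 = [percentile(c, 10) for c in cols]
    q90 = [percentile(c, 90) for c in cols]
    return med, q10, q90


# Hypothetical toy data: K=5 runs, H=2 horizon steps.
toy = [[10, 20], [12, 22], [11, 21], [15, 30], [9, 19]]
med, q10, q90 = band(toy)
print(med, q10, q90)
```

With only K=20 runs, the 10th and 90th percentiles fall between order statistics, so the band edges are interpolated values rather than observed trajectories.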
uq/sample_lhs.py ADDED
@@ -0,0 +1,71 @@
+ """Latin hypercube sample of K=20 demand-side perturbation configurations
+ for the §4.4 UQ experiment.
+
+ Three demand-side knobs, perturbed jointly around the §4 baseline:
+     phi_AR in [0.95, 0.999]  AR(1) drift coefficient (used directly)
+     rho_G  in [0.5, 2.0]     macro-shock multiplier (applied jointly to
+                              shock_count_scale and shock_height_scale)
+     rho_B  in [0.5, 2.0]     burst multiplier (applied jointly to
+                              burst_rate_scale and burst_height_scale)
+
+ Writes data/output_uq/manifest.csv with columns (k, phi_AR, rho_G, rho_B).
+ The submit script reads this and launches one simulator job per row.
+
+ Synchronous scaling for shock and burst matches the §3.3 axis convention
+ (shock axis = N x h^G; burst axis = r x h^P). Drift is a single scalar
+ per run, replacing the per-item U[phi_lo, phi_hi] draw with phi_lo =
+ phi_hi = phi_AR_k (matching the §3.3 drift sweep convention).
+ """
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import qmc
+
+
+ K = 20
+ SEED = 2025
+
+ # (lo, hi) for each of the three demand-side knobs
+ RANGES = {
+     "phi_AR": (0.95, 0.999),
+     "rho_G": (0.5, 2.0),
+     "rho_B": (0.5, 2.0),
+ }
+
+ REPO = Path(__file__).resolve().parents[1]
+ OUT_DIR = REPO / "data" / "output_uq"
+
+
+ def main() -> None:
+     sampler = qmc.LatinHypercube(d=len(RANGES), seed=SEED)
+     unit = sampler.random(n=K)  # (K, 3) in [0, 1)
+
+     los = np.array([RANGES[k][0] for k in RANGES])
+     his = np.array([RANGES[k][1] for k in RANGES])
+     scaled = los + unit * (his - los)  # (K, 3) in ranges
+
+     df = pd.DataFrame({
+         "k": np.arange(1, K + 1),
+         "phi_AR": np.round(scaled[:, 0], 6),
+         "rho_G": np.round(scaled[:, 1], 6),
+         "rho_B": np.round(scaled[:, 2], 6),
+     })
+
+     OUT_DIR.mkdir(parents=True, exist_ok=True)
+     out_path = OUT_DIR / "manifest.csv"
+     df.to_csv(out_path, index=False)
+     print(f"Wrote {out_path}")
+     print(df.to_string(index=False))
+
+     print("\nRange checks:")
+     for col in ["phi_AR", "rho_G", "rho_B"]:
+         lo, hi = RANGES[col]
+         print(f" {col:7s}: [{df[col].min():.4f}, {df[col].max():.4f}] "
+               f"(target [{lo}, {hi}])")
+
+
+ if __name__ == "__main__":
+     main()
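The property that makes a Latin hypercube useful here is stratification: with K samples, each knob's range is cut into K equal strata and every stratum is hit exactly once. The sketch below illustrates that property with a hand-rolled sampler built on the standard library. It is a stand-in for illustration, not SciPy's `qmc.LatinHypercube`, and it will not reproduce the values in `manifest.csv`; `lhs_unit` and `scale` are hypothetical names.

```python
import random


def lhs_unit(k, d, seed=0):
    """Hand-rolled Latin hypercube: k points in [0, 1)^d, one per stratum."""
    rng = random.Random(seed)
    columns = []
    for _ in range(d):
        strata = list(range(k))
        rng.shuffle(strata)  # independent stratum order per dimension
        # One uniform draw inside each stratum [s/k, (s+1)/k).
        columns.append([(s + rng.random()) / k for s in strata])
    return [list(row) for row in zip(*columns)]


def scale(unit_rows, ranges):
    """Map unit-cube rows onto the given (lo, hi) ranges."""
    return [[lo + u * (hi - lo) for u, (lo, hi) in zip(row, ranges)]
            for row in unit_rows]


K = 20
RANGES = [(0.95, 0.999), (0.5, 2.0), (0.5, 2.0)]  # phi_AR, rho_G, rho_B
unit_rows = lhs_unit(K, len(RANGES), seed=2025)
rows = scale(unit_rows, RANGES)

# Each dimension should visit every stratum exactly once.
for dim in range(len(RANGES)):
    hit = sorted(int(row[dim] * K) for row in unit_rows)
    print(dim, hit == list(range(K)))
```

Compared with K independent uniform draws, this guarantees that no region of any single knob's range is left unsampled, which matters when K is as small as 20.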