ISOMORPH-demo / demo /visualize.py
HyeminGu
Modified GIF
3eb868e
"""
ISOMORPH Demo — Visualization Module
=====================================
All functions return plotly.graph_objects.Figure objects for rendering in
Gradio gr.Plot components.
Public API
----------
make_network_animation(result, frame_step=None)
Animated US map: shipment particles travel along edges, nodes colored
by backlog stress. Play / Pause controls + scrubber slider.
make_node_timeseries(result, node_id, item_ids=None)
Subplot panel for one node: inventory, backlog, inflow, outflow,
and (at the destination) realized demand.
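make_bullwhip_chart(result)
Bar chart of mean B = Var(inflow)/Var(outflow) per tier, from
destination (left) to the hub (right); source nodes are excluded.
make_edge_heatmap(result)
Heatmap of edge utilization (fraction of daily capacity) over time.
make_network_animation_html(result, frame_step=None, frame_duration_ms=150)
The animated map as an <iframe srcdoc=...> HTML string (for gr.HTML), which
keeps Play/Pause working where gr.Plot would drop the animation frames.
make_network_animation_gif(result, frame_step=None, frame_duration_ms=150, max_frames=80)
The animated map rendered to a temporary GIF file (matplotlib); returns its path.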
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# ── Color constants ───────────────────────────────────────────────────────────
# Item palette (matplotlib "tab10" subset). Call sites index it with
# ``position % len(ITEM_COLORS)``, so items beyond the fifth reuse colors.
ITEM_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#9467bd",  # purple
    "#d62728",  # red
]

# Node stress colorscale in Plotly's [position, color] form:
# 0 = no backlog (green) → 0.35 (amber) → 1 = all backlog (red).
NODE_COLORSCALE = [
    [0.0, "#4CAF50"],
    [0.35, "#FFC107"],
    [1.0, "#F44336"],
]

# Tier-keyed lookup tables. Tiers run 0 (destination) … 6 (sources); each
# table below is built from a list whose index is the tier number.

# Visual size of a node marker, per tier.
_TIER_MARKER_SIZE = dict(enumerate([
    22,  # 0 destination
    16,  # 1 last-mile DC
    13,  # 2 tier-4 warehouse
    13,  # 3 tier-3 warehouse
    15,  # 4 tier-2 warehouse
    18,  # 5 regional hub
    11,  # 6 source
]))

# Short, human-readable tier labels (bullwhip chart x-axis).
_TIER_LABEL = dict(enumerate([
    "Destination",
    "Last-mile",
    "Tier-4",
    "Tier-3",
    "Tier-2<br>(Atlanta)",
    "Hub<br>(Nashville)",
    "Sources",
]))

# Descriptive role names (hover tooltip and node-type legend).
_TIER_TYPE_NAME = dict(enumerate([
    "Destination (end customer)",
    "Last-mile DC",
    "Tier-4 Warehouse",
    "Tier-3 Warehouse",
    "Tier-2 Warehouse (Atlanta)",
    "Regional Hub (Nashville)",
    "Source / Supplier",
]))

# Plotly Scattergeo marker symbol per tier, so node roles are distinguishable
# by shape as well as by label.
_TIER_SYMBOL = dict(enumerate([
    "star",         # 0 destination — prominent star
    "square",       # 1 last-mile DCs
    "circle",       # 2 tier-4 warehouses
    "circle",       # 3 tier-3 warehouses
    "diamond",      # 4 tier-2 (Atlanta pivot)
    "hexagram",     # 5 hub (Nashville)
    "triangle-up",  # 6 source suppliers
]))

# Plotly symbol name → matplotlib marker code, used by the GIF renderer
# (matplotlib needs one scatter call per node to vary the marker).
_MPL_MARKER = {
    "star": "*",
    "square": "s",
    "circle": "o",
    "diamond": "D",
    "hexagram": "h",
    "triangle-up": "^",
}
# ============================================================================
# Private helpers
# ============================================================================
def _edge_travel_times(result) -> Dict[Tuple[str, str], float]:
    """
    Map each directed hop ``(u, v)`` to its travel time in days.

    The time recorded for a hop comes from the first shipment (in
    ``result.shipments`` order) whose ``path_nodes`` traverse it; later
    shipments never overwrite it. Scanning stops as soon as the number of
    distinct hops seen equals ``len(result.edge_ids)``.
    """
    first_seen: Dict[Tuple[str, str], float] = {}
    for shipment in result.shipments:
        nodes = shipment["path_nodes"]
        hop_times = shipment["edge_times"]
        for idx, hop_time in enumerate(hop_times):
            hop = (nodes[idx], nodes[idx + 1])
            if hop not in first_seen:
                first_seen[hop] = float(hop_time)
            if len(first_seen) == len(result.edge_ids):
                # Every edge already has a time; no need to scan further.
                return first_seen
    return first_seen
def _node_stress(result, t: int) -> List[float]:
    """
    Return a per-node backlog fraction at timestep ``t``.

    For each node, backlog and on-hand inventory are summed over all items::

        stress = total_backlog / max(total_backlog + total_inventory, 1)

    For non-negative quantities this lies in [0, 1]: 0 means no backlog and
    1 means the node holds backlog but no inventory. The ``max(..., 1)`` guard
    makes an empty node (no stock, no backlog) read as 0 instead of dividing
    by zero. The result is ordered to match ``result.node_ids``.
    """
    items = result.item_ids
    fracs: List[float] = []
    for nid in result.node_ids:
        bl = sum(int(result.backlog[nid][iid][t]) for iid in items)
        inv = sum(int(result.inventory[nid][iid][t]) for iid in items)
        fracs.append(bl / max(bl + inv, 1))
    return fracs
def _node_hover_custom(result, t: int,
                       type_names: Optional[List[str]] = None) -> List[List]:
    """
    Build the per-node hover ``customdata`` rows for timestep ``t``.

    Each row is ``[total_inventory, total_backlog, type_name]``, with totals
    summed over ``result.item_ids``; rows follow ``result.node_ids`` order.
    ``type_names`` is static and aligned by position with ``result.node_ids``;
    passing it on every frame keeps ``%{customdata[2]}`` resolvable in the
    hovertemplate. When it is ``None`` or empty, the third column is ``""``.
    """
    def _total(table, nid):
        # One node's quantity at day t, summed across every item.
        return sum(int(table[nid][item][t]) for item in result.item_ids)

    return [
        [
            _total(result.inventory, nid),
            _total(result.backlog, nid),
            type_names[pos] if type_names else "",
        ]
        for pos, nid in enumerate(result.node_ids)
    ]
def _interpolate_position(
    shipment: dict,
    t: int,
    node_coords: Dict[str, Tuple[float, float]],
) -> Optional[Tuple[float, float]]:
    """
    Locate a shipment particle at integer day ``t``.

    A shipment is visible on days ``shipment["day"] <= t < shipment["arrival_day"]``.
    Within that window it moves hop by hop along ``path_nodes``, spending
    ``edge_times[h]`` days on hop ``h`` and interpolating linearly in
    (lat, lon) between the hop's endpoint coordinates.

    Returns ``(lat, lon)``, or ``None`` when the shipment is not active at
    ``t`` or ``t`` lies beyond the summed ``edge_times``.
    """
    depart = shipment["day"]
    arrive = shipment["arrival_day"]
    if t < depart or t >= arrive:
        return None

    nodes = shipment["path_nodes"]
    legs = shipment["edge_times"]
    elapsed = float(t - depart)
    leg_start = 0.0
    for leg, raw_duration in enumerate(legs):
        duration = float(raw_duration)
        leg_end = leg_start + duration
        # Small epsilon so a particle exactly at a hop boundary stays on it.
        if elapsed <= leg_end + 1e-9:
            # max(..., 1e-6) avoids dividing by a zero-length hop.
            progress = (elapsed - leg_start) / max(duration, 1e-6)
            progress = max(0.0, min(1.0, progress))
            lat0, lon0 = node_coords[nodes[leg]]
            lat1, lon1 = node_coords[nodes[leg + 1]]
            return lat0 + progress * (lat1 - lat0), lon0 + progress * (lon1 - lon0)
        leg_start = leg_end
    return None
def _frame_particles(result, t: int, item_filter: Optional[str] = None):
    """
    Collect the shipment particles that are in transit at timestep ``t``.

    Parameters
    ----------
    result : SimResult
    t : int
        Simulation day.
    item_filter : str or None
        If given (and non-empty), only shipments of this item are returned.

    Returns
    -------
    (lats, lons, texts, colors) : tuple of lists
        Parallel lists with one entry per active particle. ``texts`` are HTML
        hover labels. ``colors`` are CSS hex strings, one fixed color per item
        chosen by the item's position in ``result.item_ids`` (wrapping over
        ``ITEM_COLORS``).
    """
    item_color_map = {iid: ITEM_COLORS[i % len(ITEM_COLORS)]
                      for i, iid in enumerate(result.item_ids)}
    lats, lons, texts, colors = [], [], [], []
    for s in result.shipments:
        if item_filter and s["item"] != item_filter:
            continue
        pos = _interpolate_position(s, t, result.node_coords)
        if pos is None:
            continue
        lat, lon = pos
        lats.append(lat)
        lons.append(lon)
        # Explicit separators: without them origin/destination names and the
        # departure/arrival days run together (e.g. "AtlantaNashville", "day 37").
        texts.append(
            f"<b>{s['item']}</b><br>{s['from']} → {s['to']}<br>"
            f"qty: {s['units']} · day {s['day']} → {s['arrival_day']}"
        )
        colors.append(item_color_map[s["item"]])
    return lats, lons, texts, colors
def _bullwhip_ratios(result, min_var_out: float = 1.0) -> Dict[str, Dict[str, float]]:
    """
    Compute B_(n,i) = Var(inflow_(n,i)) / Var(outflow_(n,i)) per node/item.

    Variances are sample variances (``ddof=1``) over the whole horizon.
    At the destination node (tier 0) outflow is replaced by realized customer
    demand ``result.demand[item]``; if the result has no tier-0 node, every
    node uses its outflow. Source nodes (tier 6) are excluded (no inflow on
    network edges).

    Parameters
    ----------
    result : SimResult
    min_var_out : float
        A ratio is kept only when the outflow variance is strictly greater
        than this value (and than 0). This avoids huge, meaningless ratios for
        near-constant outflow. The default of 1.0 (units²) matches the
        historical behavior.

    Returns
    -------
    dict
        ``{node_id: {item_id: B}}`` for every non-source node. A node's inner
        dict may be empty or miss items when the variance guard drops them;
        it is always empty when ``result.T <= 1``.
    """
    # default=None: a result without a tier-0 node must not leak StopIteration.
    dest_id = next(
        (nid for nid in result.node_ids if result.tier.get(nid) == 0), None
    )
    has_variance = result.T > 1  # sample variance needs at least two points
    ratios: Dict[str, Dict[str, float]] = {}
    for nid in result.node_ids:
        if result.tier.get(nid, -1) == 6:
            continue
        ratios[nid] = {}
        if not has_variance:
            continue
        for iid in result.item_ids:
            inflow_arr = np.asarray(result.inflow[nid][iid], dtype=float)
            if nid == dest_id:
                outflow_arr = np.asarray(result.demand[iid], dtype=float)
            else:
                outflow_arr = np.asarray(result.outflow[nid][iid], dtype=float)
            var_in = float(np.var(inflow_arr, ddof=1))
            var_out = float(np.var(outflow_arr, ddof=1))
            # Guard against near-zero outflow variance (and division by zero
            # if a caller passes a non-positive threshold).
            if var_out > 0.0 and var_out > min_var_out:
                ratios[nid][iid] = var_in / var_out
    return ratios
# ============================================================================
# 1. Animated network map
# ============================================================================
def make_network_animation(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
    max_frames: int = 200,
) -> go.Figure:
    """
    Animated Scattergeo map of the supply-chain network.

    Edge lines are static (width grows with the square root of daily
    capacity). Each animation frame updates node colors (backlog stress),
    node hover data, and the positions of in-transit shipment particles
    (colored by item).

    Parameters
    ----------
    result : SimResult
    frame_step : int or None
        Days between animation frames. When None, the smallest step that
        yields at most ``max_frames`` frames is used.
    frame_duration_ms : int
        Milliseconds per frame during playback.
    max_frames : int
        Upper bound on the frame count when ``frame_step`` is None
        (default 200).

    Returns
    -------
    go.Figure
        Figure with ``frames``, Play/Pause buttons and a day slider.
    """
    T = result.T
    if frame_step is None:
        # Ceiling division: floor (T // max_frames) would allow nearly
        # 2 * max_frames frames (e.g. T=399 → step 1 → 399 frames).
        frame_step = max(1, -(-T // max(1, max_frames)))
    frame_times = list(range(0, T, frame_step))
    node_ids = result.node_ids
    coords = result.node_coords
    edge_tt = _edge_travel_times(result)

    # ── Node metadata ──────────────────────────────────────────────────────
    node_lats = [coords[n][0] for n in node_ids]
    node_lons = [coords[n][1] for n in node_ids]
    node_sizes = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13)
                  for n in node_ids]
    node_labels = node_ids
    node_type_names = [_TIER_TYPE_NAME.get(result.tier.get(n, -1), "Warehouse")
                       for n in node_ids]
    node_symbols = [_TIER_SYMBOL.get(result.tier.get(n, 3), "circle")
                    for n in node_ids]

    # ── Edge traces (static) ───────────────────────────────────────────────
    # `or 1.0` covers both an empty capacity map and an all-zero one, either
    # of which would otherwise break the width normalisation below.
    max_cap = max(result.edge_cap.values(), default=0.0) or 1.0
    edge_traces = []
    for eid in result.edge_ids:
        u, v = eid
        if u not in coords or v not in coords:
            continue
        cap = result.edge_cap.get(eid, 1.0)
        lw = 1.0 + 3.5 * (cap / max_cap) ** 0.5
        tt = edge_tt.get(eid, "?")
        edge_traces.append(go.Scattergeo(
            lat=[coords[u][0], coords[v][0], None],
            lon=[coords[u][1], coords[v][1], None],
            mode="lines",
            line=dict(width=lw, color="rgba(140,140,160,0.55)"),
            hoverinfo="text",
            text=f"{u}{v}<br>travel: {tt} day(s)<br>daily cap: {cap:.0f} units",
            showlegend=False,
            name=f"{u}{v}",
        ))
    n_edge_traces = len(edge_traces)

    # ── Initial node trace (t=0) ───────────────────────────────────────────
    stress_0 = _node_stress(result, 0)
    custom_0 = _node_hover_custom(result, 0, node_type_names)
    node_trace = go.Scattergeo(
        lat=node_lats,
        lon=node_lons,
        mode="markers+text",
        text=node_labels,
        textposition="top center",
        textfont=dict(size=9, color="black"),
        marker=dict(
            size=node_sizes,
            symbol=node_symbols,
            color=stress_0,
            colorscale=NODE_COLORSCALE,
            cmin=0.0, cmax=1.0,
            colorbar=dict(
                title="Backlog<br>stress",
                thickness=12,
                len=0.5,
                x=1.01,
                tickvals=[0, 0.5, 1],
                ticktext=["0 (healthy)", "0.5", "1 (stockout)"],
                tickfont=dict(size=9),
            ),
            line=dict(width=1.5, color="white"),
        ),
        customdata=custom_0,
        hovertemplate=(
            "<b>%{text}</b><br>"
            "Role: %{customdata[2]}<br>"
            "Total inventory: %{customdata[0]:,} units<br>"
            "Total backlog: %{customdata[1]:,} units<br>"
            "<extra></extra>"
        ),
        showlegend=False,
        name="nodes",
    )

    # ── Initial shipment trace (t=0) ───────────────────────────────────────
    lats0, lons0, texts0, colors0 = _frame_particles(result, 0)
    ship_trace = go.Scattergeo(
        lat=lats0,
        lon=lons0,
        mode="markers",
        marker=dict(size=7, color=colors0, opacity=0.85,
                    line=dict(width=0.5, color="white")),
        hoverinfo="text",
        text=texts0,
        showlegend=False,
        name="shipments",
    )

    # ── Legend: item colors (shipment dots) ────────────────────────────────
    # Proxy traces with no data; they exist only to populate the legend.
    item_legend_traces = []
    for i, iid in enumerate(result.item_ids):
        item_legend_traces.append(go.Scattergeo(
            lat=[None], lon=[None],
            mode="markers",
            marker=dict(size=9, color=ITEM_COLORS[i % len(ITEM_COLORS)],
                        symbol="circle"),
            name=iid,
            legendgrouptitle_text="Shipment items" if i == 0 else None,
            legendgroup="items",
            showlegend=True,
        ))

    # ── Legend: node-type symbols ──────────────────────────────────────────
    # One proxy trace per distinct tier present in this result.
    seen_tiers: dict = {}
    for n in node_ids:
        t_val = result.tier.get(n, -1)
        if t_val not in seen_tiers:
            seen_tiers[t_val] = (_TIER_TYPE_NAME.get(t_val, "Unknown"),
                                 _TIER_SYMBOL.get(t_val, "circle"))
    node_type_legend_traces = []
    for i, (t_val, (tname, sym)) in enumerate(
            sorted(seen_tiers.items(), key=lambda x: x[0])):
        node_type_legend_traces.append(go.Scattergeo(
            lat=[None], lon=[None],
            mode="markers",
            marker=dict(size=10, color="gray", symbol=sym),
            name=tname,
            legendgrouptitle_text="Node types" if i == 0 else None,
            legendgroup="node_types",
            showlegend=True,
        ))

    # ── Assemble base figure ───────────────────────────────────────────────
    all_traces = (edge_traces + [node_trace, ship_trace]
                  + item_legend_traces + node_type_legend_traces)
    fig = go.Figure(data=all_traces)
    # Indices of the dynamic traces (edges come first, legend proxies last).
    idx_nodes = n_edge_traces
    idx_ships = n_edge_traces + 1

    # ── Pre-compute frames ─────────────────────────────────────────────────
    frames = []
    for t in frame_times:
        stress_t = _node_stress(result, t)
        custom_t = _node_hover_custom(result, t, node_type_names)
        lats_t, lons_t, texts_t, colors_t = _frame_particles(result, t)
        frames.append(go.Frame(
            data=[
                # Plain dicts avoid serialising default None lat/lon values
                # that go.Scattergeo() would inject and confuse Plotly.js.
                {"type": "scattergeo",
                 "marker": {"color": stress_t},
                 "customdata": custom_t},
                {"type": "scattergeo",
                 "lat": lats_t, "lon": lons_t,
                 "text": texts_t,
                 "marker": {"color": colors_t, "size": 7, "opacity": 0.85}},
            ],
            traces=[idx_nodes, idx_ships],
            layout={"title": {"text": f"<b>ISOMORPH Supply Chain Digital Twin</b> — Day {t} / {T - 1}"}},
            name=str(t),
        ))
    fig.frames = frames

    # ── Slider steps ───────────────────────────────────────────────────────
    slider_steps = [
        dict(
            method="animate",
            args=[[str(t)],
                  dict(mode="immediate",
                       frame=dict(duration=0, redraw=True),
                       transition=dict(duration=0))],
            label=str(t),
        )
        for t in frame_times
    ]

    # ── Layout ─────────────────────────────────────────────────────────────
    fig.update_layout(
        title=dict(
            text=f"<b>ISOMORPH Supply Chain Digital Twin</b> — Day 0 / {T - 1}",
            x=0.5, xanchor="center", font=dict(size=14),
        ),
        geo=dict(
            scope="usa",
            projection_type="albers usa",
            showland=True, landcolor="rgb(243,243,243)",
            showlakes=True, lakecolor="rgb(210,230,255)",
            showrivers=True, rivercolor="rgb(210,230,255)",
            showcoastlines=True, coastlinecolor="rgb(180,180,200)",
            showsubunits=True, subunitcolor="rgb(200,200,215)",
            bgcolor="rgba(255,255,255,0)",
        ),
        legend=dict(
            x=0.01, y=0.01,
            bgcolor="rgba(255,255,255,0.82)",
            bordercolor="lightgray", borderwidth=1,
            font=dict(size=9),
            tracegroupgap=6,
        ),
        updatemenus=[dict(
            type="buttons",
            showactive=False,
            y=1.08, x=0.0, xanchor="left",
            buttons=[
                dict(
                    label="▶ Play",
                    method="animate",
                    args=[None, dict(
                        frame=dict(duration=frame_duration_ms, redraw=True),
                        fromcurrent=True,
                        transition=dict(duration=0),
                    )],
                ),
                dict(
                    label="⏸ Pause",
                    method="animate",
                    # [None] as the frame list stops the running animation.
                    args=[[None], dict(
                        frame=dict(duration=0, redraw=False),
                        mode="immediate",
                        transition=dict(duration=0),
                    )],
                ),
            ],
        )],
        sliders=[dict(
            currentvalue=dict(
                prefix="Day: ",
                font=dict(size=11),
                visible=True,
                xanchor="center",
            ),
            pad=dict(t=50, b=10),
            len=0.9,
            x=0.05,
            steps=slider_steps,
            transition=dict(duration=0),
        )],
        margin=dict(l=0, r=0, t=80, b=60),
        height=540,
        paper_bgcolor="white",
    )
    return fig
# ============================================================================
# 2. Node detail time-series panel
# ============================================================================
def make_node_timeseries(
    result,
    node_id: str,
    item_ids: Optional[List[str]] = None,
) -> go.Figure:
    """
    4–5 subplot panel for a single node showing inventory, backlog, inflow,
    outflow, and (destination only) realized demand.

    Parameters
    ----------
    result : SimResult
    node_id : str
        Node to visualise.
    item_ids : list[str] or None
        Subset of items to plot. Defaults to all items in result. Each item
        keeps the color given by its position in ``result.item_ids``, so a
        subset uses the same colors as the full view.

    Returns
    -------
    go.Figure
        Vertically stacked subplots sharing the day axis. Days listed in
        ``result.disruption_log`` are marked with red dashed lines.
    """
    if item_ids is None:
        item_ids = result.item_ids
    # The demand panel is shown only for the tier-0 node. default=None lets a
    # result without one render the 4-panel view instead of leaking
    # StopIteration.
    dest_id = next((n for n in result.node_ids if result.tier.get(n) == 0), None)
    is_dest = dest_id is not None and node_id == dest_id

    # Use plain Python lists — avoids numpy int32/int64 serialization issues
    # inside Gradio 4.x's gr.Plot JSON path.
    days = list(range(result.T))
    n_panels = 5 if is_dest else 4
    panel_titles = [
        "On-hand inventory (units stored at this node)",
        "Backlog (unfulfilled demand pending)",
        "Inflow (units arriving per day)",
        "Outflow (units shipped out per day)",
    ]
    if is_dest:
        panel_titles.append("Realized demand (customer orders received per day)")
    fig = make_subplots(
        rows=n_panels, cols=1,
        shared_xaxes=True,
        subplot_titles=panel_titles,
        vertical_spacing=0.06,
    )

    def _add_series(row, y_arr, iid, color, dash="solid"):
        """Add one item's line to subplot ``row`` (legend entry on row 1 only)."""
        fig.add_trace(
            go.Scatter(
                x=days,
                y=np.asarray(y_arr).tolist(),  # numpy scalars → Python numbers
                mode="lines",
                name=iid,
                line=dict(color=color, width=1.5, dash=dash),
                legendgroup=iid,
                showlegend=(row == 1),
                hovertemplate=f"Day %{{x}}<br>{iid}: %{{y:,.0f}}<extra></extra>",
            ),
            row=row, col=1,
        )

    # Color by position in the full item list (the same scheme the map,
    # particle and bullwhip views use), not by position in the subset.
    color_pos = {iid: i for i, iid in enumerate(result.item_ids)}
    for idx, iid in enumerate(item_ids):
        color = ITEM_COLORS[color_pos.get(iid, idx) % len(ITEM_COLORS)]
        _add_series(1, result.inventory[node_id][iid], iid, color)
        _add_series(2, result.backlog[node_id][iid], iid, color)
        _add_series(3, result.inflow[node_id][iid], iid, color)
        _add_series(4, result.outflow[node_id][iid], iid, color)
        if is_dest:
            _add_series(5, result.demand[iid], iid, color, dash="dot")

    # Disruption event markers — one vline per panel.
    # add_vline(row=, col=) was not available in all Plotly 5.x builds;
    # iterate over subplot y-axis references instead.
    if result.disruption_log:
        yref_list = ["y"] + [f"y{i}" for i in range(2, n_panels + 1)]
        for ev in result.disruption_log:
            for yref in yref_list:
                fig.add_shape(
                    type="line",
                    x0=ev["day"], x1=ev["day"],
                    y0=0, y1=1,
                    xref="x", yref=f"{yref} domain",
                    line=dict(dash="dash", color="rgba(255,80,80,0.5)", width=1),
                )

    # Panel y-axis labels
    y_labels = ["Units", "Units", "Units/day", "Units/day"]
    if is_dest:
        y_labels.append("Units/day")
    for row, lbl in enumerate(y_labels, start=1):
        fig.update_yaxes(title_text=lbl, row=row, col=1,
                         title_font=dict(size=10), tickfont=dict(size=9))
    fig.update_xaxes(title_text="Day", row=n_panels, col=1,
                     tickfont=dict(size=9))

    type_name = _TIER_TYPE_NAME.get(result.tier.get(node_id, -1), "")
    disruption_note = (
        " · <span style='color:rgba(255,80,80,0.9)'>red dashes = disruption events</span>"
        if result.disruption_log else ""
    )
    fig.update_layout(
        title=dict(
            text=(f"<b>{node_id}</b> — {type_name}"
                  f"{disruption_note}"),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        legend=dict(
            title="Items (SKUs)",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray", borderwidth=1,
        ),
        height=180 * n_panels + 80,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=50),
    )
    # Light grid
    fig.update_xaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig
# ============================================================================
# 3. Bullwhip amplification chart
# ============================================================================
def make_bullwhip_chart(result) -> go.Figure:
    """
    Bar chart of tier-level bullwhip ratio B = Var(inflow)/Var(outflow).

    For each item, the bar height at a tier is the mean B over that tier's
    nodes that have a ratio for the item (see ``_bullwhip_ratios``). Bars are
    grouped by item; tiers run from destination (left) to hub (right), and
    source nodes are not shown. A dashed reference line at B = 1 marks the
    no-amplification baseline.
    """
    ratios = _bullwhip_ratios(result)

    def _tier_of(nid):
        # _bullwhip_ratios keeps nodes missing from result.tier (as tier -1),
        # so look tiers up with the same default rather than result.tier[n].
        return result.tier.get(nid, -1)

    # Tiers with at least one ratio, ordered downstream → upstream.
    tiers_present = sorted({_tier_of(n) for n, r in ratios.items() if r})
    x_labels = [_TIER_LABEL.get(t, f"Tier {t}") for t in tiers_present]

    # Per-item bar traces
    fig = go.Figure()
    for idx, iid in enumerate(result.item_ids):
        tier_means = []
        for t in tiers_present:
            vals = [r[iid] for n, r in ratios.items()
                    if _tier_of(n) == t and iid in r]
            tier_means.append(float(np.mean(vals)) if vals else None)
        fig.add_trace(go.Bar(
            x=x_labels,
            y=tier_means,
            name=iid,
            marker_color=ITEM_COLORS[idx % len(ITEM_COLORS)],
            opacity=0.82,
            text=[f"{v:.2f}" if v is not None else "" for v in tier_means],
            textposition="outside",
            textfont=dict(size=9),
        ))

    # Reference line at B = 1
    fig.add_hline(
        y=1.0, line_dash="dash",
        line_color="rgba(80,80,80,0.6)", line_width=1.5,
        annotation_text="B = 1 (no amplification)",
        annotation_position="top right",
        annotation_font_size=9,
    )
    fig.update_layout(
        title=dict(
            text=(
                "<b>Bullwhip Amplification by Tier</b><br>"
                "<sup>B &gt; 1 means demand variability grows as orders travel upstream</sup>"
            ),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        xaxis=dict(
            # Sources are excluded from the ratios, so the hub is the
            # upstream-most tier actually plotted.
            title="Network tier (Destination = downstream → Hub = upstream)",
            title_font=dict(size=11),
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            title="B = Var(inflow) / Var(outflow)",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            rangemode="tozero",
        ),
        barmode="group",
        legend=dict(
            title="Items",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray", borderwidth=1,
        ),
        height=400,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=60),
    )
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig
# ============================================================================
# 4. Edge utilization heatmap
# ============================================================================
def make_edge_heatmap(result) -> go.Figure:
    """
    Heatmap of edge utilization (fraction of daily capacity) over time.

    Y-axis: edges labelled ``"u → v"``, sorted downstream → upstream by the
    tier of the edge's source node ``u`` (nodes missing from ``result.tier``
    sort last).
    X-axis: simulation day.
    Color: 0 (very light blue) = empty → 0.5 (yellow-orange) → 1 (red) = at
    capacity.
    Disruption events are shown as vertical dashed lines.
    """
    # Sort edges: downstream edges first (by tier of source node)
    def _edge_tier(eid):
        return result.tier.get(eid[0], 99)

    sorted_edges = sorted(result.edge_ids, key=_edge_tier)
    # Explicit separator: bare concatenation ("AB") is unreadable and can make
    # two distinct edges share a label, merging their heatmap rows.
    edge_labels = [f"{u}{v}" for u, v in sorted_edges]
    z = np.array([result.edge_util[eid] for eid in sorted_edges],
                 dtype=np.float64)  # shape (n_edges, T)
    days = list(range(result.T))
    fig = go.Figure(go.Heatmap(
        z=z.tolist(),
        x=days,
        y=edge_labels,
        colorscale=[
            [0.0, "rgb(240,248,255)"],  # near-empty: very light blue
            [0.5, "rgb(255,200,100)"],  # half full: yellow-orange
            [1.0, "rgb(220,50,50)"],    # saturated: red
        ],
        zmin=0.0, zmax=1.0,
        colorbar=dict(
            title="Utilization<br>fraction",
            thickness=13,
            len=0.7,
            tickvals=[0, 0.5, 1],
            ticktext=["0%", "50%", "100%"],
            tickfont=dict(size=9),
        ),
        hovertemplate=(
            "Edge: %{y}<br>Day: %{x}<br>Utilization: %{z:.1%}<extra></extra>"
        ),
    ))

    # Disruption event markers
    for ev in result.disruption_log:
        fig.add_vline(
            x=ev["day"], line_dash="dash",
            line_color="rgba(80,80,255,0.6)", line_width=1.5,
        )
        fig.add_annotation(
            x=ev["day"], y=1.01, xref="x", yref="paper",
            text="disruption", showarrow=False,
            font=dict(size=8, color="blue"), textangle=-90,
        )

    fig.update_layout(
        title=dict(
            text=(
                "<b>Edge Utilization over Time</b><br>"
                "<sup>Fraction of daily shipping capacity used per edge "
                "(0 = empty, 1 = fully saturated). "
                "Blue dashes mark disruption events.</sup>"
            ),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        xaxis=dict(
            title="Day",
            title_font=dict(size=11),
            tickfont=dict(size=9),
        ),
        yaxis=dict(
            title="Edge",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            autorange="reversed",  # downstream edges at top
        ),
        height=max(320, 30 * len(sorted_edges) + 120),
        paper_bgcolor="white",
        margin=dict(l=160, r=80, t=60, b=60),
    )
    return fig
# ============================================================================
# 5. HTML wrapper for animated network map (Gradio gr.HTML compatible)
# ============================================================================
def make_network_animation_html(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
) -> str:
    """
    Render the animated network map as an ``<iframe srcdoc=...>`` HTML string.

    Rationale: Gradio's ``gr.Plot`` re-renders through ``Plotly.react()``,
    which drops the ``frames`` array and so breaks Play/Pause. A complete
    standalone page (``full_html=True``, plotly.js from the CDN) embedded via
    ``srcdoc`` is initialised with ``Plotly.newPlot()`` and keeps its frames.
    The iframe's fixed 600px height also avoids the ``height:100%`` collapse
    seen when partial HTML is placed in ``gr.HTML``.

    Parameters are forwarded unchanged to ``make_network_animation``.
    """
    import html as _html
    import plotly.io as pio

    animated = make_network_animation(result, frame_step, frame_duration_ms)
    page = pio.to_html(
        animated,
        include_plotlyjs="cdn",
        full_html=True,
        config={"responsive": True},
    )
    # srcdoc sits in a double-quoted attribute, so quotes must be escaped too.
    srcdoc_value = _html.escape(page, quote=True)
    iframe = (
        f'<iframe srcdoc="{srcdoc_value}" '
        f'width="100%" height="600" frameborder="0" scrolling="no" '
        f'style="border:none;display:block;"></iframe>'
    )
    return iframe
# ============================================================================
# 6. Animated GIF export
# ============================================================================
def make_network_animation_gif(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
    max_frames: int = 80,
) -> str:
    """
    Render the network animation as an animated GIF and return its file path.

    Uses matplotlib (Agg backend). When cartopy is installed it draws a US
    map background; otherwise it falls back to plain lon/lat axes. No kaleido
    or Chrome is required. The GIF is written to a new temporary file
    (``isomorph_*.gif``) that is not deleted automatically.

    Parameters
    ----------
    result : SimResult
    frame_step : int or None
        Days between frames. When None, the smallest step that yields at most
        ``max_frames`` frames is used.
    frame_duration_ms : int
        Milliseconds per frame during playback.
    max_frames : int
        Maximum number of frames to render when ``frame_step`` is None
        (keeps file size reasonable).

    Returns
    -------
    str
        Path of the written GIF.

    Raises
    ------
    ImportError
        If Pillow is not installed.
    ValueError
        If ``result.T`` is not positive (there is nothing to render).
    """
    import io as _io
    import tempfile
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    try:
        from PIL import Image
    except ImportError as err:
        raise ImportError("Pillow is required for GIF export: pip install Pillow") from err
    try:
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        _PROJ = ccrs.LambertConformal(central_longitude=-96, central_latitude=39,
                                      standard_parallels=(33, 45))
        _GEO = ccrs.PlateCarree()
        HAS_CARTOPY = True
    except ImportError:
        HAS_CARTOPY = False

    T = result.T
    if T <= 0:
        raise ValueError("result.T must be positive to render a GIF")

    # Same green → amber → red stress ramp as the interactive map.
    stress_cmap = mcolors.LinearSegmentedColormap.from_list(
        "stress", [(pos, col) for pos, col in NODE_COLORSCALE],
    )
    stress_norm = mcolors.Normalize(vmin=0.0, vmax=1.0)

    if frame_step is None:
        # Ceiling division so the frame count never exceeds max_frames
        # (floor division allowed up to ~2 * max_frames).
        frame_step = max(1, -(-T // max(1, max_frames)))
    frame_times = list(range(0, T, frame_step))

    coords = result.node_coords
    node_ids = result.node_ids
    node_lons = [coords[n][1] for n in node_ids]
    node_lats = [coords[n][0] for n in node_ids]
    # s is marker area in pt² — keep proportional to Plotly sizes but much smaller
    node_s = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13) * 12 for n in node_ids]
    # Marker shapes do not change over time, so compute them once.
    node_markers = [_MPL_MARKER.get(_TIER_SYMBOL.get(result.tier.get(n, 3), "circle"), "o")
                    for n in node_ids]

    max_cap = max(result.edge_cap.values(), default=0.0) or 1.0
    # Precompute: (lon0, lat0, lon1, lat1, linewidth)
    edges: List[Tuple] = []
    for eid in result.edge_ids:
        u, v = eid
        if u not in coords or v not in coords:
            continue
        cap = result.edge_cap.get(eid, 1.0)
        lw = 0.6 + 2.0 * (cap / max_cap) ** 0.5
        edges.append((coords[u][1], coords[u][0], coords[v][1], coords[v][0], lw))

    W, H, DPI = 8.5, 5.0, 90
    pil_frames = []
    for t in frame_times:
        stress_t = _node_stress(result, t)
        node_colors = [stress_cmap(stress_norm(s)) for s in stress_t]  # RGBA tuples
        lats_t, lons_t, _, colors_t = _frame_particles(result, t)

        fig = plt.figure(figsize=(W, H), dpi=DPI, facecolor="#f4f6f9")
        try:  # always close the figure, even if drawing fails
            if HAS_CARTOPY:
                ax = fig.add_subplot(1, 1, 1, projection=_PROJ)
                ax.set_extent([-125, -66, 24, 50], crs=_GEO)
                ax.add_feature(cfeature.LAND, facecolor="#f0f0ee", zorder=0)
                ax.add_feature(cfeature.OCEAN, facecolor="#d0e4f0", zorder=0)
                ax.add_feature(cfeature.LAKES, facecolor="#d0e4f0", alpha=0.6, zorder=1)
                ax.add_feature(cfeature.STATES, edgecolor="#c0c0c0", linewidth=0.4, zorder=2)
                ax.add_feature(cfeature.COASTLINE, edgecolor="#999999", linewidth=0.5, zorder=2)
                for lon0, lat0, lon1, lat1, lw in edges:
                    ax.plot([lon0, lon1], [lat0, lat1], transform=_GEO,
                            color="#7a7a90", linewidth=lw, alpha=0.65, zorder=3,
                            solid_capstyle="round")
                for nid, lon, lat, color, s, mkr in zip(
                        node_ids, node_lons, node_lats, node_colors, node_s, node_markers):
                    ax.scatter([lon], [lat], c=[color], s=[s], marker=mkr, zorder=4,
                               linewidths=1.5, edgecolors="white", alpha=0.85, transform=_GEO)
                    ax.text(lon, lat + 0.9, nid, transform=_GEO,
                            fontsize=6.5, ha="center", va="bottom",
                            fontweight="bold", color="#1a1a2e", zorder=6,
                            bbox=dict(boxstyle="round,pad=0.12", fc="white",
                                      ec="none", alpha=0.65))
                if lons_t:
                    rgba_t = [mcolors.to_rgba(c) for c in colors_t]
                    ax.scatter(lons_t, lats_t, c=rgba_t, s=45, zorder=8,
                               linewidths=0.5, edgecolors="white", alpha=1.0,
                               transform=_GEO)
            else:
                # Fallback: plain axes, no geographic background
                ax = fig.add_subplot(1, 1, 1)
                ax.set_facecolor("#dce8f0")
                ax.set_xlim(-128, -65)
                ax.set_ylim(23, 50)
                ax.axis("off")
                for lon0, lat0, lon1, lat1, lw in edges:
                    ax.plot([lon0, lon1], [lat0, lat1], color="#7a7a90", linewidth=lw,
                            alpha=0.65, zorder=1, solid_capstyle="round")
                for nid, lon, lat, color, s, mkr in zip(
                        node_ids, node_lons, node_lats, node_colors, node_s, node_markers):
                    ax.scatter([lon], [lat], c=[color], s=[s], marker=mkr, zorder=3,
                               linewidths=1.5, edgecolors="white", alpha=0.85)
                    ax.text(lon, lat + 0.8, nid, fontsize=6.5, ha="center", va="bottom",
                            fontweight="bold", color="#1a1a2e", zorder=5,
                            bbox=dict(boxstyle="round,pad=0.12", fc="white",
                                      ec="none", alpha=0.65))
                if lons_t:
                    ax.scatter(lons_t, lats_t, c=colors_t, s=45, zorder=8,
                               linewidths=0.5, edgecolors="white", alpha=0.95)
            ax.set_title(
                f"ISOMORPH Supply Chain Digital Twin — Day {t} / {T - 1}",
                fontsize=10, fontweight="bold", pad=6, color="#1a1a2e",
            )
            fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.01)
            buf = _io.BytesIO()
            fig.savefig(buf, format="png", dpi=DPI, facecolor=fig.get_facecolor())
        finally:
            plt.close(fig)
        buf.seek(0)
        pil_frames.append(Image.open(buf).convert("RGB"))

    # Normalize all frames to the same size (first frame drives dimensions)
    W_px, H_px = pil_frames[0].size
    pil_frames = [f if f.size == (W_px, H_px) else f.resize((W_px, H_px), Image.LANCZOS)
                  for f in pil_frames]

    # GIF supports only 256 colours. The map background consumes many palette
    # slots with similar grays/blues, crowding out item and stress colours.
    # Fix: append large solid-colour blocks of every important colour to the
    # combined image before quantising, so the median-cut algorithm is forced
    # to reserve palette entries for them.
    pinned_hex = (
        ITEM_COLORS[:5]                                        # item colours
        + ["#4CAF50", "#8BC34A", "#FFC107", "#FF9800",         # stress ramp
           "#FF5722", "#F44336"]
        + ["#ffffff", "#1a1a2e", "#7a7a90",                    # white / text / edges
           "#f0f0ee", "#d0e4f0", "#c0c0c0"]                    # map background tones
    )
    BLOCK = 2000  # large block → many pixels → dominates palette selection
    frames_width = W_px * len(pil_frames)
    combined = Image.new("RGB", (frames_width + BLOCK * len(pinned_hex), H_px))
    for i, f in enumerate(pil_frames):
        combined.paste(f, (i * W_px, 0))
    for i, hex_color in enumerate(pinned_hex):
        left = frames_width + i * BLOCK
        rgb = tuple(bytes.fromhex(hex_color.lstrip("#")))
        combined.paste(rgb, (left, 0, left + BLOCK, H_px))  # solid fill
    palette_source = combined.quantize(colors=255, method=Image.Quantize.MEDIANCUT)
    quantized = [f.quantize(palette=palette_source, dither=0) for f in pil_frames]

    # Write through the open handle: re-opening tmp.name while the handle is
    # still open fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False,
                                     prefix="isomorph_") as tmp:
        quantized[0].save(
            tmp,
            format="GIF",
            save_all=True,
            append_images=quantized[1:],
            loop=0,
            duration=frame_duration_ms,
            optimize=False,
        )
    return tmp.name