"""
ISOMORPH Demo — Visualization Module
=====================================
All public functions except make_network_animation_html return
plotly.graph_objects.Figure objects for rendering in Gradio gr.Plot
components; make_network_animation_html returns an HTML string for gr.HTML.
Public API
----------
make_network_animation(result, frame_step=None, frame_duration_ms=150)
    Animated US map: shipment particles travel along edges, and nodes are
    colored by backlog stress. Play / Pause controls + day slider.
make_network_animation_html(result, frame_step=None, frame_duration_ms=150)
    HTML wrapper around the animated map (Gradio gr.HTML compatible).
make_node_timeseries(result, node_id, item_ids=None)
    Subplot panel for one node: inventory, backlog, inflow, outflow,
    and (at the destination) realized demand.
make_bullwhip_chart(result)
    Bar chart of mean B = Var(inflow)/Var(outflow) per tier, from the
    destination (left) to the hub (right); source nodes are excluded.
make_edge_heatmap(result)
    Heatmap of edge utilization (fraction of daily capacity) over time.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# ── Color constants ───────────────────────────────────────────────────────────
# One color per item (up to 5) — matplotlib tab10 subset
ITEM_COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd", "#d62728"]
# Node stress: green (0 = no backlog) → yellow → red (1 = all backlog)
NODE_COLORSCALE = [[0.0, "#4CAF50"], [0.35, "#FFC107"], [1.0, "#F44336"]]
# Visual size of node marker by tier
_TIER_MARKER_SIZE = {0: 22, 1: 16, 2: 13, 3: 13, 4: 15, 5: 18, 6: 11}
# Human-readable tier labels (for bullwhip x-axis)
_TIER_LABEL = {
0: "Destination",
1: "Last-mile",
2: "Tier-4",
3: "Tier-3",
4: "Tier-2
(Atlanta)",
5: "Hub
(Nashville)",
6: "Sources",
}
# Descriptive role name shown in hover tooltip and node-type legend
_TIER_TYPE_NAME = {
0: "Destination (end customer)",
1: "Last-mile DC",
2: "Tier-4 Warehouse",
3: "Tier-3 Warehouse",
4: "Tier-2 Warehouse (Atlanta)",
5: "Regional Hub (Nashville)",
6: "Source / Supplier",
}
# Plotly Scattergeo marker symbol per tier — visually distinguishes role
_TIER_SYMBOL = {
0: "star", # Destination — prominent star
1: "square", # Last-mile DCs
2: "circle", # Tier-4 warehouses
3: "circle", # Tier-3 warehouses
4: "diamond", # Tier-2 (Atlanta pivot)
5: "hexagram", # Hub (Nashville)
6: "triangle-up", # Source suppliers
}
# Matplotlib equivalents for GIF rendering (one marker per node requires individual scatter calls)
_MPL_MARKER = {
"star": "*",
"square": "s",
"circle": "o",
"diamond": "D",
"hexagram": "h",
"triangle-up": "^",
}
# ============================================================================
# Private helpers
# ============================================================================
def _edge_travel_times(result) -> Dict[Tuple[str, str], float]:
"""Derive edge travel times from the first shipment that uses each hop."""
tt: Dict[Tuple[str, str], float] = {}
for s in result.shipments:
path = s["path_nodes"]
ets = s["edge_times"]
for h in range(len(ets)):
hop = (path[h], path[h + 1])
if hop not in tt:
tt[hop] = float(ets[h])
if len(tt) == len(result.edge_ids):
return tt # early exit once all edges covered
return tt
def _node_stress(result, t: int) -> List[float]:
"""
Return a per-node backlog-fraction in [0, 1] at timestep t.
stress = total_backlog / (total_backlog + total_inventory + 1)
Ordered to match result.node_ids.
"""
fracs = []
for nid in result.node_ids:
bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids)
inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids)
fracs.append(bl / max(bl + inv, 1))
return fracs
def _node_hover_custom(result, t: int,
type_names: Optional[List[str]] = None) -> List[List]:
"""
Return per-node customdata rows [inv, bl, type_name] at time t.
type_names (static) is passed once and repeated in every frame so the
hovertemplate can reference %{customdata[2]}.
"""
rows = []
for i, nid in enumerate(result.node_ids):
inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids)
bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids)
tname = type_names[i] if type_names else ""
rows.append([inv, bl, tname])
return rows
def _interpolate_position(
shipment: dict,
t: int,
node_coords: Dict[str, Tuple[float, float]],
) -> Optional[Tuple[float, float]]:
"""
Return (lat, lon) for a shipment at integer day t, or None if not active.
The particle travels along path_nodes using edge_times for timing.
"""
d = shipment["day"]
a = shipment["arrival_day"]
if t < d or t >= a:
return None
path = shipment["path_nodes"]
ets = shipment["edge_times"]
elapsed = float(t - d)
cum = 0.0
for h, tt in enumerate(ets):
tt = float(tt)
if elapsed <= cum + tt + 1e-9:
frac = (elapsed - cum) / max(tt, 1e-6)
frac = max(0.0, min(1.0, frac))
lat_a, lon_a = node_coords[path[h]]
lat_b, lon_b = node_coords[path[h + 1]]
return lat_a + frac * (lat_b - lat_a), lon_a + frac * (lon_b - lon_a)
cum += tt
return None
def _frame_particles(result, t: int, item_filter: Optional[str] = None):
    """
    Return (lats, lons, texts, colors) for all shipment particles in transit
    at timestep t.

    Parameters
    ----------
    result : SimResult
    t : int
        Simulation day.
    item_filter : str or None
        If given (and non-empty), only shipments of this item are included.

    colors holds one CSS color string per particle, chosen by the shipment's
    item (item i gets ITEM_COLORS[i % len(ITEM_COLORS)]). texts are Plotly
    hover labels, where "<br>" is the line break.
    """
    item_color_map = {iid: ITEM_COLORS[i % len(ITEM_COLORS)]
                      for i, iid in enumerate(result.item_ids)}
    lats, lons, texts, colors = [], [], [], []
    for s in result.shipments:
        if item_filter and s["item"] != item_filter:
            continue
        pos = _interpolate_position(s, t, result.node_coords)
        if pos is None:
            continue  # not in transit on day t
        lat, lon = pos
        lats.append(lat)
        lons.append(lon)
        texts.append(
            f"{s['item']}<br>{s['from']}→{s['to']}<br>"
            f"qty: {s['units']} day {s['day']}→{s['arrival_day']}"
        )
        colors.append(item_color_map[s["item"]])
    return lats, lons, texts, colors
def _bullwhip_ratios(result) -> Dict[str, Dict[str, float]]:
"""
Compute B_(n,i) = Var(inflow_(n,i)) / Var(outflow_(n,i)) per node/item.
At the destination node outflow is replaced by customer demand.
Source nodes (tier 6) are excluded (no inflow on network edges).
"""
dest_id = next(
nid for nid in result.node_ids if result.tier.get(nid) == 0
)
ratios: Dict[str, Dict[str, float]] = {}
for nid in result.node_ids:
tier = result.tier.get(nid, -1)
if tier == 6:
continue
ratios[nid] = {}
for iid in result.item_ids:
inflow_arr = result.inflow[nid][iid].astype(float)
if nid == dest_id:
outflow_arr = result.demand[iid].astype(float)
else:
outflow_arr = result.outflow[nid][iid].astype(float)
var_in = float(np.var(inflow_arr, ddof=1)) if result.T > 1 else 0.0
var_out = float(np.var(outflow_arr, ddof=1)) if result.T > 1 else 0.0
if var_out > 1.0: # guard against near-zero outflow variance
ratios[nid][iid] = var_in / var_out
return ratios
# ============================================================================
# 1. Animated network map
# ============================================================================
def make_network_animation(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
) -> go.Figure:
    """
    Animated Scattergeo map of the supply-chain network.

    Edges are static line traces. The node trace (marker color = backlog
    stress) and the shipment-particle trace are replaced in every frame.

    Parameters
    ----------
    result : SimResult
    frame_step : int or None
        Days between animation frames. When None it is chosen so that the
        regular frame grid has at most 200 frames. The last simulated day is
        always appended as a final frame so the slider can reach it.
    frame_duration_ms : int
        Milliseconds per frame during playback.

    Returns
    -------
    go.Figure
        Figure with pre-computed frames, Play / Pause buttons and a day slider.

    Raises
    ------
    ValueError
        If frame_step is given and is smaller than 1.
    """
    T = result.T
    max_frames = 200
    if frame_step is None:
        # Ceiling division guarantees len(range(0, T, step)) <= max_frames;
        # floor division let e.g. T=399 produce 399 frames.
        frame_step = max(1, -(-T // max_frames))
    elif frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step!r}")
    frame_times = list(range(0, T, frame_step))
    if frame_times and frame_times[-1] != T - 1:
        frame_times.append(T - 1)  # let playback / slider reach the last day
    node_ids = result.node_ids
    coords = result.node_coords
    edge_tt = _edge_travel_times(result)

    # ── Node metadata ──────────────────────────────────────────────────────
    node_lats = [coords[n][0] for n in node_ids]
    node_lons = [coords[n][1] for n in node_ids]
    node_sizes = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13)
                  for n in node_ids]
    node_type_names = [_TIER_TYPE_NAME.get(result.tier.get(n, -1), "Warehouse")
                       for n in node_ids]
    node_symbols = [_TIER_SYMBOL.get(result.tier.get(n, 3), "circle")
                    for n in node_ids]

    # ── Edge traces (static) ───────────────────────────────────────────────
    edge_traces = []
    max_cap = max(result.edge_cap.values(), default=0.0)
    if max_cap <= 0:
        max_cap = 1.0  # no positive capacity anywhere: avoid division by zero
    for eid in result.edge_ids:
        u, v = eid
        if u not in coords or v not in coords:
            continue
        cap = result.edge_cap.get(eid, 1.0)
        # Line width grows with sqrt(relative capacity): 1.0 .. 4.5 px.
        lw = 1.0 + 3.5 * (cap / max_cap) ** 0.5
        tt = edge_tt.get(eid, "?")  # "?" when no shipment used this edge
        edge_traces.append(go.Scattergeo(
            lat=[coords[u][0], coords[v][0], None],
            lon=[coords[u][1], coords[v][1], None],
            mode="lines",
            line=dict(width=lw, color="rgba(140,140,160,0.55)"),
            hoverinfo="text",
            # Plotly text is HTML: "<br>" is the line break.
            text=f"{u} → {v}<br>travel: {tt} day(s)<br>daily cap: {cap:.0f} units",
            showlegend=False,
            name=f"{u}→{v}",
        ))
    n_edge_traces = len(edge_traces)

    # ── Initial node trace (t=0) ───────────────────────────────────────────
    stress_0 = _node_stress(result, 0)
    custom_0 = _node_hover_custom(result, 0, node_type_names)
    node_trace = go.Scattergeo(
        lat=node_lats,
        lon=node_lons,
        mode="markers+text",
        text=node_ids,
        textposition="top center",
        textfont=dict(size=9, color="black"),
        marker=dict(
            size=node_sizes,
            symbol=node_symbols,
            color=stress_0,
            colorscale=NODE_COLORSCALE,
            cmin=0.0, cmax=1.0,
            colorbar=dict(
                title="Backlog<br>stress",
                thickness=12,
                len=0.5,
                x=1.01,
                tickvals=[0, 0.5, 1],
                ticktext=["0 (healthy)", "0.5", "1 (stockout)"],
                tickfont=dict(size=9),
            ),
            line=dict(width=1.5, color="white"),
        ),
        customdata=custom_0,
        hovertemplate=(
            "%{text}<br>"
            "Role: %{customdata[2]}<br>"
            "Total inventory: %{customdata[0]:,} units<br>"
            "Total backlog: %{customdata[1]:,} units"
            "<extra></extra>"  # hide the secondary trace-name hover box
        ),
        showlegend=False,
        name="nodes",
    )

    # ── Initial shipment trace (t=0) ───────────────────────────────────────
    lats0, lons0, texts0, colors0 = _frame_particles(result, 0)
    ship_trace = go.Scattergeo(
        lat=lats0,
        lon=lons0,
        mode="markers",
        marker=dict(size=7, color=colors0, opacity=0.85,
                    line=dict(width=0.5, color="white")),
        hoverinfo="text",
        text=texts0,
        showlegend=False,
        name="shipments",
    )

    # ── Legend: item colors (shipment dots) ────────────────────────────────
    # Data-less proxy traces: they only contribute legend entries.
    item_legend_traces = [
        go.Scattergeo(
            lat=[None], lon=[None],
            mode="markers",
            marker=dict(size=9, color=ITEM_COLORS[i % len(ITEM_COLORS)],
                        symbol="circle"),
            name=iid,
            legendgrouptitle_text="Shipment items" if i == 0 else None,
            legendgroup="items",
            showlegend=True,
        )
        for i, iid in enumerate(result.item_ids)
    ]

    # ── Legend: node-type symbols ──────────────────────────────────────────
    # One proxy trace per distinct tier present in this result, by tier.
    present_tiers = sorted({result.tier.get(n, -1) for n in node_ids})
    node_type_legend_traces = [
        go.Scattergeo(
            lat=[None], lon=[None],
            mode="markers",
            marker=dict(size=10, color="gray",
                        symbol=_TIER_SYMBOL.get(t_val, "circle")),
            name=_TIER_TYPE_NAME.get(t_val, "Unknown"),
            legendgrouptitle_text="Node types" if i == 0 else None,
            legendgroup="node_types",
            showlegend=True,
        )
        for i, t_val in enumerate(present_tiers)
    ]

    # ── Assemble base figure ───────────────────────────────────────────────
    all_traces = (edge_traces + [node_trace, ship_trace]
                  + item_legend_traces + node_type_legend_traces)
    fig = go.Figure(data=all_traces)
    # Trace indices of dynamic traces (edge + node + ship; legend traces follow)
    idx_nodes = n_edge_traces
    idx_ships = n_edge_traces + 1

    # ── Pre-compute frames ─────────────────────────────────────────────────
    frames = []
    for t in frame_times:
        stress_t = _node_stress(result, t)
        custom_t = _node_hover_custom(result, t, node_type_names)
        lats_t, lons_t, texts_t, colors_t = _frame_particles(result, t)
        frames.append(go.Frame(
            data=[
                # Plain dicts avoid serialising default None lat/lon values
                # that go.Scattergeo() would inject and confuse Plotly.js.
                {"type": "scattergeo",
                 "marker": {"color": stress_t},
                 "customdata": custom_t},
                {"type": "scattergeo",
                 "lat": lats_t, "lon": lons_t,
                 "text": texts_t,
                 "marker": {"color": colors_t, "size": 7, "opacity": 0.85}},
            ],
            traces=[idx_nodes, idx_ships],
            layout={"title": {"text": f"ISOMORPH Supply Chain Digital Twin — Day {t} / {T - 1}"}},
            name=str(t),
        ))
    fig.frames = frames

    # ── Slider steps ───────────────────────────────────────────────────────
    slider_steps = [
        dict(
            method="animate",
            args=[[str(t)],
                  dict(mode="immediate",
                       frame=dict(duration=0, redraw=True),
                       transition=dict(duration=0))],
            label=str(t),
        )
        for t in frame_times
    ]

    # ── Layout ─────────────────────────────────────────────────────────────
    fig.update_layout(
        title=dict(
            text=f"ISOMORPH Supply Chain Digital Twin — Day 0 / {T - 1}",
            x=0.5, xanchor="center", font=dict(size=14),
        ),
        geo=dict(
            scope="usa",
            projection_type="albers usa",
            showland=True, landcolor="rgb(243,243,243)",
            showlakes=True, lakecolor="rgb(210,230,255)",
            showrivers=True, rivercolor="rgb(210,230,255)",
            showcoastlines=True, coastlinecolor="rgb(180,180,200)",
            showsubunits=True, subunitcolor="rgb(200,200,215)",
            bgcolor="rgba(255,255,255,0)",
        ),
        legend=dict(
            x=0.01, y=0.01,
            bgcolor="rgba(255,255,255,0.82)",
            bordercolor="lightgray", borderwidth=1,
            font=dict(size=9),
            tracegroupgap=6,
        ),
        updatemenus=[dict(
            type="buttons",
            showactive=False,
            y=1.08, x=0.0, xanchor="left",
            buttons=[
                dict(
                    label="▶ Play",
                    method="animate",
                    args=[None, dict(
                        frame=dict(duration=frame_duration_ms, redraw=True),
                        fromcurrent=True,
                        transition=dict(duration=0),
                    )],
                ),
                dict(
                    label="⏸ Pause",
                    method="animate",
                    # [[None]] + immediate mode is Plotly's "stop" idiom.
                    args=[[None], dict(
                        frame=dict(duration=0, redraw=False),
                        mode="immediate",
                        transition=dict(duration=0),
                    )],
                ),
            ],
        )],
        sliders=[dict(
            currentvalue=dict(
                prefix="Day: ",
                font=dict(size=11),
                visible=True,
                xanchor="center",
            ),
            pad=dict(t=50, b=10),
            len=0.9,
            x=0.05,
            steps=slider_steps,
            transition=dict(duration=0),
        )],
        margin=dict(l=0, r=0, t=80, b=60),
        height=540,
        paper_bgcolor="white",
    )
    return fig
# ============================================================================
# 2. Node detail time-series panel
# ============================================================================
def make_node_timeseries(
    result,
    node_id: str,
    item_ids: Optional[List[str]] = None,
) -> go.Figure:
    """
    4–5 subplot panel for a single node showing inventory, backlog, inflow,
    outflow, and (destination only) realized demand.

    Parameters
    ----------
    result : SimResult
    node_id : str
        Node to visualise.
    item_ids : list[str] or None
        Subset of items to plot. Defaults to all items in result.

    Returns
    -------
    go.Figure
        One shared-x row per panel and one line per item in every panel; each
        item's legend entry is shown once (row 1). Days listed in
        result.disruption_log are drawn as red dashed vertical lines.
    """
    if item_ids is None:
        item_ids = result.item_ids
    # The destination is the tier-0 node. A result without one simply gets
    # no demand panel instead of an opaque StopIteration.
    dest_id = next((n for n in result.node_ids if result.tier.get(n) == 0), None)
    is_dest = dest_id is not None and node_id == dest_id
    # Use plain Python lists — avoids numpy int32/int64 serialization issues
    # inside Gradio 4.x's gr.Plot JSON path.
    days = list(range(result.T))

    # (subplot title, y-axis label) per panel, top to bottom.
    panels = [
        ("On-hand inventory (units stored at this node)", "Units"),
        ("Backlog (unfulfilled demand pending)", "Units"),
        ("Inflow (units arriving per day)", "Units/day"),
        ("Outflow (units shipped out per day)", "Units/day"),
    ]
    if is_dest:
        panels.append(
            ("Realized demand (customer orders received per day)", "Units/day")
        )
    n_panels = len(panels)

    fig = make_subplots(
        rows=n_panels, cols=1,
        shared_xaxes=True,
        subplot_titles=[title for title, _ in panels],
        vertical_spacing=0.06,
    )

    def _add(row, y_arr, iid, color, dash="solid"):
        """Add one item's line to subplot `row` (legend entry only in row 1)."""
        fig.add_trace(
            go.Scatter(
                x=days,
                # np.asarray(...).tolist() yields plain Python numbers and
                # accepts lists as well as numpy arrays.
                y=np.asarray(y_arr).tolist(),
                mode="lines",
                name=iid,
                line=dict(color=color, width=1.5, dash=dash),
                legendgroup=iid,
                showlegend=(row == 1),
                hovertemplate=f"Day %{{x}}<br>{iid}: %{{y:,.0f}}",
            ),
            row=row, col=1,
        )

    for idx, iid in enumerate(item_ids):
        color = ITEM_COLORS[idx % len(ITEM_COLORS)]
        node_series = (
            result.inventory[node_id][iid],
            result.backlog[node_id][iid],
            result.inflow[node_id][iid],
            result.outflow[node_id][iid],
        )
        for row, series in enumerate(node_series, start=1):
            _add(row, series, iid, color)
        if is_dest:
            _add(5, result.demand[iid], iid, color, dash="dot")

    # Disruption event markers — one vline per panel.
    # add_vline(row=, col=) was not available in all Plotly 5.x builds;
    # iterate over subplot y-axis references instead.
    if result.disruption_log:
        yref_list = ["y"] + [f"y{i}" for i in range(2, n_panels + 1)]
        for ev in result.disruption_log:
            for yref in yref_list:
                fig.add_shape(
                    type="line",
                    x0=ev["day"], x1=ev["day"],
                    y0=0, y1=1,
                    xref="x", yref=f"{yref} domain",
                    line=dict(dash="dash", color="rgba(255,80,80,0.5)", width=1),
                )

    # Panel y-axis labels
    for row, (_, ylabel) in enumerate(panels, start=1):
        fig.update_yaxes(title_text=ylabel, row=row, col=1,
                         title_font=dict(size=10), tickfont=dict(size=9))
    fig.update_xaxes(title_text="Day", row=n_panels, col=1,
                     tickfont=dict(size=9))

    type_name = _TIER_TYPE_NAME.get(result.tier.get(node_id, -1), "")
    disruption_note = (
        " · red dashes = disruption events"
        if result.disruption_log else ""
    )
    fig.update_layout(
        title=dict(
            text=(f"{node_id} — {type_name}"
                  f"{disruption_note}"),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        legend=dict(
            title="Items (SKUs)",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray", borderwidth=1,
        ),
        height=180 * n_panels + 80,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=50),
    )
    # Light grid
    fig.update_xaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig
# ============================================================================
# 3. Bullwhip amplification chart
# ============================================================================
def make_bullwhip_chart(result) -> go.Figure:
    """
    Bar chart of tier-level bullwhip ratio B = Var(inflow)/Var(outflow).

    Bars are grouped by item. Tiers run from destination (left) to hub
    (right); source nodes are excluded because _bullwhip_ratios skips them.
    Each bar is the mean B over the nodes of that tier that have a ratio for
    the item; a tier with no such node gets an empty slot. A dashed
    reference line at B = 1 marks the no-amplification baseline.
    """
    ratios = _bullwhip_ratios(result)

    def _tier_of(nid):
        # Same fallback as _bullwhip_ratios, so untiered nodes cannot KeyError.
        return result.tier.get(nid, -1)

    # Tiers with at least one computed ratio (sources never appear in ratios)
    tiers_present = sorted(
        {_tier_of(n) for n, per_item in ratios.items() if per_item}
    )
    x_labels = [_TIER_LABEL.get(t, f"Tier {t}") for t in tiers_present]

    # Per-item bar traces
    fig = go.Figure()
    for idx, iid in enumerate(result.item_ids):
        tier_means: List[Optional[float]] = []
        for t in tiers_present:
            vals = [per_item[iid] for n, per_item in ratios.items()
                    if _tier_of(n) == t and iid in per_item]
            tier_means.append(float(np.mean(vals)) if vals else None)
        fig.add_trace(go.Bar(
            x=x_labels,
            y=tier_means,
            name=iid,
            marker_color=ITEM_COLORS[idx % len(ITEM_COLORS)],
            opacity=0.82,
            text=[f"{v:.2f}" if v is not None else "" for v in tier_means],
            textposition="outside",
            textfont=dict(size=9),
        ))

    # Reference line at B = 1
    fig.add_hline(
        y=1.0, line_dash="dash",
        line_color="rgba(80,80,80,0.6)", line_width=1.5,
        annotation_text="B = 1 (no amplification)",
        annotation_position="top right",
        annotation_font_size=9,
    )
    fig.update_layout(
        title=dict(
            text=(
                "Bullwhip Amplification by Tier<br>"
                "B > 1 means demand variability grows as orders travel upstream"
            ),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        xaxis=dict(
            # Sources are excluded, so the upstream end of the axis is the hub.
            title="Network tier (Destination = downstream → Hub = upstream)",
            title_font=dict(size=11),
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            title="B = Var(inflow) / Var(outflow)",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            rangemode="tozero",
        ),
        barmode="group",
        legend=dict(
            title="Items",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray", borderwidth=1,
        ),
        height=400,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=60),
    )
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig
# ============================================================================
# 4. Edge utilization heatmap
# ============================================================================
def make_edge_heatmap(result) -> go.Figure:
    """
    Heatmap of edge utilization (fraction of daily capacity) over time.

    Y-axis: edges sorted by the tier of their source node, lowest tier
    (downstream) first and drawn at the top. Edges whose source node has no
    tier entry go last.
    X-axis: simulation day.
    Color: 0 (very light blue) = empty → 0.5 (yellow-orange) → 1 (red) =
    at capacity. The color scale saturates outside [0, 1] (zmin/zmax).
    Days in result.disruption_log are drawn as blue dashed vertical lines
    labelled "disruption".
    """
    # Stable sort keeps result.edge_ids order within each tier.
    sorted_edges = sorted(result.edge_ids,
                          key=lambda eid: result.tier.get(eid[0], 99))
    edge_labels = [f"{u} → {v}" for u, v in sorted_edges]
    z = np.array([result.edge_util[eid] for eid in sorted_edges],
                 dtype=np.float64)  # presumably shape (n_edges, T)
    days = list(range(result.T))

    fig = go.Figure(go.Heatmap(
        z=z.tolist(),
        x=days,
        y=edge_labels,
        colorscale=[
            [0.0, "rgb(240,248,255)"],  # near-empty: very light blue
            [0.5, "rgb(255,200,100)"],  # half full: yellow-orange
            [1.0, "rgb(220,50,50)"],    # saturated: red
        ],
        zmin=0.0, zmax=1.0,
        colorbar=dict(
            title="Utilization<br>fraction",
            thickness=13,
            len=0.7,
            tickvals=[0, 0.5, 1],
            ticktext=["0%", "50%", "100%"],
            tickfont=dict(size=9),
        ),
        hovertemplate=(
            "Edge: %{y}<br>Day: %{x}<br>Utilization: %{z:.1%}"
        ),
    ))

    # Disruption event markers
    for ev in result.disruption_log:
        fig.add_vline(
            x=ev["day"], line_dash="dash",
            line_color="rgba(80,80,255,0.6)", line_width=1.5,
        )
        fig.add_annotation(
            x=ev["day"], y=1.01, xref="x", yref="paper",
            text="disruption", showarrow=False,
            font=dict(size=8, color="blue"), textangle=-90,
        )

    # Only mention the dashes when there are any (matches make_node_timeseries).
    disruption_note = (
        " Blue dashes mark disruption events."
        if result.disruption_log else ""
    )
    fig.update_layout(
        title=dict(
            text=(
                "Edge Utilization over Time<br>"
                "Fraction of daily shipping capacity used per edge "
                "(0 = empty, 1 = fully saturated)."
                f"{disruption_note}"
            ),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        xaxis=dict(
            title="Day",
            title_font=dict(size=11),
            tickfont=dict(size=9),
        ),
        yaxis=dict(
            title="Edge",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            autorange="reversed",  # downstream edges at top
        ),
        height=max(320, 30 * len(sorted_edges) + 120),
        paper_bgcolor="white",
        margin=dict(l=160, r=80, t=60, b=60),
    )
    return fig
# ============================================================================
# 5. HTML wrapper for animated network map (Gradio gr.HTML compatible)
# ============================================================================
def make_network_animation_html(
result,
frame_step: Optional[int] = None,
frame_duration_ms: int = 150,
) -> str:
"""
Return an