""" ISOMORPH Demo — Visualization Module ===================================== All functions return plotly.graph_objects.Figure objects for rendering in Gradio gr.Plot components. Public API ---------- make_network_animation(result, frame_step=None) Animated US map: shipment particles travel along edges, nodes colored by backlog stress. Play / Pause controls + scrubber slider. make_node_timeseries(result, node_id, item_ids=None) Subplot panel for one node: inventory, backlog, inflow, outflow, and (at the destination) realized demand. make_bullwhip_chart(result) Bar chart of mean B = Var(inflow)/Var(outflow) per tier, from destination (left) to hub/sources (right). make_edge_heatmap(result) Heatmap of edge utilization (fraction of daily capacity) over time. """ from __future__ import annotations from typing import Dict, List, Optional, Tuple import numpy as np import plotly.graph_objects as go from plotly.subplots import make_subplots # ── Color constants ─────────────────────────────────────────────────────────── # One color per item (up to 5) — matplotlib tab10 subset ITEM_COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd", "#d62728"] # Node stress: green (0 = no backlog) → yellow → red (1 = all backlog) NODE_COLORSCALE = [[0.0, "#4CAF50"], [0.35, "#FFC107"], [1.0, "#F44336"]] # Visual size of node marker by tier _TIER_MARKER_SIZE = {0: 22, 1: 16, 2: 13, 3: 13, 4: 15, 5: 18, 6: 11} # Human-readable tier labels (for bullwhip x-axis) _TIER_LABEL = { 0: "Destination", 1: "Last-mile", 2: "Tier-4", 3: "Tier-3", 4: "Tier-2
(Atlanta)", 5: "Hub
(Nashville)", 6: "Sources", } # Descriptive role name shown in hover tooltip and node-type legend _TIER_TYPE_NAME = { 0: "Destination (end customer)", 1: "Last-mile DC", 2: "Tier-4 Warehouse", 3: "Tier-3 Warehouse", 4: "Tier-2 Warehouse (Atlanta)", 5: "Regional Hub (Nashville)", 6: "Source / Supplier", } # Plotly Scattergeo marker symbol per tier — visually distinguishes role _TIER_SYMBOL = { 0: "star", # Destination — prominent star 1: "square", # Last-mile DCs 2: "circle", # Tier-4 warehouses 3: "circle", # Tier-3 warehouses 4: "diamond", # Tier-2 (Atlanta pivot) 5: "hexagram", # Hub (Nashville) 6: "triangle-up", # Source suppliers } # Matplotlib equivalents for GIF rendering (one marker per node requires individual scatter calls) _MPL_MARKER = { "star": "*", "square": "s", "circle": "o", "diamond": "D", "hexagram": "h", "triangle-up": "^", } # ============================================================================ # Private helpers # ============================================================================ def _edge_travel_times(result) -> Dict[Tuple[str, str], float]: """Derive edge travel times from the first shipment that uses each hop.""" tt: Dict[Tuple[str, str], float] = {} for s in result.shipments: path = s["path_nodes"] ets = s["edge_times"] for h in range(len(ets)): hop = (path[h], path[h + 1]) if hop not in tt: tt[hop] = float(ets[h]) if len(tt) == len(result.edge_ids): return tt # early exit once all edges covered return tt def _node_stress(result, t: int) -> List[float]: """ Return a per-node backlog-fraction in [0, 1] at timestep t. stress = total_backlog / (total_backlog + total_inventory + 1) Ordered to match result.node_ids. 
""" fracs = [] for nid in result.node_ids: bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids) inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids) fracs.append(bl / max(bl + inv, 1)) return fracs def _node_hover_custom(result, t: int, type_names: Optional[List[str]] = None) -> List[List]: """ Return per-node customdata rows [inv, bl, type_name] at time t. type_names (static) is passed once and repeated in every frame so the hovertemplate can reference %{customdata[2]}. """ rows = [] for i, nid in enumerate(result.node_ids): inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids) bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids) tname = type_names[i] if type_names else "" rows.append([inv, bl, tname]) return rows def _interpolate_position( shipment: dict, t: int, node_coords: Dict[str, Tuple[float, float]], ) -> Optional[Tuple[float, float]]: """ Return (lat, lon) for a shipment at integer day t, or None if not active. The particle travels along path_nodes using edge_times for timing. """ d = shipment["day"] a = shipment["arrival_day"] if t < d or t >= a: return None path = shipment["path_nodes"] ets = shipment["edge_times"] elapsed = float(t - d) cum = 0.0 for h, tt in enumerate(ets): tt = float(tt) if elapsed <= cum + tt + 1e-9: frac = (elapsed - cum) / max(tt, 1e-6) frac = max(0.0, min(1.0, frac)) lat_a, lon_a = node_coords[path[h]] lat_b, lon_b = node_coords[path[h + 1]] return lat_a + frac * (lat_b - lat_a), lon_a + frac * (lon_b - lon_a) cum += tt return None def _frame_particles(result, t: int, item_filter: Optional[str] = None): """ Return (lats, lons, texts, colors) for all active shipment particles at timestep t. colors is a list of CSS color strings keyed by item. 
""" item_color_map = {iid: ITEM_COLORS[i % len(ITEM_COLORS)] for i, iid in enumerate(result.item_ids)} lats, lons, texts, colors = [], [], [], [] for s in result.shipments: if item_filter and s["item"] != item_filter: continue pos = _interpolate_position(s, t, result.node_coords) if pos is None: continue lat, lon = pos lats.append(lat) lons.append(lon) texts.append( f"{s['item']}
{s['from']}→{s['to']}
" f"qty: {s['units']} day {s['day']}→{s['arrival_day']}" ) colors.append(item_color_map[s["item"]]) return lats, lons, texts, colors def _bullwhip_ratios(result) -> Dict[str, Dict[str, float]]: """ Compute B_(n,i) = Var(inflow_(n,i)) / Var(outflow_(n,i)) per node/item. At the destination node outflow is replaced by customer demand. Source nodes (tier 6) are excluded (no inflow on network edges). """ dest_id = next( nid for nid in result.node_ids if result.tier.get(nid) == 0 ) ratios: Dict[str, Dict[str, float]] = {} for nid in result.node_ids: tier = result.tier.get(nid, -1) if tier == 6: continue ratios[nid] = {} for iid in result.item_ids: inflow_arr = result.inflow[nid][iid].astype(float) if nid == dest_id: outflow_arr = result.demand[iid].astype(float) else: outflow_arr = result.outflow[nid][iid].astype(float) var_in = float(np.var(inflow_arr, ddof=1)) if result.T > 1 else 0.0 var_out = float(np.var(outflow_arr, ddof=1)) if result.T > 1 else 0.0 if var_out > 1.0: # guard against near-zero outflow variance ratios[nid][iid] = var_in / var_out return ratios # ============================================================================ # 1. Animated network map # ============================================================================ def make_network_animation( result, frame_step: Optional[int] = None, frame_duration_ms: int = 150, ) -> go.Figure: """ Animated Scattergeo map of the supply-chain network. Parameters ---------- result : SimResult frame_step : int or None Days between animation frames. Auto-computed to cap at ~200 frames. frame_duration_ms : int Milliseconds per frame during playback. 
""" T = result.T max_frames = 200 if frame_step is None: frame_step = max(1, T // max_frames) frame_times = list(range(0, T, frame_step)) node_ids = result.node_ids coords = result.node_coords edge_tt = _edge_travel_times(result) # ── Node metadata ────────────────────────────────────────────────────── node_lats = [coords[n][0] for n in node_ids] node_lons = [coords[n][1] for n in node_ids] node_sizes = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13) for n in node_ids] node_labels = node_ids node_type_names = [_TIER_TYPE_NAME.get(result.tier.get(n, -1), "Warehouse") for n in node_ids] node_symbols = [_TIER_SYMBOL.get(result.tier.get(n, 3), "circle") for n in node_ids] # ── Edge traces (static) ─────────────────────────────────────────────── edge_traces = [] max_cap = max(result.edge_cap.values()) if result.edge_cap else 1.0 for eid in result.edge_ids: u, v = eid if u not in coords or v not in coords: continue cap = result.edge_cap.get(eid, 1.0) lw = 1.0 + 3.5 * (cap / max_cap) ** 0.5 tt = edge_tt.get(eid, "?") edge_traces.append(go.Scattergeo( lat=[coords[u][0], coords[v][0], None], lon=[coords[u][1], coords[v][1], None], mode="lines", line=dict(width=lw, color="rgba(140,140,160,0.55)"), hoverinfo="text", text=f"{u} → {v}
travel: {tt} day(s)
daily cap: {cap:.0f} units", showlegend=False, name=f"{u}→{v}", )) n_edge_traces = len(edge_traces) # ── Initial node trace (t=0) ─────────────────────────────────────────── stress_0 = _node_stress(result, 0) custom_0 = _node_hover_custom(result, 0, node_type_names) node_trace = go.Scattergeo( lat=node_lats, lon=node_lons, mode="markers+text", text=node_labels, textposition="top center", textfont=dict(size=9, color="black"), marker=dict( size=node_sizes, symbol=node_symbols, color=stress_0, colorscale=NODE_COLORSCALE, cmin=0.0, cmax=1.0, colorbar=dict( title="Backlog
stress", thickness=12, len=0.5, x=1.01, tickvals=[0, 0.5, 1], ticktext=["0 (healthy)", "0.5", "1 (stockout)"], tickfont=dict(size=9), ), line=dict(width=1.5, color="white"), ), customdata=custom_0, hovertemplate=( "%{text}
" "Role: %{customdata[2]}
" "Total inventory: %{customdata[0]:,} units
" "Total backlog: %{customdata[1]:,} units
" "" ), showlegend=False, name="nodes", ) # ── Initial shipment trace (t=0) ─────────────────────────────────────── lats0, lons0, texts0, colors0 = _frame_particles(result, 0) ship_trace = go.Scattergeo( lat=lats0, lon=lons0, mode="markers", marker=dict(size=7, color=colors0, opacity=0.85, line=dict(width=0.5, color="white")), hoverinfo="text", text=texts0, showlegend=False, name="shipments", ) # ── Legend: item colors (shipment dots) ────────────────────────────── item_legend_traces = [] for i, iid in enumerate(result.item_ids): item_legend_traces.append(go.Scattergeo( lat=[None], lon=[None], mode="markers", marker=dict(size=9, color=ITEM_COLORS[i % len(ITEM_COLORS)], symbol="circle"), name=iid, legendgrouptitle_text="Shipment items" if i == 0 else None, legendgroup="items", showlegend=True, )) # ── Legend: node-type symbols ───────────────────────────────────────── # One proxy trace per distinct tier present in this result. seen_tiers: dict = {} for n in node_ids: t_val = result.tier.get(n, -1) if t_val not in seen_tiers: seen_tiers[t_val] = (_TIER_TYPE_NAME.get(t_val, "Unknown"), _TIER_SYMBOL.get(t_val, "circle")) node_type_legend_traces = [] for i, (t_val, (tname, sym)) in enumerate( sorted(seen_tiers.items(), key=lambda x: x[0])): node_type_legend_traces.append(go.Scattergeo( lat=[None], lon=[None], mode="markers", marker=dict(size=10, color="gray", symbol=sym), name=tname, legendgrouptitle_text="Node types" if i == 0 else None, legendgroup="node_types", showlegend=True, )) # ── Assemble base figure ─────────────────────────────────────────────── all_traces = (edge_traces + [node_trace, ship_trace] + item_legend_traces + node_type_legend_traces) fig = go.Figure(data=all_traces) # Trace indices of dynamic traces (edge + node + ship; legend traces follow) idx_nodes = n_edge_traces idx_ships = n_edge_traces + 1 # ── Pre-compute frames ───────────────────────────────────────────────── frames = [] for t in frame_times: stress_t = _node_stress(result, t) custom_t = 
_node_hover_custom(result, t, node_type_names) lats_t, lons_t, texts_t, colors_t = _frame_particles(result, t) frames.append(go.Frame( data=[ # Plain dicts avoid serialising default None lat/lon values # that go.Scattergeo() would inject and confuse Plotly.js. {"type": "scattergeo", "marker": {"color": stress_t}, "customdata": custom_t}, {"type": "scattergeo", "lat": lats_t, "lon": lons_t, "text": texts_t, "marker": {"color": colors_t, "size": 7, "opacity": 0.85}}, ], traces=[idx_nodes, idx_ships], layout={"title": {"text": f"ISOMORPH Supply Chain Digital Twin — Day {t} / {T - 1}"}}, name=str(t), )) fig.frames = frames # ── Slider steps ────────────────────────────────────────────────────── slider_steps = [ dict( method="animate", args=[[str(t)], dict(mode="immediate", frame=dict(duration=0, redraw=True), transition=dict(duration=0))], label=str(t), ) for t in frame_times ] # ── Layout ──────────────────────────────────────────────────────────── fig.update_layout( title=dict( text=f"ISOMORPH Supply Chain Digital Twin — Day 0 / {T - 1}", x=0.5, xanchor="center", font=dict(size=14), ), geo=dict( scope="usa", projection_type="albers usa", showland=True, landcolor="rgb(243,243,243)", showlakes=True, lakecolor="rgb(210,230,255)", showrivers=True, rivercolor="rgb(210,230,255)", showcoastlines=True, coastlinecolor="rgb(180,180,200)", showsubunits=True, subunitcolor="rgb(200,200,215)", bgcolor="rgba(255,255,255,0)", ), legend=dict( x=0.01, y=0.01, bgcolor="rgba(255,255,255,0.82)", bordercolor="lightgray", borderwidth=1, font=dict(size=9), tracegroupgap=6, ), updatemenus=[dict( type="buttons", showactive=False, y=1.08, x=0.0, xanchor="left", buttons=[ dict( label="▶ Play", method="animate", args=[None, dict( frame=dict(duration=frame_duration_ms, redraw=True), fromcurrent=True, transition=dict(duration=0), )], ), dict( label="⏸ Pause", method="animate", args=[[None], dict( frame=dict(duration=0, redraw=False), mode="immediate", transition=dict(duration=0), )], ), ], )], 
sliders=[dict( currentvalue=dict( prefix="Day: ", font=dict(size=11), visible=True, xanchor="center", ), pad=dict(t=50, b=10), len=0.9, x=0.05, steps=slider_steps, transition=dict(duration=0), )], margin=dict(l=0, r=0, t=80, b=60), height=540, paper_bgcolor="white", ) return fig # ============================================================================ # 2. Node detail time-series panel # ============================================================================ def make_node_timeseries( result, node_id: str, item_ids: Optional[List[str]] = None, ) -> go.Figure: """ 4–5 subplot panel for a single node showing inventory, backlog, inflow, outflow, and (destination only) realized demand. Parameters ---------- result : SimResult node_id : str Node to visualise. item_ids : list[str] or None Subset of items to plot. Defaults to all items in result. """ if item_ids is None: item_ids = result.item_ids dest_id = next(n for n in result.node_ids if result.tier.get(n) == 0) is_dest = (node_id == dest_id) # Use plain Python lists — avoids numpy int32/int64 serialization issues # inside Gradio 4.x's gr.Plot JSON path. days = list(range(result.T)) n_panels = 5 if is_dest else 4 panel_titles = [ "On-hand inventory (units stored at this node)", "Backlog (unfulfilled demand pending)", "Inflow (units arriving per day)", "Outflow (units shipped out per day)", ] if is_dest: panel_titles.append("Realized demand (customer orders received per day)") fig = make_subplots( rows=n_panels, cols=1, shared_xaxes=True, subplot_titles=panel_titles, vertical_spacing=0.06, ) for idx, iid in enumerate(item_ids): color = ITEM_COLORS[idx % len(ITEM_COLORS)] # Capture loop variables by value via default args to avoid closure issues. 
def _add(row, y_arr, dash="solid", _iid=iid, _color=color): fig.add_trace( go.Scatter( x=days, y=y_arr.tolist(), # convert np.int32 → Python ints mode="lines", name=_iid, line=dict(color=_color, width=1.5, dash=dash), legendgroup=_iid, showlegend=(row == 1), hovertemplate=f"Day %{{x}}
{_iid}: %{{y:,.0f}}", ), row=row, col=1, ) _add(1, result.inventory[node_id][iid]) _add(2, result.backlog[node_id][iid]) _add(3, result.inflow[node_id][iid]) _add(4, result.outflow[node_id][iid]) if is_dest: _add(5, result.demand[iid], dash="dot") # Disruption event markers — one vline per panel. # add_vline(row=, col=) was not available in all Plotly 5.x builds; # iterate over subplot y-axis references instead. if result.disruption_log: yref_list = ["y"] + [f"y{i}" for i in range(2, n_panels + 1)] for ev in result.disruption_log: for yref in yref_list: fig.add_shape( type="line", x0=ev["day"], x1=ev["day"], y0=0, y1=1, xref="x", yref=f"{yref} domain", line=dict(dash="dash", color="rgba(255,80,80,0.5)", width=1), ) # Panel y-axis labels y_labels = ["Units", "Units", "Units/day", "Units/day"] if is_dest: y_labels.append("Units/day") for row, lbl in enumerate(y_labels, start=1): fig.update_yaxes(title_text=lbl, row=row, col=1, title_font=dict(size=10), tickfont=dict(size=9)) fig.update_xaxes(title_text="Day", row=n_panels, col=1, tickfont=dict(size=9)) tier_name = _TIER_LABEL.get(result.tier.get(node_id, -1), "") type_name = _TIER_TYPE_NAME.get(result.tier.get(node_id, -1), "") disruption_note = ( " · red dashes = disruption events" if result.disruption_log else "" ) fig.update_layout( title=dict( text=(f"{node_id} — {type_name}" f"{disruption_note}"), x=0.5, xanchor="center", font=dict(size=13), ), legend=dict( title="Items (SKUs)", x=1.01, y=1.0, font=dict(size=10), bgcolor="rgba(255,255,255,0.8)", bordercolor="lightgray", borderwidth=1, ), height=180 * n_panels + 80, paper_bgcolor="white", plot_bgcolor="rgba(248,248,252,1)", margin=dict(l=60, r=120, t=60, b=50), ) # Light grid fig.update_xaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)") fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)") return fig # ============================================================================ # 3. 
# ============================================================================
# 3. Bullwhip amplification chart
# ============================================================================

def make_bullwhip_chart(result) -> go.Figure:
    """
    Bar chart of tier-level bullwhip ratio B = Var(inflow)/Var(outflow).

    Bars are grouped by item; each bar is the mean ratio over the nodes of a
    tier that have a ratio for that item (see ``_bullwhip_ratios``).  Tiers
    run in ascending tier number from the destination (left) to the hub
    (right); source nodes are excluded.  A dashed reference line at B = 1
    marks the no-amplification baseline.
    """
    ratios = _bullwhip_ratios(result)

    # Tiers with at least one ratio.  .get() matches _bullwhip_ratios, which
    # also keeps nodes that have no tier entry (they land in tier -1).
    tiers_present = sorted(
        {result.tier.get(n, -1) for n, per_item in ratios.items() if per_item}
    )
    x_labels = [_TIER_LABEL.get(t, f"Tier {t}") for t in tiers_present]

    fig = go.Figure()
    for idx, iid in enumerate(result.item_ids):
        tier_means = []
        for t in tiers_present:
            values = [per_item[iid] for n, per_item in ratios.items()
                      if result.tier.get(n, -1) == t and iid in per_item]
            tier_means.append(float(np.mean(values)) if values else None)

        fig.add_trace(go.Bar(
            x=x_labels,
            y=tier_means,
            name=iid,
            marker_color=ITEM_COLORS[idx % len(ITEM_COLORS)],
            opacity=0.82,
            text=[f"{v:.2f}" if v is not None else "" for v in tier_means],
            textposition="outside",
            textfont=dict(size=9),
        ))

    # Reference line at B = 1
    fig.add_hline(
        y=1.0,
        line_dash="dash",
        line_color="rgba(80,80,80,0.6)",
        line_width=1.5,
        annotation_text="B = 1 (no amplification)",
        annotation_position="top right",
        annotation_font_size=9,
    )

    fig.update_layout(
        title=dict(
            text=(
                "Bullwhip Amplification by Tier<br>"
                "B > 1 means demand variability grows as orders travel upstream"
            ),
            x=0.5, xanchor="center",
            font=dict(size=13),
        ),
        xaxis=dict(
            # Sources (tier 6) are excluded, so the upstream end is the hub.
            title="Network tier (destination = downstream → hub = upstream)",
            title_font=dict(size=11),
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            title="B = Var(inflow) / Var(outflow)",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            rangemode="tozero",
        ),
        barmode="group",
        legend=dict(
            title="Items",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray",
            borderwidth=1,
        ),
        height=400,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=60),
    )
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig


# ============================================================================
# 4. Edge utilization heatmap
# ============================================================================

def make_edge_heatmap(result) -> go.Figure:
    """
    Heatmap of edge utilization (fraction of daily capacity) over time.

    Y-axis: edges sorted by the tier of their first endpoint ``u`` (lowest
    tier, i.e. nearest the destination, at the top; nodes without a tier
    sort last).  X-axis: simulation day.
    Color: 0 (very light blue) = empty → 0.5 (yellow-orange) → 1 (red) =
    at capacity.  Disruption events are shown as vertical dashed lines.
    """
    sorted_edges = sorted(result.edge_ids, key=lambda eid: result.tier.get(eid[0], 99))
    edge_labels = [f"{u} → {v}" for u, v in sorted_edges]

    # shape (n_edges, T)
    z = np.array([result.edge_util[eid] for eid in sorted_edges], dtype=np.float64)
    days = list(range(result.T))

    fig = go.Figure(go.Heatmap(
        z=z.tolist(),
        x=days,
        y=edge_labels,
        colorscale=[
            [0.0, "rgb(240,248,255)"],  # near-empty: very light blue
            [0.5, "rgb(255,200,100)"],  # half full: yellow-orange
            [1.0, "rgb(220,50,50)"],    # saturated: red
        ],
        zmin=0.0,
        zmax=1.0,
        colorbar=dict(
            title="Utilization<br>fraction",
            thickness=13,
            len=0.7,
            tickvals=[0, 0.5, 1],
            ticktext=["0%", "50%", "100%"],
            tickfont=dict(size=9),
        ),
        hovertemplate="Edge: %{y}<br>Day: %{x}<br>Utilization: %{z:.1%}",
    ))

    # Disruption event markers (tolerate a missing log).
    for ev in result.disruption_log or ():
        fig.add_vline(
            x=ev["day"],
            line_dash="dash",
            line_color="rgba(80,80,255,0.6)",
            line_width=1.5,
        )
        fig.add_annotation(
            x=ev["day"], y=1.01,
            xref="x", yref="paper",
            text="disruption",
            showarrow=False,
            font=dict(size=8, color="blue"),
            textangle=-90,
        )

    fig.update_layout(
        title=dict(
            text=(
                "Edge Utilization over Time<br>"
                "Fraction of daily shipping capacity used per edge "
                "(0 = empty, 1 = fully saturated). "
                "Blue dashes mark disruption events."
            ),
            x=0.5, xanchor="center",
            font=dict(size=13),
        ),
        xaxis=dict(
            title="Day",
            title_font=dict(size=11),
            tickfont=dict(size=9),
        ),
        yaxis=dict(
            title="Edge",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            autorange="reversed",  # downstream edges at top
        ),
        height=max(320, 30 * len(sorted_edges) + 120),
        paper_bgcolor="white",
        margin=dict(l=160, r=80, t=60, b=60),
    )
    return fig


# ============================================================================
# 5. HTML wrapper for animated network map (Gradio gr.HTML compatible)
# ============================================================================

def make_network_animation_html(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
) -> str:
    """
    Return an ``<iframe>`` HTML snippet embedding the animated network map.

    The figure from ``make_network_animation`` is serialized to a standalone
    HTML page (plotly.js loaded from the CDN).  That page is HTML-escaped and
    placed in the iframe's ``srcdoc`` attribute.  The iframe presumably exists
    because ``<script>`` tags inserted through ``gr.HTML`` do not execute,
    while an iframe document's scripts do.

    NOTE(review): this body is a reconstruction. Confirm the iframe height and
    the plotly.js source (CDN vs. inline bundle) against the deployment.

    Parameters
    ----------
    result : SimResult
    frame_step, frame_duration_ms
        Forwarded to ``make_network_animation``.
    """
    import html as _html

    fig = make_network_animation(result, frame_step=frame_step,
                                 frame_duration_ms=frame_duration_ms)
    page = fig.to_html(include_plotlyjs="cdn", full_html=True, auto_play=False)
    # A little taller than the figure so the slider is not clipped.
    height_px = int(fig.layout.height or 540) + 40
    return (
        f'<iframe srcdoc="{_html.escape(page, quote=True)}" '
        f'style="width:100%; height:{height_px}px; border:none;"></iframe>'
    )


# ============================================================================
# 6. Animated GIF export
# ============================================================================

def make_network_animation_gif(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
    max_frames: int = 80,
) -> str:
    """
    Render the network animation as an animated GIF and return the temp file path.

    Uses matplotlib (Agg backend).  When cartopy is importable, frames get a
    Lambert-conformal US map background; otherwise plain axes on a lon/lat
    grid.  No kaleido or Chrome required.  The file is created with
    ``delete=False``, so the caller owns it and should remove it when done.

    Parameters
    ----------
    result : SimResult
    frame_step : int or None
        Days between frames.  When None it is ceil(T / max_frames), so at
        most ``max_frames`` frames are rendered.
    frame_duration_ms : int
        Milliseconds per frame during playback.
    max_frames : int
        Upper bound on the auto-computed frame count (keeps file size
        reasonable).  Ignored when ``frame_step`` is given.

    Returns
    -------
    str
        Path of the written ``.gif`` file.

    Raises
    ------
    ImportError
        If Pillow is not installed.
    ValueError
        If ``frame_step`` is less than 1, or there is nothing to render
        (``result.T == 0``).
    """
    import io as _io
    import tempfile

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.colors as mcolors
    import matplotlib.pyplot as plt

    try:
        from PIL import Image
    except ImportError as err:
        raise ImportError("Pillow is required for GIF export: pip install Pillow") from err

    try:
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        _PROJ = ccrs.LambertConformal(central_longitude=-96, central_latitude=39,
                                      standard_parallels=(33, 45))
        _GEO = ccrs.PlateCarree()
        HAS_CARTOPY = True
    except ImportError:
        HAS_CARTOPY = False

    # Same green(0) → yellow(0.35) → red(1.0) ramp as the Plotly map.
    stress_cmap = mcolors.LinearSegmentedColormap.from_list(
        "stress", [(stop, color) for stop, color in NODE_COLORSCALE],
    )
    stress_norm = mcolors.Normalize(vmin=0.0, vmax=1.0)

    T = result.T
    if frame_step is None:
        # Ceiling division so the frame count never exceeds max_frames.
        frame_step = max(1, -(-T // max(1, max_frames)))
    elif frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")
    frame_times = list(range(0, T, frame_step))
    if not frame_times:
        raise ValueError("result.T must be >= 1 to render a GIF")

    # ── Static geometry (identical for every frame) ───────────────────────
    coords = result.node_coords
    node_ids = result.node_ids
    node_lons = [coords[n][1] for n in node_ids]
    node_lats = [coords[n][0] for n in node_ids]
    # s is marker area in pt² — keep proportional to Plotly sizes but much smaller
    node_s = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13) * 12 for n in node_ids]
    node_markers = [_MPL_MARKER.get(_TIER_SYMBOL.get(result.tier.get(n, 3), "circle"), "o")
                    for n in node_ids]

    max_cap = max(result.edge_cap.values(), default=0.0) or 1.0
    edges: List[Tuple] = []  # (lon0, lat0, lon1, lat1, linewidth)
    for eid in result.edge_ids:
        u, v = eid
        if u not in coords or v not in coords:
            continue
        cap = result.edge_cap.get(eid, 1.0)
        edges.append((coords[u][1], coords[u][0], coords[v][1], coords[v][0],
                      0.6 + 2.0 * (cap / max_cap) ** 0.5))

    W, H, DPI = 8.5, 5.0, 90
    label_bbox = dict(boxstyle="round,pad=0.12", fc="white", ec="none", alpha=0.65)

    pil_frames = []
    for t in frame_times:
        node_colors = [stress_cmap(stress_norm(s)) for s in _node_stress(result, t)]
        lats_t, lons_t, _, colors_t = _frame_particles(result, t)

        if HAS_CARTOPY:
            fig = plt.figure(figsize=(W, H), dpi=DPI, facecolor="#f4f6f9")
        else:
            fig = plt.figure(figsize=(W, H), dpi=DPI)
            fig.patch.set_facecolor("#f4f6f9")
        try:
            if HAS_CARTOPY:
                ax = fig.add_subplot(1, 1, 1, projection=_PROJ)
                ax.set_extent([-125, -66, 24, 50], crs=_GEO)
                ax.add_feature(cfeature.LAND, facecolor="#f0f0ee", zorder=0)
                ax.add_feature(cfeature.OCEAN, facecolor="#d0e4f0", zorder=0)
                ax.add_feature(cfeature.LAKES, facecolor="#d0e4f0", alpha=0.6, zorder=1)
                ax.add_feature(cfeature.STATES, edgecolor="#c0c0c0", linewidth=0.4, zorder=2)
                ax.add_feature(cfeature.COASTLINE, edgecolor="#999999", linewidth=0.5, zorder=2)
                geo_kw = {"transform": _GEO}
                z_edge, z_node, z_label, label_dy, ship_alpha = 3, 4, 6, 0.9, 1.0
                ship_colors = [mcolors.to_rgba(c) for c in colors_t]
            else:
                # Fallback: plain axes, no geographic background
                ax = fig.add_subplot(1, 1, 1)
                ax.set_facecolor("#dce8f0")
                ax.set_xlim(-128, -65)
                ax.set_ylim(23, 50)
                ax.axis("off")
                geo_kw = {}
                z_edge, z_node, z_label, label_dy, ship_alpha = 1, 3, 5, 0.8, 0.95
                ship_colors = colors_t

            for lon0, lat0, lon1, lat1, lw in edges:
                ax.plot([lon0, lon1], [lat0, lat1], color="#7a7a90", linewidth=lw,
                        alpha=0.65, zorder=z_edge, solid_capstyle="round", **geo_kw)

            # One scatter call per node: matplotlib takes a single marker per call.
            for nid, lon, lat, color, s, mkr in zip(
                    node_ids, node_lons, node_lats, node_colors, node_s, node_markers):
                ax.scatter([lon], [lat], c=[color], s=[s], marker=mkr, zorder=z_node,
                           linewidths=1.5, edgecolors="white", alpha=0.85, **geo_kw)
                ax.text(lon, lat + label_dy, nid, fontsize=6.5, ha="center", va="bottom",
                        fontweight="bold", color="#1a1a2e", zorder=z_label,
                        bbox=dict(label_bbox), **geo_kw)

            if lons_t:
                ax.scatter(lons_t, lats_t, c=ship_colors, s=45, zorder=8,
                           linewidths=0.5, edgecolors="white", alpha=ship_alpha, **geo_kw)

            ax.set_title(
                f"ISOMORPH Supply Chain Digital Twin — Day {t} / {T - 1}",
                fontsize=10, fontweight="bold", pad=6, color="#1a1a2e",
            )
            fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.01)

            buf = _io.BytesIO()
            fig.savefig(buf, format="png", dpi=DPI, facecolor=fig.get_facecolor())
        finally:
            plt.close(fig)  # never leak pyplot figures, even if drawing fails
        buf.seek(0)
        pil_frames.append(Image.open(buf).convert("RGB"))

    # Normalize all frames to the same size (first frame drives dimensions)
    W_px, H_px = pil_frames[0].size
    pil_frames = [f.resize((W_px, H_px), Image.LANCZOS) for f in pil_frames]

    # GIF supports only 256 colours. The map background consumes many palette
    # slots with similar grays/blues, crowding out item and stress colours.
    # Fix: append large solid-colour blocks of every important colour to the
    # combined image before quantising, so the median-cut algorithm is forced
    # to reserve palette entries for them.
    pinned_hex = (
        ITEM_COLORS[:5]                                          # all 5 item colours
        + ["#4CAF50", "#8BC34A", "#FFC107", "#FF9800",           # stress ramp
           "#FF5722", "#F44336"]
        + ["#ffffff", "#1a1a2e", "#7a7a90",                      # white / text / edges
           "#f0f0ee", "#d0e4f0", "#c0c0c0"]                      # map background tones
    )
    BLOCK = 2000  # large block → many pixels → dominates palette selection

    ref_x0 = W_px * len(pil_frames)
    combined = Image.new("RGB", (ref_x0 + BLOCK * len(pinned_hex), H_px))
    for i, frame in enumerate(pil_frames):
        combined.paste(frame, (i * W_px, 0))
    for i, hex_color in enumerate(pinned_hex):
        rgb = tuple(bytes.fromhex(hex_color.lstrip("#")))
        combined.paste(rgb, (ref_x0 + i * BLOCK, 0, ref_x0 + (i + 1) * BLOCK, H_px))

    palette_source = combined.quantize(colors=255, method=Image.Quantize.MEDIANCUT)
    quantized = [f.quantize(palette=palette_source, dither=0) for f in pil_frames]  # 0 = no dithering

    tmp = tempfile.NamedTemporaryFile(suffix=".gif", delete=False, prefix="isomorph_")
    tmp.close()  # close first: Windows cannot reopen a still-open temp file by name
    quantized[0].save(
        tmp.name,
        save_all=True,
        append_images=quantized[1:],
        loop=0,
        duration=frame_duration_ms,
        optimize=False,
    )
    return tmp.name