Spaces:
Running
Running
"""
ISOMORPH Demo — Visualization Module
=====================================
The chart builders return plotly.graph_objects.Figure objects for rendering
in Gradio gr.Plot components; the two export helpers at the end return an
HTML string and a GIF file path instead.

Public API
----------
make_network_animation(result, frame_step=None, frame_duration_ms=150)
    Animated US map: shipment particles travel along edges, nodes colored
    by backlog stress. Play / Pause controls + scrubber slider.
make_node_timeseries(result, node_id, item_ids=None)
    Subplot panel for one node: inventory, backlog, inflow, outflow,
    and (at the destination) realized demand.
make_bullwhip_chart(result)
    Bar chart of mean B = Var(inflow)/Var(outflow) per tier, from
    destination (left) to hub (right); source nodes are excluded.
make_edge_heatmap(result)
    Heatmap of edge utilization (fraction of daily capacity) over time.
make_network_animation_html(result, frame_step=None, frame_duration_ms=150)
    The animated map wrapped in an <iframe srcdoc=...> string for gr.HTML,
    which keeps Play / Pause working where gr.Plot drops the frames.
make_network_animation_gif(result, frame_step=None, frame_duration_ms=150,
                           max_frames=80)
    The animated map rendered with matplotlib (cartopy optional) into an
    animated GIF; returns the temp-file path.
"""
| from __future__ import annotations | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
# ── Color constants ───────────────────────────────────────────────────────────
# Per-item palette, cycled by item index: five hues taken from matplotlib tab10.
ITEM_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#9467bd",  # purple
    "#d62728",  # red
]
# Backlog-stress colorscale for node markers:
# 0.0 green (no backlog) → 0.35 amber → 1.0 red (all backlog).
NODE_COLORSCALE = [
    [0.0, "#4CAF50"],
    [0.35, "#FFC107"],
    [1.0, "#F44336"],
]

# ── Per-tier presentation table ───────────────────────────────────────────────
# One row per network tier, destination (0) through sources (6):
#   (tier, Plotly marker size, bullwhip x-axis label, role name, Plotly symbol)
_TIER_TABLE = (
    (0, 22, "Destination",         "Destination (end customer)", "star"),
    (1, 16, "Last-mile",           "Last-mile DC",               "square"),
    (2, 13, "Tier-4",              "Tier-4 Warehouse",           "circle"),
    (3, 13, "Tier-3",              "Tier-3 Warehouse",           "circle"),
    (4, 15, "Tier-2<br>(Atlanta)", "Tier-2 Warehouse (Atlanta)", "diamond"),
    (5, 18, "Hub<br>(Nashville)",  "Regional Hub (Nashville)",   "hexagram"),
    (6, 11, "Sources",             "Source / Supplier",          "triangle-up"),
)
# Visual size of the node marker per tier.
_TIER_MARKER_SIZE = {row[0]: row[1] for row in _TIER_TABLE}
# Human-readable tier label (bullwhip chart x-axis).
_TIER_LABEL = {row[0]: row[2] for row in _TIER_TABLE}
# Descriptive role name shown in the hover tooltip and the node-type legend.
_TIER_TYPE_NAME = {row[0]: row[3] for row in _TIER_TABLE}
# Scattergeo marker symbol per tier, so node roles differ by shape.
_TIER_SYMBOL = {row[0]: row[4] for row in _TIER_TABLE}

# Plotly symbol name → matplotlib marker code, for the GIF renderer (which
# needs one scatter call per node to vary the marker).
_MPL_MARKER = {
    "star": "*", "square": "s", "circle": "o",
    "diamond": "D", "hexagram": "h", "triangle-up": "^",
}
| # ============================================================================ | |
| # Private helpers | |
| # ============================================================================ | |
| def _edge_travel_times(result) -> Dict[Tuple[str, str], float]: | |
| """Derive edge travel times from the first shipment that uses each hop.""" | |
| tt: Dict[Tuple[str, str], float] = {} | |
| for s in result.shipments: | |
| path = s["path_nodes"] | |
| ets = s["edge_times"] | |
| for h in range(len(ets)): | |
| hop = (path[h], path[h + 1]) | |
| if hop not in tt: | |
| tt[hop] = float(ets[h]) | |
| if len(tt) == len(result.edge_ids): | |
| return tt # early exit once all edges covered | |
| return tt | |
| def _node_stress(result, t: int) -> List[float]: | |
| """ | |
| Return a per-node backlog-fraction in [0, 1] at timestep t. | |
| stress = total_backlog / (total_backlog + total_inventory + 1) | |
| Ordered to match result.node_ids. | |
| """ | |
| fracs = [] | |
| for nid in result.node_ids: | |
| bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids) | |
| inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids) | |
| fracs.append(bl / max(bl + inv, 1)) | |
| return fracs | |
| def _node_hover_custom(result, t: int, | |
| type_names: Optional[List[str]] = None) -> List[List]: | |
| """ | |
| Return per-node customdata rows [inv, bl, type_name] at time t. | |
| type_names (static) is passed once and repeated in every frame so the | |
| hovertemplate can reference %{customdata[2]}. | |
| """ | |
| rows = [] | |
| for i, nid in enumerate(result.node_ids): | |
| inv = sum(int(result.inventory[nid][iid][t]) for iid in result.item_ids) | |
| bl = sum(int(result.backlog[nid][iid][t]) for iid in result.item_ids) | |
| tname = type_names[i] if type_names else "" | |
| rows.append([inv, bl, tname]) | |
| return rows | |
| def _interpolate_position( | |
| shipment: dict, | |
| t: int, | |
| node_coords: Dict[str, Tuple[float, float]], | |
| ) -> Optional[Tuple[float, float]]: | |
| """ | |
| Return (lat, lon) for a shipment at integer day t, or None if not active. | |
| The particle travels along path_nodes using edge_times for timing. | |
| """ | |
| d = shipment["day"] | |
| a = shipment["arrival_day"] | |
| if t < d or t >= a: | |
| return None | |
| path = shipment["path_nodes"] | |
| ets = shipment["edge_times"] | |
| elapsed = float(t - d) | |
| cum = 0.0 | |
| for h, tt in enumerate(ets): | |
| tt = float(tt) | |
| if elapsed <= cum + tt + 1e-9: | |
| frac = (elapsed - cum) / max(tt, 1e-6) | |
| frac = max(0.0, min(1.0, frac)) | |
| lat_a, lon_a = node_coords[path[h]] | |
| lat_b, lon_b = node_coords[path[h + 1]] | |
| return lat_a + frac * (lat_b - lat_a), lon_a + frac * (lon_b - lon_a) | |
| cum += tt | |
| return None | |
def _frame_particles(result, t: int, item_filter: Optional[str] = None):
    """
    Gather every shipment particle in transit at timestep ``t``.

    Returns four parallel lists ``(lats, lons, texts, colors)``:
    - the particle position from ``_interpolate_position``;
    - an HTML hover label;
    - a CSS colour from ``ITEM_COLORS``, chosen by the item's index in
      ``result.item_ids`` (cycling when there are more items than colours).

    A truthy ``item_filter`` keeps only shipments of that item.
    """
    n_colors = len(ITEM_COLORS)
    color_of = {
        item: ITEM_COLORS[k % n_colors] for k, item in enumerate(result.item_ids)
    }
    lats, lons, texts, colors = [], [], [], []
    wanted = (
        shp for shp in result.shipments
        if not item_filter or shp["item"] == item_filter
    )
    for shp in wanted:
        where = _interpolate_position(shp, t, result.node_coords)
        if where is None:
            continue  # not in transit on day t
        lat, lon = where
        lats.append(lat)
        lons.append(lon)
        texts.append(
            f"<b>{shp['item']}</b><br>{shp['from']}→{shp['to']}<br>"
            f"qty: {shp['units']} day {shp['day']}→{shp['arrival_day']}"
        )
        colors.append(color_of[shp["item"]])
    return lats, lons, texts, colors
| def _bullwhip_ratios(result) -> Dict[str, Dict[str, float]]: | |
| """ | |
| Compute B_(n,i) = Var(inflow_(n,i)) / Var(outflow_(n,i)) per node/item. | |
| At the destination node outflow is replaced by customer demand. | |
| Source nodes (tier 6) are excluded (no inflow on network edges). | |
| """ | |
| dest_id = next( | |
| nid for nid in result.node_ids if result.tier.get(nid) == 0 | |
| ) | |
| ratios: Dict[str, Dict[str, float]] = {} | |
| for nid in result.node_ids: | |
| tier = result.tier.get(nid, -1) | |
| if tier == 6: | |
| continue | |
| ratios[nid] = {} | |
| for iid in result.item_ids: | |
| inflow_arr = result.inflow[nid][iid].astype(float) | |
| if nid == dest_id: | |
| outflow_arr = result.demand[iid].astype(float) | |
| else: | |
| outflow_arr = result.outflow[nid][iid].astype(float) | |
| var_in = float(np.var(inflow_arr, ddof=1)) if result.T > 1 else 0.0 | |
| var_out = float(np.var(outflow_arr, ddof=1)) if result.T > 1 else 0.0 | |
| if var_out > 1.0: # guard against near-zero outflow variance | |
| ratios[nid][iid] = var_in / var_out | |
| return ratios | |
| # ============================================================================ | |
| # 1. Animated network map | |
| # ============================================================================ | |
def make_network_animation(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
) -> go.Figure:
    """
    Animated Scattergeo map of the supply-chain network.

    The figure has four parts:
    - edges, drawn as static grey lines whose width grows with the square
      root of daily capacity;
    - node markers, recoloured every frame by backlog stress
      (``_node_stress``);
    - shipment particles, re-positioned every frame;
    - Play / Pause buttons and a day slider that drive the animation.

    Parameters
    ----------
    result : SimResult
    frame_step : int or None
        Days between animation frames.  When None it is chosen by ceiling
        division so that at most 200 frames are generated.
    frame_duration_ms : int
        Milliseconds per frame during playback.

    Raises
    ------
    ValueError
        If ``result.T < 1`` or an explicit ``frame_step`` is < 1.
    """
    T = result.T
    if T < 1:
        raise ValueError("result.T must be at least 1 to build an animation")
    max_frames = 200
    if frame_step is None:
        # Ceiling division keeps len(frame_times) <= max_frames; floor
        # division would allow nearly twice as many (T=399 → step 1 → 399).
        frame_step = max(1, (T + max_frames - 1) // max_frames)
    elif frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")
    frame_times = list(range(0, T, frame_step))
    node_ids = result.node_ids
    coords = result.node_coords
    edge_tt = _edge_travel_times(result)

    # ── Node metadata ──────────────────────────────────────────────────────
    node_lats = [coords[n][0] for n in node_ids]
    node_lons = [coords[n][1] for n in node_ids]
    node_sizes = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13)
                  for n in node_ids]
    node_labels = node_ids
    node_type_names = [_TIER_TYPE_NAME.get(result.tier.get(n, -1), "Warehouse")
                       for n in node_ids]
    node_symbols = [_TIER_SYMBOL.get(result.tier.get(n, 3), "circle")
                    for n in node_ids]

    # ── Edge traces (static) ───────────────────────────────────────────────
    edge_traces = []
    # `or 1.0` covers both an empty capacity map and all-zero capacities,
    # which would otherwise raise ZeroDivisionError in the width formula.
    max_cap = max(result.edge_cap.values(), default=0.0) or 1.0
    for eid in result.edge_ids:
        u, v = eid
        if u not in coords or v not in coords:
            continue
        cap = result.edge_cap.get(eid, 1.0)
        lw = 1.0 + 3.5 * (cap / max_cap) ** 0.5
        tt = edge_tt.get(eid, "?")
        edge_traces.append(go.Scattergeo(
            lat=[coords[u][0], coords[v][0], None],
            lon=[coords[u][1], coords[v][1], None],
            mode="lines",
            line=dict(width=lw, color="rgba(140,140,160,0.55)"),
            hoverinfo="text",
            text=f"{u} → {v}<br>travel: {tt} day(s)<br>daily cap: {cap:.0f} units",
            showlegend=False,
            name=f"{u}→{v}",
        ))
    n_edge_traces = len(edge_traces)

    # ── Initial node trace (t=0) ───────────────────────────────────────────
    stress_0 = _node_stress(result, 0)
    custom_0 = _node_hover_custom(result, 0, node_type_names)
    node_trace = go.Scattergeo(
        lat=node_lats,
        lon=node_lons,
        mode="markers+text",
        text=node_labels,
        textposition="top center",
        textfont=dict(size=9, color="black"),
        marker=dict(
            size=node_sizes,
            symbol=node_symbols,
            color=stress_0,
            colorscale=NODE_COLORSCALE,
            cmin=0.0, cmax=1.0,
            colorbar=dict(
                title="Backlog<br>stress",
                thickness=12,
                len=0.5,
                x=1.01,
                tickvals=[0, 0.5, 1],
                ticktext=["0 (healthy)", "0.5", "1 (stockout)"],
                tickfont=dict(size=9),
            ),
            line=dict(width=1.5, color="white"),
        ),
        customdata=custom_0,
        hovertemplate=(
            "<b>%{text}</b><br>"
            "Role: %{customdata[2]}<br>"
            "Total inventory: %{customdata[0]:,} units<br>"
            "Total backlog: %{customdata[1]:,} units<br>"
            "<extra></extra>"
        ),
        showlegend=False,
        name="nodes",
    )

    # ── Initial shipment trace (t=0) ───────────────────────────────────────
    lats0, lons0, texts0, colors0 = _frame_particles(result, 0)
    ship_trace = go.Scattergeo(
        lat=lats0,
        lon=lons0,
        mode="markers",
        marker=dict(size=7, color=colors0, opacity=0.85,
                    line=dict(width=0.5, color="white")),
        hoverinfo="text",
        text=texts0,
        showlegend=False,
        name="shipments",
    )

    # ── Legend: item colors (shipment dots) ────────────────────────────────
    # Off-map proxy traces (lat/lon None) exist only to populate the legend.
    item_legend_traces = []
    for i, iid in enumerate(result.item_ids):
        item_legend_traces.append(go.Scattergeo(
            lat=[None], lon=[None],
            mode="markers",
            marker=dict(size=9, color=ITEM_COLORS[i % len(ITEM_COLORS)],
                        symbol="circle"),
            name=iid,
            legendgrouptitle_text="Shipment items" if i == 0 else None,
            legendgroup="items",
            showlegend=True,
        ))

    # ── Legend: node-type symbols ──────────────────────────────────────────
    # One proxy trace per distinct tier present in this result.
    seen_tiers: dict = {}
    for n in node_ids:
        t_val = result.tier.get(n, -1)
        if t_val not in seen_tiers:
            seen_tiers[t_val] = (_TIER_TYPE_NAME.get(t_val, "Unknown"),
                                 _TIER_SYMBOL.get(t_val, "circle"))
    node_type_legend_traces = []
    for i, (t_val, (tname, sym)) in enumerate(
            sorted(seen_tiers.items(), key=lambda x: x[0])):
        node_type_legend_traces.append(go.Scattergeo(
            lat=[None], lon=[None],
            mode="markers",
            marker=dict(size=10, color="gray", symbol=sym),
            name=tname,
            legendgrouptitle_text="Node types" if i == 0 else None,
            legendgroup="node_types",
            showlegend=True,
        ))

    # ── Assemble base figure ───────────────────────────────────────────────
    all_traces = (edge_traces + [node_trace, ship_trace]
                  + item_legend_traces + node_type_legend_traces)
    fig = go.Figure(data=all_traces)
    # Indices of the two dynamic traces; frames update only these.
    idx_nodes = n_edge_traces
    idx_ships = n_edge_traces + 1

    # ── Pre-compute frames ─────────────────────────────────────────────────
    frames = []
    for t in frame_times:
        stress_t = _node_stress(result, t)
        custom_t = _node_hover_custom(result, t, node_type_names)
        lats_t, lons_t, texts_t, colors_t = _frame_particles(result, t)
        frames.append(go.Frame(
            data=[
                # Plain dicts avoid serialising default None lat/lon values
                # that go.Scattergeo() would inject and confuse Plotly.js.
                {"type": "scattergeo",
                 "marker": {"color": stress_t},
                 "customdata": custom_t},
                {"type": "scattergeo",
                 "lat": lats_t, "lon": lons_t,
                 "text": texts_t,
                 "marker": {"color": colors_t, "size": 7, "opacity": 0.85}},
            ],
            traces=[idx_nodes, idx_ships],
            layout={"title": {"text": f"<b>ISOMORPH Supply Chain Digital Twin</b> — Day {t} / {T - 1}"}},
            name=str(t),
        ))
    fig.frames = frames

    # ── Slider steps ───────────────────────────────────────────────────────
    # Step args reference frame names (str(t)), so the slider also tracks
    # the current frame during playback.
    slider_steps = [
        dict(
            method="animate",
            args=[[str(t)],
                  dict(mode="immediate",
                       frame=dict(duration=0, redraw=True),
                       transition=dict(duration=0))],
            label=str(t),
        )
        for t in frame_times
    ]

    # ── Layout ─────────────────────────────────────────────────────────────
    fig.update_layout(
        title=dict(
            text=f"<b>ISOMORPH Supply Chain Digital Twin</b> — Day 0 / {T - 1}",
            x=0.5, xanchor="center", font=dict(size=14),
        ),
        geo=dict(
            scope="usa",
            projection_type="albers usa",
            showland=True, landcolor="rgb(243,243,243)",
            showlakes=True, lakecolor="rgb(210,230,255)",
            showrivers=True, rivercolor="rgb(210,230,255)",
            showcoastlines=True, coastlinecolor="rgb(180,180,200)",
            showsubunits=True, subunitcolor="rgb(200,200,215)",
            bgcolor="rgba(255,255,255,0)",
        ),
        legend=dict(
            x=0.01, y=0.01,
            bgcolor="rgba(255,255,255,0.82)",
            bordercolor="lightgray", borderwidth=1,
            font=dict(size=9),
            tracegroupgap=6,
        ),
        updatemenus=[dict(
            type="buttons",
            showactive=False,
            y=1.08, x=0.0, xanchor="left",
            buttons=[
                dict(
                    label="▶ Play",
                    method="animate",
                    args=[None, dict(
                        frame=dict(duration=frame_duration_ms, redraw=True),
                        fromcurrent=True,
                        transition=dict(duration=0),
                    )],
                ),
                dict(
                    label="⏸ Pause",
                    method="animate",
                    args=[[None], dict(
                        frame=dict(duration=0, redraw=False),
                        mode="immediate",
                        transition=dict(duration=0),
                    )],
                ),
            ],
        )],
        sliders=[dict(
            currentvalue=dict(
                prefix="Day: ",
                font=dict(size=11),
                visible=True,
                xanchor="center",
            ),
            pad=dict(t=50, b=10),
            len=0.9,
            x=0.05,
            steps=slider_steps,
            transition=dict(duration=0),
        )],
        margin=dict(l=0, r=0, t=80, b=60),
        height=540,
        paper_bgcolor="white",
    )
    return fig
| # ============================================================================ | |
| # 2. Node detail time-series panel | |
| # ============================================================================ | |
def make_node_timeseries(
    result,
    node_id: str,
    item_ids: Optional[List[str]] = None,
) -> go.Figure:
    """
    4–5 subplot panel for a single node showing inventory, backlog, inflow,
    outflow, and (destination only) realized demand, one line per item.

    Each day in ``result.disruption_log`` is marked by a dashed red vertical
    line in every panel.

    Parameters
    ----------
    result : SimResult
    node_id : str
        Node to visualise.
    item_ids : list[str] or None
        Subset of items to plot. Defaults to all items in result.

    Raises
    ------
    ValueError
        If no node in ``result.node_ids`` has tier 0 (destination).
    """
    if item_ids is None:
        item_ids = result.item_ids
    dest_id = next(
        (n for n in result.node_ids if result.tier.get(n) == 0), None
    )
    if dest_id is None:
        # A bare next() would leak an uninformative StopIteration here.
        raise ValueError("result has no destination node (tier 0)")
    is_dest = (node_id == dest_id)

    # Use plain Python lists — avoids numpy int32/int64 serialization issues
    # inside Gradio 4.x's gr.Plot JSON path.
    days = list(range(result.T))
    n_panels = 5 if is_dest else 4
    panel_titles = [
        "On-hand inventory (units stored at this node)",
        "Backlog (unfulfilled demand pending)",
        "Inflow (units arriving per day)",
        "Outflow (units shipped out per day)",
    ]
    if is_dest:
        panel_titles.append("Realized demand (customer orders received per day)")

    fig = make_subplots(
        rows=n_panels, cols=1,
        shared_xaxes=True,
        subplot_titles=panel_titles,
        vertical_spacing=0.06,
    )

    def _add(row, y_arr, iid, color, dash="solid"):
        """Add one item's series to subplot `row` (legend entry only in row 1)."""
        fig.add_trace(
            go.Scatter(
                x=days,
                # np.asarray(...).tolist() yields plain Python numbers and
                # also accepts series stored as plain lists.
                y=np.asarray(y_arr).tolist(),
                mode="lines",
                name=iid,
                line=dict(color=color, width=1.5, dash=dash),
                legendgroup=iid,
                showlegend=(row == 1),
                hovertemplate=f"Day %{{x}}<br>{iid}: %{{y:,.0f}}<extra></extra>",
            ),
            row=row, col=1,
        )

    for idx, iid in enumerate(item_ids):
        color = ITEM_COLORS[idx % len(ITEM_COLORS)]
        _add(1, result.inventory[node_id][iid], iid, color)
        _add(2, result.backlog[node_id][iid], iid, color)
        _add(3, result.inflow[node_id][iid], iid, color)
        _add(4, result.outflow[node_id][iid], iid, color)
        if is_dest:
            _add(5, result.demand[iid], iid, color, dash="dot")

    # Disruption event markers — one vline per panel.
    # add_vline(row=, col=) was not available in all Plotly 5.x builds;
    # iterate over subplot y-axis references instead.
    if result.disruption_log:
        yref_list = ["y"] + [f"y{i}" for i in range(2, n_panels + 1)]
        for ev in result.disruption_log:
            for yref in yref_list:
                fig.add_shape(
                    type="line",
                    x0=ev["day"], x1=ev["day"],
                    y0=0, y1=1,
                    xref="x", yref=f"{yref} domain",
                    line=dict(dash="dash", color="rgba(255,80,80,0.5)", width=1),
                )

    # Panel y-axis labels
    y_labels = ["Units", "Units", "Units/day", "Units/day"]
    if is_dest:
        y_labels.append("Units/day")
    for row, lbl in enumerate(y_labels, start=1):
        fig.update_yaxes(title_text=lbl, row=row, col=1,
                         title_font=dict(size=10), tickfont=dict(size=9))
    fig.update_xaxes(title_text="Day", row=n_panels, col=1,
                     tickfont=dict(size=9))

    type_name = _TIER_TYPE_NAME.get(result.tier.get(node_id, -1), "")
    disruption_note = (
        " · <span style='color:rgba(255,80,80,0.9)'>red dashes = disruption events</span>"
        if result.disruption_log else ""
    )
    fig.update_layout(
        title=dict(
            text=(f"<b>{node_id}</b> — {type_name}"
                  f"{disruption_note}"),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        legend=dict(
            title="Items (SKUs)",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray", borderwidth=1,
        ),
        height=180 * n_panels + 80,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=50),
    )
    # Light grid
    fig.update_xaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig
| # ============================================================================ | |
| # 3. Bullwhip amplification chart | |
| # ============================================================================ | |
def make_bullwhip_chart(result) -> go.Figure:
    """
    Bar chart of tier-level bullwhip ratio B = Var(inflow)/Var(outflow).

    Bars are grouped by item; each bar is the mean ratio over the tier's
    nodes that have a ratio for that item (see ``_bullwhip_ratios``).  Tiers
    run from destination (left) to hub (right); source nodes are excluded.
    A dashed reference line at B = 1 marks the no-amplification baseline.
    """
    ratios = _bullwhip_ratios(result)

    def _tier_of(nid):
        # Same default as _bullwhip_ratios, which keeps nodes that are missing
        # from result.tier; indexing result.tier[nid] would raise KeyError.
        return result.tier.get(nid, -1)

    # Tiers that have at least one node with a recorded ratio
    tiers_present = sorted({_tier_of(n) for n in ratios if ratios[n]})
    x_labels = [_TIER_LABEL.get(t, f"Tier {t}") for t in tiers_present]

    # Per-item bar traces
    fig = go.Figure()
    for idx, iid in enumerate(result.item_ids):
        tier_means: List[Optional[float]] = []
        for t in tiers_present:
            vals = [ratios[n][iid] for n in ratios
                    if _tier_of(n) == t and iid in ratios[n]]
            tier_means.append(float(np.mean(vals)) if vals else None)
        fig.add_trace(go.Bar(
            x=x_labels,
            y=tier_means,
            name=iid,
            marker_color=ITEM_COLORS[idx % len(ITEM_COLORS)],
            opacity=0.82,
            text=[f"{v:.2f}" if v is not None else "" for v in tier_means],
            textposition="outside",
            textfont=dict(size=9),
        ))

    # Reference line at B = 1
    fig.add_hline(
        y=1.0, line_dash="dash",
        line_color="rgba(80,80,80,0.6)", line_width=1.5,
        annotation_text="B = 1 (no amplification)",
        annotation_position="top right",
        annotation_font_size=9,
    )
    fig.update_layout(
        title=dict(
            text=(
                "<b>Bullwhip Amplification by Tier</b><br>"
                "<sup>B > 1 means demand variability grows as orders travel upstream</sup>"
            ),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        xaxis=dict(
            # Sources (tier 6) are never plotted, so the axis ends at the hub.
            title="Network tier (Destination = downstream → Hub = upstream)",
            title_font=dict(size=11),
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            title="B = Var(inflow) / Var(outflow)",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            rangemode="tozero",
        ),
        barmode="group",
        legend=dict(
            title="Items",
            x=1.01, y=1.0,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="lightgray", borderwidth=1,
        ),
        height=400,
        paper_bgcolor="white",
        plot_bgcolor="rgba(248,248,252,1)",
        margin=dict(l=60, r=120, t=60, b=60),
    )
    fig.update_yaxes(showgrid=True, gridcolor="rgba(200,200,220,0.5)")
    return fig
| # ============================================================================ | |
| # 4. Edge utilization heatmap | |
| # ============================================================================ | |
def make_edge_heatmap(result) -> go.Figure:
    """
    Heatmap of edge utilization (fraction of daily capacity) over time.

    Y-axis: edges ordered by the tier of their source node, lowest tier
    (closest to the destination) first and drawn at the top.  Edges whose
    source node has no tier entry sort last.
    X-axis: simulation day.
    Color: 0 (very light blue) = empty → 0.5 (yellow-orange) → 1 (red) =
    at capacity; the colour range is fixed to [0, 1] via zmin/zmax.
    Disruption events are shown as vertical dashed blue lines.
    """
    # Sort edges: downstream edges first (by tier of source node); sorted()
    # is stable, so edges within a tier keep result.edge_ids order.
    def _edge_tier(eid):
        return result.tier.get(eid[0], 99)

    sorted_edges = sorted(result.edge_ids, key=_edge_tier)
    edge_labels = [f"{u} → {v}" for u, v in sorted_edges]
    z = np.array([result.edge_util[eid] for eid in sorted_edges],
                 dtype=np.float64)  # shape (n_edges, T)
    days = list(range(result.T))

    fig = go.Figure(go.Heatmap(
        z=z.tolist(),
        x=days,
        y=edge_labels,
        colorscale=[
            [0.0, "rgb(240,248,255)"],   # near-empty: very light blue
            [0.5, "rgb(255,200,100)"],   # half full: yellow-orange
            [1.0, "rgb(220,50,50)"],     # saturated: red
        ],
        zmin=0.0, zmax=1.0,
        colorbar=dict(
            title="Utilization<br>fraction",
            thickness=13,
            len=0.7,
            tickvals=[0, 0.5, 1],
            ticktext=["0%", "50%", "100%"],
            tickfont=dict(size=9),
        ),
        hovertemplate=(
            "Edge: %{y}<br>Day: %{x}<br>Utilization: %{z:.1%}<extra></extra>"
        ),
    ))

    # Disruption event markers
    for ev in result.disruption_log:
        fig.add_vline(
            x=ev["day"], line_dash="dash",
            line_color="rgba(80,80,255,0.6)", line_width=1.5,
        )
        fig.add_annotation(
            x=ev["day"], y=1.01, xref="x", yref="paper",
            text="disruption", showarrow=False,
            font=dict(size=8, color="blue"), textangle=-90,
        )

    fig.update_layout(
        title=dict(
            text=(
                "<b>Edge Utilization over Time</b><br>"
                "<sup>Fraction of daily shipping capacity used per edge "
                "(0 = empty, 1 = fully saturated). "
                "Blue dashes mark disruption events.</sup>"
            ),
            x=0.5, xanchor="center", font=dict(size=13),
        ),
        xaxis=dict(
            title="Day",
            title_font=dict(size=11),
            tickfont=dict(size=9),
        ),
        yaxis=dict(
            title="Edge",
            title_font=dict(size=11),
            tickfont=dict(size=9),
            autorange="reversed",   # downstream edges at top
        ),
        # Roughly 30 px per edge row plus room for title/axes.
        height=max(320, 30 * len(sorted_edges) + 120),
        paper_bgcolor="white",
        margin=dict(l=160, r=80, t=60, b=60),
    )
    return fig
| # ============================================================================ | |
| # 5. HTML wrapper for animated network map (Gradio gr.HTML compatible) | |
| # ============================================================================ | |
def make_network_animation_html(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
) -> str:
    """
    Wrap the animated network map in an ``<iframe srcdoc=...>`` string.

    Gradio's gr.Plot renders through Plotly.react(), which drops the
    ``frames`` array, so Play / Pause stop working.  A standalone page
    (``full_html=True``) inside an iframe calls Plotly.newPlot() itself,
    which keeps the frames.  The iframe also sidesteps the height:100%
    collapse seen when partial HTML is put straight into gr.HTML.

    The page loads plotly.js from the CDN.  The page source is HTML-escaped
    (quotes included) so it can sit inside the double-quoted ``srcdoc``
    attribute.
    """
    import html
    import plotly.io as pio

    figure = make_network_animation(result, frame_step, frame_duration_ms)
    page = pio.to_html(
        figure,
        include_plotlyjs="cdn",
        full_html=True,
        config={"responsive": True},
    )
    escaped = html.escape(page, quote=True)
    return "".join((
        f'<iframe srcdoc="{escaped}" ',
        'width="100%" height="600" frameborder="0" scrolling="no" ',
        'style="border:none;display:block;"></iframe>',
    ))
| # ============================================================================ | |
| # 6. Animated GIF export | |
| # ============================================================================ | |
def make_network_animation_gif(
    result,
    frame_step: Optional[int] = None,
    frame_duration_ms: int = 150,
    max_frames: int = 80,
) -> str:
    """
    Render the network animation as an animated GIF and return the temp file path.

    Uses matplotlib (Agg backend).  When cartopy is importable the map gets a
    Lambert Conformal US background; otherwise nodes and edges are drawn on
    plain lon/lat axes.  No kaleido or Chrome required.  The GIF is written
    to a NamedTemporaryFile created with ``delete=False``, so the caller is
    responsible for removing it.

    Parameters
    ----------
    result : SimResult
    frame_step : int or None
        Days between frames.  When None it is chosen by ceiling division so
        that at most ``max_frames`` frames are rendered.
    frame_duration_ms : int
        Milliseconds per frame during playback.
    max_frames : int
        Maximum number of frames to render when ``frame_step`` is None
        (keeps file size reasonable).

    Raises
    ------
    ImportError
        If Pillow is not installed.
    ValueError
        If ``result.T < 1``, ``frame_step < 1``, or ``max_frames < 1`` while
        ``frame_step`` is None.
    """
    import io as _io
    import tempfile

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors

    try:
        from PIL import Image
    except ImportError as err:
        raise ImportError("Pillow is required for GIF export: pip install Pillow") from err
    try:
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        _PROJ = ccrs.LambertConformal(central_longitude=-96, central_latitude=39,
                                      standard_parallels=(33, 45))
        _GEO = ccrs.PlateCarree()
        HAS_CARTOPY = True
    except ImportError:
        HAS_CARTOPY = False

    # ── Frame schedule ─────────────────────────────────────────────────────
    T = result.T
    if T < 1:
        raise ValueError("result.T must be at least 1 to render an animation")
    if frame_step is None:
        if max_frames < 1:
            raise ValueError(f"max_frames must be >= 1, got {max_frames}")
        # Ceiling division guarantees len(frame_times) <= max_frames; floor
        # division (T // max_frames) would allow nearly twice as many
        # (e.g. T=150, max_frames=80 → step 1 → 150 frames).
        frame_step = max(1, (T + max_frames - 1) // max_frames)
    elif frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")
    frame_times = list(range(0, T, frame_step))

    # Stress colormap: green(0) → yellow(0.35) → red(1.0), matching NODE_COLORSCALE
    stress_cmap = mcolors.LinearSegmentedColormap.from_list(
        "stress",
        [(0.00, "#4CAF50"), (0.35, "#FFC107"), (1.00, "#F44336")],
    )
    stress_norm = mcolors.Normalize(vmin=0.0, vmax=1.0)

    # ── Static node / edge geometry (identical in every frame) ─────────────
    coords = result.node_coords
    node_ids = result.node_ids
    node_lons = [coords[n][1] for n in node_ids]
    node_lats = [coords[n][0] for n in node_ids]
    # s is marker area in pt² — keep proportional to Plotly sizes but much smaller
    node_s = [_TIER_MARKER_SIZE.get(result.tier.get(n, 3), 13) * 12 for n in node_ids]
    node_markers = [
        _MPL_MARKER.get(_TIER_SYMBOL.get(result.tier.get(n, 3), "circle"), "o")
        for n in node_ids
    ]
    # `or 1.0` covers an empty map and all-zero capacities (avoids 0-division).
    max_cap = max(result.edge_cap.values(), default=0.0) or 1.0
    # Precompute: (lon0, lat0, lon1, lat1, linewidth)
    edges: List[Tuple] = []
    for eid in result.edge_ids:
        u, v = eid
        if u not in coords or v not in coords:
            continue
        cap = result.edge_cap.get(eid, 1.0)
        lw = 0.6 + 2.0 * (cap / max_cap) ** 0.5
        edges.append((coords[u][1], coords[u][0], coords[v][1], coords[v][0], lw))

    # ── Render each frame to an in-memory PNG ──────────────────────────────
    W, H, DPI = 8.5, 5.0, 90
    pil_frames = []
    for t in frame_times:
        stress_t = _node_stress(result, t)
        node_colors = [stress_cmap(stress_norm(s)) for s in stress_t]  # RGBA tuples
        lats_t, lons_t, _, colors_t = _frame_particles(result, t)

        fig = plt.figure(figsize=(W, H), dpi=DPI, facecolor="#f4f6f9")
        try:
            if HAS_CARTOPY:
                ax = fig.add_subplot(1, 1, 1, projection=_PROJ)
                ax.set_extent([-125, -66, 24, 50], crs=_GEO)
                ax.add_feature(cfeature.LAND, facecolor="#f0f0ee", zorder=0)
                ax.add_feature(cfeature.OCEAN, facecolor="#d0e4f0", zorder=0)
                ax.add_feature(cfeature.LAKES, facecolor="#d0e4f0", alpha=0.6, zorder=1)
                ax.add_feature(cfeature.STATES, edgecolor="#c0c0c0", linewidth=0.4, zorder=2)
                ax.add_feature(cfeature.COASTLINE, edgecolor="#999999", linewidth=0.5, zorder=2)
                for lon0, lat0, lon1, lat1, lw in edges:
                    ax.plot([lon0, lon1], [lat0, lat1], transform=_GEO,
                            color="#7a7a90", linewidth=lw, alpha=0.65, zorder=3,
                            solid_capstyle="round")
                for nid, lon, lat, color, s, mkr in zip(
                        node_ids, node_lons, node_lats, node_colors, node_s, node_markers):
                    ax.scatter([lon], [lat], c=[color], s=[s], marker=mkr, zorder=4,
                               linewidths=1.5, edgecolors="white", alpha=0.85, transform=_GEO)
                    ax.text(lon, lat + 0.9, nid, transform=_GEO,
                            fontsize=6.5, ha="center", va="bottom",
                            fontweight="bold", color="#1a1a2e", zorder=6,
                            bbox=dict(boxstyle="round,pad=0.12", fc="white",
                                      ec="none", alpha=0.65))
                if lons_t:
                    rgba_t = [mcolors.to_rgba(c) for c in colors_t]
                    ax.scatter(lons_t, lats_t, c=rgba_t, s=45, zorder=8,
                               linewidths=0.5, edgecolors="white", alpha=1.0,
                               transform=_GEO)
            else:
                # Fallback: plain axes, no geographic background
                ax = fig.add_subplot(1, 1, 1)
                ax.set_facecolor("#dce8f0")
                ax.set_xlim(-128, -65)
                ax.set_ylim(23, 50)
                ax.axis("off")
                for lon0, lat0, lon1, lat1, lw in edges:
                    ax.plot([lon0, lon1], [lat0, lat1], color="#7a7a90", linewidth=lw,
                            alpha=0.65, zorder=1, solid_capstyle="round")
                for nid, lon, lat, color, s, mkr in zip(
                        node_ids, node_lons, node_lats, node_colors, node_s, node_markers):
                    ax.scatter([lon], [lat], c=[color], s=[s], marker=mkr, zorder=3,
                               linewidths=1.5, edgecolors="white", alpha=0.85)
                    ax.text(lon, lat + 0.8, nid, fontsize=6.5, ha="center", va="bottom",
                            fontweight="bold", color="#1a1a2e", zorder=5,
                            bbox=dict(boxstyle="round,pad=0.12", fc="white",
                                      ec="none", alpha=0.65))
                if lons_t:
                    ax.scatter(lons_t, lats_t, c=colors_t, s=45, zorder=8,
                               linewidths=0.5, edgecolors="white", alpha=0.95)
            ax.set_title(
                f"ISOMORPH Supply Chain Digital Twin — Day {t} / {T - 1}",
                fontsize=10, fontweight="bold", pad=6, color="#1a1a2e",
            )
            fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.01)
            buf = _io.BytesIO()
            fig.savefig(buf, format="png", dpi=DPI, facecolor=fig.get_facecolor())
        finally:
            # Always release the figure, even if drawing fails, so repeated
            # calls do not accumulate open matplotlib figures.
            plt.close(fig)
        buf.seek(0)
        pil_frames.append(Image.open(buf).convert("RGB"))

    # Normalize all frames to the same size (first frame drives dimensions)
    W_px, H_px = pil_frames[0].size
    pil_frames = [f.resize((W_px, H_px), Image.LANCZOS) for f in pil_frames]

    # ── Shared palette ─────────────────────────────────────────────────────
    # GIF supports only 256 colours.  The map background consumes many palette
    # slots with similar grays/blues, crowding out item and stress colours.
    # Fix: append large solid-colour blocks of every important colour to the
    # combined image before quantising, so the median-cut algorithm is forced
    # to reserve palette entries for them.
    pinned_hex = (
        ITEM_COLORS[:5]                                          # all 5 item colours
        + ["#4CAF50", "#8BC34A", "#FFC107", "#FF9800",          # stress ramp
           "#FF5722", "#F44336"]
        + ["#ffffff", "#1a1a2e", "#7a7a90",                     # white / text / edges
           "#f0f0ee", "#d0e4f0", "#c0c0c0"]                     # map background tones
    )
    BLOCK = 2000  # large block → many pixels → dominates palette selection
    ref = Image.new("RGB", (BLOCK * len(pinned_hex), H_px))
    for i, hex_color in enumerate(pinned_hex):
        rgb = tuple(bytes.fromhex(hex_color.lstrip("#")))
        ref.paste(rgb, (i * BLOCK, 0, (i + 1) * BLOCK, H_px))

    combined = Image.new("RGB", (W_px * len(pil_frames) + ref.width, H_px))
    for i, f in enumerate(pil_frames):
        combined.paste(f, (i * W_px, 0))
    combined.paste(ref, (W_px * len(pil_frames), 0))
    palette_source = combined.quantize(colors=255, method=Image.Quantize.MEDIANCUT)
    quantized = [f.quantize(palette=palette_source, dither=0) for f in pil_frames]

    # ── Write the GIF ──────────────────────────────────────────────────────
    # Create the file and close the handle before Pillow reopens it by name;
    # writing to a still-open NamedTemporaryFile fails on Windows.
    with tempfile.NamedTemporaryFile(
            suffix=".gif", delete=False, prefix="isomorph_") as tmp:
        out_path = tmp.name
    quantized[0].save(
        out_path,
        save_all=True,
        append_images=quantized[1:],
        loop=0,
        duration=frame_duration_ms,
        optimize=False,
    )
    return out_path