"""Heatmap visualization utilities (colorbar, overlays, Plotly)."""

import base64
import functools

import cv2
import numpy as np
import streamlit as st
import plotly.graph_objects as go

from config.constants import COLORMAPS


@functools.lru_cache(maxsize=32)
def _colormap_gradient_base64(colormap_name, width=512):
    """Return a horizontal gradient bar for *colormap_name* as a base64-encoded PNG.

    The bar is ``width`` pixels wide and 6 pixels tall and sweeps the colormap
    from its minimum (left) to its maximum (right). Names missing from
    ``COLORMAPS`` fall back to ``cv2.COLORMAP_JET``. The result depends only on
    the arguments, so it is cached; Streamlit re-runs the script on every
    interaction.

    Raises:
        RuntimeError: if OpenCV fails to encode the PNG.
    """
    cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
    gradient = np.linspace(0, 255, width, dtype=np.uint8).reshape(1, -1)
    # applyColorMap already yields BGR, which is what imencode expects, so no
    # BGR->RGB->BGR round trip is needed (the encoded pixels are identical).
    bar = np.repeat(cv2.applyColorMap(gradient, cv2_cmap), 6, axis=0)
    ok, buf = cv2.imencode(".png", bar)
    if not ok:
        raise RuntimeError(f"Failed to encode colorbar PNG for colormap {colormap_name!r}")
    return base64.b64encode(buf.tobytes()).decode("utf-8")


# Distinct colors for each region (RGB - heatmap_rgb is RGB)
_REGION_COLORS = [
    (0, 188, 212),  # cyan (matches drawing tool)
    (0, 230, 118),  # green
    (255, 235, 59),  # yellow
    (171, 71, 188),  # purple
    (0, 150, 255),  # blue
    (255, 167, 38),  # amber
    (124, 179, 66),  # light green
    (233, 30, 99),  # pink
]

# Font settings for region labels drawn by make_annotated_heatmap_multi_regions.
_LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
_LABEL_SCALE = 0.7


def _as_uint8_mask(mask):
    """Return *mask* as a uint8 array accepted by ``cv2.findContours`` / ``cv2.moments``.

    uint8 masks are returned unchanged, so intensity-weighted moments behave
    exactly as before. Any other dtype (bool, int, float) is binarised to
    0/255, with non-zero meaning "inside".
    """
    mask = np.asarray(mask)
    if mask.dtype == np.uint8:
        return mask
    return np.where(mask > 0, 255, 0).astype(np.uint8)


def _draw_region_overlay(annotated, mask, color, fill_alpha=0.3, stroke_width=2):
    """Tint one region of *annotated* and outline it, modifying *annotated* in place.

    Pixels where *mask* is non-zero are blended as
    ``(1 - fill_alpha) * pixel + fill_alpha * color`` (computed in float32 and
    truncated back to uint8). The region's external contours are then stroked
    in *color*.

    Args:
        annotated: 3-channel uint8 image (same height/width as *mask*); written in place.
        mask: single-channel region mask; non-zero marks the region.
        color: 3-tuple in the same channel order as *annotated*.
        fill_alpha: weight of the fill colour (0 = no tint, 1 = solid).
        stroke_width: contour thickness in pixels.
    """
    mask = _as_uint8_mask(mask)
    inside = mask > 0
    if not inside.any():
        # Nothing to tint and no contours to draw.
        return
    tint = np.asarray(color, dtype=np.float32)
    annotated[inside] = (
        (1 - fill_alpha) * annotated[inside].astype(np.float32) + fill_alpha * tint
    ).astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(annotated, contours, -1, color, stroke_width)


def _draw_region_label(annotated, mask, label):
    """Write *label* centred on the centroid of *mask*; no-op when the mask is empty.

    The text is drawn twice, thick white then thin black, so it stays legible
    over any heatmap colour.
    """
    moments = cv2.moments(mask)
    if moments["m00"] <= 0:
        return
    cx = int(moments["m10"] / moments["m00"])
    cy = int(moments["m01"] / moments["m00"])
    text = str(label)
    # Centre on the centroid for any label length, rather than a fixed offset
    # that only suits ~2-character labels.
    (text_w, text_h), _ = cv2.getTextSize(text, _LABEL_FONT, _LABEL_SCALE, 2)
    origin = (cx - text_w // 2, cy + text_h // 2)
    for text_color, thickness in (((255, 255, 255), 2), ((0, 0, 0), 1)):
        cv2.putText(
            annotated, text, origin, _LABEL_FONT, _LABEL_SCALE, text_color, thickness, cv2.LINE_AA
        )


def render_horizontal_colorbar(colormap_name, clip_min=0, clip_max=1, is_rescale=False, caption=None):
    """
    Render a compact horizontal colorbar for batch mode, anchored above the table.

    When ``is_rescale`` is True (Force scale **Range** with a strict sub-interval),
    tick labels show model force values in ``[clip_min, clip_max]``. The gradient
    still spans the full colormap because the heatmap has already been **rescaled**
    so the lowest (highest) value in your range maps to the colormap minimum
    (maximum) - same convention as the main Plotly view. Otherwise the ticks are
    the normalised values 0.00 ... 1.00.

    Args:
        colormap_name: key into ``COLORMAPS`` (unknown names fall back to JET).
        clip_min, clip_max: value range shown on the ticks when ``is_rescale`` is True.
        is_rescale: label ticks with ``[clip_min, clip_max]`` instead of ``[0, 1]``.
        caption: optional text rendered with ``st.caption`` below the bar.
    """
    ticks = (0, 0.25, 0.5, 0.75, 1)
    if is_rescale:
        span = clip_max - clip_min
        labels = [f"{clip_min + t * span:.2f}" for t in ticks]
    else:
        labels = [f"{t:.2f}" for t in ticks]

    png_b64 = _colormap_gradient_base64(colormap_name)
    labels_html = "".join(f"<span>{label}</span>" for label in labels)
    # Kept on one line with no leading indentation: Markdown treats lines
    # indented by 4+ spaces as a code block instead of raw HTML.
    html = (
        '<div style="margin:0 0 0.5rem 0;">'
        f'<img src="data:image/png;base64,{png_b64}" alt="colorbar" '
        'style="display:block;width:100%;height:12px;border-radius:2px;"/>'
        '<div style="display:flex;justify-content:space-between;'
        'font-size:0.75rem;opacity:0.8;margin-top:2px;">'
        f"{labels_html}"
        "</div>"
        "</div>"
    )
    st.markdown(html, unsafe_allow_html=True)
    if caption:
        st.caption(caption)


def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(0, 188, 212), stroke_width=2):
    """Return a copy of *heatmap_rgb* with one region tinted and outlined.

    *stroke_color* is used for both the translucent fill and the outline. The
    input image is not modified.
    """
    annotated = heatmap_rgb.copy()
    _draw_region_overlay(annotated, mask, stroke_color, fill_alpha, stroke_width)
    return annotated


def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=None, fill_alpha=0.3):
    """Draw each region separately with a distinct color and label (R1, R2, ...). No merging.

    Args:
        heatmap_rgb: RGB heatmap image; a copy is annotated and returned.
        masks: iterable of per-region masks (non-zero = inside). Region ``i``
            uses ``_REGION_COLORS[i % len(_REGION_COLORS)]``.
        labels: text per region; missing entries (or ``None``) fall back to ``f"R{i + 1}"``.
        cell_mask: optional mask whose external contours are drawn in red
            (thickness 3) before the regions.
        fill_alpha: fill blend weight passed to each region overlay.

    Returns:
        The annotated copy of *heatmap_rgb*.
    """
    annotated = heatmap_rgb.copy()
    if cell_mask is not None:
        cell_mask = _as_uint8_mask(cell_mask)
        if np.any(cell_mask > 0):
            contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(annotated, contours, -1, (255, 0, 0), 3)

    labels = [] if labels is None else list(labels)
    for i, mask in enumerate(masks):
        mask = _as_uint8_mask(mask)
        color = _REGION_COLORS[i % len(_REGION_COLORS)]
        _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
        label = labels[i] if i < len(labels) else f"R{i + 1}"
        _draw_region_label(annotated, mask, label)
    return annotated


def add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
    """Add a red contour overlay to a Plotly heatmap subplot.

    One closed ``go.Scatter`` line trace is added per external contour of
    *cell_mask*, so disconnected regions are all outlined. Does nothing when
    *cell_mask* is None or has no non-zero pixels.

    Args:
        fig_pl: Plotly figure with a subplot grid (``row``/``col`` are forwarded to ``add_trace``).
        cell_mask: single-channel mask; non-zero marks the cell.
        row, col: subplot position that receives the traces.
    """
    if cell_mask is None:
        return
    cell_mask = _as_uint8_mask(cell_mask)
    if not np.any(cell_mask > 0):
        return
    contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for cnt in contours:
        # OpenCV contours are (N, 1, 2); reshape handles the single-point case too.
        pts = cnt.reshape(-1, 2)
        x, y = pts[:, 0].tolist(), pts[:, 1].tolist()
        # Close the polygon so the drawn line returns to its start point.
        if x[0] != x[-1] or y[0] != y[-1]:
            x.append(x[0])
            y.append(y[0])
        fig_pl.add_trace(
            go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=3), showlegend=False),
            row=row,
            col=col,
        )