| """Heatmap visualization utilities (colorbar, overlays, Plotly).""" |
| import base64 |
|
|
| import cv2 |
| import numpy as np |
| import streamlit as st |
| import plotly.graph_objects as go |
|
|
| from config.constants import COLORMAPS |
|
|
|
|
def _colormap_gradient_base64(colormap_name, width=512):
    """Return a horizontal gradient bar for a colormap as a base64-encoded PNG.

    Args:
        colormap_name: Key looked up in ``COLORMAPS``; names that are not
            present fall back to ``cv2.COLORMAP_JET``.
        width: Number of gradient steps, which is also the bar width in pixels.

    Returns:
        The PNG bytes of a 6-pixel-high bar, base64-encoded and decoded to a
        ``str`` (no ``data:`` URL prefix is added here).

    Raises:
        RuntimeError: If OpenCV reports that PNG encoding failed.
    """
    cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
    # A single row sweeping the whole 0..255 input range of the colormap.
    ramp = np.linspace(0, 255, width, dtype=np.uint8).reshape(1, -1)
    # applyColorMap produces BGR and imencode consumes BGR, so the bar can be
    # encoded directly without converting to RGB and back.
    bar_bgr = np.repeat(cv2.applyColorMap(ramp, cv2_cmap), 6, axis=0)
    ok, buf = cv2.imencode(".png", bar_bgr)
    if not ok:
        raise RuntimeError(
            f"Failed to encode colorbar gradient for colormap {colormap_name!r}"
        )
    return base64.b64encode(buf.tobytes()).decode("utf-8")
|
|
| |
# Palette cycled through when several regions are drawn on one heatmap.
# The triples are applied to images named ``heatmap_rgb``, so they are read
# here as (R, G, B); the hex codes in the comments assume that channel order.
_REGION_COLORS = [
    (0, 188, 212),   # #00BCD4
    (0, 230, 118),   # #00E676
    (255, 235, 59),  # #FFEB3B
    (171, 71, 188),  # #AB47BC
    (0, 150, 255),   # #0096FF
    (255, 167, 38),  # #FFA726
    (124, 179, 66),  # #7CB342
    (233, 30, 99),   # #E91E63
]
|
|
|
|
def _draw_region_overlay(annotated, mask, color, fill_alpha=0.3, stroke_width=2):
    """Tint one region of ``annotated`` and outline it. Modifies ``annotated`` in place.

    Args:
        annotated: Image to draw on, indexed as (row, column, 3 channels).
            NOTE(review): presumably a uint8 RGB heatmap - confirm with callers.
        mask: 2-D array with the same height/width as ``annotated``; every
            non-zero (or ``True``) element is part of the region. Any numeric
            or boolean dtype is accepted.
        color: 3-component color, in the same channel order as ``annotated``.
        fill_alpha: Opacity of the tint; values outside ``[0, 1]`` are clamped
            so the blended result always fits in uint8.
        stroke_width: Thickness passed to ``cv2.drawContours`` for the outline.

    Returns:
        None. An all-zero mask leaves ``annotated`` untouched.
    """
    region = np.asarray(mask) != 0
    if not region.any():
        # Nothing to tint and no contour to find.
        return

    alpha = min(max(float(fill_alpha), 0.0), 1.0)
    tint = np.asarray(color, dtype=np.float32)
    # Every masked pixel is blended with the same flat color, so blend against
    # the color directly instead of painting a filled copy of the whole image.
    pixels = annotated[region].astype(np.float32)
    annotated[region] = ((1 - alpha) * pixels + alpha * tint).astype(np.uint8)

    # findContours needs an 8-bit single-channel image; a fresh 0/1 array also
    # keeps the caller's mask from being handed to OpenCV directly.
    contours, _ = cv2.findContours(
        region.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(annotated, contours, -1, color, stroke_width)
|
|
|
|
def render_horizontal_colorbar(colormap_name, clip_min=0, clip_max=1, is_rescale=False, caption=None):
    """
    Emit a compact horizontal colorbar (tick labels plus gradient strip) through Streamlit.

    Five ticks are placed at 0, 25, 50, 75 and 100 % of the bar. With ``is_rescale`` False
    they are labelled with those fractions (``0.00`` .. ``1.00``). With ``is_rescale`` True
    they are labelled with the matching values in ``[clip_min, clip_max]``. The gradient
    image is always the full colormap: in rescale mode the heatmap is expected to have
    been rescaled already, so the ends of the chosen range sit on the ends of the colormap.

    Args:
        colormap_name: Colormap name forwarded to ``_colormap_gradient_base64``.
        clip_min: Value shown at the left end when ``is_rescale`` is True.
        clip_max: Value shown at the right end when ``is_rescale`` is True.
        is_rescale: Whether tick labels show the clip range instead of 0..1.
        caption: Optional text rendered below the bar with ``st.caption``.
    """
    fractions = (0, 0.25, 0.5, 0.75, 1)
    if is_rescale:
        span = clip_max - clip_min
        tick_values = [clip_min + frac * span for frac in fractions]
    else:
        tick_values = fractions

    tick_spans = [f'<span class="cb-tick">{value:.2f}</span>' for value in tick_values]
    labels_html = "".join(tick_spans)
    data_url = _colormap_gradient_base64(colormap_name)

    html = f"""
    <div class="colorbar-table-header">
    <div class="colorbar-ticks">{labels_html}</div>
    <div class="colorbar-bar" style="background-image: url(data:image/png;base64,{data_url});"></div>
    </div>
    """
    st.markdown(html, unsafe_allow_html=True)
    if caption:
        st.caption(caption)
|
|
|
|
def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(0, 188, 212), stroke_width=2):
    """Return a copy of ``heatmap_rgb`` with a single region tinted and outlined.

    The input image is not modified; the drawing is done on a copy by
    ``_draw_region_overlay``, using ``stroke_color`` for both fill and outline.
    """
    canvas = heatmap_rgb.copy()
    _draw_region_overlay(
        canvas,
        mask,
        color=stroke_color,
        fill_alpha=fill_alpha,
        stroke_width=stroke_width,
    )
    return canvas
|
|
|
|
def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=None, fill_alpha=0.3):
    """Return a copy of the heatmap with each region drawn and labelled separately.

    Regions are never merged: region ``i`` is tinted and outlined with
    ``_REGION_COLORS[i % len(_REGION_COLORS)]`` and its label is written near
    the centroid of its mask.

    Args:
        heatmap_rgb: Source image; it is copied, not modified.
        masks: Iterable of 2-D masks, one per region. Non-zero / ``True``
            elements belong to the region; any numeric or boolean dtype works.
        labels: Sequence of labels matched to ``masks`` by position. ``None``
            or a missing entry falls back to ``R1``, ``R2``, ... Labels are
            converted with ``str`` before drawing.
        cell_mask: Optional mask whose outer contours are outlined in
            ``(255, 0, 0)`` with thickness 3 before the regions are drawn.
        fill_alpha: Tint opacity forwarded to ``_draw_region_overlay``.

    Returns:
        The annotated copy of ``heatmap_rgb``.
    """
    annotated = heatmap_rgb.copy()
    labels = () if labels is None else labels

    if cell_mask is not None:
        # OpenCV contour functions need an 8-bit image, so normalise first.
        cell_u8 = (np.asarray(cell_mask) > 0).astype(np.uint8)
        if cell_u8.any():
            contours, _ = cv2.findContours(cell_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(annotated, contours, -1, (255, 0, 0), 3)

    # White text (thickness 2) under black text (thickness 1) gives the label
    # a light halo so it stays readable on any heatmap color.
    text_layers = (((255, 255, 255), 2), ((0, 0, 0), 1))

    for i, mask in enumerate(masks):
        region_u8 = (np.asarray(mask) != 0).astype(np.uint8)
        color = _REGION_COLORS[i % len(_REGION_COLORS)]
        _draw_region_overlay(annotated, region_u8, color, fill_alpha, stroke_width=2)

        moments = cv2.moments(region_u8)
        if moments["m00"] <= 0:
            # Empty mask: there is no centroid to anchor a label on.
            continue
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        label = str(labels[i]) if i < len(labels) else f"R{i + 1}"
        origin = (cx - 12, cy + 5)
        for text_color, thickness in text_layers:
            cv2.putText(
                annotated, label, origin,
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_color, thickness, cv2.LINE_AA
            )
    return annotated
|
|
|
|
def add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
    """Outline every connected component of ``cell_mask`` in red on a Plotly subplot.

    Each outer contour found by OpenCV becomes its own closed ``go.Scatter``
    line trace (red, width 3, hidden from the legend) added at ``row``/``col``,
    so masks made of several disconnected pieces are fully outlined. Nothing
    is added when ``cell_mask`` is ``None`` or has no positive element.
    """
    has_cell = cell_mask is not None and np.any(cell_mask > 0)
    if not has_cell:
        return

    outlines, _hierarchy = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for outline in outlines:
        # Flatten to rows of (x, y); this also covers a single-point contour.
        points = outline.reshape(-1, 2)
        xs = points[:, 0].tolist()
        ys = points[:, 1].tolist()
        # Repeat the first vertex at the end so the line is drawn closed.
        if (xs[0], ys[0]) != (xs[-1], ys[-1]):
            xs.append(xs[0])
            ys.append(ys[0])
        red_line = go.Scatter(
            x=xs, y=ys, mode="lines", line=dict(color="red", width=3), showlegend=False
        )
        fig_pl.add_trace(red_line, row=row, col=col)
|
|