Shape2Force / S2FApp /ui /result_display.py
kaveh's picture
roll back remaining batch result option
f537675
"""Result display: single and batch prediction views."""
import csv
import io
import os
import zipfile
import cv2
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale, is_display_range_remapped
from utils.report import heatmap_to_rgb_with_contour, heatmap_to_png_bytes, create_pdf_report
from utils.segmentation import estimate_cell_mask
from ui.heatmaps import render_horizontal_colorbar, add_cell_contour_to_fig
from ui.measure_tool import (
build_original_vals,
build_cell_vals,
render_region_canvas,
_compute_cell_metrics,
HAS_DRAWABLE_CANVAS,
)
# Histogram bar color (matches static/s2f_styles.css accent)
_HISTOGRAM_ACCENT = "#0d9488"
_RESULT_FIG_HEIGHT = 320
_HISTOGRAM_HEIGHT = 180
def _result_banner(badge: str, badge_class: str, title: str) -> str:
"""HTML row for INPUT/OUTPUT section titles (batch + single views). badge_class: input | output."""
return (
f'<div class="result-label"><span class="result-badge {badge_class}">{badge}</span> {title}</div>'
)
def render_batch_results(batch_results, colormap_name="Jet", display_mode="Default",
clip_min=0, clip_max=1,
auto_cell_boundary=False, clamp_only=False):
"""
Render batch prediction results: summary table, bright-field row, heatmap row, and bulk download.
batch_results: list of dicts with img, heatmap, force, pixel_sum, key_img, cell_mask.
cell_mask is computed on-the-fly when auto_cell_boundary is True and not stored.
"""
if not batch_results:
return
# Resolve cell_mask and precompute display_heatmap for each result
for r in batch_results:
if auto_cell_boundary and (r.get("cell_mask") is None or not np.any(r.get("cell_mask", 0) > 0)):
r["_cell_mask"] = estimate_cell_mask(r["heatmap"])
else:
r["_cell_mask"] = r.get("cell_mask") if auto_cell_boundary else None
r["_display_heatmap"] = apply_display_scale(
r["heatmap"], display_mode,
clip_min=clip_min, clip_max=clip_max, clamp_only=clamp_only,
)
# Build table rows - consistent column names for both modes
headers = ["Image", "Force", "Sum", "Max", "Mean"]
rows = []
csv_rows = [["image"] + headers[1:]]
for r in batch_results:
heatmap = r["heatmap"]
cell_mask = r.get("_cell_mask")
key = r["key_img"] or "image"
if auto_cell_boundary and cell_mask is not None and np.any(cell_mask > 0):
vals = heatmap[cell_mask > 0]
cell_pixel_sum = float(np.sum(vals))
cell_force = cell_pixel_sum * (r["force"] / r["pixel_sum"]) if r["pixel_sum"] > 0 else cell_pixel_sum
cell_mean = cell_pixel_sum / np.sum(cell_mask) if np.sum(cell_mask) > 0 else 0
row = [key, f"{cell_force:.2f}", f"{cell_pixel_sum:.2f}",
f"{np.max(heatmap):.4f}", f"{cell_mean:.4f}"]
else:
row = [key, f"{r['force']:.2f}", f"{r['pixel_sum']:.2f}",
f"{np.max(heatmap):.4f}", f"{np.mean(heatmap):.4f}"]
rows.append(row)
csv_rows.append([os.path.splitext(key)[0]] + row[1:])
st.markdown(_result_banner("INPUT", "input", "Bright-field images"), unsafe_allow_html=True)
n_cols = min(5, len(batch_results))
bf_cols = st.columns(n_cols)
for i, r in enumerate(batch_results):
img = r["img"]
if img.ndim == 2:
img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
else:
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
with bf_cols[i % n_cols]:
st.image(img_rgb, caption=r["key_img"], use_container_width=True)
is_rescale_b = is_display_range_remapped(display_mode, clip_min, clip_max)
st.markdown(_result_banner("OUTPUT", "output", "Predicted force maps"), unsafe_allow_html=True)
hm_cols = st.columns(n_cols)
for i, r in enumerate(batch_results):
hm_rgb = heatmap_to_rgb_with_contour(
r["_display_heatmap"], colormap_name,
r.get("_cell_mask") if auto_cell_boundary else None,
)
with hm_cols[i % n_cols]:
st.image(hm_rgb, caption=r["key_img"], use_container_width=True)
render_horizontal_colorbar(colormap_name, clip_min, clip_max, is_rescale_b)
# Table
st.dataframe(
{h: [r[i] for r in rows] for i, h in enumerate(headers)},
use_container_width=True,
hide_index=True,
)
# Histograms in accordion (one per row for visibility)
with st.expander("Force distribution (histograms)", expanded=False):
for i, r in enumerate(batch_results):
heatmap = r["heatmap"]
cell_mask = r.get("_cell_mask")
vals = heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and auto_cell_boundary) else heatmap.flatten()
vals = vals[vals > 0] if np.any(vals > 0) else vals
st.markdown(f"**{r['key_img']}**")
if len(vals) > 0:
fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color=_HISTOGRAM_ACCENT)])
fig.update_layout(
height=_HISTOGRAM_HEIGHT, margin=dict(l=40, r=20, t=10, b=40),
xaxis_title="Force value", yaxis_title="Count",
showlegend=False,
)
st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
else:
st.caption("No data")
if i < len(batch_results) - 1:
st.divider()
# Bulk downloads: CSV and heatmaps (zip)
buf_csv = io.StringIO()
csv.writer(buf_csv).writerows(csv_rows)
zip_buf = io.BytesIO()
with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
for r in batch_results:
hm_bytes = heatmap_to_png_bytes(
r["_display_heatmap"], colormap_name,
r.get("_cell_mask") if auto_cell_boundary else None,
)
base = os.path.splitext(r["key_img"] or "image")[0]
zf.writestr(f"{base}_heatmap.png", hm_bytes.getvalue())
zip_buf.seek(0)
dl_col1, dl_col2 = st.columns(2)
with dl_col1:
st.download_button(
"Download all as CSV",
data=buf_csv.getvalue(),
file_name="s2f_batch_results.csv",
mime="text/csv",
key="download_batch_csv",
icon=":material/download:",
)
with dl_col2:
st.download_button(
"Download all heatmaps",
data=zip_buf.getvalue(),
file_name="s2f_batch_heatmaps.zip",
mime="application/zip",
key="download_batch_heatmaps",
icon=":material/image:",
)
def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
colormap_name="Jet", display_mode="Default", measure_region_dialog=None, auto_cell_boundary=True,
cell_mask=None, clip_min=0.0, clip_max=1.0, clamp_only=False):
"""
Render prediction result: plot, metrics, expander, and download/measure buttons.
measure_region_dialog: callable to open measure dialog (when ST_DIALOG available).
auto_cell_boundary: when True, use estimated cell area for metrics; when False, use entire map.
cell_mask: optional precomputed cell mask; if None and auto_cell_boundary, will be computed.
"""
if cell_mask is None and auto_cell_boundary:
cell_mask = estimate_cell_mask(raw_heatmap)
elif not auto_cell_boundary:
cell_mask = None
cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force) if cell_mask is not None else (None, None, None)
use_cell_metrics = auto_cell_boundary and cell_pixel_sum is not None and cell_force is not None and cell_mean is not None
base_name = os.path.splitext(key_img or "image")[0]
if use_cell_metrics:
main_csv_rows = [
["image", "Cell sum", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
[base_name, f"{cell_pixel_sum:.2f}", f"{cell_force:.2f}",
f"{np.max(raw_heatmap):.4f}", f"{cell_mean:.4f}"],
]
else:
main_csv_rows = [
["image", "Sum of all pixels", "Force (scaled)", "Heatmap max", "Heatmap mean"],
[base_name, f"{pixel_sum:.2f}", f"{force:.2f}",
f"{np.max(raw_heatmap):.4f}", f"{np.mean(raw_heatmap):.4f}"],
]
buf_main_csv = io.StringIO()
csv.writer(buf_main_csv).writerows(main_csv_rows)
buf_hm = heatmap_to_png_bytes(display_heatmap, colormap_name, cell_mask=cell_mask)
is_rescale = is_display_range_remapped(display_mode, clip_min, clip_max)
tit1, tit2 = st.columns(2)
with tit1:
st.markdown(_result_banner("INPUT", "input", "Bright-field image"), unsafe_allow_html=True)
with tit2:
st.markdown(_result_banner("OUTPUT", "output", "Predicted force map"), unsafe_allow_html=True)
fig_pl = make_subplots(rows=1, cols=2)
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
plotly_colorscale = cv_colormap_to_plotly_colorscale(colormap_name)
colorbar_cfg = dict(len=0.4, thickness=12, tickmode="array")
tick_positions = [0, 0.25, 0.5, 0.75, 1]
if is_rescale:
rng = clip_max - clip_min
colorbar_cfg["tickvals"] = tick_positions
colorbar_cfg["ticktext"] = [f"{clip_min + t * rng:.2f}" for t in tick_positions]
else:
colorbar_cfg["tickvals"] = tick_positions
colorbar_cfg["ticktext"] = [f"{t:.2f}" for t in tick_positions]
fig_pl.add_trace(go.Heatmap(z=display_heatmap, colorscale=plotly_colorscale, zmin=0.0, zmax=1.0, showscale=True,
colorbar=colorbar_cfg), row=1, col=2)
add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2)
fig_pl.update_layout(
height=_RESULT_FIG_HEIGHT,
margin=dict(l=10, r=10, t=10, b=10),
xaxis=dict(scaleanchor="y", scaleratio=1),
xaxis2=dict(scaleanchor="y2", scaleratio=1),
)
fig_pl.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig_pl.update_yaxes(showticklabels=False, autorange="reversed", showgrid=False, zeroline=False)
st.plotly_chart(fig_pl, use_container_width=True, config={"displayModeBar": True, "responsive": True})
col1, col2, col3, col4 = st.columns(4)
if use_cell_metrics:
with col1:
st.metric("Cell sum", f"{cell_pixel_sum:.2f}", help="Sum over estimated cell area (background excluded)")
with col2:
st.metric("Cell force (scaled)", f"{cell_force:.2f}", help="Total traction force in physical units")
with col3:
st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
with col4:
st.metric("Heatmap mean", f"{cell_mean:.4f}", help="Mean force over estimated cell area")
else:
with col1:
st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
with col2:
st.metric("Force (scaled)", f"{force:.2f}", help="Total traction force in physical units (full field of view)")
with col3:
st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
with col4:
st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
# Statistics panel (mean, std, percentiles, histogram)
with st.expander("Statistics"):
vals = raw_heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and use_cell_metrics) else raw_heatmap.flatten()
if len(vals) > 0:
stat_col1, stat_col2, stat_col3 = st.columns(3)
p25, p50, p75, p90 = (
float(np.percentile(vals, 25)), float(np.percentile(vals, 50)),
float(np.percentile(vals, 75)), float(np.percentile(vals, 90)),
)
with stat_col1:
st.metric("Mean", f"{float(np.mean(vals)):.4f}")
st.metric("Std", f"{float(np.std(vals)):.4f}")
with stat_col2:
st.metric("P25", f"{p25:.4f}")
st.metric("P50 (median)", f"{p50:.4f}")
with stat_col3:
st.metric("P75", f"{p75:.4f}")
st.metric("P90", f"{p90:.4f}")
st.markdown("**Histogram**")
hist_fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color=_HISTOGRAM_ACCENT)])
hist_fig.update_layout(
height=_HISTOGRAM_HEIGHT, margin=dict(l=40, r=20, t=20, b=40),
xaxis_title="Force value", yaxis_title="Count",
showlegend=False,
)
st.plotly_chart(hist_fig, use_container_width=True, config={"displayModeBar": False})
else:
st.caption("No nonzero values to compute statistics.")
with st.expander("How to read the results"):
if use_cell_metrics:
st.markdown("""
**Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
This is the raw image you provided—it shows cell shape but not forces.
**Output (right):** Predicted traction force map.
- **Color** indicates force magnitude: blue = low, red = high
- **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
- **Red border = estimated cell area** (background excluded from metrics)
- Values are normalized to [0, 1] for visualization
**Metrics (auto cell boundary on):**
- **Cell sum:** Sum over estimated cell area (background excluded)
- **Cell force (scaled):** Total traction force in physical units
- **Heatmap max:** Peak force intensity in the map
- **Heatmap mean:** Mean force over the estimated cell area
""")
else:
st.markdown("""
**Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
This is the raw image you provided—it shows cell shape but not forces.
**Output (right):** Predicted traction force map.
- **Color** indicates force magnitude: blue = low, red = high
- **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
- Values are normalized to [0, 1] for visualization
**Metrics (auto cell boundary off):**
- **Sum of all pixels:** Raw sum over entire map
- **Force (scaled):** Total traction force in physical units (full field of view)
- **Heatmap max/mean:** Peak and average force intensity (full field of view)
""")
original_vals = build_original_vals(raw_heatmap, pixel_sum, force)
pdf_bytes = create_pdf_report(
img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name,
cell_mask=cell_mask, cell_pixel_sum=cell_pixel_sum, cell_force=cell_force, cell_mean=cell_mean
)
btn_col1, btn_col2, btn_col3, btn_col4 = st.columns(4)
with btn_col1:
if HAS_DRAWABLE_CANVAS and measure_region_dialog is not None:
if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
st.session_state["open_measure_dialog"] = True
st.rerun()
elif HAS_DRAWABLE_CANVAS:
with st.expander("Measure tool"):
expander_cell_vals = build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force) if (auto_cell_boundary and cell_mask is not None) else None
expander_cell_mask = cell_mask if auto_cell_boundary else None
render_region_canvas(
display_heatmap,
raw_heatmap=raw_heatmap,
bf_img=img,
original_vals=original_vals,
cell_vals=expander_cell_vals,
cell_mask=expander_cell_mask,
key_suffix="expander",
input_filename=key_img,
colormap_name=colormap_name,
)
else:
st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
with btn_col2:
st.download_button(
"Download heatmap",
width="stretch",
data=buf_hm.getvalue(),
file_name="s2f_heatmap.png",
mime="image/png",
key=f"download_heatmap{download_key_suffix}",
icon=":material/download:",
)
with btn_col3:
st.download_button(
"Download values",
width="stretch",
data=buf_main_csv.getvalue(),
file_name=f"{base_name}_main_values.csv",
mime="text/csv",
key=f"download_main_values{download_key_suffix}",
icon=":material/download:",
)
with btn_col4:
st.download_button(
"Download report",
width="stretch",
data=pdf_bytes,
file_name=f"{base_name}_report.pdf",
mime="application/pdf",
key=f"download_pdf{download_key_suffix}",
icon=":material/picture_as_pdf:",
)