# Shape2Force / S2FApp / ui / measure_tool.py
# Last commit 7068f5c by kaveh: "optimised, removed dead code"
"""Measure tool: drawable canvas, region metrics, and downloads."""
import csv
import html
import io
import os
import cv2
import numpy as np
import streamlit as st
from PIL import Image
from config.constants import CANVAS_SIZE, DRAW_TOOLS, TOOL_LABELS
from utils.report import heatmap_to_rgb_with_contour, create_measure_pdf_report
from ui.heatmaps import make_annotated_heatmap_multi_regions
try:
from streamlit_drawable_canvas import st_canvas
HAS_DRAWABLE_CANVAS = True
except (ImportError, AttributeError):
HAS_DRAWABLE_CANVAS = False
def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
    """Convert one drawable-canvas object into polygon vertices in heatmap pixel coordinates.

    Supported ``obj["type"]`` values:
      * ``"rect"`` -- corners from ``left``/``top``/``width``/``height``.
      * ``"circle"`` / ``"ellipse"`` -- approximated by a 32-vertex polygon. When
        ``radius`` > 0 the centre is placed ``radius`` away from (``left``, ``top``)
        along ``angle`` (degrees); otherwise the shape is the ellipse inscribed in
        the ``width`` x ``height`` box whose top-left corner is (``left``, ``top``).
      * ``"path"`` -- one vertex per ``M``/``L``/``Q``/``C`` command, taken from the
        command's end point (curves are not sampled); other commands are skipped.

    Args:
        obj: One entry of the canvas JSON ``"objects"`` list.
        scale_x: Canvas-to-heatmap scale factor along x.
        scale_y: Canvas-to-heatmap scale factor along y.
        heatmap_w: Heatmap width; x coordinates are clipped to ``[0, heatmap_w - 1]``.
        heatmap_h: Heatmap height; y coordinates are clipped to ``[0, heatmap_h - 1]``.

    Returns:
        An ``(N, 2)`` int32 array of ``(x, y)`` vertices, or ``None`` when the type
        is unsupported or the shape is degenerate (zero-size rectangle,
        non-positive radius/size, or a path with fewer than 3 vertices).
    """
    obj_type = obj.get("type", "")
    if obj_type == "rect":
        left = obj.get("left", 0)
        top = obj.get("top", 0)
        w = obj.get("width", 0)
        h = obj.get("height", 0)
        # A zero-size rectangle (e.g. a single click with the rect tool) encloses no
        # area; reject it like a zero-size circle instead of filling a 1-px line.
        if w <= 0 or h <= 0:
            return None
        pts = np.array(
            [[left, top], [left + w, top], [left + w, top + h], [left, top + h]],
            dtype=np.float32,
        )
    elif obj_type in ("circle", "ellipse"):
        left = obj.get("left", 0)
        top = obj.get("top", 0)
        radius = obj.get("radius", 0)
        if radius > 0:
            # Centre lies one radius from (left, top) along the object's rotation angle.
            rx = ry = radius
            angle_rad = np.deg2rad(obj.get("angle", 0))
            cx = left + radius * np.cos(angle_rad)
            cy = top + radius * np.sin(angle_rad)
        else:
            width = obj.get("width", 0)
            height = obj.get("height", 0)
            rx = width / 2 if width > 0 else 0
            ry = height / 2 if height > 0 else 0
            cx = left + rx
            cy = top + ry
        if rx <= 0 or ry <= 0:
            return None
        angles = np.linspace(0, 2 * np.pi, 32, endpoint=False)
        pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)
    elif obj_type == "path":
        # The last two numbers of each command are its end point:
        # M/L x y | Q cx cy x y | C c1x c1y c2x c2y x y.
        vertices = []
        for cmd in obj.get("path", []):
            if not isinstance(cmd, (list, tuple)) or len(cmd) < 3:
                continue
            op = cmd[0]
            if op in ("M", "L"):
                vertices.append([float(cmd[1]), float(cmd[2])])
            elif op == "Q" and len(cmd) >= 5:
                vertices.append([float(cmd[3]), float(cmd[4])])
            elif op == "C" and len(cmd) >= 7:
                vertices.append([float(cmd[5]), float(cmd[6])])
        if len(vertices) < 3:
            return None
        pts = np.array(vertices, dtype=np.float32)
    else:
        return None
    pts[:, 0] *= scale_x
    pts[:, 1] *= scale_y
    return np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)
def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
    """Rasterise each drawn canvas object into its own binary mask at heatmap resolution.

    Args:
        json_data: Canvas JSON (a dict holding an ``"objects"`` list), or a falsy value.
        canvas_h: Canvas height in pixels.
        canvas_w: Canvas width in pixels.
        heatmap_h: Height of the target masks.
        heatmap_w: Width of the target masks.

    Returns:
        A list of ``(heatmap_h, heatmap_w)`` uint8 masks (1 inside the shape, 0
        elsewhere), one per object that ``_obj_to_pts`` accepts, in drawing order.
        Empty when there is nothing drawn or no usable object.
    """
    shapes = json_data["objects"] if json_data and "objects" in json_data else None
    if not shapes:
        return []
    sx = heatmap_w / canvas_w
    sy = heatmap_h / canvas_h
    result = []
    for shape in shapes:
        polygon = _obj_to_pts(shape, sx, sy, heatmap_w, heatmap_h)
        if polygon is None:
            # Unsupported or degenerate object: contributes no mask.
            continue
        shape_mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
        cv2.fillPoly(shape_mask, [polygon], 1)
        result.append(shape_mask)
    return result
def build_original_vals(raw_heatmap, pixel_sum, force):
    """Summarise the whole heatmap for the measure tool's "Full map" panel.

    Args:
        raw_heatmap: Heatmap array whose peak and mean are reported.
        pixel_sum: Precomputed pixel sum, passed through unchanged.
        force: Precomputed force value, passed through unchanged.

    Returns:
        Dict with keys ``pixel_sum``, ``force``, ``max`` and ``mean``; the latter two
        are Python floats.
    """
    peak = float(np.max(raw_heatmap))
    average = float(np.mean(raw_heatmap))
    return dict(pixel_sum=pixel_sum, force=force, max=peak, mean=average)
def _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force):
    """Sum, scaled force and mean of ``raw_heatmap`` over the estimated cell area.

    Args:
        raw_heatmap: Heatmap values.
        cell_mask: Mask multiplied element-wise into the heatmap; its sum is used as
            the area (so a 0/1 mask gives a pixel count).
        pixel_sum: Full-map pixel sum; when > 0, ``force / pixel_sum`` converts the
            cell sum into force units.
        force: Full-map force.

    Returns:
        ``(cell_pixel_sum, cell_force, cell_mean)``, or ``(None, None, None)`` when the
        mask sums to zero.
    """
    cell_area = int(np.sum(cell_mask))
    if not cell_area:
        return None, None, None
    total = float(np.sum(raw_heatmap * cell_mask))
    if pixel_sum > 0:
        scaled = total * (force / pixel_sum)
    else:
        # No usable full-map reference: report the raw sum unscaled.
        scaled = total
    return total, scaled, total / cell_area
def build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force):
    """Summarise the estimated cell area for the measure tool's "Cell area" panel.

    Args:
        raw_heatmap: Heatmap values.
        cell_mask: Cell mask; entries > 0 are inside the cell.
        pixel_sum: Full-map pixel sum used to scale the cell sum into force.
        force: Full-map force.

    Returns:
        Dict with keys ``pixel_sum``, ``force``, ``max`` and ``mean`` for the cell
        region, or ``None`` when the mask sums to zero.
    """
    cell_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force)
    if cell_sum is None:
        return None
    inside = (raw_heatmap * cell_mask)[cell_mask > 0]
    cell_max = float(np.max(inside)) if inside.size else 0
    return {"pixel_sum": cell_sum, "force": cell_force, "max": cell_max, "mean": cell_mean}
def compute_region_metrics(raw_heatmap, mask, original_vals=None):
    """Measure ``raw_heatmap`` inside ``mask``.

    Args:
        raw_heatmap: Heatmap values.
        mask: Region mask multiplied into the heatmap; entries > 0 are inside.
        original_vals: Optional full-map summary (see ``build_original_vals``). When it
            is truthy and its ``pixel_sum`` is > 0, ``force / pixel_sum`` is used to
            convert the region sum into force units.

    Returns:
        Dict with:
          * ``area_px``: sum of the mask (pixel count for a 0/1 mask);
          * ``force_sum``: sum of ``raw_heatmap * mask``;
          * ``density``: ``force_sum / area_px`` (0 for an empty mask);
          * ``max`` / ``mean``: over the masked values (0 for an empty mask);
          * ``force_scaled``: ``force_sum`` in force units, or ``force_sum`` itself
            when no usable ``original_vals`` is given.
    """
    pixel_count = int(np.sum(mask))
    weighted = raw_heatmap * mask
    inside = weighted[mask > 0]
    total = float(np.sum(weighted))
    has_pixels = len(inside) > 0
    if original_vals and original_vals.get("pixel_sum", 0) > 0:
        force_scaled = total * (original_vals["force"] / original_vals["pixel_sum"])
    else:
        force_scaled = total
    return {
        "area_px": pixel_count,
        "force_sum": total,
        "density": total / pixel_count if pixel_count > 0 else 0,
        "max": float(np.max(inside)) if has_pixels else 0,
        "mean": float(np.mean(inside)) if has_pixels else 0,
        "force_scaled": force_scaled,
    }
def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=3):
    """Outline the external contours of ``mask`` on ``img_rgb``.

    Args:
        img_rgb: Image to draw on; ``cv2.drawContours`` modifies it in place.
        mask: 2-D mask; every value > 0 counts as inside. It may be bool, integer or
            float. When its shape differs from the image's, it is resized with
            nearest-neighbour interpolation.
        stroke_color: Outline colour, in the image's channel order.
        stroke_width: Outline thickness in pixels.

    Returns:
        ``img_rgb`` (the same array object, with contours drawn when any exist).
    """
    h, w = img_rgb.shape[:2]
    # cv2.findContours needs 8-bit single-channel input here. Binarise every mask up
    # front (not only on the resize path) so bool/float masks that already match the
    # image size do not make it fail. "> 0" matches the callers' np.any(mask > 0) check.
    mask_u8 = (np.asarray(mask) > 0).astype(np.uint8)
    if mask_u8.shape[:2] != (h, w):
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy); [-2] selects the contours in both.
    contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if contours:
        cv2.drawContours(img_rgb, contours, -1, stroke_color, stroke_width)
    return img_rgb
def _region_rows(metrics_list, base_name, has_original_vals, first_region_label):
    """Build ``(table_rows, csv_rows)`` for the regions table; each list starts with its header row."""
    if has_original_vals:
        header = ["Region", "Area", "F.sum", "Force", "Max", "Mean"]
        csv_header = ["image", "region"] + header[1:]
    else:
        header = ["Region", "Area (px²)", "Force sum", "Mean"]
        csv_header = ["image", "region", "Area", "Force sum", "Mean"]
    # User regions are numbered from 1 even when the first row carries a custom label.
    offset = 1 if first_region_label else 0
    table_rows = [header]
    csv_rows = [csv_header]
    for idx, m in enumerate(metrics_list, 1):
        if idx == 1 and first_region_label:
            label = first_region_label
        else:
            label = f"Region {idx - offset}"
        if has_original_vals:
            numbers = [f"{m['force_sum']:.3f}", f"{m['force_scaled']:.1f}", f"{m['max']:.3f}", f"{m['mean']:.4f}"]
        else:
            numbers = [f"{m['force_sum']:.4f}", f"{m['mean']:.6f}"]
        table_rows.append([label, str(m["area_px"])] + numbers)
        csv_rows.append([base_name, label, m["area_px"]] + numbers)
    return table_rows, csv_rows


def _rows_to_html_table(table_rows):
    """Render rows (the first is the header) as an HTML table string with every cell escaped."""
    header, *body = table_rows
    th_html = "".join(
        f'<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">{html.escape(str(cell))}</th>'
        for cell in header
    )
    body_html = "".join(
        "<tr>"
        + "".join(f'<td style="border: 1px solid #ddd; padding: 8px;">{html.escape(str(cell))}</td>' for cell in row)
        + "</tr>"
        for row in body
    )
    return (
        '<table style="border-collapse: collapse; width: 100%;">'
        f"<thead><tr>{th_html}</tr></thead>"
        f"<tbody>{body_html}</tbody></table>"
    )


def render_region_metrics_and_downloads(metrics_list, masks, heatmap_rgb, input_filename, key_suffix, has_original_vals,
                                        first_region_label=None, bf_img=None, cell_mask=None, colormap_name="Jet"):
    """Show the per-region metrics table and CSV / PNG / (optional) PDF download buttons.

    Args:
        metrics_list: One ``compute_region_metrics`` dict per table row.
        masks: User-drawn region masks only; they are labelled R1, R2, ... on the
            annotated heatmap.
        heatmap_rgb: RGB heatmap to annotate (copied, not modified).
        input_filename: Source file name; its stem prefixes the download names
            (``"image"`` when missing).
        key_suffix: Suffix that keeps the download widget keys unique.
        has_original_vals: Selects the wider column set (with scaled force and max).
        first_region_label: Custom label for the first row (e.g. ``"Auto boundary"``).
        bf_img: Bright-field image; the PDF report button is offered only when given.
        cell_mask: Accepted for interface compatibility; not used here (the annotated
            heatmap is drawn with ``cell_mask=None``).
        colormap_name: Accepted for interface compatibility; not used here.
    """
    base_name = os.path.splitext(input_filename or "image")[0]
    st.markdown("**Regions (each selection = one row)**")
    table_rows, csv_rows = _region_rows(metrics_list, base_name, has_original_vals, first_region_label)
    # Plain HTML avoids Streamlit's default row/column indices.
    st.markdown(_rows_to_html_table(table_rows), unsafe_allow_html=True)

    csv_text = io.StringIO()
    csv.writer(csv_text).writerows(csv_rows)

    # Each drawn region is annotated separately (no merging).
    labels = [f"R{n}" for n in range(1, len(masks) + 1)]
    annotated = make_annotated_heatmap_multi_regions(heatmap_rgb.copy(), masks, labels, cell_mask=None)
    png = io.BytesIO()
    Image.fromarray(annotated).save(png, format="PNG")
    png.seek(0)

    pdf_bytes = create_measure_pdf_report(bf_img, annotated, table_rows, base_name) if bf_img is not None else None
    columns = st.columns(3 if pdf_bytes is not None else 2)
    with columns[0]:
        st.download_button("Download all regions", data=csv_text.getvalue(),
                           file_name=f"{base_name}_all_regions.csv", mime="text/csv",
                           key=f"download_all_regions_{key_suffix}", icon=":material/download:")
    with columns[1]:
        st.download_button("Download heatmap", data=png.getvalue(),
                           file_name=f"{base_name}_annotated_heatmap.png", mime="image/png",
                           key=f"download_annotated_{key_suffix}", icon=":material/image:")
    if pdf_bytes is not None:
        with columns[2]:
            st.download_button("Download report", data=pdf_bytes,
                               file_name=f"{base_name}_measure_report.pdf", mime="application/pdf",
                               key=f"download_measure_pdf_{key_suffix}", icon=":material/picture_as_pdf:")
def _bright_field_to_rgb(bf_img, size):
    """Resize a bright-field image to ``size`` x ``size`` and convert it to 3-channel RGB.

    The colour code is chosen from the resized array, because that is what
    ``cv2.cvtColor`` receives. ``cv2.resize`` turns H x W x 1 input into a 2-D array.
    2-D/1-channel input is treated as grayscale, 4-channel input as BGRA, and anything
    else as BGR. NOTE(review): BGR order for colour input is assumed from the original
    code; confirm against the image loader.
    """
    resized = cv2.resize(bf_img, (size, size))
    if resized.ndim == 2 or resized.shape[2] == 1:
        return cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    if resized.shape[2] == 4:
        return cv2.cvtColor(resized, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)


def _region_canvas(pil_bg, draw_mode, key_suffix):
    """Show the drawable region canvas over ``pil_bg`` and return the ``st_canvas`` result."""
    return st_canvas(
        fill_color="rgba(0, 188, 212, 0.25)", stroke_width=2, stroke_color="#00bcd4",
        background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
        height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
        key=f"region_measure_canvas_{key_suffix}",
    )


def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, original_vals=None, cell_vals=None,
                         cell_mask=None, key_suffix="", input_filename=None, colormap_name="Jet"):
    """Render the drawable region canvas and, once shapes are drawn, their metrics.

    With ``bf_img``: the canvas is shown beside a panel with the summary values
    (``cell_vals`` as "Cell area" when given, otherwise ``original_vals`` as
    "Full map") and the bright-field image, which is outlined with ``cell_mask``
    when that mask is non-empty. Without ``bf_img``: only the canvas is shown.

    Args:
        display_heatmap: 2-D heatmap shown as the canvas background; its shape sets
            the mask resolution.
        raw_heatmap: Values to measure (defaults to ``display_heatmap``).
            NOTE(review): masks are built at ``display_heatmap``'s shape and applied
            to ``raw_heatmap``, so both are assumed to share a shape.
        bf_img: Optional bright-field image.
        original_vals: Full-map summary (see ``build_original_vals``) or ``None``.
        cell_vals: Cell-area summary (see ``build_cell_vals``) or ``None``.
        cell_mask: Optional auto-detected cell mask. When non-empty, it is outlined
            and measured as the first "Auto boundary" row.
        key_suffix: Suffix that keeps the widget keys unique.
        input_filename: Source file name used for download names.
        colormap_name: Colormap for the heatmap rendering.
    """
    if not HAS_DRAWABLE_CANVAS:
        st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
        return
    raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
    h, w = display_heatmap.shape
    # Evaluated once: drives the contour overlay, the extra metrics row and its label.
    has_cell_mask = cell_mask is not None and bool(np.any(cell_mask > 0))
    heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
    pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)

    if bf_img is not None:
        bf_rgb = _bright_field_to_rgb(bf_img, CANVAS_SIZE)
        left_col, right_col = st.columns(2, gap=None)
        with left_col:
            draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}")
            st.caption("Left-click add, right-click close. \nForce map (draw region)")
            canvas_result = _region_canvas(pil_bg, draw_mode, key_suffix)
        with right_col:
            vals = cell_vals if cell_vals else original_vals
            if vals:
                label = "Cell area" if cell_vals else "Full map"
                st.markdown(
                    f'<p class="s2f-measure-vals-heading">{html.escape(label)}</p>'
                    f'<div class="s2f-measure-vals-panel"><div class="s2f-measure-vals-grid">'
                    f"<span><strong>Sum:</strong> {vals['pixel_sum']:.1f}</span>"
                    f"<span><strong>Force:</strong> {vals['force']:.1f}</span>"
                    f"<span><strong>Max:</strong> {vals['max']:.3f}</span>"
                    f"<span><strong>Mean:</strong> {vals['mean']:.3f}</span>"
                    f"</div></div>",
                    unsafe_allow_html=True,
                )
            st.caption("Bright-field")
            bf_display = bf_rgb.copy()
            if has_cell_mask:
                bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=3)
            st.image(bf_display, width=CANVAS_SIZE)
    else:
        st.markdown("**Draw a region** on the heatmap.")
        draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS,
                                 format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
                                 key=f"draw_mode_region_{key_suffix}")
        st.caption("Polygon: left-click to add points, right-click to close.")
        canvas_result = _region_canvas(pil_bg, draw_mode, key_suffix)

    if not canvas_result.json_data:
        return
    masks = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
    if not masks:
        return
    metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
    if has_cell_mask:
        # The auto-detected boundary goes first, ahead of the user-drawn regions.
        metrics_list.insert(0, compute_region_metrics(raw_heatmap, cell_mask, original_vals))
    render_region_metrics_and_downloads(
        metrics_list, masks, heatmap_rgb, input_filename, key_suffix, original_vals is not None,
        first_region_label="Auto boundary" if has_cell_mask else None,
        bf_img=bf_img, cell_mask=cell_mask, colormap_name=colormap_name,
    )