""" Shape2Force (S2F) - GUI for force map prediction from bright field microscopy images. """ import os import sys import traceback # Suppress OpenCV verbose logging (cv2.utils.logging not reliably available in all builds) os.environ.setdefault("OPENCV_LOG_LEVEL", "ERROR") import cv2 import numpy as np import streamlit as st S2F_ROOT = os.path.dirname(os.path.abspath(__file__)) if S2F_ROOT not in sys.path: sys.path.insert(0, S2F_ROOT) from config.constants import ( BATCH_INFERENCE_SIZE, BATCH_MAX_IMAGES, COLORMAPS, DEFAULT_SUBSTRATE, MODEL_TYPE_LABELS, SAMPLE_EXTENSIONS, SAMPLE_THUMBNAIL_LIMIT, ) from utils.paths import get_ckp_base, get_ckp_folder, get_sample_folder, list_files_in_folder, model_subfolder from utils.segmentation import estimate_cell_mask from utils.substrate_settings import list_substrates from utils.display import apply_display_scale from ui.components import ( build_original_vals, build_cell_vals, render_batch_results, render_result_display, render_region_canvas, ST_DIALOG, HAS_DRAWABLE_CANVAS, ) CITATION = ( "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. " "\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\", 2026." ) def _inference_cache_condition_key(model_type, use_manual, substrate_val, substrate_config): """Hashable key for substrate / manual conditions so cache invalidates when single-cell inputs change.""" if model_type != "single_cell": return None if use_manual and substrate_config is not None: return ( "manual", round(float(substrate_config["pixelsize"]), 6), round(float(substrate_config["young"]), 2), ) return ("preset", str(substrate_val)) # Measure tool dialog: defined early so it exists before render_result_display uses it if HAS_DRAWABLE_CANVAS and ST_DIALOG: @ST_DIALOG("Measure tool", width="medium") def measure_region_dialog(): raw_heatmap = st.session_state.get("measure_raw_heatmap") if raw_heatmap is None: st.warning("No prediction available to measure.") return display_mode = st.session_state.get("measure_display_mode", "Default") _m_clamp = st.session_state.get("measure_clamp_only", False) display_heatmap = apply_display_scale( raw_heatmap, display_mode, clip_min=st.session_state.get("measure_clip_min", 0), clip_max=st.session_state.get("measure_clip_max", 1), clamp_only=_m_clamp, ) bf_img = st.session_state.get("measure_bf_img") original_vals = st.session_state.get("measure_original_vals") cell_vals = st.session_state.get("measure_cell_vals") cell_mask = st.session_state.get("measure_cell_mask") input_filename = st.session_state.get("measure_input_filename", "image") colormap_name = st.session_state.get("measure_colormap", "Jet") render_region_canvas( display_heatmap, raw_heatmap=raw_heatmap, bf_img=bf_img, original_vals=original_vals, cell_vals=cell_vals, cell_mask=cell_mask, key_suffix="dialog", input_filename=input_filename, colormap_name=colormap_name, ) else: def measure_region_dialog(): pass def _get_measure_dialog_fn(): """Return measure dialog callable if available, else None (fixes st_dialog ordering).""" return measure_region_dialog if (HAS_DRAWABLE_CANVAS and ST_DIALOG) else None def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name, display_mode, auto_cell_boundary, cell_mask=None, clip_min=0, clip_max=1, clamp_only=False): """Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it.""" if cell_mask is None and auto_cell_boundary: cell_mask = estimate_cell_mask(heatmap) st.session_state["measure_raw_heatmap"] = heatmap.copy() st.session_state["measure_display_mode"] = display_mode st.session_state["measure_clip_min"] = clip_min st.session_state["measure_clip_max"] = clip_max st.session_state["measure_clamp_only"] = clamp_only st.session_state["measure_bf_img"] = img.copy() st.session_state["measure_input_filename"] = key_img or "image" st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force) st.session_state["measure_colormap"] = colormap_name st.session_state["measure_auto_cell_on"] = auto_cell_boundary st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="wide") st.markdown( '', unsafe_allow_html=True, ) _css_path = os.path.join(S2F_ROOT, "static", "s2f_styles.css") if os.path.exists(_css_path): with open(_css_path, "r") as f: st.markdown(f"", unsafe_allow_html=True) st.markdown("""

🦠 Shape2Force (S2F)

Predict traction force maps from bright-field microscopy images of cells or spheroids

""", unsafe_allow_html=True) # Folders ckp_base = get_ckp_base(S2F_ROOT) def get_cached_sample_thumbnails(model_type, sample_folder, sample_files): """Return cached sample thumbnails. Key by (model_type, tuple(files)).""" cache_key = (model_type, tuple(sample_files)) if "sample_thumbnails" not in st.session_state: st.session_state["sample_thumbnails"] = {} cache = st.session_state["sample_thumbnails"] if cache_key not in cache: thumbnails = [] for fname in sample_files[:SAMPLE_THUMBNAIL_LIMIT]: path = os.path.join(sample_folder, fname) img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) thumbnails.append((fname, img)) cache[cache_key] = thumbnails return cache[cache_key] def _render_sample_selector(model_type, batch_mode): """ Render sample image selector (Example mode). Returns (img, imgs_batch, selected_sample, selected_samples). For single mode: img is set, imgs_batch=[]. For batch: img=None, imgs_batch=list of (img, key). """ sample_folder = get_sample_folder(S2F_ROOT, model_type) sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS) sample_subfolder_name = model_subfolder(model_type) img = None imgs_batch = [] selected_sample = None selected_samples = [] if not sample_files: st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.") return img, imgs_batch, selected_sample, selected_samples if batch_mode: selected_samples = st.multiselect( f"Select example images (max {BATCH_MAX_IMAGES})", sample_files, default=None, max_selections=BATCH_MAX_IMAGES, key=f"sample_batch_{model_type}", ) if selected_samples: for fname in selected_samples[:BATCH_MAX_IMAGES]: sample_path = os.path.join(sample_folder, fname) loaded = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE) if loaded is not None: imgs_batch.append((loaded, fname)) else: selected_sample = st.selectbox( f"Select example image (from `samples/{sample_subfolder_name}/`)", sample_files, format_func=lambda x: x, key=f"sample_{model_type}", ) if selected_sample: sample_path = os.path.join(sample_folder, selected_sample) img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE) thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files) n_cols = min(5, len(thumbnails)) cols = st.columns(n_cols) for i, (fname, sample_img) in enumerate(thumbnails): if sample_img is not None: with cols[i % n_cols]: st.image(sample_img, caption=fname, width=120) return img, imgs_batch, selected_sample, selected_samples # Sidebar with st.sidebar: st.markdown(""" """, unsafe_allow_html=True) with st.container(border=False, key="s2f_grp_model"): model_type = st.radio( "Model type", ["single_cell", "spheroid"], format_func=lambda x: MODEL_TYPE_LABELS[x], horizontal=False, help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.", ) ckp_folder = get_ckp_folder(ckp_base, model_type) ckp_files = list_files_in_folder(ckp_folder, ".pth") ckp_subfolder_name = model_subfolder(model_type) if ckp_files: checkpoint = st.selectbox( "Checkpoint", ckp_files, key=f"checkpoint_{model_type}", help=f"Select a .pth file from ckp/{ckp_subfolder_name}/", ) else: st.warning(f"No .pth files in ckp/{ckp_subfolder_name}/. Add checkpoints to load.") checkpoint = None substrate_config = None substrate_val = DEFAULT_SUBSTRATE use_manual = True if model_type == "single_cell": try: with st.container(border=False, key="s2f_grp_conditions"): st.markdown('

Conditions

', unsafe_allow_html=True) conditions_source = st.radio( "Conditions", ["From config", "Manually"], horizontal=True, label_visibility="collapsed", ) from_config = conditions_source == "From config" if from_config: substrate_config = None substrates = list_substrates() substrate_val = st.selectbox( "Conditions (from config)", substrates, help="Select a preset from config/substrate_settings.json", label_visibility="collapsed", ) use_manual = False else: manual_pixelsize = st.number_input("Pixel size (µm/px)", min_value=0.1, max_value=50.0, value=3.0769, step=0.1, format="%.4f") manual_young = st.number_input("Pascals", min_value=100.0, max_value=100000.0, value=6000.0, step=100.0, format="%.0f") substrate_config = {"pixelsize": manual_pixelsize, "young": manual_young} use_manual = True except FileNotFoundError: st.error("config/substrate_settings.json not found") batch_mode = st.toggle( "Batch mode", value=False, help=f"Process up to {BATCH_MAX_IMAGES} images at once. Upload multiple files or select multiple examples.", ) auto_cell_boundary = st.toggle( "Auto boundary", value=False, help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.", ) with st.container(border=False, key="s2f_grp_force"): force_scale_mode = st.radio( "Force scale", ["Default", "Range"], horizontal=True, key="s2f_force_scale", help="Default: display forces on the full 0–1 scale. Range: set a sub-range; values outside are zeroed and the rest is stretched to the colormap.", ) if force_scale_mode == "Default": clip_min, clip_max = 0.0, 1.0 display_mode = "Default" clamp_only = True else: mn_col, mx_col = st.columns(2) with mn_col: clip_min = st.number_input( "Min", min_value=0.0, max_value=1.0, value=0.0, step=0.01, format="%.2f", key="s2f_clip_min", help="Lower bound of the display range (0–1).", ) with mx_col: clip_max = st.number_input( "Max", min_value=0.0, max_value=1.0, value=1.0, step=0.01, format="%.2f", key="s2f_clip_max", help="Upper bound of the display range (0–1).", ) if clip_min >= clip_max: st.warning("Min must be less than max. Using 0.00–1.00 for display.") clip_min, clip_max = 0.0, 1.0 display_mode = "Range" clamp_only = False cm_col_lbl, cm_col_sb = st.columns([1, 2]) with cm_col_lbl: st.markdown('

Colormap

', unsafe_allow_html=True) with cm_col_sb: colormap_name = st.selectbox( "Colormap", list(COLORMAPS.keys()), key="s2f_colormap", label_visibility="collapsed", help="Color scheme for the force map. Viridis is often preferred for accessibility.", ) # Main area: image input img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed", key="s2f_img_source") img = None imgs_batch = [] # list of (img, key_img) for batch mode uploaded = None uploaded_list = [] selected_sample = None selected_samples = [] if batch_mode: # Batch mode: multiple images (max BATCH_MAX_IMAGES) if img_source == "Upload": uploaded_list = st.file_uploader( "Upload bright-field images", type=["tif", "tiff", "png", "jpg", "jpeg"], accept_multiple_files=True, help=f"Select up to {BATCH_MAX_IMAGES} images. Bright-field microscopy (grayscale or RGB).", ) if uploaded_list: uploaded_list = uploaded_list[:BATCH_MAX_IMAGES] for u in uploaded_list: bytes_data = u.read() nparr = np.frombuffer(bytes_data, np.uint8) decoded = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) if decoded is not None: imgs_batch.append((decoded, u.name)) u.seek(0) else: img, imgs_batch, selected_sample, selected_samples = _render_sample_selector(model_type, batch_mode=True) else: # Single image mode if img_source == "Upload": uploaded = st.file_uploader( "Upload bright-field image", type=["tif", "tiff", "png", "jpg", "jpeg"], help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB).", ) if uploaded: bytes_data = uploaded.read() nparr = np.frombuffer(bytes_data, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) uploaded.seek(0) else: img, imgs_batch, selected_sample, selected_samples = _render_sample_selector(model_type, batch_mode=False) st.markdown("") col_btn, col_info = st.columns([1, 3]) with col_btn: run = st.button("Run prediction", type="primary", use_container_width=True) with col_info: ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/" st.markdown(f"""
{MODEL_TYPE_LABELS[model_type]} {ckp_path}
""", unsafe_allow_html=True) has_image = img is not None has_batch = len(imgs_batch) > 0 if "prediction_result" not in st.session_state: st.session_state["prediction_result"] = None if "batch_results" not in st.session_state: st.session_state["batch_results"] = None if not batch_mode: st.session_state["batch_results"] = None # Clear when switching to single mode # Single-image keys (for non-batch) key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample _cond_key = _inference_cache_condition_key(model_type, use_manual, substrate_val, substrate_config) current_key = (model_type, checkpoint, key_img, _cond_key) cached = st.session_state["prediction_result"] has_cached = cached is not None and cached.get("cache_key") == current_key and not batch_mode just_ran = run and checkpoint and has_image and not batch_mode just_ran_batch = run and checkpoint and has_batch and batch_mode @st.cache_resource def _load_predictor(model_type, checkpoint, ckp_folder): """Load and cache predictor. Invalidated when model_type or checkpoint changes.""" from predictor import S2FPredictor return S2FPredictor( model_type=model_type, checkpoint_path=checkpoint, ckp_folder=ckp_folder, ) def _prepare_and_render_cached_result(r, key_img, colormap_name, display_mode, auto_cell_boundary, clip_min, clip_max, clamp_only, download_key_suffix="", check_measure_dialog=False, show_success=False): """Prepare display from cached result and render. Used by both just_ran and has_cached paths.""" img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"] display_heatmap = apply_display_scale( heatmap, display_mode, clip_min=clip_min, clip_max=clip_max, clamp_only=clamp_only, ) cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None _populate_measure_session_state( heatmap, img, pixel_sum, force, key_img, colormap_name, display_mode, auto_cell_boundary, cell_mask=cell_mask, clip_min=clip_min, clip_max=clip_max, clamp_only=clamp_only, ) if check_measure_dialog and st.session_state.pop("open_measure_dialog", False): measure_region_dialog() if show_success: st.success("Prediction complete!") render_result_display( img, heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix=download_key_suffix, colormap_name=colormap_name, display_mode=display_mode, measure_region_dialog=_get_measure_dialog_fn(), auto_cell_boundary=auto_cell_boundary, cell_mask=cell_mask, clip_min=clip_min, clip_max=clip_max, clamp_only=clamp_only, ) if just_ran_batch: st.session_state["prediction_result"] = None st.session_state["batch_results"] = None with st.spinner("Loading model and predicting..."): progress_bar = None try: predictor = _load_predictor(model_type, checkpoint, ckp_folder) sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE n_images = len(imgs_batch) progress_bar = st.progress(0, text=f"Predicting 0 / {n_images} images") pred_results = [] for start in range(0, n_images, BATCH_INFERENCE_SIZE): chunk = imgs_batch[start : start + BATCH_INFERENCE_SIZE] chunk_results = predictor.predict_batch( chunk, substrate=sub_val, substrate_config=substrate_config if model_type == "single_cell" else None, ) pred_results.extend(chunk_results) progress_bar.progress(min(start + len(chunk), n_images) / n_images, text=f"Predicting {len(pred_results)} / {n_images} images") batch_results = [ { "img": img_b.copy(), "heatmap": heatmap.copy(), "force": force, "pixel_sum": pixel_sum, "key_img": key_b, "cell_mask": estimate_cell_mask(heatmap) if auto_cell_boundary else None, } for (img_b, key_b), (heatmap, force, pixel_sum) in zip(imgs_batch, pred_results) ] st.session_state["batch_results"] = batch_results progress_bar.empty() st.success(f"Prediction complete for {len(batch_results)} image(s)!") render_batch_results( batch_results, colormap_name=colormap_name, display_mode=display_mode, clip_min=clip_min, clip_max=clip_max, auto_cell_boundary=auto_cell_boundary, clamp_only=clamp_only, ) except Exception as e: if progress_bar is not None: progress_bar.empty() st.error(f"Prediction failed: {e}") st.code(traceback.format_exc()) elif batch_mode and st.session_state.get("batch_results"): render_batch_results( st.session_state["batch_results"], colormap_name=colormap_name, display_mode=display_mode, clip_min=clip_min, clip_max=clip_max, auto_cell_boundary=auto_cell_boundary, clamp_only=clamp_only, ) elif just_ran: st.session_state["prediction_result"] = None with st.spinner("Loading model and predicting..."): try: predictor = _load_predictor(model_type, checkpoint, ckp_folder) sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE heatmap, force, pixel_sum = predictor.predict( image_array=img, substrate=sub_val, substrate_config=substrate_config if model_type == "single_cell" else None, ) cache_key = (model_type, checkpoint, key_img, _cond_key) r = { "img": img.copy(), "heatmap": heatmap.copy(), "force": force, "pixel_sum": pixel_sum, "cache_key": cache_key, } st.session_state["prediction_result"] = r _prepare_and_render_cached_result( r, key_img, colormap_name, display_mode, auto_cell_boundary, clip_min, clip_max, clamp_only, download_key_suffix="", check_measure_dialog=False, show_success=True, ) except Exception as e: st.error(f"Prediction failed: {e}") st.code(traceback.format_exc()) elif has_cached: r = st.session_state["prediction_result"] _prepare_and_render_cached_result( r, key_img, colormap_name, display_mode, auto_cell_boundary, clip_min, clip_max, clamp_only, download_key_suffix="_cached", check_measure_dialog=True, show_success=False, ) elif run and not checkpoint: st.warning("Please add checkpoint files to the ckp/ folder and select one.") elif run and not has_image and not has_batch: st.warning("Please upload an image or select an example.") elif run and batch_mode and not has_batch: st.warning(f"Please upload or select 1–{BATCH_MAX_IMAGES} images for batch processing.") st.markdown(f""" """, unsafe_allow_html=True)