| """Cell segmentation from force map for background exclusion.""" |
| import numpy as np |
| from scipy.ndimage import gaussian_filter |
| from skimage.filters import threshold_otsu |
| from skimage.morphology import closing, opening, dilation, remove_small_objects, disk |
| from skimage.measure import label, regionprops |
|
|
|
|
| def estimate_cell_mask(heatmap, sigma=2, min_size=200, exclude_full_image=True, |
| threshold_relax=0.85, dilate_radius=4, min_area_ratio=0.2): |
| """ |
| Estimate cell region from force map using Otsu thresholding and morphological cleanup. |
| |
| Supports multiple disconnected regions (e.g., two cells): components whose area is |
| at least min_area_ratio of the largest are merged into the final mask. |
| |
| Args: |
| heatmap: 2D float array [0, 1] - predicted force map |
| sigma: Gaussian smoothing sigma to reduce noise. Default 2. |
| min_size: Minimum object size in pixels; smaller objects removed. Default 200. |
| exclude_full_image: If True, exclude the largest connected component when it |
| covers most of the image (>70%) and use the second largest. Default True. |
| threshold_relax: Multiply Otsu threshold by this (<1 = looser, include more pixels). |
| Default 0.85. |
| dilate_radius: Radius to dilate mask outward to include surrounding pixels. |
| Default 4. |
| min_area_ratio: Include components with area >= this fraction of the largest |
| component (0–1). E.g. 0.2 = include regions at least 20% the size of the |
| largest. Handles multiple disconnected force regions. Default 0.2. |
| |
| Returns: |
| mask: Binary uint8 array, 1 = estimated cell, 0 = background |
| """ |
| heatmap = np.clip(heatmap, 0, 1).astype(np.float64) |
| if np.max(heatmap) <= 0: |
| return np.zeros_like(heatmap, dtype=np.uint8) |
|
|
| |
| smoothed = gaussian_filter(heatmap, sigma=sigma) |
|
|
| |
| thresh = threshold_otsu(smoothed) * threshold_relax |
| mask = (smoothed > thresh).astype(np.uint8) |
|
|
| |
| mask = closing(mask, disk(5)).astype(np.uint8) |
| mask = opening(mask, disk(3)).astype(np.uint8) |
| mask_bool = mask.astype(bool) |
| min_size_int = max(int(min_size), 0) |
| |
| |
| try: |
| mask_bool = remove_small_objects(mask_bool, max_size=max(min_size_int - 1, 0)) |
| except TypeError: |
| mask_bool = remove_small_objects(mask_bool, min_size=min_size_int) |
| mask = mask_bool.astype(np.uint8) |
|
|
| |
| |
| labeled = label(mask) |
| props = list(regionprops(labeled)) |
|
|
| if len(props) == 0: |
| return np.zeros_like(heatmap, dtype=np.uint8) |
|
|
| props_sorted = sorted(props, key=lambda x: x.area, reverse=True) |
| total_px = heatmap.shape[0] * heatmap.shape[1] |
|
|
| |
| if exclude_full_image and len(props_sorted) >= 2 and props_sorted[0].area > 0.7 * total_px: |
| props_sorted = props_sorted[1:] |
|
|
| |
| ref_area = props_sorted[0].area |
| |
| labels_to_keep = [p.label for p in props_sorted if p.area >= min_area_ratio * ref_area] |
|
|
| mask = np.zeros_like(labeled, dtype=np.uint8) |
| for lab in labels_to_keep: |
| mask[labeled == lab] = 1 |
|
|
| |
| if dilate_radius > 0: |
| mask = dilation(mask, disk(dilate_radius)).astype(np.uint8) |
|
|
| return mask |
|
|