# S2FApp/predictor.py — Shape2Force (S2F).
# History note (commit 6727da5, kaveh): implemented batch processing for
# predictions to optimize memory usage, added progress tracking, updated
# BATCH_INFERENCE_SIZE constant, and adjusted contour drawing parameters
# for better visibility in UI elements.
"""
Core inference logic for S2F (Shape2Force).
Predicts force maps from bright field microscopy images.
"""
import os
import sys
import cv2
import torch
import numpy as np
# Ensure S2F is in path when running from project root or S2F.
# Allows the `config`, `models`, and `utils` package imports below to
# resolve regardless of the current working directory.
S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
if S2F_ROOT not in sys.path:
    sys.path.insert(0, S2F_ROOT)
from config.constants import BATCH_INFERENCE_SIZE, DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE
from models.s2f_model import create_s2f_model
from utils.paths import get_ckp_base, model_subfolder
from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
from utils import config
def load_image(filepath, target_size=None):
    """Load a bright field image as grayscale, resize, and scale to [0, 1].

    Args:
        filepath: Path to the image file (tif, png, jpg, ...).
        target_size: Optional int or (w, h) tuple. Defaults to MODEL_INPUT_SIZE.

    Returns:
        float32 numpy array of shape (H, W) with values in [0, 1].

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    raw = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        raise ValueError(f"Could not load image: {filepath}")
    dims = MODEL_INPUT_SIZE if target_size is None else target_size
    if isinstance(dims, int):
        dims = (dims, dims)
    resized = cv2.resize(raw, dims)
    # Normalize 8-bit intensities into [0, 1] for the network.
    return resized.astype(np.float32) / 255.0
def sum_force_map(force_map):
    """Compute cell force as the spatial sum of the force map, scaled by
    config.SCALE_FACTOR_FORCE.

    Accepts a numpy array or torch tensor of rank 2 (H, W), rank 3 or
    rank 4; leading singleton dimensions are prepended until the tensor
    is [B, C, H, W].

    Returns:
        Tensor with the spatial dims (2, 3) summed out and scaled.
    """
    if isinstance(force_map, np.ndarray):
        force_map = torch.from_numpy(force_map.astype(np.float32))
    # Promote to [B, C, H, W] by prepending singleton dimensions.
    while force_map.dim() < 4:
        force_map = force_map.unsqueeze(0)
    # Sum over the spatial axes only; batch/channel axes survive.
    return torch.sum(force_map, dim=(2, 3)) * config.SCALE_FACTOR_FORCE
def create_settings_channels_single(substrate_name, device, height, width, config_path=None,
                                    substrate_config=None):
    """
    Create settings channels for a single image (single-cell mode).

    Builds two constant-valued spatial channels holding the normalized
    substrate pixel size and Young's modulus, clamped to [0, 1].

    Args:
        substrate_name: Substrate name (used if substrate_config is None)
        device: torch device
        height, width: spatial dimensions
        config_path: Path to substrate config JSON
        substrate_config: Optional dict with 'pixelsize' and 'young'. If provided, overrides substrate_name.

    Returns:
        Tensor of shape [1, 2, height, width].
    """
    norm = compute_settings_normalization(config_path=config_path)
    has_override = (
        substrate_config is not None
        and 'pixelsize' in substrate_config
        and 'young' in substrate_config
    )
    if has_override:
        settings = substrate_config
    else:
        settings = get_settings_of_category(substrate_name, config_path=config_path)

    def _to_unit_range(value, bounds):
        # Min-max normalize; fall back to 0.5 when the range is degenerate.
        lo, hi = bounds['min'], bounds['max']
        frac = (value - lo) / (hi - lo) if hi > lo else 0.5
        return min(1.0, max(0.0, frac))

    pixelsize_norm = _to_unit_range(settings['pixelsize'], norm['pixelsize'])
    young_norm = _to_unit_range(settings['young'], norm['young'])
    shape = (1, 1, height, width)
    channels = [
        torch.full(shape, pixelsize_norm, device=device, dtype=torch.float32),
        torch.full(shape, young_norm, device=device, dtype=torch.float32),
    ]
    return torch.cat(channels, dim=1)
class S2FPredictor:
    """
    Shape2Force predictor for single-cell or spheroid force map prediction.

    The generator checkpoint is loaded once at construction; `predict` and
    `predict_batch` then run inference on preprocessed bright field images.
    Shared preprocessing / postprocessing lives in private helpers so the
    single and batched paths cannot drift apart.
    """

    def __init__(self, model_type="single_cell", checkpoint_path=None, ckp_folder=None, device=None):
        """
        Args:
            model_type: "single_cell" or "spheroid"
            checkpoint_path: Path to .pth checkpoint (relative to ckp_folder or absolute)
            ckp_folder: Folder containing checkpoints (default: S2F/ckp)
            device: "cuda" or "cpu" (auto-detected if None)

        Raises:
            FileNotFoundError: If checkpoint_path is given but cannot be resolved.
        """
        self.model_type = model_type
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        ckp_base = get_ckp_base(S2F_ROOT)
        subfolder = model_subfolder(model_type)
        ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
        if not os.path.isdir(ckp_dir):
            ckp_dir = ckp_base  # fallback if subfolders not used
        # Single-cell models take [image, pixelsize, young]; spheroid takes the image alone.
        in_channels = 3 if model_type == "single_cell" else 1
        s2f_model_type = "s2f" if model_type == "single_cell" else "s2f_spheroid"
        generator, _ = create_s2f_model(in_channels=in_channels, model_type=s2f_model_type)
        self.generator = generator
        if checkpoint_path:
            full_path = self._resolve_checkpoint(checkpoint_path, ckp_dir, ckp_base)
            if model_type == "single_cell":
                self.generator.load_checkpoint_with_expansion(full_path, strict=True)
            else:
                # NOTE(security): weights_only=False unpickles arbitrary objects from
                # the checkpoint file — only load checkpoints from trusted sources.
                checkpoint = torch.load(full_path, map_location="cpu", weights_only=False)
                # Checkpoints may store the weights under either key, or be a bare state dict.
                state = checkpoint.get("generator_state_dict") or checkpoint.get("model_state_dict") or checkpoint
                self.generator.load_state_dict(state, strict=True)
        if hasattr(self.generator, "set_output_mode"):
            self.generator.set_output_mode(use_tanh=False)  # sigmoid [0,1] for inference
        self.generator = self.generator.to(self.device)
        self.generator.eval()
        self.norm_params = compute_settings_normalization() if model_type == "single_cell" else None
        self._use_tanh_output = model_type == "single_cell"  # single_cell uses tanh, spheroid uses sigmoid
        self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")

    @staticmethod
    def _resolve_checkpoint(checkpoint_path, ckp_dir, ckp_base):
        """Resolve a checkpoint path: absolute paths as-is; relative paths are
        tried under ckp_dir first, then ckp_base.

        Raises:
            FileNotFoundError: If no candidate exists on disk.
        """
        full_path = checkpoint_path
        if not os.path.isabs(checkpoint_path):
            full_path = os.path.join(ckp_dir, checkpoint_path)
            if not os.path.exists(full_path):
                full_path = os.path.join(ckp_base, checkpoint_path)  # try base folder
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Checkpoint not found: {full_path}")
        return full_path

    @staticmethod
    def _preprocess_array(image_array):
        """Convert an input array to a float32 (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE)
        grayscale image in [0, 1].

        Accepts (H, W) or (H, W, C); only the first channel of a multi-channel
        image is kept. Values > 1.0 are assumed to be in [0, 255] and rescaled.
        """
        img = np.asarray(image_array, dtype=np.float32)
        if img.ndim == 3 and img.shape[-1] >= 1:
            img = img[:, :, 0]  # keep first channel only
        if img.max() > 1.0:
            img = img / 255.0
        return cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))

    def _attach_settings(self, x, substrate, substrate_config):
        """Append normalized pixelsize/young channels in single-cell mode.

        Args:
            x: Input batch tensor [B, 1, H, W].

        Returns:
            [B, 3, H, W] tensor in single-cell mode; x unchanged otherwise.
        """
        if self.model_type != "single_cell" or self.norm_params is None:
            return x
        sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
        settings_ch = create_settings_channels_single(
            sub, self.device, x.shape[2], x.shape[3],
            config_path=self.config_path, substrate_config=substrate_config
        )
        # expand() broadcasts the [1, 2, H, W] settings over the batch without copying.
        return torch.cat([x, settings_ch.expand(x.shape[0], -1, -1, -1)], dim=1)

    def _postprocess(self, pred):
        """Map raw generator output into [0, 1]."""
        if self._use_tanh_output:
            return (pred + 1.0) / 2.0  # Tanh [-1, 1] -> [0, 1]
        return pred  # spheroid already outputs sigmoid [0, 1]

    def predict(self, image_path=None, image_array=None, substrate=None,
                substrate_config=None):
        """
        Run prediction on an image.

        Args:
            image_path: Path to bright field image (tif, png, jpg)
            image_array: numpy array (H, W) or (H, W, C) in [0, 255] or [0, 1]
            substrate: Substrate name for single-cell mode (used if substrate_config is None)
            substrate_config: Optional dict with 'pixelsize' and 'young'. Overrides substrate lookup.

        Returns:
            heatmap: numpy array (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE) in [0, 1]
            force: scalar cell force (sum of heatmap * SCALE_FACTOR_FORCE)
            pixel_sum: raw sum of all pixel values in heatmap

        Raises:
            ValueError: If neither image_path nor image_array is provided.
        """
        if image_path is not None:
            img = load_image(image_path)
        elif image_array is not None:
            img = self._preprocess_array(image_array)
        else:
            raise ValueError("Provide image_path or image_array")
        x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0).to(self.device)  # [1, 1, H, W]
        x = self._attach_settings(x, substrate, substrate_config)
        with torch.no_grad():
            pred = self._postprocess(self.generator(x))
        heatmap = pred[0, 0].cpu().numpy()
        force = sum_force_map(pred).item()
        pixel_sum = float(np.sum(heatmap))
        return heatmap, force, pixel_sum

    def predict_batch(self, images, substrate=None, substrate_config=None, batch_size=None,
                      on_progress=None):
        """
        Run prediction on a batch of images. Processes in chunks to avoid OOM on
        memory-constrained environments (e.g. Hugging Face free tier).

        Args:
            images: List of (img_array, key) or list of img arrays. img_array: (H, W) or (H, W, C).
            substrate: Substrate name for single-cell mode (same for all images).
            substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
            batch_size: Max images per forward pass (default: BATCH_INFERENCE_SIZE). Use 1 for minimal memory.
            on_progress: Optional callback(processed: int, total: int) called after each forward pass.

        Returns:
            List of (heatmap, force, pixel_sum) tuples, in input order.
        """
        batch_size = batch_size if batch_size is not None else BATCH_INFERENCE_SIZE
        imgs = [
            self._preprocess_array(item[0] if isinstance(item, tuple) else item)
            for item in images
        ]
        results = []
        for start in range(0, len(imgs), batch_size):
            chunk = imgs[start:start + batch_size]
            x = torch.from_numpy(np.stack(chunk)).float().unsqueeze(1).to(self.device)  # [B, 1, H, W]
            x = self._attach_settings(x, substrate, substrate_config)
            with torch.no_grad():
                pred = self._postprocess(self.generator(x))
            for i in range(pred.shape[0]):
                heatmap = pred[i, 0].cpu().numpy()
                force = sum_force_map(pred[i:i + 1]).item()
                results.append((heatmap, force, float(np.sum(heatmap))))
            if on_progress is not None:
                on_progress(len(results), len(imgs))
        return results