| """ |
| Core inference logic for S2F (Shape2Force). |
| Predicts force maps from bright field microscopy images. |
| """ |
| import os |
| import sys |
| import cv2 |
| import torch |
| import numpy as np |
|
|
| |
# Absolute directory containing this module. It is put at the front of
# sys.path (only if not already listed) so that the project-local imports
# that follow resolve relative to this directory.
S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
if S2F_ROOT not in sys.path:
    sys.path[:0] = [S2F_ROOT]
|
|
| from config.constants import BATCH_INFERENCE_SIZE, DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE |
| from models.s2f_model import create_s2f_model |
| from utils.paths import get_ckp_base, model_subfolder |
| from utils.substrate_settings import get_settings_of_category, compute_settings_normalization |
| from utils import config |
|
|
|
|
def load_image(filepath, target_size=None):
    """
    Load a bright field image as a grayscale float32 array.

    Args:
        filepath: Path of the image file, read with cv2.imread in grayscale mode.
        target_size: Output size handed to cv2.resize. An int is expanded to
            (size, size); any other value is passed through unchanged. When
            None, MODEL_INPUT_SIZE is used.

    Returns:
        The resized image as a float32 numpy array, divided by 255.

    Raises:
        ValueError: If cv2.imread returns None for filepath.
    """
    gray = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise ValueError(f"Could not load image: {filepath}")

    if target_size is None:
        target_size = MODEL_INPUT_SIZE
    dsize = (target_size, target_size) if isinstance(target_size, int) else target_size

    resized = cv2.resize(gray, dsize)
    return resized.astype(np.float32) / 255.0
|
|
|
|
def sum_force_map(force_map):
    """
    Compute cell force as the sum of pixel values scaled by SCALE_FACTOR_FORCE.

    Args:
        force_map: numpy array (converted to a float32 tensor) or torch tensor.
            A 2-D input gets two leading singleton axes and a 3-D input gets
            one, so both are reduced as 4-D tensors. Inputs of any other rank
            are used as given and must support a reduction over dims 2 and 3.

    Returns:
        Tensor of sums over dims (2, 3), multiplied by config.SCALE_FACTOR_FORCE.
    """
    if isinstance(force_map, np.ndarray):
        force_map = torch.from_numpy(force_map.astype(np.float32))

    rank = force_map.dim()
    if rank == 2:
        force_map = force_map[None, None]
    elif rank == 3:
        force_map = force_map[None]

    return force_map.sum(dim=(2, 3)) * config.SCALE_FACTOR_FORCE
|
|
|
|
def create_settings_channels_single(substrate_name, device, height, width, config_path=None,
                                    substrate_config=None):
    """
    Create settings channels for a single image (single-cell mode).

    Each of the two settings ('pixelsize', 'young') is min-max normalized with
    the bounds returned by compute_settings_normalization, clamped to [0, 1],
    and broadcast into a constant plane. When a setting's max is not greater
    than its min, 0.5 is used instead.

    Args:
        substrate_name: Substrate name, looked up via get_settings_of_category
            when substrate_config is not usable.
        device: torch device for the returned tensor.
        height, width: spatial dimensions of each plane.
        config_path: Path to substrate config JSON, forwarded to the helpers.
        substrate_config: Optional dict. It is used instead of the
            substrate_name lookup only when it contains both 'pixelsize' and
            'young'.

    Returns:
        float32 tensor of shape (1, 2, height, width): the pixelsize plane
        followed by the young plane.
    """
    keys = ('pixelsize', 'young')
    norm_params = compute_settings_normalization(config_path=config_path)

    use_explicit = substrate_config is not None and all(k in substrate_config for k in keys)
    if use_explicit:
        settings = substrate_config
    else:
        settings = get_settings_of_category(substrate_name, config_path=config_path)

    # Read every bound before touching `settings`, matching the lookup order
    # of the straight-line version of this computation.
    bounds = [(norm_params[k]['min'], norm_params[k]['max']) for k in keys]

    fractions = []
    for key, (lower, upper) in zip(keys, bounds):
        if upper > lower:
            fractions.append((settings[key] - lower) / (upper - lower))
        else:
            fractions.append(0.5)

    planes = [
        torch.full(
            (1, 1, height, width), max(0.0, min(1.0, fraction)),
            device=device, dtype=torch.float32
        )
        for fraction in fractions
    ]
    return torch.cat(planes, dim=1)
|
|
|
|
class S2FPredictor:
    """
    Shape2Force predictor for single-cell or spheroid force map prediction.

    With model_type "single_cell" the generator receives the image plus two
    constant settings channels and its output is remapped with (pred + 1) / 2.
    Any other model_type is handled as the spheroid configuration: image-only
    input and the generator output is used as returned.
    """

    def __init__(self, model_type="single_cell", checkpoint_path=None, ckp_folder=None, device=None):
        """
        Build the generator, optionally load a checkpoint, and switch to eval mode.

        Args:
            model_type: "single_cell" or "spheroid"
            checkpoint_path: Path to .pth checkpoint (relative to ckp_folder or absolute).
                When falsy, no checkpoint is loaded.
            ckp_folder: Folder containing checkpoints. Defaults to the
                model-type subfolder of the checkpoint base directory; if the
                chosen folder is not a directory, the base directory is used.
            device: "cuda" or "cpu" (auto-detected if None)

        Raises:
            FileNotFoundError: If checkpoint_path is given but not found in
                the checkpoint folder or the checkpoint base directory.
        """
        self.model_type = model_type
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        is_single_cell = model_type == "single_cell"

        ckp_base = get_ckp_base(S2F_ROOT)
        subfolder = model_subfolder(model_type)
        ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
        if not os.path.isdir(ckp_dir):
            ckp_dir = ckp_base

        in_channels = 3 if is_single_cell else 1
        s2f_model_type = "s2f" if is_single_cell else "s2f_spheroid"
        generator, _ = create_s2f_model(in_channels=in_channels, model_type=s2f_model_type)
        self.generator = generator

        if checkpoint_path:
            full_path = checkpoint_path
            if not os.path.isabs(checkpoint_path):
                full_path = os.path.join(ckp_dir, checkpoint_path)
            if not os.path.exists(full_path):
                # Second chance: look directly under the checkpoint base directory.
                full_path = os.path.join(ckp_base, checkpoint_path)
            if not os.path.exists(full_path):
                raise FileNotFoundError(f"Checkpoint not found: {full_path}")

            if is_single_cell:
                self.generator.load_checkpoint_with_expansion(full_path, strict=True)
            else:
                # NOTE(security): weights_only=False allows arbitrary pickled
                # objects to be executed on load; only use trusted checkpoints.
                checkpoint = torch.load(full_path, map_location="cpu", weights_only=False)
                # Accept either wrapped checkpoints or a bare state dict.
                state = checkpoint.get("generator_state_dict") or checkpoint.get("model_state_dict") or checkpoint
                self.generator.load_state_dict(state, strict=True)
                if hasattr(self.generator, "set_output_mode"):
                    self.generator.set_output_mode(use_tanh=False)

        self.generator = self.generator.to(self.device)
        self.generator.eval()

        self.norm_params = compute_settings_normalization() if is_single_cell else None
        self._use_tanh_output = is_single_cell
        self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")

    @staticmethod
    def _prepare_array(image):
        """Convert one in-memory image to a float32 (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE) array."""
        img = np.asarray(image, dtype=np.float32)
        if img.size == 0:
            raise ValueError("Image array is empty")
        if img.ndim == 3:
            # Only the first channel is used for multi-channel input.
            img = img[:, :, 0]
        if img.max() > 1.0:
            # Values above 1 are taken to mean a [0, 255] range.
            img = img / 255.0
        return cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))

    def _settings_channels(self, height, width, substrate, substrate_config):
        """Return the (1, 2, height, width) settings tensor, or None when the model takes none."""
        if self.model_type != "single_cell" or self.norm_params is None:
            return None
        sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
        return create_settings_channels_single(
            sub, self.device, height, width,
            config_path=self.config_path, substrate_config=substrate_config
        )

    def _forward(self, x):
        """Run the generator without gradients and apply the (pred + 1) / 2 remap if enabled."""
        with torch.no_grad():
            pred = self.generator(x)
        if self._use_tanh_output:
            pred = (pred + 1.0) / 2.0
        return pred

    def predict(self, image_path=None, image_array=None, substrate=None,
                substrate_config=None):
        """
        Run prediction on an image.

        Args:
            image_path: Path to bright field image (tif, png, jpg). Takes
                precedence over image_array when both are given.
            image_array: numpy array (H, W) or (H, W, C) in [0, 255] or [0, 1]
            substrate: Substrate name for single-cell mode (used if substrate_config is None)
            substrate_config: Optional dict with 'pixelsize' and 'young'. Overrides substrate lookup.

        Returns:
            heatmap: numpy array, channel 0 of the (single) prediction
            force: scalar cell force (sum of prediction * SCALE_FACTOR_FORCE)
            pixel_sum: raw sum of all pixel values in heatmap

        Raises:
            ValueError: If neither image_path nor image_array is provided, or
                image_array is empty.
        """
        if image_path is not None:
            img = load_image(image_path)
        elif image_array is not None:
            img = self._prepare_array(image_array)
        else:
            raise ValueError("Provide image_path or image_array")

        # Shape (1, 1, H, W): one image, one channel.
        x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0).to(self.device)

        settings_ch = self._settings_channels(x.shape[2], x.shape[3], substrate, substrate_config)
        if settings_ch is not None:
            x = torch.cat([x, settings_ch], dim=1)

        pred = self._forward(x)

        heatmap = pred[0, 0].cpu().numpy()
        force = sum_force_map(pred).item()
        pixel_sum = float(np.sum(heatmap))

        return heatmap, force, pixel_sum

    def predict_batch(self, images, substrate=None, substrate_config=None, batch_size=None,
                      on_progress=None):
        """
        Run prediction on a batch of images. Processes in chunks to avoid OOM on
        memory-constrained environments (e.g. Hugging Face free tier).

        Args:
            images: List of (img_array, key) tuples or list of img arrays. img_array: (H, W) or (H, W, C).
            substrate: Substrate name for single-cell mode (same for all images).
            substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
            batch_size: Max images per forward pass (default: BATCH_INFERENCE_SIZE). Use 1 for minimal memory.
            on_progress: Optional callback(processed: int, total: int) called after each forward pass.

        Returns:
            List of (heatmap, force, pixel_sum) tuples, in input order.

        Raises:
            ValueError: If batch_size is smaller than 1, or an image is empty.
        """
        batch_size = batch_size if batch_size is not None else BATCH_INFERENCE_SIZE
        if batch_size < 1:
            # A negative step would make the chunk loop empty and silently
            # drop every image; zero would fail with an unrelated range() error.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        imgs = [
            self._prepare_array(item[0] if isinstance(item, tuple) else item)
            for item in images
        ]
        total = len(imgs)

        results = []
        settings_ch = None
        settings_ready = False
        for start in range(0, total, batch_size):
            chunk = imgs[start:start + batch_size]
            # Shape (B, 1, H, W).
            x = torch.from_numpy(np.stack(chunk)).float().unsqueeze(1).to(self.device)

            if not settings_ready:
                # Identical for every chunk (same substrate and spatial size),
                # so it is built once, and only if there is at least one chunk.
                settings_ch = self._settings_channels(
                    x.shape[2], x.shape[3], substrate, substrate_config
                )
                settings_ready = True
            if settings_ch is not None:
                settings_batch = settings_ch.expand(x.shape[0], -1, -1, -1)
                x = torch.cat([x, settings_batch], dim=1)

            pred = self._forward(x)

            for sample in pred:
                # sample is (C, H, W); sum_force_map adds the batch axis itself.
                heatmap = sample[0].cpu().numpy()
                force = sum_force_map(sample).item()
                pixel_sum = float(np.sum(heatmap))
                results.append((heatmap, force, pixel_sum))

            if on_progress is not None:
                on_progress(len(results), total)

        return results
|
|