"""Cell Segmentation Explorer.

Gradio app that lets the user pick one of four segmentation models
(SegFormer, StarDist, MEDIAR, Cellpose), upload an image, and view the
predicted mask/overlay side by side with the input.
"""

import io
import os
import tempfile
from pathlib import Path  # kept from original import block (currently unused)

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from PIL import Image as PILImage

import tensorflow as tf
# Hide every GPU from TensorFlow *before* StarDist/CSBDeep import it, so TF
# does not grab GPU memory that the PyTorch models need. Order matters here.
tf.config.set_visible_devices([], 'GPU')

from skimage.io import imread
from skimage.color import rgb2gray
from csbdeep.utils import normalize
from stardist.models import StarDist2D
from stardist.plot import render_label
from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

from MEDIARFormer import MEDIARFormer
from Predictor import Predictor

# === Setup for GPU or CPU ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load SegFormer (fine-tuned head: 8 classes; base checkpoint has a different
# head size, hence ignore_mismatched_sizes before loading our own weights).
processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
model_segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=8,
    ignore_mismatched_sizes=True
)
model_segformer.load_state_dict(torch.load("trained_model_200.pt", map_location=device))
model_segformer.to(device)
model_segformer.eval()

# Load StarDist model (CPU-only, no GPU support)
model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')

# Load Cellpose model with GPU if available
model_cellpose = cellpose_models.CellposeModel(gpu=torch.cuda.is_available())

# Lazily-built MEDIAR model; loading the checkpoint is expensive, so do it
# once on first use instead of on every request (original reloaded per call).
_mediar_model = None


def _get_mediar_model():
    """Build and cache the MEDIARFormer model with phase-1 weights."""
    global _mediar_model
    if _mediar_model is None:
        model_args = {
            "classes": 3,
            "decoder_channels": [1024, 512, 256, 128, 64],
            "decoder_pab_channels": 256,
            "encoder_name": 'mit_b5',
            "in_channels": 3
        }
        model = MEDIARFormer(**model_args)
        weights = torch.load("from_phase1.pth", map_location=device)
        # strict=False: checkpoint may lack/extra some heads relative to the
        # current architecture — tolerated deliberately.
        model.load_state_dict(weights, strict=False)
        model.to(device)
        model.eval()
        _mediar_model = model
    return _mediar_model


def infer_segformer(image):
    """Run SegFormer on a PIL image; return (input RGB image, colorized mask).

    BUGFIX vs. original: SegFormer emits logits at 1/4 of the processor
    resolution (128x128 for 512x512 inputs). The original argmaxed the raw
    logits, producing a mask that did not align with the input image. We now
    bilinearly upsample the logits to the original image size first, as the
    HF semantic-segmentation docs prescribe.
    """
    image = image.convert("RGB")
    inputs = processor_segformer(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model_segformer(**inputs).logits
    # PIL size is (width, height); interpolate expects (height, width).
    logits = torch.nn.functional.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()
    # Colorize: one fixed RGB color per class index (8 classes).
    colors = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255],
                       [255, 255, 0], [255, 0, 255], [0, 255, 255], [128, 128, 128]])
    color_mask = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for c in range(8):
        color_mask[pred_mask == c] = colors[c]
    return image, Image.fromarray(color_mask)


def infer_stardist(image):
    """Run StarDist on a (grayscale-converted) image; return (input, overlay)."""
    image_gray = rgb2gray(np.array(image)) if image.mode == 'RGB' else np.array(image)
    labels, _ = model_stardist.predict_instances(normalize(image_gray))
    overlay = render_label(labels, img=image_gray)
    # render_label returns RGBA floats in [0, 1]; drop alpha, scale to uint8.
    overlay = (overlay[..., :3] * 255).astype(np.uint8)
    return image, Image.fromarray(overlay)


def infer_mediar(image, temp_dir="temp_mediar"):
    """Run MEDIAR via its file-based Predictor; return (input, rendered mask).

    The Predictor reads/writes TIFFs on disk, so the image is round-tripped
    through temp_dir. The label image is rendered with matplotlib's cividis
    colormap into a PNG buffer.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tiff")
    # Predictor writes its output as "<name>_label.tiff" next to the input.
    output_path = os.path.join(temp_dir, "input_image_label.tiff")
    image.save(input_path)
    model = _get_mediar_model()
    predictor = Predictor(model, device.type, temp_dir, temp_dir,
                          algo_params={"use_tta": False})
    predictor.img_names = ["input_image.tiff"]
    _ = predictor.conduct_prediction()
    pred = imread(output_path)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(pred, cmap="cividis")
    ax.axis("off")
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)  # close the figure we created, not just the "current" one
    buf.seek(0)
    return image, Image.open(buf)


def infer_cellpose(image, temp_dir="temp_cellpose"):
    """Run Cellpose; return (input, segmentation overlay figure as an image)."""
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tif")
    output_overlay = os.path.join(temp_dir, "overlay.png")
    image.save(input_path)
    img = cellpose_io.imread(input_path)
    masks, flows, styles = model_cellpose.eval(img, batch_size=1)
    fig = plt.figure(figsize=(12, 5))
    cellpose_plot.show_segmentation(fig, img, masks, flows[0])
    plt.tight_layout()
    fig.savefig(output_overlay)
    plt.close(fig)
    return image, Image.open(output_overlay)


def segment(model_name, image):
    """Dispatch to the selected model.

    Returns (input_image, result_image) on success, or (None, error_message)
    when the format check or model lookup fails.
    """
    # NOTE(review): Gradio re-decodes uploads, so .format is often None here;
    # None is deliberately accepted for Cellpose below.
    ext = image.format.lower() if hasattr(image, 'format') and image.format else None
    if model_name == "Cellpose" and ext not in ["tif", "tiff", None]:
        return None, f"❌ Cellpose only supports `.tif` or `.tiff` images."
    if model_name == "SegFormer":
        return infer_segformer(image)
    elif model_name == "StarDist":
        return infer_stardist(image)
    elif model_name == "MEDIAR":
        return infer_mediar(image)
    elif model_name == "Cellpose":
        return infer_cellpose(image)
    else:
        return None, f"❌ Unknown model: {model_name}"


# === Gradio UI ===
with gr.Blocks(title="Cell Segmentation Explorer") as app:
    gr.Markdown("## Cell Segmentation Explorer")
    gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=["SegFormer", "StarDist", "MEDIAR", "Cellpose"],
                label="Select Segmentation Model",
                value="SegFormer"
            )
            image_input = gr.Image(type="pil", label="Uploaded Image")
            description_box = gr.Markdown("Accepted formats: `.png`, `.jpg`, `.tif`, `.tiff`.")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")

    def handle_submit(model_name, img):
        """Submit handler: run segmentation, surface dispatcher errors.

        BUGFIX vs. original: on the error paths segment() returns
        (None, "❌ ..."); the original fed that string into gr.Image. Raise
        gr.Error instead so the message is shown to the user.
        """
        if img is None:
            return None
        preview, result = segment(model_name, img)
        if preview is None:
            raise gr.Error(result)
        return result

    submit_btn.click(
        fn=handle_submit,
        inputs=[model_dropdown, image_input],
        outputs=output_image
    )
    clear_btn.click(
        lambda: [None, None],
        inputs=None,
        outputs=[image_input, output_image]
    )

    gr.Markdown("---")
    gr.Markdown("### Sample Images (click to use as input)")

    original_sample_paths = ["img1.png", "img2.png", "img3.png"]
    resized_sample_paths = []
    # Portable temp location (original hard-coded "/tmp", which breaks on
    # Windows); thumbnails are 128x128 previews of the full-size samples.
    _tmp_dir = tempfile.gettempdir()
    for idx, p in enumerate(original_sample_paths):
        thumb = PILImage.open(p).resize((128, 128))
        temp_path = os.path.join(_tmp_dir, f"sample_resized_{idx}.png")
        thumb.save(temp_path)
        resized_sample_paths.append(temp_path)

    sample_image_components = []
    with gr.Row():
        for i, img_path in enumerate(resized_sample_paths):
            # Bind i as a default arg to avoid the late-binding-closure trap.
            def load_full_image(idx=i):
                return PILImage.open(original_sample_paths[idx])
            sample_img = gr.Image(value=img_path, type="pil",
                                  interactive=True, show_label=False)
            sample_img.select(
                fn=load_full_image,
                inputs=[],
                outputs=image_input
            )
            sample_image_components.append(sample_img)

app.launch()