Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import torch | |
| import tensorflow as tf | |
| tf.config.set_visible_devices([], 'GPU') | |
| import numpy as np | |
| from PIL import Image | |
| from PIL import Image as PILImage | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import io | |
| from skimage.io import imread | |
| from skimage.color import rgb2gray | |
| from csbdeep.utils import normalize | |
| from stardist.models import StarDist2D | |
| from stardist.plot import render_label | |
| from MEDIARFormer import MEDIARFormer | |
| from Predictor import Predictor | |
| from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot | |
| from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
# === Setup for GPU or CPU ===
# Probe CUDA once and reuse the answer for every framework that can use it.
_CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if _CUDA_AVAILABLE else "cpu")
print(f"Using device: {device}")

# Load SegFormer: build the B0 architecture from the ADE20K checkpoint with an
# 8-class head (ignore_mismatched_sizes allows the classifier shape to differ),
# then overwrite its weights with the locally fine-tuned state dict. The
# default strict load_state_dict requires every key to match.
processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
model_segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=8,
    ignore_mismatched_sizes=True,
)
model_segformer.load_state_dict(
    torch.load("trained_model_200.pt", map_location=device)
)
model_segformer.to(device)
model_segformer.eval()

# Load StarDist (TensorFlow/Keras based). TensorFlow's GPUs are hidden right
# after `import tensorflow` (tf.config.set_visible_devices([], 'GPU')), so
# StarDist always runs on the CPU in this app.
model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')

# Load Cellpose, using the GPU whenever CUDA is available.
model_cellpose = cellpose_models.CellposeModel(gpu=_CUDA_AVAILABLE)
# SegFormer Inference
# Fixed RGB palette, one row per SegFormer class id (0..7; the head has 8 labels).
_SEGFORMER_COLORS = np.array(
    [
        [0, 0, 0],
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 255, 0],
        [255, 0, 255],
        [0, 255, 255],
        [128, 128, 128],
    ],
    dtype=np.uint8,
)


def infer_segformer(image):
    """Run the fine-tuned SegFormer on *image* and return a colorized class mask.

    Args:
        image: PIL image; it is converted to RGB before inference.

    Returns:
        ``(rgb_image, mask)``. ``rgb_image`` is the RGB-converted input.
        ``mask`` is an RGB PIL image with the same width and height as the
        input, where each pixel is colored by its predicted class using
        ``_SEGFORMER_COLORS``.
    """
    image = image.convert("RGB")
    inputs = processor_segformer(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model_segformer(**inputs).logits
        # SegFormer emits logits at 1/4 of the processor's (resized) input
        # resolution. Upsample them back to the original image size so the
        # mask lines up pixel-for-pixel with the input.
        logits = torch.nn.functional.interpolate(
            logits,
            size=(image.height, image.width),
            mode="bilinear",
            align_corners=False,
        )
    pred_mask = logits.argmax(dim=1)[0].cpu().numpy()
    # Vectorized palette lookup: (H, W) class ids -> (H, W, 3) uint8 colors.
    color_mask = _SEGFORMER_COLORS[pred_mask]
    return image, Image.fromarray(color_mask)
# StarDist Inference
# PIL modes that already hold a single intensity channel and can be fed as-is.
_SINGLE_CHANNEL_MODES = frozenset({"L", "I", "I;16", "I;16L", "I;16B", "I;16N", "F"})


def infer_stardist(image):
    """Segment objects with StarDist and return an instance-label overlay.

    Args:
        image: PIL image. Single-channel modes (L, I, I;16*, F) are used
            directly. Every other mode (RGB, RGBA, LA, P, CMYK, 1, ...) is
            converted to RGB and then to grayscale, because the
            2D_versatile_fluo model takes single-channel 2D input.

    Returns:
        ``(image, overlay)``: the untouched input image, and an RGB PIL image
        of the predicted instance labels rendered over the grayscale input.
    """
    if image.mode in _SINGLE_CHANNEL_MODES:
        image_gray = np.array(image)
    else:
        # RGBA/LA/P/CMYK would otherwise reach rgb2gray (or StarDist) with the
        # wrong channel count, or as palette indices; normalize to RGB first.
        image_gray = rgb2gray(np.array(image.convert("RGB")))
    labels, _ = model_stardist.predict_instances(normalize(image_gray))
    overlay = render_label(labels, img=image_gray)
    # render_label yields float colors (presumably in [0, 1]). Clip before
    # scaling so any out-of-range value cannot wrap around in the uint8 cast.
    overlay = (np.clip(overlay[..., :3], 0, 1) * 255).astype(np.uint8)
    return image, Image.fromarray(overlay)
# MEDIAR Inference
# The mit_b5 MEDIARFormer is expensive to build and load, so it is created once
# on first use and reused by later requests, instead of being rebuilt per call.
_mediar_model = None


def _get_mediar_model():
    """Return the cached MEDIARFormer, building and loading it on the first call."""
    global _mediar_model
    if _mediar_model is None:
        model_args = {
            "classes": 3,
            "decoder_channels": [1024, 512, 256, 128, 64],
            "decoder_pab_channels": 256,
            "encoder_name": 'mit_b5',
            "in_channels": 3,
        }
        model = MEDIARFormer(**model_args)
        weights = torch.load("from_phase1.pth", map_location=device)
        # NOTE(review): strict=False silently ignores missing/unexpected keys,
        # so an architecture/checkpoint mismatch would go unnoticed. Confirm
        # from_phase1.pth matches model_args.
        model.load_state_dict(weights, strict=False)
        model.to(device)
        model.eval()
        _mediar_model = model
    return _mediar_model


def infer_mediar(image, temp_dir="temp_mediar"):
    """Segment *image* with MEDIAR and return a rendered label map.

    The image is saved as ``input_image.tiff`` in *temp_dir*. The Predictor is
    pointed at that directory for both input and output, and the label map it
    writes to ``input_image_label.tiff`` is read back and rendered with the
    ``cividis`` colormap.

    Args:
        image: PIL image to segment.
        temp_dir: Working directory for the TIFF round-trip (created if missing).

    Returns:
        ``(image, rendered)``: the untouched input image and a PNG-backed PIL
        image of the predicted label map.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tiff")
    output_path = os.path.join(temp_dir, "input_image_label.tiff")
    image.save(input_path)

    # NOTE(review): reusing the cached model assumes Predictor does not mutate
    # it (e.g. switch it to train mode) between calls. Confirm in Predictor.
    predictor = Predictor(
        _get_mediar_model(), device.type, temp_dir, temp_dir,
        algo_params={"use_tta": False},
    )
    predictor.img_names = ["input_image.tiff"]
    predictor.conduct_prediction()
    pred = imread(output_path)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(pred, cmap="cividis")
    ax.axis("off")
    buf = io.BytesIO()
    # Save/close this specific figure rather than pyplot's global "current" one.
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    rendered = Image.open(buf)
    rendered.load()  # decode now so the result does not depend on buf later
    return image, rendered
# Cellpose Inference
def infer_cellpose(image, temp_dir="temp_cellpose"):
    """Segment *image* with Cellpose and return its segmentation figure.

    The image is round-tripped through ``input_image.tif`` in *temp_dir* so
    that Cellpose's own reader loads it. The figure from
    ``cellpose_plot.show_segmentation`` is saved to ``overlay.png`` in the same
    directory.

    Args:
        image: PIL image to segment.
        temp_dir: Working directory for intermediate files (created if missing).

    Returns:
        ``(image, overlay)``: the untouched input image, and the rendered
        segmentation figure as an in-memory PIL image.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tif")
    output_overlay = os.path.join(temp_dir, "overlay.png")
    image.save(input_path)
    img = cellpose_io.imread(input_path)
    masks, flows, _ = model_cellpose.eval(img, batch_size=1)

    fig = plt.figure(figsize=(12, 5))
    cellpose_plot.show_segmentation(fig, img, masks, flows[0])
    fig.tight_layout()
    fig.savefig(output_overlay)
    plt.close(fig)

    # Image.open is lazy and keeps the file open, and the next request
    # overwrites overlay.png. Copy the pixels into memory and close the file now.
    with Image.open(output_overlay) as overlay:
        result = overlay.copy()
    return image, result
# Main segmentation dispatcher
def segment(model_name, image):
    """Send *image* to the inference function selected by *model_name*.

    Returns the ``(input_image, result_image)`` pair from the chosen model.
    On failure it returns ``(None, message)`` with a user-facing error string.
    Failure means either the model name is not recognized, or Cellpose was
    asked to process an image whose PIL ``format`` is set and is not TIFF.
    """
    fmt = getattr(image, "format", None)
    ext = fmt.lower() if fmt else None

    if model_name not in ("SegFormer", "StarDist", "MEDIAR", "Cellpose"):
        return None, f"❌ Unknown model: {model_name}"

    if model_name == "Cellpose":
        if ext not in ("tif", "tiff", None):
            return None, "❌ Cellpose only supports `.tif` or `.tiff` images."
        return infer_cellpose(image)
    if model_name == "MEDIAR":
        return infer_mediar(image)
    if model_name == "StarDist":
        return infer_stardist(image)
    return infer_segformer(image)
# === Gradio UI ===
with gr.Blocks(title="Cell Segmentation Explorer") as app:
    gr.Markdown("## Cell Segmentation Explorer")
    gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=["SegFormer", "StarDist", "MEDIAR", "Cellpose"],
                label="Select Segmentation Model",
                value="SegFormer",
            )
            image_input = gr.Image(type="pil", label="Uploaded Image")
            description_box = gr.Markdown("Accepted formats: `.png`, `.jpg`, `.tif`, `.tiff`.")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")

    def handle_submit(model_name, img):
        """Run the selected model and return only the segmentation result.

        ``segment`` reports failures as ``(None, message)``. The message is
        raised as ``gr.Error`` so Gradio displays it to the user, instead of
        passing a string to the image output.
        """
        if img is None:
            return None
        _, result = segment(model_name, img)
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    submit_btn.click(
        fn=handle_submit,
        inputs=[model_dropdown, image_input],
        outputs=output_image,
    )
    clear_btn.click(
        lambda: [None, None],
        inputs=None,
        outputs=[image_input, output_image],
    )

    gr.Markdown("---")
    gr.Markdown("### Sample Images (click to use as input)")
    original_sample_paths = ["img1.png", "img2.png", "img3.png"]

    # Build the 128x128 thumbnails in memory and hand them to gr.Image
    # directly. This avoids writing to a hard-coded /tmp path, and each source
    # file is closed as soon as its thumbnail exists.
    sample_thumbnails = []
    for p in original_sample_paths:
        with PILImage.open(p) as src:
            sample_thumbnails.append(src.resize((128, 128)))

    sample_image_components = []
    with gr.Row():
        for i, thumbnail in enumerate(sample_thumbnails):
            # Bind i as a default argument so each handler keeps its own index
            # (avoids the late-binding closure pitfall).
            def load_full_image(idx=i):
                """Load the full-resolution sample *idx* and close its file."""
                with PILImage.open(original_sample_paths[idx]) as full:
                    return full.copy()

            sample_img = gr.Image(value=thumbnail, type="pil", interactive=True, show_label=False)
            sample_img.select(
                fn=load_full_image,
                inputs=[],
                outputs=image_input,
            )
            sample_image_components.append(sample_img)

app.launch()