| """Gradio Space for exploring Curia models and CuriaBench datasets. |
| |
| This application allows users to: |
| |
| - Select any available Curia classification head. |
| - Load the matching CuriaBench test split and sample random images per class. |
| - Upload custom medical images that match the model's expected orientation. |
| - Forward images through the selected model head and visualise class probabilities. |
| |
| The space expects an HF token with access to "raidium" resources to be |
| provided via the HF_TOKEN environment variable (configure it as a secret when |
| deploying to Hugging Face Spaces). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import base64 |
| import random |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| import torch |
| from datasets import Dataset |
| from PIL import Image |
| import traceback |
|
|
| from inference import ( |
| load_curia_dataset, |
| load_id_to_labels, |
| infer_image, |
| ) |
|
|
| |
| |
| |
|
|
# Model heads offered in the "Model head" dropdown as (head key, display label)
# pairs, in menu order.
HEAD_OPTIONS: List[Tuple[str, str]] = list(
    {
        "abdominal-trauma": "Active Extravasation",
        "anatomy-ct": "Anatomy CT",
        "anatomy-mri": "Anatomy MRI",
        "atlas-stroke": "Atlas Stroke",
        "covidx-ct": "COVIDx CT",
        "deep-lesion-site": "Deep Lesion Site",
        "emidec-classification-mask": "EMIDEC Classification",
        "ich": "Intracranial Hemorrhage",
        "ixi": "IXI",
        "kits": "KiTS",
        "kneeMRI": "Knee MRI",
        "luna16-3D": "LUNA16 3D",
        "oasis": "OASIS",
    }.items()
)


# Heads whose inference takes a dataset-provided mask; custom uploads are
# disabled for these because an uploaded image carries no mask.
HEADS_REQUIRING_MASK: set[str] = {
    "anatomy-ct",
    "anatomy-mri",
    "deep-lesion-site",
    "emidec-classification-mask",
    "kits",
    "kneeMRI",
    "luna16-3D",
    "neural_foraminal_narrowing",
    "spinal_canal_stenosis",
    "subarticular_stenosis",
}


# Heads that take a 3D volume; 2D custom uploads are disabled for them.
HEADS_3D = {"kneeMRI", "luna16-3D", "oasis"}


# Heads that return a single scalar instead of a class-probability vector.
REGRESSION_HEADS = {"ixi"}


# Display labels for the CuriaBench test subsets.
DATASET_OPTIONS: Dict[str, str] = {
    "anatomy-ct": "Anatomy CT (test)",
    "anatomy-ct-hard": "Anatomy CT Hard (test)",
    "anatomy-mri": "Anatomy MRI (test)",
    "covidx-ct": "COVIDx CT (test)",
    "deep-lesion-site": "Deep Lesion Site (test)",
    "emidec-classification-mask": "EMIDEC Classification Mask (test)",
    "ixi": "IXI (test)",
    "kits": "KiTS (test)",
    "kneeMRI": "Knee MRI (test)",
    "luna16-3D": "LUNA16 3D (test)",
    "oasis": "OASIS (test)",
}


# Head -> dataset subset used for sampling. Every head that has a dataset
# currently samples from the subset with the same name; heads missing here
# have no dataset available in the UI.
DEFAULT_DATASET_FOR_HEAD: Dict[str, str] = {
    name: name
    for name in (
        "anatomy-ct",
        "anatomy-mri",
        "covidx-ct",
        "deep-lesion-site",
        "emidec-classification-mask",
        "ixi",
        "kits",
        "kneeMRI",
        "luna16-3D",
        "oasis",
    )
}


# Windowing per head/subset key as (level, width); None means "no windowing"
# (the image is passed through unchanged by apply_windowing).
DEFAULT_WINDOWINGS: Dict[str, Optional[Dict[str, int]]] = {
    key: (
        None
        if level_width is None
        else {"window_level": level_width[0], "window_width": level_width[1]}
    )
    for key, level_width in (
        ("anatomy-ct", (40, 400)),
        ("anatomy-ct-hard", (40, 400)),
        ("anatomy-mri", None),
        ("atlas-stroke", None),
        ("covidx-ct", (-600, 1500)),
        ("deep-lesion-site", (40, 400)),
        ("emidec-classification-mask", None),
        ("ich", (40, 80)),
        ("ixi", None),
        ("kits", (40, 400)),
        ("kneeMRI", None),
        ("luna16", (-600, 1500)),
        ("luna16-3D", (-600, 1500)),
        ("oasis", None),
    )
}


# Logo file inlined into the page header (see load_logo_data_uri).
LOGO_PATH = "Logo horizontal medium copie 4_CREME.png"
|
|
# Stylesheet passed to ``gr.Blocks(css=...)`` in build_demo. It lays out the
# ``#app-hero`` header (logo + title text) as a horizontal flex row and, at
# widths of 768px or less, stacks it vertically with centred text.
CUSTOM_CSS = """
.gr-prose { max-width: 900px; }
#app-hero {
display: flex;
align-items: center;
gap: 2.5rem;
margin-bottom: 1.5rem;
padding-right: 1.5rem;
}
#app-hero .hero-text {
flex: 1;
padding-right: 1rem;
}
#app-hero .hero-text h1 {
font-size: 2.25rem;
margin-bottom: 0.5rem;
}
#app-hero .hero-text p {
margin: 0.25rem 0;
line-height: 1.5;
}
#app-hero .hero-logo img {
max-height: 60px;
width: auto;
display: block;
}
@media (max-width: 768px) {
#app-hero {
flex-direction: column;
text-align: center;
padding-right: 0;
}
#app-hero .hero-text {
padding-right: 0;
}
#app-hero .hero-text h1,
#app-hero .hero-text p {
text-align: center;
}
#app-hero .hero-logo img {
margin: 0 auto 1rem;
}
}
"""
|
|
|
|
def load_logo_data_uri(path: Optional[str] = None) -> str:
    """Read the logo file and return it as an inline ``data:image/png`` URI.

    The logo is purely cosmetic, so any OS-level failure to read it (missing
    file, permission denied, path is a directory, ...) yields ``""``. An empty
    result means the header renders without a logo; it does not raise.

    Args:
        path: File to read. Defaults to the module-level ``LOGO_PATH``.

    Returns:
        ``"data:image/png;base64,<...>"`` on success, otherwise ``""``.
    """
    logo_path = LOGO_PATH if path is None else path
    try:
        with open(logo_path, "rb") as logo_file:
            raw = logo_file.read()
    except OSError:
        # FileNotFoundError is a subclass; PermissionError/IsADirectoryError
        # must not crash the app at import time either.
        return ""
    encoded = base64.b64encode(raw).decode("ascii")
    return f"data:image/png;base64,{encoded}"
|
|
|
|
# Inline the logo once at import time; an empty string means the header is
# rendered without a logo (build_demo checks for truthiness).
LOGO_DATA_URI = load_logo_data_uri()
|
|
|
|
| |
| |
| |
|
|
|
|
def apply_windowing(
    image: np.ndarray,
    head: str,
    *,
    windowings: Optional[Dict[str, Optional[Dict[str, int]]]] = None,
) -> np.ndarray:
    """Apply CT windowing for the given head/dataset key.

    For keys with a window specification, intensities are clipped to
    ``[level - width / 2, level + width / 2]`` and rescaled linearly to
    ``[0, 1]``. Keys mapped to ``None`` (MRI) or absent from the table are
    returned unchanged (the very same array object).

    Args:
        image: Raw image array (e.g., in Hounsfield Units for CT).
        head: Key looked up in the windowing table. Callers pass either a
            model head or a dataset subset name.
        windowings: Optional table overriding ``DEFAULT_WINDOWINGS``. Each
            value is ``None`` or a dict with ``window_level`` and
            ``window_width``.

    Returns:
        The windowed image as ``float32`` in ``[0, 1]``, or ``image`` itself
        when no windowing applies.

    Raises:
        ValueError: If the selected window width is not positive (it would
            otherwise cause a division by zero or invert the mapping).
    """
    table = DEFAULT_WINDOWINGS if windowings is None else windowings
    windowing = table.get(head)
    if windowing is None:
        return image

    window_level = windowing["window_level"]
    window_width = windowing["window_width"]
    if window_width <= 0:
        raise ValueError(
            f"window_width must be positive for {head!r}, got {window_width}"
        )

    window_min = window_level - window_width / 2
    window_max = window_level + window_width / 2

    windowed = np.clip(image, window_min, window_max)
    windowed = (windowed - window_min) / (window_max - window_min)
    return windowed.astype(np.float32)
|
|
|
|
def to_display_image(image: np.ndarray) -> np.ndarray:
    """Normalise image for display purposes (uint8, 3-channel).

    A 3-D input is reduced to its middle slice along the last axis, and a
    ``gr.Info`` toast tells the user. Non-finite values are sanitised:

    * NaN becomes 0;
    * +inf/-inf become the largest/smallest finite value;
    * an input with no finite value at all becomes 0.

    The result is then min-max scaled to 0..255. A constant image becomes all
    zeros.

    Args:
        image: 2-D image, or 3-D volume whose last axis indexes slices.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` with identical channels.
    """
    if image.ndim == 3:
        gr.Info("Image is 3D, we display only the middle slice")
        image = image[:, :, image.shape[2] // 2]

    arr = np.array(image, copy=True)
    finite = np.isfinite(arr)
    if not finite.all():
        # Clamp infinities to the finite range: np.nan_to_num's default would map
        # them to +/-float max, which wrecks the min-max scaling below.
        if finite.any():
            low = arr[finite].min()
            high = arr[finite].max()
        else:
            low = high = 0.0
        arr = np.nan_to_num(arr, nan=0.0, posinf=high, neginf=low)

    arr_min = float(arr.min())
    arr_max = float(arr.max())
    if arr_max - arr_min > 1e-6:
        arr = (arr - arr_min) / (arr_max - arr_min)
    else:
        arr = np.zeros_like(arr)

    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    return arr
|
|
|
|
def prepare_mask_tensor(mask: np.ndarray, height: int, width: int) -> Optional[torch.Tensor]:
    """Convert a mask into a stack of non-empty boolean planes of size ``height x width``.

    Accepted layouts, after squeezing singleton axes:

    * ``(height, width)`` - a single plane;
    * ``(..., height, width)`` - channel-first planes;
    * ``(..., height, width, C)`` - channel-last planes;
    * any other array whose size is a multiple of ``height * width``. This is
      reshaped row-major into planes as a best-effort fallback.

    A pixel is "on" when its value is ``> 0``. Planes with no "on" pixel are
    dropped.

    Args:
        mask: Array-like mask (anything ``np.asarray`` accepts).
        height: Target plane height (image rows).
        width: Target plane width (image columns).

    Returns:
        A ``torch.bool`` tensor of shape ``(N, height, width)`` with ``N >= 1``,
        or ``None`` if the layout is incompatible or every plane is empty.
    """
    arr = np.squeeze(np.asarray(mask))
    plane = (height, width)
    plane_size = height * width

    if arr.ndim == 2 and arr.shape == plane:
        planes = arr.reshape(1, height, width)
    elif arr.ndim >= 3 and arr.shape[-2:] == plane:
        planes = arr.reshape(-1, height, width)
    elif arr.ndim >= 3 and arr.shape[-3:-1] == plane:
        # Channel-last: move the channel axis in front of the plane axes.
        planes = np.moveaxis(arr, -1, -3).reshape(-1, height, width)
    elif plane_size > 0 and arr.size > 0 and arr.size % plane_size == 0:
        planes = arr.reshape(-1, height, width)
    else:
        # Incompatible (including wrong-sized 2-D masks and 0-D/1-D inputs).
        return None

    mask_tensors: List[torch.Tensor] = []
    for layer in planes:
        on_pixels = layer > 0
        if on_pixels.any():
            mask_tensors.append(torch.from_numpy(np.ascontiguousarray(on_pixels)))

    if not mask_tensors:
        return None
    return torch.stack(mask_tensors, dim=0)
|
|
|
|
def apply_contour_overlay(
    image: np.ndarray,
    mask: Any,
    thickness: int = 1,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Draw only the contours of segmentation masks instead of filled masks.

    Each non-empty mask plane from prepare_mask_tensor has its external
    contours traced and drawn onto a copy of ``image``. If the mask cannot be
    matched to the image size, ``image`` is returned as is.

    Args:
        image: Display image (as produced by to_display_image).
        mask: Mask in any layout prepare_mask_tensor accepts.
        thickness: Contour line thickness in pixels.
        color: Contour colour triple.

    Returns:
        A new array with contours drawn, or ``image`` itself when there is no
        usable mask.
    """
    planes = prepare_mask_tensor(mask, *image.shape[:2])
    if planes is None:
        return image

    canvas = image.copy()
    for plane in planes:
        outlines, _hierarchy = cv2.findContours(
            plane.numpy().astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        cv2.drawContours(canvas, outlines, -1, color, thickness)
    return canvas
|
|
|
|
def render_image_with_mask_info(image: np.ndarray, mask: Any) -> np.ndarray:
    """Convert ``image`` for display and outline ``mask`` on it when possible.

    Overlaying is best-effort: if drawing the mask raises, a Gradio warning is
    shown and the plain display image is returned instead.
    """
    base = to_display_image(image)
    if mask is not None:
        try:
            return apply_contour_overlay(base, mask)
        except Exception:
            gr.Warning("Mask provided but could not be visualised.")
    return base
|
|
|
|
def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
    """Return a random row index, preferring rows whose ``target`` equals ``target``.

    Falls back to a uniformly random index over the whole dataset when no
    target is requested, the dataset has no ``target`` column, or no row
    matches.
    """
    population = len(dataset)
    if target is None or "target" not in dataset.column_names:
        return random.randrange(population)

    matching = [
        position
        for position, label in enumerate(dataset["target"])
        if label == target
    ]
    return random.choice(matching) if matching else random.randrange(population)
|
|
|
|
| |
| |
| |
|
|
|
|
|
|
def update_dataset_display(head: str) -> str:
    """Update the dataset name display based on the selected head.

    Returns a Markdown line naming the head's dataset subset (its display label
    when known, otherwise the raw subset key), or a "not available" notice.
    """
    dataset_key = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not dataset_key:
        return "**Dataset:** not available"
    return f"**Dataset:** {DATASET_OPTIONS.get(dataset_key, dataset_key)}"
|
|
|
|
def update_upload_component_state(head: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Disable upload component for heads that require masks or 3D inputs.

    Returns:
        ``(info_update, upload_update)``.
        * When uploads are blocked: ``info_update`` shows a warning explaining
          why, and ``upload_update`` makes the image input non-interactive.
        * Otherwise: the warning is hidden and the input is interactive.
    """
    notice: Optional[str] = None
    if head in HEADS_REQUIRING_MASK:
        notice = "⚠️ Custom image upload is disabled for this task because it requires a mask from the dataset."
    elif head in HEADS_3D:
        notice = "⚠️ Custom image upload is disabled for this task because it requires a 3D image."

    if notice is None:
        return gr.update(visible=False), gr.update(interactive=True)
    return gr.update(value=notice, visible=True), gr.update(interactive=False)
|
|
|
|
def load_dataset_metadata(head: str) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
    """Load dataset metadata based on the selected head.

    Fills the target-class dropdown with ``"Random"`` plus one
    ``"<id>: <label>"`` entry per class known for the head. It also reports how
    many test samples the head's dataset subset contains.

    Args:
        head: Key of the selected model head.

    Returns:
        ``(class_dropdown_update, status_markdown, sample_button_update)``. On
        any failure (no dataset for the head, labels or dataset failing to
        load) the dropdown is reset to ``"Random"``, both the dropdown and the
        sample button are disabled, and the status explains why.
    """

    def _disabled(message: str) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
        # Shared "nothing to sample" UI state for every failure path.
        dropdown = gr.update(choices=["Random"], value="Random", interactive=False)
        button = gr.update(interactive=False)
        return dropdown, message, button

    subset = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not subset:
        return _disabled("No dataset found for this head.")

    # Guard label loading as well, so a failure is reported in the status line
    # instead of escaping the event handler.
    try:
        id2label = load_id_to_labels().get(head, {})
    except Exception as exc:
        return _disabled(f"Failed to load class labels: {exc}")

    try:
        dataset = load_curia_dataset(subset)
    except Exception as exc:
        return _disabled(f"Failed to load dataset: {exc}")

    options = [
        "Random",
        *(f"{cls_id}: {id2label[cls_id]}" for cls_id in sorted(id2label)),
    ]
    dropdown = gr.update(choices=options, value="Random", interactive=True)
    button = gr.update(interactive=True)
    return dropdown, f"Loaded {subset} ({len(dataset)} test samples)", button
|
|
|
|
def parse_target_selection(selection: str) -> Optional[int]:
    """Extract the class id from a dropdown entry such as ``"3: Kidney"``.

    Returns ``None`` for ``"Random"``, falsy selections, and anything whose
    text before the first ``:`` is not an integer.
    """
    if not selection or selection == "Random":
        return None

    try:
        id_text, _separator, _label = selection.partition(":")
        return int(id_text.strip())
    except (ValueError, AttributeError):
        return None
|
|
|
|
def sample_dataset_example(
    subset: str,
    target_id: Optional[int],
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Draw one random record (optionally of class ``target_id``) from a CuriaBench subset.

    Returns:
        ``(image, meta)``:
        * ``image`` is the record's ``"image"`` field converted to ``float32``;
        * ``meta`` holds the sampled ``index`` and the record's ``target`` and
          ``mask`` (each ``None`` when the record lacks it).
    """
    records = load_curia_dataset(subset)
    chosen_index = pick_random_indices(records, target_id)
    row = records[chosen_index]

    pixels = np.array(row["image"]).astype(np.float32)
    mask = row.get("mask")
    return pixels, {
        "index": chosen_index,
        "target": row.get("target"),
        "mask": mask,
    }
|
|
|
|
def load_dataset_sample(
    target_selection: str,
    head: str,
) -> Tuple[
    Optional[np.ndarray],
    str,
    Dict[str, Any],
    Dict[str, Any],
    Optional[Dict[str, Any]],
]:
    """Load a dataset sample based on the selected head.

    Args:
        target_selection: Value of the class-filter dropdown (``"Random"`` or
            ``"<id>: <label>"``).
        head: Key of the selected model head.

    Returns:
        Updates for ``(image_display, main_prediction, prediction_probs,
        ground_truth_display, image_state)``:
        * the windowed display image with mask contours;
        * an empty prediction;
        * a hidden probability table;
        * the ground truth;
        * the raw image and mask stored for inference.
        On failure a Gradio warning is shown and the image and state are ``None``.
    """
    subset = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not subset:
        gr.Warning("No dataset found for this head.")
        return None, "", gr.update(visible=False), gr.update(visible=False), None

    try:
        target_id = parse_target_selection(target_selection)
        image, meta = sample_dataset_example(subset, target_id)
        mask = meta.get("mask")

        windowed_image = apply_windowing(image, subset)
        # Best-effort overlay: a malformed mask warns instead of discarding the
        # whole sample.
        display = render_image_with_mask_info(windowed_image, mask)

        target = meta.get("target")
        ground_truth_update = gr.update(value="")
        if target is not None:
            if head in REGRESSION_HEADS:
                # Regression targets are values, not class ids.
                ground_truth_text = str(target)
            else:
                id2label = load_id_to_labels().get(head, {})
                label_name = id2label.get(target, str(target))
                ground_truth_text = f"{label_name} (class {target})"
            ground_truth_update = gr.update(value=ground_truth_text, visible=True)

        return (
            display,
            "",
            gr.update(visible=False),
            ground_truth_update,
            # Keep the raw (un-windowed) image for inference.
            {"image": image, "mask": mask},
        )
    except Exception as exc:
        gr.Warning(f"Failed to load sample: {exc}")
        return None, "", gr.update(visible=False), gr.update(visible=False), None
|
|
|
|
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
    """Return a dataframe of per-class probabilities sorted by probability desc.

    Args:
        probs: Probability vector of shape ``(C,)``. A single-sample batch such
            as ``(1, C)`` is flattened; larger batches are rejected.
        id2label: Mapping from class index to label. Indices missing from it
            are labelled with their stringified index.

    Returns:
        DataFrame with columns ``class_id``, ``label`` and ``probability``,
        sorted by descending probability with a stable sort, so ties keep
        ascending ``class_id`` order. The row index is the class position. An
        empty input yields an empty frame with those columns.

    Raises:
        ValueError: If ``probs`` contains more than one sample.
    """
    values = np.atleast_1d(probs.detach().cpu().numpy())
    if values.ndim > 1:
        if values.size != values.shape[-1]:
            raise ValueError(
                f"Expected a single probability vector, got shape {tuple(values.shape)}"
            )
        values = values.reshape(-1)

    rows = [
        {"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
        for idx, val in enumerate(values)
    ]
    # Explicit columns so an empty input still has a sortable "probability" column.
    df = pd.DataFrame(rows, columns=["class_id", "label", "probability"])
    df.sort_values("probability", ascending=False, kind="stable", inplace=True)
    return df
|
|
|
|
def run_inference(
    image_state: Optional[Dict[str, Any]],
    head: str,
) -> Tuple[str, Dict[str, Any]]:
    """Run the selected head on the stored image and format the result.

    Args:
        image_state: ``{"image": ..., "mask": ...}`` as stored by the sample
            loader or the upload handler. ``None`` or a dict without ``"image"``
            means nothing has been loaded yet.
        head: Key of the selected model head.

    Returns:
        ``(prediction_text, probability_table_update)``:
        * regression heads: the value with one decimal, table hidden;
        * classification heads: the top label with its probability, plus the
          full table;
        * on error: the message, table hidden (the traceback is printed).
    """
    if not image_state or "image" not in image_state:
        return "Load a dataset sample or upload an image first.", gr.update(visible=False)

    is_regression = head in REGRESSION_HEADS
    try:
        image = image_state["image"]
        output = infer_image(image, head, image_state.get("mask"), return_probs=not is_regression)

        if is_regression:
            # float() accepts Python/NumPy scalars and single-element tensors;
            # format(tensor, ".1f") only works on 0-d tensors.
            return f"{float(output):.1f}", gr.update(visible=False)

        id2label = load_id_to_labels().get(head, {})
        df = format_probabilities(output, id2label)
        if df.empty:
            return "Model returned no class probabilities.", gr.update(visible=False)

        top_row = df.iloc[0]
        prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
        return prediction, gr.update(visible=True, value=df)
    except Exception as exc:
        traceback.print_exc()
        return f"Failed to run inference: {exc}", gr.update(visible=False)
|
|
def handle_upload_preview(
    image: np.ndarray | Image.Image | None,
    head: str,
) -> Tuple[Optional[np.ndarray], str, str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Convert an uploaded image to grayscale ``float32`` and build its preview.

    Multi-channel uploads are averaged to one channel. For LA and RGBA images
    the trailing alpha channel is dropped first so it does not bias the
    intensities.

    Args:
        image: Uploaded image (NumPy array or PIL image), or ``None``.
        head: Selected model head. It is currently unused and kept for the
            event-handler signature.

    Returns:
        Updates for ``(image_display, status_text, main_prediction,
        prediction_probs, ground_truth_display, image_state)``. On failure the
        display and state are ``None`` and the status explains the error.
    """
    if image is None:
        return None, "Please upload an image.", "", pd.DataFrame(), gr.update(visible=False), None

    try:
        np_image = np.array(image).astype(np.float32)
        if np_image.ndim == 3:
            if np_image.shape[-1] in (2, 4):
                # LA / RGBA: the last channel is alpha, not intensity.
                np_image = np_image[..., :-1]
            np_image = np_image.mean(axis=-1)

        display = to_display_image(np_image)

        return (
            display,
            "Image uploaded. Computing predictions...",
            "",
            pd.DataFrame(),
            gr.update(value=""),
            {"image": np_image, "mask": None},
        )
    except Exception as exc:
        return None, f"Failed to load image: {exc}", "", pd.DataFrame(), gr.update(value=""), None
|
|
|
|
| |
| |
| |
|
|
|
|
def build_demo() -> gr.Blocks:
    """Build the Curia playground UI and wire up its event handlers.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` application.
    """
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        logo_block = ""
        if LOGO_DATA_URI:
            logo_block = f'<div class="hero-logo"><img src="{LOGO_DATA_URI}" alt="Curia logo" /></div>'
        hero_html = f"""
<div id=\"app-hero\">
{logo_block}
<div class=\"hero-text\">
<h1>Curia Model Playground</h1>
<p>Experiment with the multi-head Curia models on CuriaBench evaluation data or your own medical images.</p>
<p>Each head expects a single 2D slice in the Curia-defined plane/orientation (PL axial, IL coronal, IP sagittal) with raw Hounsfield units (CT) or normalised MRI intensities.</p>
</div>
</div>
"""
        gr.HTML(hero_html)

        default_head = "kits"
        head_dropdown = gr.Dropdown(
            label="Model head",
            choices=[(label, key) for key, label in HEAD_OPTIONS],
            value=default_head,
        )

        # Initial upload state follows the same rules as
        # update_upload_component_state (mask-based AND 3D heads), so the first
        # render matches what the UI shows after the head is changed.
        if default_head in HEADS_REQUIRING_MASK:
            initial_upload_notice = "⚠️ Custom image upload is disabled for this task because it requires a mask from the dataset."
        elif default_head in HEADS_3D:
            initial_upload_notice = "⚠️ Custom image upload is disabled for this task because it requires a 3D image."
        else:
            initial_upload_notice = ""
        upload_disabled = bool(initial_upload_notice)

        with gr.Row():
            with gr.Column():
                # Same text the head-change handler produces.
                dataset_display = gr.Markdown(update_dataset_display(default_head))
                dataset_status = gr.Markdown("Select a model head to load class metadata.")
                class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
                dataset_btn = gr.Button("Load dataset sample")

            with gr.Column():
                gr.Markdown("### Upload custom image")
                upload_info_text = gr.Markdown(
                    value=initial_upload_notice,
                    visible=upload_disabled,
                )
                upload_component = gr.Image(
                    label="Upload image",
                    image_mode="L",
                    type="numpy",
                    interactive=not upload_disabled,
                )

        gr.Markdown("---")

        status_text = gr.Markdown()
        with gr.Row():
            with gr.Column():
                image_display = gr.Image(label="Image", interactive=False, type="numpy")

            with gr.Column():
                ground_truth_display = gr.Textbox(label="Ground Truth", interactive=False)
                main_prediction = gr.Textbox(label="Prediction", value="", interactive=False)
                prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"], visible=False)

        # Holds {"image": ..., "mask": ...} for the currently displayed input.
        image_state = gr.State()

        demo.load(
            fn=load_dataset_metadata,
            inputs=[head_dropdown],
            outputs=[class_dropdown, dataset_status, dataset_btn],
        )

        head_dropdown.change(
            fn=update_dataset_display,
            inputs=[head_dropdown],
            outputs=[dataset_display],
        ).then(
            fn=update_upload_component_state,
            inputs=[head_dropdown],
            outputs=[upload_info_text, upload_component],
        ).then(
            fn=load_dataset_metadata,
            inputs=[head_dropdown],
            outputs=[class_dropdown, dataset_status, dataset_btn],
        )

        dataset_btn.click(
            fn=load_dataset_sample,
            inputs=[class_dropdown, head_dropdown],
            outputs=[
                image_display,
                main_prediction,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        ).then(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[main_prediction, prediction_probs],
        )

        upload_component.upload(
            fn=handle_upload_preview,
            inputs=[upload_component, head_dropdown],
            outputs=[
                image_display,
                status_text,
                main_prediction,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        ).then(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[main_prediction, prediction_probs],
        )

    return demo
|
|
|
|
# Build the app at import time so the module exposes a ready ``demo`` object.
demo = build_demo()

if __name__ == "__main__":
    # Start the server only when this file is executed directly.
    demo.launch()
|
|