""" gradio_app.py — DeepLense GSoC 2026 Interactive Demo ===================================================== HuggingFace Spaces-ready Gradio application for gravitational lensing dark matter morphology classification. Features: • Upload any lensing image or pick from 3 sample images per class (9 total) • Dropdown: choose model — Baseline / Transfer / Ensemble / Equivariant • Classification result with probability bars for No Sub / CDM / Vortex • GradCAM heatmap overlay — "where is the model looking?" • TTA analysis: predictions at 0° / 90° / 180° / 270° to visually prove equivariant stability vs ensemble instability Deploy: 1. Push this file to a HuggingFace Space (Gradio SDK). 2. Place model weights in weights/ directory of the Space repo. 3. Place sample images in demo_samples/{class_name}/*.npy or *.png. 4. Add requirements.txt with: torch torchvision gradio numpy pillow opencv-python Local run: python gradio_app.py Environment variables (for HF Spaces): MODEL_WEIGHTS_DIR : path to weights/ directory (default: "weights") SAMPLES_DIR : path to sample images (default: "demo_samples") """ from __future__ import annotations import os import sys import warnings import math import traceback from pathlib import Path from typing import Optional import numpy as np from PIL import Image import torch import torch.nn.functional as F import torchvision.transforms as transforms # ── Gradio ──────────────────────────────────────────────────────────────────── try: import gradio as gr except ImportError: raise ImportError("pip install gradio>=4.0.0") # ── OpenCV for GradCAM resize ───────────────────────────────────────────────── try: import cv2 OPENCV_AVAILABLE = True except ImportError: OPENCV_AVAILABLE = False warnings.warn("opencv-python not installed — GradCAM overlay disabled.") # ── src/ module imports ──────────────────────────────────────────────────────── # Support both: python gradio_app.py (from project root) and HF Spaces layout _SCRIPT_DIR = Path(__file__).parent.resolve() for candidate in [_SCRIPT_DIR / 'src', _SCRIPT_DIR]: if (candidate / 'models.py').exists(): if str(candidate) not in sys.path: sys.path.insert(0, str(candidate)) break from models import ResNetBaseline, ResNetTransfer, ViTChampion, DeepLenseEnsemble, load_model from metrics import compute_gradcam, overlay_gradcam # ───────────────────────────────────────────────────────────────────────────── # CONFIGURATION # ───────────────────────────────────────────────────────────────────────────── # ── Path resolution — works both locally and on HuggingFace Spaces ────────── # On HF Spaces all repo files are mounted at /app/ # Locally they sit next to this script if os.path.isdir('/app/demo_samples'): # Running on HuggingFace Spaces SAMPLES_DIR = '/app/demo_samples' WEIGHTS_DIR = '/app/weights' else: # Running locally SAMPLES_DIR = str(_SCRIPT_DIR / 'demo_samples') WEIGHTS_DIR = str(_SCRIPT_DIR / 'weights') print(f"✅ WEIGHTS_DIR : {WEIGHTS_DIR} (exists={os.path.exists(WEIGHTS_DIR)})") print(f"✅ SAMPLES_DIR : {SAMPLES_DIR} (exists={os.path.exists(SAMPLES_DIR)})") DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') CLASS_NAMES = ['No Sub', 'CDM', 'Vortex'] CLASS_COLORS = {'No Sub': '#2196F3', 'CDM': '#F44336', 'Vortex': '#4CAF50'} CLASS_DESCRIPTIONS = { 'No Sub': 'Smooth lens — no dark matter substructure. 
    'CDM': 'Cold Dark Matter — localised point-mass subhalos producing small-scale perturbations.',
    'Vortex': 'Quantum condensate — extended vortex filaments from ultra-light axion dark matter.',
}

# ─────────────────────────────────────────────────────────────────────────────
# MODEL REGISTRY
# ─────────────────────────────────────────────────────────────────────────────
MODEL_CONFIG = {
    'Transfer (ResNet-18)': {
        'mode': 'RGB',
        'image_size': 224,
        'weights_key': 'transfer_best.pth',
        'description': 'ImageNet-pretrained ResNet-18. Best absolute accuracy (89.3%). '
                       'Uses GradCAM on layer4 for visualisation.',
    },
    'Baseline (ResNet-18)': {
        'mode': 'L',
        'image_size': 64,
        'weights_key': 'baseline_best.pth',
        'description': 'ResNet-18 trained from scratch on 64×64 grayscale. Lower bound (60.4%). '
                       'Uses GradCAM on layer4.',
    },
    'Ensemble (ResNet + ViT)': {
        'mode': 'RGB',
        'image_size': 224,
        'weights_key': ['transfer_best.pth', 'vit_best.pth'],
        'description': 'Stacking meta-learner fusing ResNet-18 + ViT-B/16 (84.0%). '
                       'Uses Attention Rollout for ViT sub-model visualisation. '
                       'TTA shows 6.2% drop — proves orientation bias.',
    },
    'Equivariant (C8)': {
        'mode': 'L',
        'image_size': 128,
        'weights_key': 'equivariant_best.pth',
        'description': 'C8-equivariant CNN (escnn). Only 0.4% TTA drop — '
                       'rotational symmetry baked into architecture. '
                       'GradCAM hooks group_pool for equivariant-safe visualisation.',
    },
}

# ─────────────────────────────────────────────────────────────────────────────
# MODEL CACHE — load once, reuse across calls
# ─────────────────────────────────────────────────────────────────────────────
_model_cache: dict[str, torch.nn.Module] = {}
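
# The MODEL_CONFIG registry above drives everything downstream. A quick,
# optional startup self-check can report which weight files are present before
# the UI launches. This is a sketch — `_startup_weight_report` is not part of
# the original app.
def _startup_weight_report() -> None:
    """Print which MODEL_CONFIG weight files exist under WEIGHTS_DIR."""
    for name, cfg in MODEL_CONFIG.items():
        keys = cfg['weights_key']
        keys = [keys] if isinstance(keys, str) else list(keys)
        found = all(os.path.exists(os.path.join(WEIGHTS_DIR, k)) for k in keys)
        print(f"{'✅' if found else '⚠️'} {name}: {', '.join(keys)}")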
def _load_model_cached(model_name: str) -> Optional[torch.nn.Module]:
    """Load and cache a model. Returns None if weights are not found."""
    if model_name in _model_cache:
        return _model_cache[model_name]

    cfg = MODEL_CONFIG[model_name]
    try:
        if model_name == 'Baseline (ResNet-18)':
            m = ResNetBaseline(num_classes=3)
            w = os.path.join(WEIGHTS_DIR, cfg['weights_key'])
            if not os.path.exists(w):
                return None
            m = load_model(m, w, DEVICE)

        elif model_name == 'Transfer (ResNet-18)':
            m = ResNetTransfer(num_classes=3)
            w = os.path.join(WEIGHTS_DIR, cfg['weights_key'])
            if not os.path.exists(w):
                return None
            m = load_model(m, w, DEVICE)

        elif model_name == 'Ensemble (ResNet + ViT)':
            resnet_w, vit_w = [os.path.join(WEIGHTS_DIR, k) for k in cfg['weights_key']]
            if not (os.path.exists(resnet_w) and os.path.exists(vit_w)):
                return None
            resnet = load_model(ResNetTransfer(num_classes=3), resnet_w, DEVICE)
            vit = load_model(ViTChampion(num_classes=3), vit_w, DEVICE)
            m = DeepLenseEnsemble(resnet_model=resnet, vit_model=vit, freeze_base=True)
            m = m.to(DEVICE)
            m.eval()

        elif model_name == 'Equivariant (C8)':
            try:
                from models import EquivariantCNN
                m = EquivariantCNN(num_classes=3, n_rotations=8)
            except ImportError:
                return None
            w = os.path.join(WEIGHTS_DIR, cfg['weights_key'])
            if not os.path.exists(w):
                return None
            m = load_model(m, w, DEVICE)

        else:
            return None

        _model_cache[model_name] = m
        return m
    except Exception as e:
        print(f"⚠️ Failed to load {model_name}: {e}")
        return None


# ─────────────────────────────────────────────────────────────────────────────
# TRANSFORM UTILITIES
# ─────────────────────────────────────────────────────────────────────────────

def _get_transform(mode: str, image_size: int) -> transforms.Compose:
    if mode == 'RGB':
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])
    else:  # Grayscale
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])


def _pil_to_tensor(pil_img: Image.Image, mode: str, image_size: int) -> torch.Tensor:
    """Converts a PIL Image to a normalised (1, C, H, W) tensor."""
    transform = _get_transform(mode, image_size)
    if pil_img.mode == 'RGBA':
        pil_img = pil_img.convert('RGB')
    tensor = transform(pil_img)
    return tensor.unsqueeze(0)


def _tensor_to_numpy_displayable(tensor: torch.Tensor, mode: str) -> np.ndarray:
    """Converts a (1, C, H, W) or (C, H, W) normalised tensor to (H, W, 3) uint8."""
    t = tensor.squeeze(0).cpu().numpy()
    if mode == 'RGB':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        t = t.transpose(1, 2, 0) * std + mean
        t = np.clip(t, 0, 1)
        return (t * 255).astype(np.uint8)
    else:
        t = t[0] * 0.5 + 0.5
        t = np.clip(t, 0, 1)
        t = (t * 255).astype(np.uint8)
        return np.stack([t, t, t], axis=-1)  # Grayscale → RGB for display
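
# A minimal round-trip sanity check (illustrative only — not called by the
# app, and `_preprocessing_roundtrip_check` is not part of the original file):
# preprocessing then denormalising should reproduce a displayable image of the
# target size, which is exactly what the GradCAM overlay relies on.
def _preprocessing_roundtrip_check() -> None:
    img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
    tensor = _pil_to_tensor(img, mode='RGB', image_size=224)  # (1, 3, 224, 224)
    disp = _tensor_to_numpy_displayable(tensor, mode='RGB')   # (224, 224, 3) uint8
    assert disp.shape == (224, 224, 3) and disp.dtype == np.uint8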
# ─────────────────────────────────────────────────────────────────────────────
# SAMPLE IMAGE LOADER
# ─────────────────────────────────────────────────────────────────────────────

def _load_sample_images() -> dict[str, list[str]]:
    """
    Loads sample image paths from SAMPLES_DIR.

    Directory structure expected:
        demo_samples/
            no_sub/   *.png or *.jpg
            cdm/      *.png or *.jpg
            vortex/   *.png or *.jpg

    Returns dict: {class_name: [path1, path2, path3]}
    """
    samples = {}
    class_dir_map = {'No Sub': 'no_sub', 'CDM': 'cdm', 'Vortex': 'vortex'}
    for class_name, dir_name in class_dir_map.items():
        class_dir = os.path.join(SAMPLES_DIR, dir_name)
        if not os.path.exists(class_dir):
            samples[class_name] = []
            continue
        exts = {'.png', '.jpg', '.jpeg', '.npy'}
        paths = sorted([
            os.path.join(class_dir, f)
            for f in os.listdir(class_dir)
            if Path(f).suffix.lower() in exts
        ])[:3]
        samples[class_name] = paths
    return samples


SAMPLE_IMAGES = _load_sample_images()


def _path_to_pil(path: str) -> Image.Image:
    """Loads a PNG/JPG/NPY file as PIL Image (RGB or L)."""
    if path.endswith('.npy'):
        arr = np.load(path)
        if arr.ndim == 2:
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
            arr = (arr * 255).astype(np.uint8)
            return Image.fromarray(arr, mode='L')
        elif arr.ndim == 3 and arr.shape[0] == 1:
            arr = arr[0]
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
            arr = (arr * 255).astype(np.uint8)
            return Image.fromarray(arr, mode='L')
        else:
            arr = arr.transpose(1, 2, 0) if arr.shape[0] == 3 else arr
            arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8)
            return Image.fromarray(arr)
    else:
        return Image.open(path).convert('RGB')


# ─────────────────────────────────────────────────────────────────────────────
# INFERENCE ENGINE
# ─────────────────────────────────────────────────────────────────────────────

def _run_inference(
    model: torch.nn.Module,
    image_tensor: torch.Tensor,
) -> tuple[int, np.ndarray]:
    """Returns (pred_class_idx, probs_array)."""
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(DEVICE))
        probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    return int(np.argmax(probs)), probs


def _run_tta(
    model: torch.nn.Module,
    image_tensor: torch.Tensor,
    angles: tuple[int, ...] = (0, 90, 180, 270),
) -> dict[int, tuple[int, np.ndarray]]:
    """
    Runs inference at each TTA angle independently (NOT averaged).

    Returns dict: {angle: (pred_class_idx, probs_array)}

    This is intentionally per-angle (not averaged TTA) so the demo can show
    how each rotation changes the prediction — revealing orientation bias.
    """
    model.eval()
    results = {}
    k_map = {0: 0, 90: 1, 180: 2, 270: 3}  # degrees → number of 90° rot90 steps
    with torch.no_grad():
        for angle in angles:
            k = k_map.get(angle % 360, 0)
            rotated = torch.rot90(image_tensor, k=k, dims=[2, 3])
            logits = model(rotated.to(DEVICE))
            probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
            results[angle] = (int(np.argmax(probs)), probs)
    return results
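
# Illustrative use of the per-angle results (a sketch, not wired into the UI —
# the helper name `_tta_agreement` is ours, not part of the original demo):
# orientation stability can be summarised as the fraction of rotations whose
# prediction matches the 0° prediction. A perfectly equivariant model scores
# 1.0; lower values reveal orientation bias.
def _tta_agreement(results: dict[int, tuple[int, np.ndarray]]) -> float:
    """Fraction of TTA angles whose prediction matches the 0° prediction."""
    base_pred = results[0][0]
    agree = sum(1 for pred, _ in results.values() if pred == base_pred)
    return agree / len(results)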
""" if not OPENCV_AVAILABLE: return None try: cam, _, _ = compute_gradcam( model = model, image_tensor = image_tensor, class_idx = pred_class, device = DEVICE, ) img_np = _tensor_to_numpy_displayable(image_tensor, mode).astype(float) / 255.0 overlay = overlay_gradcam(img_np, cam, alpha=0.45) return (overlay * 255).astype(np.uint8) except Exception as e: print(f"⚠️ GradCAM failed: {e}\n{traceback.format_exc()}") return None # ───────────────────────────────────────────────────────────────────────────── # MAIN GRADIO CALLBACK # ───────────────────────────────────────────────────────────────────────────── def classify_image( uploaded_image, # PIL Image from gr.Image model_name: str, sample_class: str, sample_index: int, ) -> tuple: """ Main inference callback. Returns: (result_html, probabilities_dict, gradcam_image, tta_html, model_info_text) """ # ── Resolve input image ─────────────────────────────────────────────── pil_img = None if uploaded_image is not None: if isinstance(uploaded_image, np.ndarray): pil_img = Image.fromarray(uploaded_image) elif isinstance(uploaded_image, Image.Image): pil_img = uploaded_image else: pil_img = Image.fromarray(uploaded_image) elif sample_class and SAMPLE_IMAGES.get(sample_class): idx = min(int(sample_index), len(SAMPLE_IMAGES[sample_class]) - 1) pil_img = _path_to_pil(SAMPLE_IMAGES[sample_class][idx]) # Always convert to RGB at this stage — model transforms handle # grayscale conversion internally if mode='L'. Starting from RGB # guarantees the image is never a flat 1-channel blob going in. if pil_img is not None and pil_img.mode != 'RGB': pil_img = pil_img.convert('RGB') if pil_img is None: return ( "
Please upload an image or select a sample.
", {cls: 0.0 for cls in CLASS_NAMES}, None, "No image provided.
", MODEL_CONFIG.get(model_name, {}).get('description', ''), ) # ── Load model ──────────────────────────────────────────────────────── model = _load_model_cached(model_name) if model is None: weights_info = MODEL_CONFIG[model_name]['weights_key'] missing = weights_info if isinstance(weights_info, str) else ', '.join(weights_info) return ( f"⚠️ Model weights not found: {missing}. "
f"Place .pth files in {WEIGHTS_DIR}/
Model not loaded.
", MODEL_CONFIG[model_name]['description'], ) cfg = MODEL_CONFIG[model_name] mode = cfg['mode'] image_size = cfg['image_size'] # ── Preprocess ──────────────────────────────────────────────────────── try: tensor = _pil_to_tensor(pil_img, mode, image_size) # (1, C, H, W) except Exception as e: return ( f"Image preprocessing failed: {e}
", {cls: 0.0 for cls in CLASS_NAMES}, None, "", cfg['description'], ) # ── Standard inference ──────────────────────────────────────────────── pred_class, probs = _run_inference(model, tensor) pred_name = CLASS_NAMES[pred_class] # ── Build result HTML ───────────────────────────────────────────────── color = CLASS_COLORS[pred_name] result_html = f"""{CLASS_DESCRIPTIONS[pred_name]}
    result_html = f"""
    <div style="border-left: 6px solid {color}; padding: 8px 12px;">
        <h2 style="color: {color}; margin: 0;">{pred_name}</h2>
        <p>{CLASS_DESCRIPTIONS[pred_name]}</p>
        <p><b>Confidence: {probs[pred_class]*100:.1f}%</b></p>
    </div>
    """

    probabilities = {cls: float(p) for cls, p in zip(CLASS_NAMES, probs)}

    # ── GradCAM overlay ───────────────────────────────────────────────────
    gradcam_img = _build_gradcam_image(model, tensor, pred_class, mode)

    # ── TTA stability table ───────────────────────────────────────────────
    tta_results = _run_tta(model, tensor)
    base_pred = tta_results[0][0]
    rows = []
    for angle, (pred, p) in tta_results.items():
        dist = ' / '.join(f"{c} {p[i]*100:.0f}%" for i, c in enumerate(CLASS_NAMES))
        stable = '✅' if pred == base_pred else '❌'
        rows.append(
            f"<tr><td>{angle}°</td>"
            f"<td style='color:{CLASS_COLORS[CLASS_NAMES[pred]]};'>{CLASS_NAMES[pred]}</td>"
            f"<td>{p[pred]*100:.1f}%</td><td>{dist}</td><td>{stable}</td></tr>"
        )
    n_agree = sum(1 for pred, _ in tta_results.values() if pred == base_pred)
    if n_agree == len(tta_results):
        stability_text = "✅ Stable: identical prediction at every rotation."
    else:
        stability_text = (f"❌ Orientation-biased: only {n_agree}/{len(tta_results)} "
                          f"rotations agree with the 0° prediction.")

    legend = ' '.join(
        f"{cls} = <span style='color:{CLASS_COLORS[cls]};'>■</span>"
        for cls in CLASS_NAMES
    )
    tta_html = f"""
    <p>A physically correct model (lensing has no preferred orientation) should
    give identical predictions at any rotation. This table shows whether the
    model is equivariantly stable or orientation-biased.</p>
    <table>
        <tr><th>Rotation</th><th>Prediction</th><th>Confidence</th>
            <th>Prob Distribution</th><th>Stable?</th></tr>
        {''.join(rows)}
    </table>
    <p>{stability_text}</p>
    <p>{legend}</p>
    """

    return (result_html, probabilities, gradcam_img, tta_html, cfg['description'])
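
# ─────────────────────────────────────────────────────────────────────────────
# UI — minimal Blocks sketch. The original Space layout is not included in
# this file; everything below is an illustrative reconstruction wired to the
# callbacks above. Component choices, layout, and the name `build_demo` are
# our assumptions, not the original design.
# ─────────────────────────────────────────────────────────────────────────────
def build_demo() -> gr.Blocks:
    with gr.Blocks(title="DeepLense — Dark Matter Morphology") as demo:
        gr.Markdown("# DeepLense — Dark Matter Morphology Classifier")
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(type="pil", label="Lensing image")
                model_dd = gr.Dropdown(list(MODEL_CONFIG.keys()),
                                       value='Transfer (ResNet-18)', label="Model")
                sample_cls = gr.Dropdown(CLASS_NAMES, label="Sample class")
                sample_idx = gr.Slider(0, 2, step=1, value=0, label="Sample index")
                run_btn = gr.Button("Classify")
            with gr.Column():
                result_out = gr.HTML(label="Prediction")
                probs_out = gr.Label(label="Class probabilities")
                cam_out = gr.Image(label="GradCAM")
                tta_out = gr.HTML(label="TTA stability")
                info_out = gr.Markdown()
        run_btn.click(
            classify_image,
            inputs=[image_in, model_dd, sample_cls, sample_idx],
            outputs=[result_out, probs_out, cam_out, tta_out, info_out],
        )
    return demo


if __name__ == '__main__':
    build_demo().launch()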