Spaces:
Runtime error
Runtime error
| """Gradio web demo for satellite change detection. | |
| Upload before/after satellite image pairs, select a model, and view the | |
| predicted change mask, overlay, and change-area statistics. | |
| Auto-detects available checkpoints — no manual path entry needed. | |
| Usage: | |
| python app.py | |
| """ | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from huggingface_hub import hf_hub_download | |
| from data.dataset import IMAGENET_MEAN, IMAGENET_STD | |
| from inference import sliding_window_inference | |
| from models import get_model | |
| from utils.visualization import overlay_changes | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Globals | |
| # --------------------------------------------------------------------------- | |
| _cached_model: Optional[torch.nn.Module] = None | |
| _cached_model_key: Optional[str] = None | |
| _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| _config: Optional[Dict[str, Any]] = None | |
| _hf_model_repo_id: Optional[str] = os.getenv("HF_MODEL_REPO") | |
| _hf_model_revision: Optional[str] = os.getenv("HF_MODEL_REVISION") | |
| # Search these directories for checkpoint files | |
| _CHECKPOINT_SEARCH_DIRS = [ | |
| Path("checkpoints"), | |
| Path("/home/user/app/checkpoints"), | |
| Path("/kaggle/working/checkpoints"), | |
| Path("/content/drive/MyDrive/change-detection/checkpoints"), | |
| ] | |
| # Map model names to expected checkpoint filenames | |
| _MODEL_CHECKPOINT_NAMES = { | |
| "siamese_cnn": "siamese_cnn_best.pth", | |
| "unet_pp": "unet_pp_best.pth", | |
| "changeformer": "changeformer_best.pth", | |
| } | |
| def _download_checkpoint_from_hf(model_name: str) -> Optional[Path]: | |
| """Download checkpoint from Hugging Face Hub if configured. | |
| Uses env var ``HF_MODEL_REPO`` as the source model repository and | |
| downloads to ``./checkpoints`` cache. | |
| Args: | |
| model_name: One of the supported model keys. | |
| Returns: | |
| Local path to downloaded checkpoint, or ``None`` if unavailable. | |
| """ | |
| if not _hf_model_repo_id: | |
| return None | |
| filename = _MODEL_CHECKPOINT_NAMES.get(model_name) | |
| if filename is None: | |
| return None | |
| try: | |
| local_path = hf_hub_download( | |
| repo_id=_hf_model_repo_id, | |
| filename=filename, | |
| revision=_hf_model_revision, | |
| local_dir="checkpoints", | |
| local_dir_use_symlinks=False, | |
| ) | |
| logger.info("Downloaded %s from %s", filename, _hf_model_repo_id) | |
| return Path(local_path) | |
| except Exception as exc: # pragma: no cover - best-effort fallback | |
| logger.warning("Could not download %s from HF Hub: %s", filename, exc) | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Config / model loading | |
| # --------------------------------------------------------------------------- | |
| def _load_config() -> Dict[str, Any]: | |
| """Load and cache the project config. | |
| Returns: | |
| Full config dict. | |
| """ | |
| global _config | |
| if _config is None: | |
| config_path = Path("configs/config.yaml") | |
| with open(config_path, "r") as fh: | |
| _config = yaml.safe_load(fh) | |
| return _config | |
| def _find_checkpoint(model_name: str) -> Optional[Path]: | |
| """Auto-detect the checkpoint file for a given model. | |
| Searches multiple directories for the expected checkpoint filename. | |
| Args: | |
| model_name: One of ``siamese_cnn``, ``unet_pp``, ``changeformer``. | |
| Returns: | |
| Path to the checkpoint if found, ``None`` otherwise. | |
| """ | |
| filename = _MODEL_CHECKPOINT_NAMES.get(model_name) | |
| if filename is None: | |
| return None | |
| for search_dir in _CHECKPOINT_SEARCH_DIRS: | |
| candidate = search_dir / filename | |
| if candidate.exists(): | |
| return candidate | |
| downloaded = _download_checkpoint_from_hf(model_name) | |
| if downloaded is not None and downloaded.exists(): | |
| return downloaded | |
| return None | |
| def _get_available_models() -> List[str]: | |
| """Return a list of model names that have checkpoints available. | |
| Returns: | |
| List of model name strings with detected checkpoints. | |
| """ | |
| available = [] | |
| for model_name in _MODEL_CHECKPOINT_NAMES: | |
| if _find_checkpoint(model_name) is not None: | |
| available.append(model_name) | |
| return available | |
| def _load_model(model_name: str) -> torch.nn.Module: | |
| """Load a model using auto-detected checkpoint. | |
| Args: | |
| model_name: Architecture name. | |
| Returns: | |
| Model in eval mode on the current device. | |
| Raises: | |
| FileNotFoundError: If no checkpoint is found. | |
| """ | |
| global _cached_model, _cached_model_key | |
| if _cached_model is not None and _cached_model_key == model_name: | |
| return _cached_model | |
| ckpt_path = _find_checkpoint(model_name) | |
| if ckpt_path is None: | |
| raise FileNotFoundError( | |
| f"No checkpoint found for '{model_name}'. " | |
| f"Expected '{_MODEL_CHECKPOINT_NAMES[model_name]}' in one of: " | |
| f"{[str(d) for d in _CHECKPOINT_SEARCH_DIRS]}" | |
| ) | |
| config = _load_config() | |
| model = get_model(model_name, config).to(_device) | |
| ckpt = torch.load(ckpt_path, map_location=_device) | |
| model.load_state_dict(ckpt["model_state_dict"]) | |
| model.eval() | |
| _cached_model = model | |
| _cached_model_key = model_name | |
| logger.info("Loaded %s from %s", model_name, ckpt_path) | |
| return model | |
| # --------------------------------------------------------------------------- | |
| # Preprocessing | |
| # --------------------------------------------------------------------------- | |
| def _numpy_to_tensor( | |
| img: np.ndarray, | |
| patch_size: int = 256, | |
| ) -> Tuple[torch.Tensor, Tuple[int, int]]: | |
| """Convert a uint8 RGB numpy image to a normalised, padded tensor. | |
| Args: | |
| img: Input image ``[H, W, 3]``, uint8, RGB. | |
| patch_size: Pad to a multiple of this value. | |
| Returns: | |
| Tuple of ``(tensor [1, 3, H_pad, W_pad], (orig_h, orig_w))``. | |
| """ | |
| orig_h, orig_w = img.shape[:2] | |
| pad_h = (patch_size - orig_h % patch_size) % patch_size | |
| pad_w = (patch_size - orig_w % patch_size) % patch_size | |
| if pad_h > 0 or pad_w > 0: | |
| img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") | |
| img_f = img.astype(np.float32) / 255.0 | |
| mean = np.array(IMAGENET_MEAN, dtype=np.float32) | |
| std = np.array(IMAGENET_STD, dtype=np.float32) | |
| img_f = (img_f - mean) / std | |
| tensor = torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float() | |
| return tensor, (orig_h, orig_w) | |
| # --------------------------------------------------------------------------- | |
| # Prediction | |
| # --------------------------------------------------------------------------- | |
| def predict( | |
| before_image: Optional[np.ndarray], | |
| after_image: Optional[np.ndarray], | |
| model_name: str, | |
| threshold: float, | |
| ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]: | |
| """Run change detection and return visualisations + summary text. | |
| Args: | |
| before_image: Before image as numpy ``[H, W, 3]`` RGB uint8. | |
| after_image: After image as numpy ``[H, W, 3]`` RGB uint8. | |
| model_name: Architecture name. | |
| threshold: Binarisation threshold for predictions. | |
| Returns: | |
| Tuple of ``(change_mask, overlay_image, summary_text)``. | |
| """ | |
| if before_image is None or after_image is None: | |
| return None, None, "Please upload both **before** and **after** images." | |
| config = _load_config() | |
| patch_size: int = config.get("dataset", {}).get("patch_size", 256) | |
| # Load model (auto-detects checkpoint) | |
| try: | |
| model = _load_model(model_name) | |
| except FileNotFoundError as exc: | |
| return None, None, f"**Error:** {exc}" | |
| # Preprocess | |
| tensor_a, (orig_h, orig_w) = _numpy_to_tensor(before_image, patch_size) | |
| tensor_b, _ = _numpy_to_tensor(after_image, patch_size) | |
| # Tiled inference | |
| prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device) | |
| prob_map = prob_map[:, :, :orig_h, :orig_w] | |
| prob_np = prob_map.squeeze().numpy() | |
| # Binary change mask | |
| binary_mask = (prob_np > threshold).astype(np.uint8) * 255 | |
| # Overlay on after image | |
| pred_tensor = (prob_map.squeeze(0) >= threshold).float() | |
| img_b_tensor = tensor_b.squeeze()[:, :orig_h, :orig_w] | |
| overlay_rgb = overlay_changes( | |
| img_after=img_b_tensor, | |
| mask_pred=pred_tensor, | |
| alpha=0.4, | |
| color=(255, 0, 0), | |
| ) | |
| # Change statistics | |
| total_pixels = orig_h * orig_w | |
| changed_pixels = int(binary_mask.sum() // 255) | |
| pct_changed = (changed_pixels / total_pixels) * 100.0 | |
| ckpt_path = _find_checkpoint(model_name) | |
| summary = ( | |
| f"### Change Detection Results\n\n" | |
| f"| Metric | Value |\n" | |
| f"|---|---|\n" | |
| f"| **Model** | {model_name} |\n" | |
| f"| **Image size** | {orig_w} x {orig_h} |\n" | |
| f"| **Total pixels** | {total_pixels:,} |\n" | |
| f"| **Changed pixels** | {changed_pixels:,} |\n" | |
| f"| **Area changed** | {pct_changed:.2f}% |\n" | |
| f"| **Threshold** | {threshold} |\n" | |
| f"| **Checkpoint** | {ckpt_path.name if ckpt_path else 'N/A'} |\n" | |
| f"| **Device** | {_device} |" | |
| ) | |
| return binary_mask, overlay_rgb, summary | |
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
| def build_demo() -> gr.Blocks: | |
| """Construct the Gradio Blocks interface. | |
| Returns: | |
| A ``gr.Blocks`` application ready to ``.launch()``. | |
| """ | |
| available = _get_available_models() | |
| all_models = list(_MODEL_CHECKPOINT_NAMES.keys()) | |
| # Show which models are available | |
| status_lines = [] | |
| for m in all_models: | |
| ckpt = _find_checkpoint(m) | |
| if ckpt: | |
| status_lines.append(f"- **{m}**: {ckpt.name}") | |
| else: | |
| status_lines.append(f"- **{m}**: not found") | |
| model_status = "\n".join(status_lines) | |
| default_model = available[0] if available else "changeformer" | |
| with gr.Blocks(title="Military Base Change Detection") as demo: | |
| gr.Markdown( | |
| "# Military Base Change Detection\n" | |
| "Upload **before** and **after** satellite images to detect " | |
| "construction, infrastructure changes, and runway development.\n\n" | |
| "**Available models:**\n" + model_status | |
| ) | |
| # ---- Inputs --------------------------------------------------- | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| before_img = gr.Image( | |
| label="Before Image (older)", | |
| type="numpy", | |
| sources=["upload", "clipboard"], | |
| ) | |
| with gr.Column(scale=1): | |
| after_img = gr.Image( | |
| label="After Image (newer)", | |
| type="numpy", | |
| sources=["upload", "clipboard"], | |
| ) | |
| # ---- Controls ------------------------------------------------- | |
| with gr.Row(): | |
| model_dropdown = gr.Dropdown( | |
| choices=available if available else all_models, | |
| value=default_model, | |
| label="Model Architecture", | |
| ) | |
| threshold_slider = gr.Slider( | |
| minimum=0.1, | |
| maximum=0.9, | |
| value=0.5, | |
| step=0.05, | |
| label="Detection Threshold", | |
| ) | |
| detect_btn = gr.Button("Detect Changes", variant="primary", size="lg") | |
| # ---- Outputs -------------------------------------------------- | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| change_mask_out = gr.Image(label="Change Mask") | |
| with gr.Column(scale=1): | |
| overlay_out = gr.Image(label="Overlay (changes in red)") | |
| summary_out = gr.Markdown(label="Summary") | |
| # ---- Wiring --------------------------------------------------- | |
| detect_btn.click( | |
| fn=predict, | |
| inputs=[before_img, after_img, model_dropdown, threshold_slider], | |
| outputs=[change_mask_out, overlay_out, summary_out], | |
| ) | |
| return demo | |
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
| def main() -> None: | |
| """Launch the Gradio demo server.""" | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(message)s", | |
| datefmt="%Y-%m-%d %H:%M:%S", | |
| ) | |
| config = _load_config() | |
| gradio_cfg = config.get("gradio", {}) | |
| demo = build_demo() | |
| in_hf_space = os.getenv("SPACE_ID") is not None | |
| demo.launch( | |
| server_name="0.0.0.0" if in_hf_space else "127.0.0.1", | |
| server_port=gradio_cfg.get("server_port", 7860), | |
| share=False if in_hf_space else gradio_cfg.get("share", False), | |
| ) | |
| if __name__ == "__main__": | |
| main() | |