Vedant Jigarbhai Mehta committed on
Commit
027adea
·
1 Parent(s): c95b5c2

Auto-detect checkpoints in Gradio app, no manual path needed

Browse files
Files changed (2) hide show
  1. app.py +119 -65
  2. configs/config.yaml +3 -3
app.py CHANGED
@@ -1,10 +1,9 @@
1
  """Gradio web demo for satellite change detection.
2
 
3
- Upload before/after satellite image pairs, select a model and checkpoint, and
4
- view the predicted change mask, overlay, and change-area statistics.
5
 
6
- Defaults (model, checkpoint, port, share) are read from the ``gradio`` section
7
- of ``configs/config.yaml``.
8
 
9
  Usage:
10
  python app.py
@@ -12,16 +11,15 @@ Usage:
12
 
13
  import logging
14
  from pathlib import Path
15
- from typing import Any, Dict, Optional, Tuple
16
 
17
- import cv2
18
  import gradio as gr
19
  import numpy as np
20
  import torch
21
  import yaml
22
 
23
  from data.dataset import IMAGENET_MEAN, IMAGENET_STD
24
- from inference import load_and_preprocess, sliding_window_inference
25
  from models import get_model
26
  from utils.visualization import overlay_changes
27
 
@@ -29,14 +27,32 @@ logger = logging.getLogger(__name__)
29
 
30
 
31
  # ---------------------------------------------------------------------------
32
- # Globals (model cache to avoid reloading on every prediction)
33
  # ---------------------------------------------------------------------------
34
 
35
  _cached_model: Optional[torch.nn.Module] = None
36
- _cached_model_key: Optional[str] = None # "model_name::checkpoint_path"
37
  _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  _config: Optional[Dict[str, Any]] = None
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def _load_config() -> Dict[str, Any]:
42
  """Load and cache the project config.
@@ -52,43 +68,81 @@ def _load_config() -> Dict[str, Any]:
52
  return _config
53
 
54
 
55
- def _load_model(model_name: str, checkpoint_path: str) -> torch.nn.Module:
56
- """Load a model, re-using the cache if name + checkpoint match.
 
 
57
 
58
  Args:
59
- model_name: Architecture name (``siamese_cnn``, ``unet_pp``, ``changeformer``).
60
- checkpoint_path: Path to the ``.pth`` checkpoint file.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  Returns:
63
  Model in eval mode on the current device.
64
 
65
  Raises:
66
- FileNotFoundError: If the checkpoint does not exist.
67
  """
68
  global _cached_model, _cached_model_key
69
 
70
- cache_key = f"{model_name}::{checkpoint_path}"
71
- if _cached_model is not None and _cached_model_key == cache_key:
72
  return _cached_model
73
 
74
- config = _load_config()
75
- ckpt_path = Path(checkpoint_path)
76
- if not ckpt_path.exists():
77
- raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
 
 
 
78
 
 
79
  model = get_model(model_name, config).to(_device)
80
  ckpt = torch.load(ckpt_path, map_location=_device)
81
  model.load_state_dict(ckpt["model_state_dict"])
82
  model.eval()
83
 
84
  _cached_model = model
85
- _cached_model_key = cache_key
86
- logger.info("Loaded model %s from %s", model_name, checkpoint_path)
87
  return model
88
 
89
 
90
  # ---------------------------------------------------------------------------
91
- # Preprocessing helper (numpy RGB uint8 → tensor)
92
  # ---------------------------------------------------------------------------
93
 
94
  def _numpy_to_tensor(
@@ -121,14 +175,13 @@ def _numpy_to_tensor(
121
 
122
 
123
  # ---------------------------------------------------------------------------
124
- # Prediction function (called by Gradio)
125
  # ---------------------------------------------------------------------------
126
 
127
  def predict(
128
  before_image: Optional[np.ndarray],
129
  after_image: Optional[np.ndarray],
130
  model_name: str,
131
- checkpoint_path: str,
132
  threshold: float,
133
  ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
134
  """Run change detection and return visualisations + summary text.
@@ -137,26 +190,22 @@ def predict(
137
  before_image: Before image as numpy ``[H, W, 3]`` RGB uint8.
138
  after_image: After image as numpy ``[H, W, 3]`` RGB uint8.
139
  model_name: Architecture name.
140
- checkpoint_path: Path to checkpoint file.
141
  threshold: Binarisation threshold for predictions.
142
 
143
  Returns:
144
  Tuple of ``(change_mask, overlay_image, summary_text)``.
145
- - ``change_mask``: uint8 grayscale ``[H, W]`` (0 or 255).
146
- - ``overlay_image``: uint8 RGB ``[H, W, 3]``.
147
- - ``summary_text``: Markdown string with change statistics.
148
  """
149
  if before_image is None or after_image is None:
150
- return None, None, "Please upload both before and after images."
151
 
152
  config = _load_config()
153
  patch_size: int = config.get("dataset", {}).get("patch_size", 256)
154
 
155
- # Load model
156
  try:
157
- model = _load_model(model_name, checkpoint_path)
158
  except FileNotFoundError as exc:
159
- return None, None, f"Error: {exc}"
160
 
161
  # Preprocess
162
  tensor_a, (orig_h, orig_w) = _numpy_to_tensor(before_image, patch_size)
@@ -165,14 +214,14 @@ def predict(
165
  # Tiled inference
166
  prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device)
167
  prob_map = prob_map[:, :, :orig_h, :orig_w]
168
- prob_np = prob_map.squeeze().numpy() # [H, W]
169
 
170
  # Binary change mask
171
  binary_mask = (prob_np > threshold).astype(np.uint8) * 255
172
 
173
  # Overlay on after image
174
- pred_tensor = (prob_map.squeeze(0) >= threshold).float() # [1, H, W]
175
- img_b_tensor = tensor_b.squeeze()[:, :orig_h, :orig_w] # [3, H, W]
176
  overlay_rgb = overlay_changes(
177
  img_after=img_b_tensor,
178
  mask_pred=pred_tensor,
@@ -185,14 +234,19 @@ def predict(
185
  changed_pixels = int(binary_mask.sum() // 255)
186
  pct_changed = (changed_pixels / total_pixels) * 100.0
187
 
 
188
  summary = (
189
- f"### Change Detection Summary\n"
190
- f"- **Image size**: {orig_w} x {orig_h}\n"
191
- f"- **Total pixels**: {total_pixels:,}\n"
192
- f"- **Changed pixels**: {changed_pixels:,}\n"
193
- f"- **Area changed**: {pct_changed:.2f}%\n"
194
- f"- **Model**: {model_name}\n"
195
- f"- **Threshold**: {threshold}"
 
 
 
 
196
  )
197
 
198
  return binary_mask, overlay_rgb, summary
@@ -208,31 +262,41 @@ def build_demo() -> gr.Blocks:
208
  Returns:
209
  A ``gr.Blocks`` application ready to ``.launch()``.
210
  """
211
- config = _load_config()
212
- gradio_cfg = config.get("gradio", {})
213
 
214
- with gr.Blocks(
215
- title="Military Base Change Detection",
216
- theme=gr.themes.Soft(),
217
- ) as demo:
 
 
 
 
 
 
 
 
 
218
 
219
  gr.Markdown(
220
  "# Military Base Change Detection\n"
221
  "Upload **before** and **after** satellite images to detect "
222
- "construction, infrastructure changes, and runway development."
 
223
  )
224
 
225
  # ---- Inputs ---------------------------------------------------
226
  with gr.Row():
227
  with gr.Column(scale=1):
228
  before_img = gr.Image(
229
- label="Before Image",
230
  type="numpy",
231
  sources=["upload", "clipboard"],
232
  )
233
  with gr.Column(scale=1):
234
  after_img = gr.Image(
235
- label="After Image",
236
  type="numpy",
237
  sources=["upload", "clipboard"],
238
  )
@@ -240,14 +304,10 @@ def build_demo() -> gr.Blocks:
240
  # ---- Controls -------------------------------------------------
241
  with gr.Row():
242
  model_dropdown = gr.Dropdown(
243
- choices=["siamese_cnn", "unet_pp", "changeformer"],
244
- value=gradio_cfg.get("default_model", "unet_pp"),
245
  label="Model Architecture",
246
  )
247
- checkpoint_input = gr.Textbox(
248
- value=gradio_cfg.get("default_checkpoint", "checkpoints/unet_pp_best.pth"),
249
- label="Checkpoint Path",
250
- )
251
  threshold_slider = gr.Slider(
252
  minimum=0.1,
253
  maximum=0.9,
@@ -270,13 +330,7 @@ def build_demo() -> gr.Blocks:
270
  # ---- Wiring ---------------------------------------------------
271
  detect_btn.click(
272
  fn=predict,
273
- inputs=[
274
- before_img,
275
- after_img,
276
- model_dropdown,
277
- checkpoint_input,
278
- threshold_slider,
279
- ],
280
  outputs=[change_mask_out, overlay_out, summary_out],
281
  )
282
 
 
1
  """Gradio web demo for satellite change detection.
2
 
3
+ Upload before/after satellite image pairs, select a model, and view the
4
+ predicted change mask, overlay, and change-area statistics.
5
 
6
+ Auto-detects available checkpoints — no manual path entry needed.
 
7
 
8
  Usage:
9
  python app.py
 
11
 
12
  import logging
13
  from pathlib import Path
14
+ from typing import Any, Dict, List, Optional, Tuple
15
 
 
16
  import gradio as gr
17
  import numpy as np
18
  import torch
19
  import yaml
20
 
21
  from data.dataset import IMAGENET_MEAN, IMAGENET_STD
22
+ from inference import sliding_window_inference
23
  from models import get_model
24
  from utils.visualization import overlay_changes
25
 
 
27
 
28
 
29
  # ---------------------------------------------------------------------------
30
+ # Globals
31
  # ---------------------------------------------------------------------------
32
 
33
  _cached_model: Optional[torch.nn.Module] = None
34
+ _cached_model_key: Optional[str] = None
35
  _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  _config: Optional[Dict[str, Any]] = None
37
 
38
+ # Search these directories for checkpoint files
39
+ _CHECKPOINT_SEARCH_DIRS = [
40
+ Path("checkpoints"),
41
+ Path("/kaggle/working/checkpoints"),
42
+ Path("/content/drive/MyDrive/change-detection/checkpoints"),
43
+ ]
44
+
45
+ # Map model names to expected checkpoint filenames
46
+ _MODEL_CHECKPOINT_NAMES = {
47
+ "siamese_cnn": "siamese_cnn_best.pth",
48
+ "unet_pp": "unet_pp_best.pth",
49
+ "changeformer": "changeformer_best.pth",
50
+ }
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Config / model loading
55
+ # ---------------------------------------------------------------------------
56
 
57
  def _load_config() -> Dict[str, Any]:
58
  """Load and cache the project config.
 
68
  return _config
69
 
70
 
71
+ def _find_checkpoint(model_name: str) -> Optional[Path]:
72
+ """Auto-detect the checkpoint file for a given model.
73
+
74
+ Searches multiple directories for the expected checkpoint filename.
75
 
76
  Args:
77
+ model_name: One of ``siamese_cnn``, ``unet_pp``, ``changeformer``.
78
+
79
+ Returns:
80
+ Path to the checkpoint if found, ``None`` otherwise.
81
+ """
82
+ filename = _MODEL_CHECKPOINT_NAMES.get(model_name)
83
+ if filename is None:
84
+ return None
85
+
86
+ for search_dir in _CHECKPOINT_SEARCH_DIRS:
87
+ candidate = search_dir / filename
88
+ if candidate.exists():
89
+ return candidate
90
+
91
+ return None
92
+
93
+
94
+ def _get_available_models() -> List[str]:
95
+ """Return a list of model names that have checkpoints available.
96
+
97
+ Returns:
98
+ List of model name strings with detected checkpoints.
99
+ """
100
+ available = []
101
+ for model_name in _MODEL_CHECKPOINT_NAMES:
102
+ if _find_checkpoint(model_name) is not None:
103
+ available.append(model_name)
104
+ return available
105
+
106
+
107
+ def _load_model(model_name: str) -> torch.nn.Module:
108
+ """Load a model using auto-detected checkpoint.
109
+
110
+ Args:
111
+ model_name: Architecture name.
112
 
113
  Returns:
114
  Model in eval mode on the current device.
115
 
116
  Raises:
117
+ FileNotFoundError: If no checkpoint is found.
118
  """
119
  global _cached_model, _cached_model_key
120
 
121
+ if _cached_model is not None and _cached_model_key == model_name:
 
122
  return _cached_model
123
 
124
+ ckpt_path = _find_checkpoint(model_name)
125
+ if ckpt_path is None:
126
+ raise FileNotFoundError(
127
+ f"No checkpoint found for '{model_name}'. "
128
+ f"Expected '{_MODEL_CHECKPOINT_NAMES[model_name]}' in one of: "
129
+ f"{[str(d) for d in _CHECKPOINT_SEARCH_DIRS]}"
130
+ )
131
 
132
+ config = _load_config()
133
  model = get_model(model_name, config).to(_device)
134
  ckpt = torch.load(ckpt_path, map_location=_device)
135
  model.load_state_dict(ckpt["model_state_dict"])
136
  model.eval()
137
 
138
  _cached_model = model
139
+ _cached_model_key = model_name
140
+ logger.info("Loaded %s from %s", model_name, ckpt_path)
141
  return model
142
 
143
 
144
  # ---------------------------------------------------------------------------
145
+ # Preprocessing
146
  # ---------------------------------------------------------------------------
147
 
148
  def _numpy_to_tensor(
 
175
 
176
 
177
  # ---------------------------------------------------------------------------
178
+ # Prediction
179
  # ---------------------------------------------------------------------------
180
 
181
  def predict(
182
  before_image: Optional[np.ndarray],
183
  after_image: Optional[np.ndarray],
184
  model_name: str,
 
185
  threshold: float,
186
  ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
187
  """Run change detection and return visualisations + summary text.
 
190
  before_image: Before image as numpy ``[H, W, 3]`` RGB uint8.
191
  after_image: After image as numpy ``[H, W, 3]`` RGB uint8.
192
  model_name: Architecture name.
 
193
  threshold: Binarisation threshold for predictions.
194
 
195
  Returns:
196
  Tuple of ``(change_mask, overlay_image, summary_text)``.
 
 
 
197
  """
198
  if before_image is None or after_image is None:
199
+ return None, None, "Please upload both **before** and **after** images."
200
 
201
  config = _load_config()
202
  patch_size: int = config.get("dataset", {}).get("patch_size", 256)
203
 
204
+ # Load model (auto-detects checkpoint)
205
  try:
206
+ model = _load_model(model_name)
207
  except FileNotFoundError as exc:
208
+ return None, None, f"**Error:** {exc}"
209
 
210
  # Preprocess
211
  tensor_a, (orig_h, orig_w) = _numpy_to_tensor(before_image, patch_size)
 
214
  # Tiled inference
215
  prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device)
216
  prob_map = prob_map[:, :, :orig_h, :orig_w]
217
+ prob_np = prob_map.squeeze().numpy()
218
 
219
  # Binary change mask
220
  binary_mask = (prob_np > threshold).astype(np.uint8) * 255
221
 
222
  # Overlay on after image
223
+ pred_tensor = (prob_map.squeeze(0) >= threshold).float()
224
+ img_b_tensor = tensor_b.squeeze()[:, :orig_h, :orig_w]
225
  overlay_rgb = overlay_changes(
226
  img_after=img_b_tensor,
227
  mask_pred=pred_tensor,
 
234
  changed_pixels = int(binary_mask.sum() // 255)
235
  pct_changed = (changed_pixels / total_pixels) * 100.0
236
 
237
+ ckpt_path = _find_checkpoint(model_name)
238
  summary = (
239
+ f"### Change Detection Results\n\n"
240
+ f"| Metric | Value |\n"
241
+ f"|---|---|\n"
242
+ f"| **Model** | {model_name} |\n"
243
+ f"| **Image size** | {orig_w} x {orig_h} |\n"
244
+ f"| **Total pixels** | {total_pixels:,} |\n"
245
+ f"| **Changed pixels** | {changed_pixels:,} |\n"
246
+ f"| **Area changed** | {pct_changed:.2f}% |\n"
247
+ f"| **Threshold** | {threshold} |\n"
248
+ f"| **Checkpoint** | {ckpt_path.name if ckpt_path else 'N/A'} |\n"
249
+ f"| **Device** | {_device} |"
250
  )
251
 
252
  return binary_mask, overlay_rgb, summary
 
262
  Returns:
263
  A ``gr.Blocks`` application ready to ``.launch()``.
264
  """
265
+ available = _get_available_models()
266
+ all_models = list(_MODEL_CHECKPOINT_NAMES.keys())
267
 
268
+ # Show which models are available
269
+ status_lines = []
270
+ for m in all_models:
271
+ ckpt = _find_checkpoint(m)
272
+ if ckpt:
273
+ status_lines.append(f"- **{m}**: {ckpt.name}")
274
+ else:
275
+ status_lines.append(f"- **{m}**: not found")
276
+ model_status = "\n".join(status_lines)
277
+
278
+ default_model = available[0] if available else "changeformer"
279
+
280
+ with gr.Blocks(title="Military Base Change Detection") as demo:
281
 
282
  gr.Markdown(
283
  "# Military Base Change Detection\n"
284
  "Upload **before** and **after** satellite images to detect "
285
+ "construction, infrastructure changes, and runway development.\n\n"
286
+ "**Available models:**\n" + model_status
287
  )
288
 
289
  # ---- Inputs ---------------------------------------------------
290
  with gr.Row():
291
  with gr.Column(scale=1):
292
  before_img = gr.Image(
293
+ label="Before Image (older)",
294
  type="numpy",
295
  sources=["upload", "clipboard"],
296
  )
297
  with gr.Column(scale=1):
298
  after_img = gr.Image(
299
+ label="After Image (newer)",
300
  type="numpy",
301
  sources=["upload", "clipboard"],
302
  )
 
304
  # ---- Controls -------------------------------------------------
305
  with gr.Row():
306
  model_dropdown = gr.Dropdown(
307
+ choices=available if available else all_models,
308
+ value=default_model,
309
  label="Model Architecture",
310
  )
 
 
 
 
311
  threshold_slider = gr.Slider(
312
  minimum=0.1,
313
  maximum=0.9,
 
330
  # ---- Wiring ---------------------------------------------------
331
  detect_btn.click(
332
  fn=predict,
333
+ inputs=[before_img, after_img, model_dropdown, threshold_slider],
 
 
 
 
 
 
334
  outputs=[change_mask_out, overlay_out, summary_out],
335
  )
336
 
configs/config.yaml CHANGED
@@ -9,7 +9,7 @@ project:
9
 
10
  # --- Colab / runtime settings ---
11
  colab:
12
- enabled: true
13
  drive_root: "/content/drive/MyDrive/change-detection"
14
  checkpoint_dir: "/content/drive/MyDrive/change-detection/checkpoints"
15
  log_dir: "/content/drive/MyDrive/change-detection/logs"
@@ -139,5 +139,5 @@ epoch_counts:
139
  gradio:
140
  server_port: 7860
141
  share: false
142
- default_model: "unet_pp"
143
- default_checkpoint: "checkpoints/unet_pp_best.pth"
 
9
 
10
  # --- Colab / runtime settings ---
11
  colab:
12
+ enabled: false
13
  drive_root: "/content/drive/MyDrive/change-detection"
14
  checkpoint_dir: "/content/drive/MyDrive/change-detection/checkpoints"
15
  log_dir: "/content/drive/MyDrive/change-detection/logs"
 
139
  gradio:
140
  server_port: 7860
141
  share: false
142
+ default_model: "changeformer"
143
+ default_checkpoint: "checkpoints/changeformer_best.pth"