Vedant Jigarbhai Mehta committed on
Commit
5c53fad
·
1 Parent(s): 3ad9651

Implement inference pipeline and Gradio demo app

Browse files

inference.py: tiled sliding-window inference for any resolution,
reflection padding to patch-size multiples, binary mask + overlay
output, prints percentage of area changed.

app.py: Gradio Blocks UI with before/after uploads, model dropdown,
checkpoint picker, threshold slider. Returns change mask, red overlay,
and Markdown summary with change statistics. Model caching, CPU
fallback, defaults from config.yaml gradio section.

Files changed (2) hide show
  1. app.py +209 -93
  2. inference.py +194 -77
app.py CHANGED
@@ -1,7 +1,10 @@
1
- """Gradio web demo for change detection inference.
2
 
3
- Provides an interactive interface to upload before/after satellite image pairs
4
- and visualize predicted change masks with overlays.
 
 
 
5
 
6
  Usage:
7
  python app.py
@@ -9,7 +12,7 @@ Usage:
9
 
10
  import logging
11
  from pathlib import Path
12
- from typing import Optional, Tuple
13
 
14
  import cv2
15
  import gradio as gr
@@ -18,168 +21,281 @@ import torch
18
  import yaml
19
 
20
  from data.dataset import IMAGENET_MEAN, IMAGENET_STD
21
- from inference import preprocess_image, sliding_window_inference
22
  from models import get_model
23
- from utils.visualization import denormalize, overlay_changes
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
- # Global model cache
28
- _model: Optional[torch.nn.Module] = None
29
- _model_name: Optional[str] = None
 
 
 
 
30
  _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- _config = None
32
 
33
 
34
- def load_config() -> dict:
35
- """Load project config from YAML.
36
 
37
  Returns:
38
- Config dictionary.
39
  """
40
- config_path = Path("configs/config.yaml")
41
- with open(config_path, "r") as f:
42
- return yaml.safe_load(f)
 
 
 
43
 
44
 
45
- def load_model(model_name: str, checkpoint_path: str) -> torch.nn.Module:
46
- """Load a change detection model with caching.
47
 
48
  Args:
49
- model_name: Name of the model architecture.
50
- checkpoint_path: Path to the model checkpoint.
51
 
52
  Returns:
53
- Loaded model in eval mode.
 
 
 
54
  """
55
- global _model, _model_name, _config
56
 
57
- if _config is None:
58
- _config = load_config()
 
59
 
60
- if _model is not None and _model_name == model_name:
61
- return _model
 
 
62
 
63
- model = get_model(model_name, _config).to(_device)
64
- ckpt = torch.load(checkpoint_path, map_location=_device)
65
  model.load_state_dict(ckpt["model_state_dict"])
66
  model.eval()
67
 
68
- _model = model
69
- _model_name = model_name
70
- logger.info("Loaded model: %s from %s", model_name, checkpoint_path)
71
  return model
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def predict(
75
- before_image: np.ndarray,
76
- after_image: np.ndarray,
77
  model_name: str,
78
  checkpoint_path: str,
79
  threshold: float,
80
- ) -> Tuple[np.ndarray, np.ndarray]:
81
- """Run change detection on a pair of images.
82
 
83
  Args:
84
- before_image: Before image as numpy array (RGB, uint8).
85
- after_image: After image as numpy array (RGB, uint8).
86
- model_name: Model architecture name.
87
- checkpoint_path: Path to model weights.
88
- threshold: Binarization threshold.
89
 
90
  Returns:
91
- Tuple of (binary change mask, overlay visualization).
 
 
 
92
  """
93
- model = load_model(model_name, checkpoint_path)
94
- patch_size = 256
95
-
96
- # Preprocess both images
97
- def _to_tensor(img: np.ndarray) -> torch.Tensor:
98
- h, w = img.shape[:2]
99
- pad_h = (patch_size - h % patch_size) % patch_size
100
- pad_w = (patch_size - w % patch_size) % patch_size
101
- if pad_h > 0 or pad_w > 0:
102
- img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
103
- img_f = img.astype(np.float32) / 255.0
104
- mean = np.array(IMAGENET_MEAN, dtype=np.float32)
105
- std = np.array(IMAGENET_STD, dtype=np.float32)
106
- img_f = (img_f - mean) / std
107
- return torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float()
108
-
109
- orig_h, orig_w = before_image.shape[:2]
110
- tensor_a = _to_tensor(before_image)
111
- tensor_b = _to_tensor(after_image)
112
-
113
- # Run inference
114
  prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device)
115
  prob_map = prob_map[:, :, :orig_h, :orig_w]
 
116
 
117
- # Binary mask
118
- mask_np = prob_map.squeeze().numpy()
119
- binary_mask = (mask_np > threshold).astype(np.uint8) * 255
120
 
121
  # Overlay on after image
122
- overlay = after_image.copy().astype(np.float32) / 255.0
123
- change_pixels = mask_np > threshold
124
- overlay[change_pixels, 0] = np.clip(overlay[change_pixels, 0] * 0.6 + 0.4, 0, 1)
125
- overlay[change_pixels, 1] = overlay[change_pixels, 1] * 0.6
126
- overlay[change_pixels, 2] = overlay[change_pixels, 2] * 0.6
127
- overlay = (overlay * 255).astype(np.uint8)
 
 
128
 
129
- return binary_mask, overlay
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
 
 
 
 
 
132
  def build_demo() -> gr.Blocks:
133
- """Build the Gradio demo interface.
134
 
135
  Returns:
136
- Gradio Blocks application.
137
  """
138
- config = load_config()
139
  gradio_cfg = config.get("gradio", {})
140
 
141
- with gr.Blocks(title="Military Base Change Detection") as demo:
142
- gr.Markdown("# Military Base Change Detection")
143
- gr.Markdown("Upload before/after satellite image pairs to detect construction and infrastructure changes.")
 
144
 
145
- with gr.Row():
146
- with gr.Column():
147
- before_img = gr.Image(label="Before Image", type="numpy")
148
- after_img = gr.Image(label="After Image", type="numpy")
149
- with gr.Column():
150
- change_mask = gr.Image(label="Change Mask")
151
- overlay_img = gr.Image(label="Overlay")
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  with gr.Row():
154
  model_dropdown = gr.Dropdown(
155
  choices=["siamese_cnn", "unet_pp", "changeformer"],
156
  value=gradio_cfg.get("default_model", "unet_pp"),
157
- label="Model",
158
  )
159
  checkpoint_input = gr.Textbox(
160
  value=gradio_cfg.get("default_checkpoint", "checkpoints/unet_pp_best.pth"),
161
  label="Checkpoint Path",
162
  )
163
  threshold_slider = gr.Slider(
164
- minimum=0.1, maximum=0.9, value=0.5, step=0.05,
 
 
 
165
  label="Detection Threshold",
166
  )
167
 
168
- detect_btn = gr.Button("Detect Changes", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
169
  detect_btn.click(
170
  fn=predict,
171
- inputs=[before_img, after_img, model_dropdown, checkpoint_input, threshold_slider],
172
- outputs=[change_mask, overlay_img],
 
 
 
 
 
 
173
  )
174
 
175
  return demo
176
 
177
 
 
 
 
 
178
  def main() -> None:
179
- """Launch the Gradio demo."""
180
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 
 
 
181
 
182
- config = load_config()
183
  gradio_cfg = config.get("gradio", {})
184
 
185
  demo = build_demo()
 
1
+ """Gradio web demo for satellite change detection.
2
 
3
+ Upload before/after satellite image pairs, select a model and checkpoint, and
4
+ view the predicted change mask, overlay, and change-area statistics.
5
+
6
+ Defaults (model, checkpoint, port, share) are read from the ``gradio`` section
7
+ of ``configs/config.yaml``.
8
 
9
  Usage:
10
  python app.py
 
12
 
13
  import logging
14
  from pathlib import Path
15
+ from typing import Any, Dict, Optional, Tuple
16
 
17
  import cv2
18
  import gradio as gr
 
21
  import yaml
22
 
23
  from data.dataset import IMAGENET_MEAN, IMAGENET_STD
24
+ from inference import load_and_preprocess, sliding_window_inference
25
  from models import get_model
26
+ from utils.visualization import overlay_changes
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Globals (model cache to avoid reloading on every prediction)
33
+ # ---------------------------------------------------------------------------
34
+
35
+ _cached_model: Optional[torch.nn.Module] = None
36
+ _cached_model_key: Optional[str] = None # "model_name::checkpoint_path"
37
  _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ _config: Optional[Dict[str, Any]] = None
39
 
40
 
41
+ def _load_config() -> Dict[str, Any]:
42
+ """Load and cache the project config.
43
 
44
  Returns:
45
+ Full config dict.
46
  """
47
+ global _config
48
+ if _config is None:
49
+ config_path = Path("configs/config.yaml")
50
+ with open(config_path, "r") as fh:
51
+ _config = yaml.safe_load(fh)
52
+ return _config
53
 
54
 
55
+ def _load_model(model_name: str, checkpoint_path: str) -> torch.nn.Module:
56
+ """Load a model, re-using the cache if name + checkpoint match.
57
 
58
  Args:
59
+ model_name: Architecture name (``siamese_cnn``, ``unet_pp``, ``changeformer``).
60
+ checkpoint_path: Path to the ``.pth`` checkpoint file.
61
 
62
  Returns:
63
+ Model in eval mode on the current device.
64
+
65
+ Raises:
66
+ FileNotFoundError: If the checkpoint does not exist.
67
  """
68
+ global _cached_model, _cached_model_key
69
 
70
+ cache_key = f"{model_name}::{checkpoint_path}"
71
+ if _cached_model is not None and _cached_model_key == cache_key:
72
+ return _cached_model
73
 
74
+ config = _load_config()
75
+ ckpt_path = Path(checkpoint_path)
76
+ if not ckpt_path.exists():
77
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
78
 
79
+ model = get_model(model_name, config).to(_device)
80
+ ckpt = torch.load(ckpt_path, map_location=_device)
81
  model.load_state_dict(ckpt["model_state_dict"])
82
  model.eval()
83
 
84
+ _cached_model = model
85
+ _cached_model_key = cache_key
86
+ logger.info("Loaded model %s from %s", model_name, checkpoint_path)
87
  return model
88
 
89
 
90
+ # ---------------------------------------------------------------------------
91
+ # Preprocessing helper (numpy RGB uint8 → tensor)
92
+ # ---------------------------------------------------------------------------
93
+
94
+ def _numpy_to_tensor(
95
+ img: np.ndarray,
96
+ patch_size: int = 256,
97
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
98
+ """Convert a uint8 RGB numpy image to a normalised, padded tensor.
99
+
100
+ Args:
101
+ img: Input image ``[H, W, 3]``, uint8, RGB.
102
+ patch_size: Pad to a multiple of this value.
103
+
104
+ Returns:
105
+ Tuple of ``(tensor [1, 3, H_pad, W_pad], (orig_h, orig_w))``.
106
+ """
107
+ orig_h, orig_w = img.shape[:2]
108
+
109
+ pad_h = (patch_size - orig_h % patch_size) % patch_size
110
+ pad_w = (patch_size - orig_w % patch_size) % patch_size
111
+ if pad_h > 0 or pad_w > 0:
112
+ img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
113
+
114
+ img_f = img.astype(np.float32) / 255.0
115
+ mean = np.array(IMAGENET_MEAN, dtype=np.float32)
116
+ std = np.array(IMAGENET_STD, dtype=np.float32)
117
+ img_f = (img_f - mean) / std
118
+
119
+ tensor = torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float()
120
+ return tensor, (orig_h, orig_w)
121
+
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # Prediction function (called by Gradio)
125
+ # ---------------------------------------------------------------------------
126
+
127
  def predict(
128
+ before_image: Optional[np.ndarray],
129
+ after_image: Optional[np.ndarray],
130
  model_name: str,
131
  checkpoint_path: str,
132
  threshold: float,
133
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
134
+ """Run change detection and return visualisations + summary text.
135
 
136
  Args:
137
+ before_image: Before image as numpy ``[H, W, 3]`` RGB uint8.
138
+ after_image: After image as numpy ``[H, W, 3]`` RGB uint8.
139
+ model_name: Architecture name.
140
+ checkpoint_path: Path to checkpoint file.
141
+ threshold: Binarisation threshold for predictions.
142
 
143
  Returns:
144
+ Tuple of ``(change_mask, overlay_image, summary_text)``.
145
+ - ``change_mask``: uint8 grayscale ``[H, W]`` (0 or 255).
146
+ - ``overlay_image``: uint8 RGB ``[H, W, 3]``.
147
+ - ``summary_text``: Markdown string with change statistics.
148
  """
149
+ if before_image is None or after_image is None:
150
+ return None, None, "Please upload both before and after images."
151
+
152
+ config = _load_config()
153
+ patch_size: int = config.get("dataset", {}).get("patch_size", 256)
154
+
155
+ # Load model
156
+ try:
157
+ model = _load_model(model_name, checkpoint_path)
158
+ except FileNotFoundError as exc:
159
+ return None, None, f"Error: {exc}"
160
+
161
+ # Preprocess
162
+ tensor_a, (orig_h, orig_w) = _numpy_to_tensor(before_image, patch_size)
163
+ tensor_b, _ = _numpy_to_tensor(after_image, patch_size)
164
+
165
+ # Tiled inference
 
 
 
 
166
  prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device)
167
  prob_map = prob_map[:, :, :orig_h, :orig_w]
168
+ prob_np = prob_map.squeeze().numpy() # [H, W]
169
 
170
+ # Binary change mask
171
+ binary_mask = (prob_np > threshold).astype(np.uint8) * 255
 
172
 
173
  # Overlay on after image
174
+ pred_tensor = (prob_map.squeeze(0) >= threshold).float() # [1, H, W]
175
+ img_b_tensor = tensor_b.squeeze()[:, :orig_h, :orig_w] # [3, H, W]
176
+ overlay_rgb = overlay_changes(
177
+ img_after=img_b_tensor,
178
+ mask_pred=pred_tensor,
179
+ alpha=0.4,
180
+ color=(255, 0, 0),
181
+ )
182
 
183
+ # Change statistics
184
+ total_pixels = orig_h * orig_w
185
+ changed_pixels = int(binary_mask.sum() // 255)
186
+ pct_changed = (changed_pixels / total_pixels) * 100.0
187
+
188
+ summary = (
189
+ f"### Change Detection Summary\n"
190
+ f"- **Image size**: {orig_w} x {orig_h}\n"
191
+ f"- **Total pixels**: {total_pixels:,}\n"
192
+ f"- **Changed pixels**: {changed_pixels:,}\n"
193
+ f"- **Area changed**: {pct_changed:.2f}%\n"
194
+ f"- **Model**: {model_name}\n"
195
+ f"- **Threshold**: {threshold}"
196
+ )
197
+
198
+ return binary_mask, overlay_rgb, summary
199
 
200
 
201
+ # ---------------------------------------------------------------------------
202
+ # Gradio UI
203
+ # ---------------------------------------------------------------------------
204
+
205
  def build_demo() -> gr.Blocks:
206
+ """Construct the Gradio Blocks interface.
207
 
208
  Returns:
209
+ A ``gr.Blocks`` application ready to ``.launch()``.
210
  """
211
+ config = _load_config()
212
  gradio_cfg = config.get("gradio", {})
213
 
214
+ with gr.Blocks(
215
+ title="Military Base Change Detection",
216
+ theme=gr.themes.Soft(),
217
+ ) as demo:
218
 
219
+ gr.Markdown(
220
+ "# Military Base Change Detection\n"
221
+ "Upload **before** and **after** satellite images to detect "
222
+ "construction, infrastructure changes, and runway development."
223
+ )
 
 
224
 
225
+ # ---- Inputs ---------------------------------------------------
226
+ with gr.Row():
227
+ with gr.Column(scale=1):
228
+ before_img = gr.Image(
229
+ label="Before Image",
230
+ type="numpy",
231
+ sources=["upload", "clipboard"],
232
+ )
233
+ with gr.Column(scale=1):
234
+ after_img = gr.Image(
235
+ label="After Image",
236
+ type="numpy",
237
+ sources=["upload", "clipboard"],
238
+ )
239
+
240
+ # ---- Controls -------------------------------------------------
241
  with gr.Row():
242
  model_dropdown = gr.Dropdown(
243
  choices=["siamese_cnn", "unet_pp", "changeformer"],
244
  value=gradio_cfg.get("default_model", "unet_pp"),
245
+ label="Model Architecture",
246
  )
247
  checkpoint_input = gr.Textbox(
248
  value=gradio_cfg.get("default_checkpoint", "checkpoints/unet_pp_best.pth"),
249
  label="Checkpoint Path",
250
  )
251
  threshold_slider = gr.Slider(
252
+ minimum=0.1,
253
+ maximum=0.9,
254
+ value=0.5,
255
+ step=0.05,
256
  label="Detection Threshold",
257
  )
258
 
259
+ detect_btn = gr.Button("Detect Changes", variant="primary", size="lg")
260
+
261
+ # ---- Outputs --------------------------------------------------
262
+ with gr.Row():
263
+ with gr.Column(scale=1):
264
+ change_mask_out = gr.Image(label="Change Mask")
265
+ with gr.Column(scale=1):
266
+ overlay_out = gr.Image(label="Overlay (changes in red)")
267
+
268
+ summary_out = gr.Markdown(label="Summary")
269
+
270
+ # ---- Wiring ---------------------------------------------------
271
  detect_btn.click(
272
  fn=predict,
273
+ inputs=[
274
+ before_img,
275
+ after_img,
276
+ model_dropdown,
277
+ checkpoint_input,
278
+ threshold_slider,
279
+ ],
280
+ outputs=[change_mask_out, overlay_out, summary_out],
281
  )
282
 
283
  return demo
284
 
285
 
286
+ # ---------------------------------------------------------------------------
287
+ # Entry point
288
+ # ---------------------------------------------------------------------------
289
+
290
  def main() -> None:
291
+ """Launch the Gradio demo server."""
292
+ logging.basicConfig(
293
+ level=logging.INFO,
294
+ format="%(asctime)s [%(levelname)s] %(message)s",
295
+ datefmt="%Y-%m-%d %H:%M:%S",
296
+ )
297
 
298
+ config = _load_config()
299
  gradio_cfg = config.get("gradio", {})
300
 
301
  demo = build_demo()
inference.py CHANGED
@@ -1,11 +1,16 @@
1
- """Run inference on arbitrary before/after image pairs.
2
 
3
- Loads a trained change detection model and produces binary change masks
4
- for new satellite image pairs.
 
 
5
 
6
  Usage:
7
  python inference.py --before path/to/before.png --after path/to/after.png \
8
  --model changeformer --checkpoint checkpoints/changeformer_best.pth
 
 
 
9
  """
10
 
11
  import argparse
@@ -17,55 +22,67 @@ import cv2
17
  import numpy as np
18
  import torch
19
  import torch.nn as nn
20
- import torch.nn.functional as F
21
  import yaml
22
 
23
  from data.dataset import IMAGENET_MEAN, IMAGENET_STD
24
  from models import get_model
25
- from utils.visualization import overlay_changes, plot_prediction
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
 
30
- def preprocess_image(
 
 
 
 
31
  image_path: Path,
32
  patch_size: int = 256,
33
  ) -> Tuple[torch.Tensor, Tuple[int, int]]:
34
- """Load and preprocess a single image for inference.
35
-
36
- Reads the image, pads to a multiple of patch_size, and applies
37
- ImageNet normalization.
38
 
39
  Args:
40
- image_path: Path to the input image.
41
- patch_size: Patch size the model expects.
42
 
43
  Returns:
44
- Tuple of (preprocessed tensor [1, 3, H, W], original (H, W)).
 
 
 
 
 
45
  """
46
  img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
47
  if img is None:
48
  raise FileNotFoundError(f"Could not read image: {image_path}")
49
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
50
  orig_h, orig_w = img.shape[:2]
 
51
 
52
- # Pad to multiple of patch_size
53
  pad_h = (patch_size - orig_h % patch_size) % patch_size
54
  pad_w = (patch_size - orig_w % patch_size) % patch_size
55
  if pad_h > 0 or pad_w > 0:
56
  img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
57
 
58
- # Normalize
59
  img = img.astype(np.float32) / 255.0
60
  mean = np.array(IMAGENET_MEAN, dtype=np.float32)
61
  std = np.array(IMAGENET_STD, dtype=np.float32)
62
  img = (img - mean) / std
63
 
64
- # HWC -> CHW, add batch dim
65
  tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()
66
  return tensor, (orig_h, orig_w)
67
 
68
 
 
 
 
 
 
69
  def sliding_window_inference(
70
  model: nn.Module,
71
  img_a: torch.Tensor,
@@ -73,103 +90,203 @@ def sliding_window_inference(
73
  patch_size: int = 256,
74
  device: torch.device = torch.device("cpu"),
75
  ) -> torch.Tensor:
76
- """Run inference using sliding window for large images.
77
 
78
- Splits images into non-overlapping patches, runs model on each,
79
- and stitches results back together.
80
 
81
  Args:
82
- model: Trained change detection model.
83
- img_a: Before image tensor [1, 3, H, W].
84
- img_b: After image tensor [1, 3, H, W].
85
- patch_size: Size of each patch.
86
- device: Inference device.
87
 
88
  Returns:
89
- Probability map [1, 1, H, W] (after sigmoid).
 
90
  """
 
91
  _, _, h, w = img_a.shape
92
- output = torch.zeros(1, 1, h, w, device="cpu")
93
 
94
- model.eval()
95
- with torch.no_grad():
96
- for y in range(0, h, patch_size):
97
- for x in range(0, w, patch_size):
98
- patch_a = img_a[:, :, y:y + patch_size, x:x + patch_size].to(device)
99
- patch_b = img_b[:, :, y:y + patch_size, x:x + patch_size].to(device)
 
 
 
 
 
100
 
101
- logits = model(patch_a, patch_b)
102
- probs = torch.sigmoid(logits).cpu()
103
- output[:, :, y:y + patch_size, x:x + patch_size] = probs
104
 
 
105
  return output
106
 
107
 
108
- def save_change_mask(
109
- mask: np.ndarray,
 
 
 
 
110
  save_path: Path,
111
  threshold: float = 0.5,
112
  ) -> None:
113
- """Save binary change mask as an image.
114
 
115
  Args:
116
- mask: Probability map [H, W] with values in [0, 1].
117
- save_path: Output file path.
118
- threshold: Binarization threshold.
119
  """
120
- binary = (mask > threshold).astype(np.uint8) * 255
121
  save_path.parent.mkdir(parents=True, exist_ok=True)
122
  cv2.imwrite(str(save_path), binary)
123
- logger.info("Saved change mask: %s", save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
 
 
 
 
126
  def main() -> None:
127
- """Main inference entry point."""
128
- parser = argparse.ArgumentParser(description="Run change detection inference")
129
- parser.add_argument("--before", type=Path, required=True, help="Path to before image")
130
- parser.add_argument("--after", type=Path, required=True, help="Path to after image")
131
- parser.add_argument("--model", type=str, default=None, help="Model name")
132
- parser.add_argument("--checkpoint", type=Path, required=True, help="Path to model checkpoint")
133
- parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
134
- parser.add_argument("--output", type=Path, default=Path("outputs/inference"))
135
- parser.add_argument("--threshold", type=float, default=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  args = parser.parse_args()
137
 
138
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 
 
 
 
 
 
 
139
 
140
- # Load config
141
- with open(args.config, "r") as f:
142
- config = yaml.safe_load(f)
143
 
144
- model_name = args.model or config["model"]["name"]
145
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146
- patch_size = config.get("dataset", {}).get("patch_size", 256)
147
 
148
- # Load model
149
  model = get_model(model_name, config).to(device)
150
  ckpt = torch.load(args.checkpoint, map_location=device)
151
  model.load_state_dict(ckpt["model_state_dict"])
152
- logger.info("Loaded model '%s' from %s", model_name, args.checkpoint)
153
-
154
- # Preprocess images
155
- img_a, (orig_h, orig_w) = preprocess_image(args.before, patch_size)
156
- img_b, _ = preprocess_image(args.after, patch_size)
157
-
158
- # Run inference
 
 
 
 
 
 
 
 
 
 
159
  prob_map = sliding_window_inference(model, img_a, img_b, patch_size, device)
160
 
161
- # Crop back to original size and save
162
  prob_map = prob_map[:, :, :orig_h, :orig_w]
163
- mask_np = prob_map.squeeze().numpy()
164
-
165
- args.output.mkdir(parents=True, exist_ok=True)
166
- save_change_mask(mask_np, args.output / "change_mask.png", args.threshold)
167
-
168
- # Save overlay visualization
169
- overlay = overlay_changes(img_b.squeeze()[:, :orig_h, :orig_w], prob_map.squeeze(0))
170
- overlay_uint8 = (overlay * 255).astype(np.uint8)
171
- cv2.imwrite(str(args.output / "overlay.png"), cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR))
172
- logger.info("Saved overlay: %s", args.output / "overlay.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
 
175
  if __name__ == "__main__":
 
1
+ """Run change-detection inference on arbitrary before/after image pairs.
2
 
3
+ Handles images of any resolution by tiling into 256x256 patches, running the
4
+ model on each patch, and stitching the probability map back together. Outputs
5
+ a binary change mask PNG, an overlay visualisation, and prints the percentage
6
+ of area changed.
7
 
8
  Usage:
9
  python inference.py --before path/to/before.png --after path/to/after.png \
10
  --model changeformer --checkpoint checkpoints/changeformer_best.pth
11
+
12
+ python inference.py --before big_before.tif --after big_after.tif \
13
+ --checkpoint checkpoints/unet_pp_best.pth --output results/
14
  """
15
 
16
  import argparse
 
22
  import numpy as np
23
  import torch
24
  import torch.nn as nn
 
25
  import yaml
26
 
27
  from data.dataset import IMAGENET_MEAN, IMAGENET_STD
28
  from models import get_model
29
+ from utils.visualization import overlay_changes
30
 
31
  logger = logging.getLogger(__name__)
32
 
33
 
34
+ # ---------------------------------------------------------------------------
35
+ # Image preprocessing
36
+ # ---------------------------------------------------------------------------
37
+
38
+ def load_and_preprocess(
39
  image_path: Path,
40
  patch_size: int = 256,
41
  ) -> Tuple[torch.Tensor, Tuple[int, int]]:
42
+ """Load an image from disk, pad to a patch-size multiple, and normalise.
 
 
 
43
 
44
  Args:
45
+ image_path: Path to the input image (any format OpenCV supports).
46
+ patch_size: Spatial size the model expects per patch.
47
 
48
  Returns:
49
+ Tuple of ``(tensor, original_size)`` where tensor has shape
50
+ ``[1, 3, H_padded, W_padded]`` and ``original_size`` is
51
+ ``(orig_h, orig_w)`` before padding.
52
+
53
+ Raises:
54
+ FileNotFoundError: If the image cannot be read.
55
  """
56
  img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
57
  if img is None:
58
  raise FileNotFoundError(f"Could not read image: {image_path}")
59
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
60
+
61
  orig_h, orig_w = img.shape[:2]
62
+ logger.info("Loaded %s (%d x %d)", image_path.name, orig_w, orig_h)
63
 
64
+ # Pad to the nearest multiple of patch_size using reflection
65
  pad_h = (patch_size - orig_h % patch_size) % patch_size
66
  pad_w = (patch_size - orig_w % patch_size) % patch_size
67
  if pad_h > 0 or pad_w > 0:
68
  img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
69
 
70
+ # uint8 → float32 [0,1] → ImageNet normalisation
71
  img = img.astype(np.float32) / 255.0
72
  mean = np.array(IMAGENET_MEAN, dtype=np.float32)
73
  std = np.array(IMAGENET_STD, dtype=np.float32)
74
  img = (img - mean) / std
75
 
76
+ # HWC -> CHW, add batch dim
77
  tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()
78
  return tensor, (orig_h, orig_w)
79
 
80
 
81
+ # ---------------------------------------------------------------------------
82
+ # Tiled (sliding-window) inference
83
+ # ---------------------------------------------------------------------------
84
+
85
+ @torch.no_grad()
86
  def sliding_window_inference(
87
  model: nn.Module,
88
  img_a: torch.Tensor,
 
90
  patch_size: int = 256,
91
  device: torch.device = torch.device("cpu"),
92
  ) -> torch.Tensor:
93
+ """Run inference by tiling large images into non-overlapping patches.
94
 
95
+ Each patch pair is fed through the model independently; the resulting
96
+ probability maps are stitched back into a single full-resolution output.
97
 
98
  Args:
99
+ model: Trained change-detection model (set to eval internally).
100
+ img_a: Before image ``[1, 3, H, W]`` (padded to patch-size multiples).
101
+ img_b: After image ``[1, 3, H, W]`` (same spatial size as ``img_a``).
102
+ patch_size: Tile size in pixels.
103
+ device: Inference device (CUDA or CPU).
104
 
105
  Returns:
106
+ Probability map ``[1, 1, H, W]`` with values in ``[0, 1]`` (after
107
+ sigmoid), on CPU.
108
  """
109
+ model.eval()
110
  _, _, h, w = img_a.shape
111
+ output = torch.zeros(1, 1, h, w)
112
 
113
+ n_tiles = (h // patch_size) * (w // patch_size)
114
+ tile_idx = 0
115
+
116
+ for y in range(0, h, patch_size):
117
+ for x in range(0, w, patch_size):
118
+ patch_a = img_a[:, :, y:y + patch_size, x:x + patch_size].to(device)
119
+ patch_b = img_b[:, :, y:y + patch_size, x:x + patch_size].to(device)
120
+
121
+ logits = model(patch_a, patch_b)
122
+ probs = torch.sigmoid(logits).cpu()
123
+ output[:, :, y:y + patch_size, x:x + patch_size] = probs
124
 
125
+ tile_idx += 1
 
 
126
 
127
+ logger.info("Inference complete: %d tiles processed", n_tiles)
128
  return output
129
 
130
 
131
+ # ---------------------------------------------------------------------------
132
+ # Output helpers
133
+ # ---------------------------------------------------------------------------
134
+
135
+ def save_binary_mask(
136
+ prob_map: np.ndarray,
137
  save_path: Path,
138
  threshold: float = 0.5,
139
  ) -> None:
140
+ """Binarise a probability map and save as a PNG.
141
 
142
  Args:
143
+ prob_map: Probability values ``[H, W]`` in ``[0, 1]``.
144
+ save_path: Destination file path.
145
+ threshold: Decision threshold.
146
  """
147
+ binary = (prob_map > threshold).astype(np.uint8) * 255
148
  save_path.parent.mkdir(parents=True, exist_ok=True)
149
  cv2.imwrite(str(save_path), binary)
150
+ logger.info("Saved binary mask: %s", save_path)
151
+
152
+
153
+ def save_overlay(
154
+ img_b_tensor: torch.Tensor,
155
+ pred_tensor: torch.Tensor,
156
+ save_path: Path,
157
+ threshold: float = 0.5,
158
+ ) -> None:
159
+ """Create and save an overlay visualisation.
160
+
161
+ Args:
162
+ img_b_tensor: After image ``[3, H, W]`` (ImageNet-normalised).
163
+ pred_tensor: Prediction mask ``[1, H, W]`` (probability).
164
+ save_path: Destination file path.
165
+ threshold: Binarisation threshold applied before overlay.
166
+ """
167
+ binary_pred = (pred_tensor >= threshold).float()
168
+ overlay_rgb = overlay_changes(
169
+ img_after=img_b_tensor,
170
+ mask_pred=binary_pred,
171
+ alpha=0.4,
172
+ color=(255, 0, 0),
173
+ )
174
+ save_path.parent.mkdir(parents=True, exist_ok=True)
175
+ cv2.imwrite(str(save_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))
176
+ logger.info("Saved overlay: %s", save_path)
177
 
178
 
179
+ # ---------------------------------------------------------------------------
180
+ # Main
181
+ # ---------------------------------------------------------------------------
182
+
183
  def main() -> None:
184
+ """Entry point — parse CLI args, run inference, save outputs."""
185
+ parser = argparse.ArgumentParser(
186
+ description="Run change-detection inference on a before/after image pair",
187
+ )
188
+ parser.add_argument(
189
+ "--before", type=Path, required=True,
190
+ help="Path to the *before* image.",
191
+ )
192
+ parser.add_argument(
193
+ "--after", type=Path, required=True,
194
+ help="Path to the *after* image.",
195
+ )
196
+ parser.add_argument(
197
+ "--model", type=str, default=None,
198
+ help="Model name (overrides config). One of: siamese_cnn, unet_pp, changeformer.",
199
+ )
200
+ parser.add_argument(
201
+ "--checkpoint", type=Path, required=True,
202
+ help="Path to the model checkpoint (.pth).",
203
+ )
204
+ parser.add_argument(
205
+ "--config", type=Path, default=Path("configs/config.yaml"),
206
+ help="Path to the YAML configuration file.",
207
+ )
208
+ parser.add_argument(
209
+ "--output", type=Path, default=Path("outputs/inference"),
210
+ help="Output directory for results.",
211
+ )
212
+ parser.add_argument(
213
+ "--threshold", type=float, default=None,
214
+ help="Binarisation threshold (default: from config).",
215
+ )
216
  args = parser.parse_args()
217
 
218
+ logging.basicConfig(
219
+ level=logging.INFO,
220
+ format="%(asctime)s [%(levelname)s] %(message)s",
221
+ datefmt="%Y-%m-%d %H:%M:%S",
222
+ )
223
+
224
+ # ---- Config -------------------------------------------------------
225
+ with open(args.config, "r") as fh:
226
+ config: Dict[str, Any] = yaml.safe_load(fh)
227
 
228
+ model_name: str = args.model or config["model"]["name"]
229
+ threshold: float = args.threshold or config.get("evaluation", {}).get("threshold", 0.5)
230
+ patch_size: int = config.get("dataset", {}).get("patch_size", 256)
231
 
 
232
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
233
+ logger.info("Device: %s | Model: %s | Threshold: %.2f", device, model_name, threshold)
234
 
235
+ # ---- Load model ---------------------------------------------------
236
  model = get_model(model_name, config).to(device)
237
  ckpt = torch.load(args.checkpoint, map_location=device)
238
  model.load_state_dict(ckpt["model_state_dict"])
239
+ logger.info(
240
+ "Loaded checkpoint: %s (epoch %d)",
241
+ args.checkpoint, ckpt.get("epoch", -1),
242
+ )
243
+
244
+ # ---- Preprocess images --------------------------------------------
245
+ img_a, (orig_h, orig_w) = load_and_preprocess(args.before, patch_size)
246
+ img_b, (orig_h_b, orig_w_b) = load_and_preprocess(args.after, patch_size)
247
+
248
+ if (orig_h, orig_w) != (orig_h_b, orig_w_b):
249
+ logger.warning(
250
+ "Image sizes differ: before=(%d,%d) after=(%d,%d). "
251
+ "Using before dimensions for cropping.",
252
+ orig_h, orig_w, orig_h_b, orig_w_b,
253
+ )
254
+
255
+ # ---- Run tiled inference ------------------------------------------
256
  prob_map = sliding_window_inference(model, img_a, img_b, patch_size, device)
257
 
258
+ # Crop back to original resolution (remove padding)
259
  prob_map = prob_map[:, :, :orig_h, :orig_w]
260
+ prob_np = prob_map.squeeze().numpy() # [H, W]
261
+
262
+ # ---- Compute change statistics ------------------------------------
263
+ binary_np = (prob_np > threshold).astype(np.float32)
264
+ total_pixels = orig_h * orig_w
265
+ changed_pixels = int(binary_np.sum())
266
+ pct_changed = (changed_pixels / total_pixels) * 100.0
267
+
268
+ logger.info("=" * 50)
269
+ logger.info(" CHANGE DETECTION RESULTS")
270
+ logger.info("=" * 50)
271
+ logger.info(" Image size : %d x %d", orig_w, orig_h)
272
+ logger.info(" Total pixels : %d", total_pixels)
273
+ logger.info(" Changed pixels : %d", changed_pixels)
274
+ logger.info(" Area changed : %.2f%%", pct_changed)
275
+ logger.info("=" * 50)
276
+
277
+ # ---- Save outputs -------------------------------------------------
278
+ output_dir = Path(args.output)
279
+ output_dir.mkdir(parents=True, exist_ok=True)
280
+
281
+ # Binary change mask
282
+ save_binary_mask(prob_np, output_dir / "change_mask.png", threshold)
283
+
284
+ # Overlay visualisation
285
+ img_b_cropped = img_b.squeeze()[:, :orig_h, :orig_w] # [3, H, W]
286
+ pred_cropped = prob_map.squeeze(0)[:, :orig_h, :orig_w] # [1, H, W]
287
+ save_overlay(img_b_cropped, pred_cropped, output_dir / "overlay.png", threshold)
288
+
289
+ logger.info("All outputs saved to: %s", output_dir)
290
 
291
 
292
  if __name__ == "__main__":