Vedant Jigarbhai Mehta committed on
Commit
3ad9651
·
1 Parent(s): 0cbf4d6

Implement full test-set evaluation with metrics, visualizations, and overlays

Browse files

- MetricTracker for F1/IoU/Precision/Recall/OA on raw logits
- results.json with all metrics and metadata
- 5x4 prediction grid (Before|After|GT|Pred) + 20 individual plots
- Top-10 overlay images ranked by predicted change area
- Auto eval batch size at 2x training (no gradients needed)
- Colab vs local path resolution, formatted console metrics table

Files changed (1) hide show
  1. evaluate.py +372 -77
evaluate.py CHANGED
@@ -1,17 +1,30 @@
1
- """Evaluation script for change detection models.
2
 
3
- Runs a trained model on the test set, computes all metrics, and generates
4
- visualization outputs.
 
 
5
 
6
  Usage:
7
- python evaluate.py --config configs/config.yaml --checkpoint checkpoints/unet_pp_best.pth
 
 
 
 
 
8
  """
9
 
10
  import argparse
 
11
  import logging
 
12
  from pathlib import Path
13
- from typing import Any, Dict
14
 
 
 
 
 
15
  import torch
16
  import torch.nn as nn
17
  from torch.utils.data import DataLoader
@@ -20,115 +33,397 @@ import yaml
20
 
21
  from data.dataset import ChangeDetectionDataset
22
  from models import get_model
23
- from utils.metrics import ConfusionMatrix
24
- from utils.visualization import plot_prediction
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
 
29
- def evaluate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  model: nn.Module,
31
  loader: DataLoader,
32
  device: torch.device,
33
- threshold: float = 0.5,
34
- output_dir: Path = Path("./outputs"),
35
- max_vis: int = 20,
36
- ) -> Dict[str, float]:
37
- """Evaluate model on the full test set.
38
 
39
  Args:
40
- model: Trained change detection model.
41
- loader: Test DataLoader.
42
  device: Target device.
43
- threshold: Binarization threshold for predictions.
44
- output_dir: Directory to save visualization outputs.
45
- max_vis: Maximum number of sample predictions to save.
46
 
47
  Returns:
48
- Dict of metric name -> value.
 
 
 
49
  """
50
  model.eval()
51
- cm = ConfusionMatrix()
52
- vis_dir = output_dir / "visualizations"
53
- vis_dir.mkdir(parents=True, exist_ok=True)
54
- vis_count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- with torch.no_grad():
57
- for batch in tqdm(loader, desc="Evaluating"):
58
- img_a = batch["A"].to(device)
59
- img_b = batch["B"].to(device)
60
- mask = batch["mask"].to(device)
61
 
62
- logits = model(img_a, img_b)
63
- preds = (torch.sigmoid(logits) > threshold).float()
64
- cm.update(preds, mask)
65
 
66
- # Save sample visualizations
67
- if vis_count < max_vis:
68
- for i in range(min(img_a.size(0), max_vis - vis_count)):
69
- plot_prediction(
70
- img_a[i], img_b[i], mask[i], preds[i],
71
- save_path=vis_dir / f"sample_{vis_count:04d}.png",
72
- )
73
- vis_count += 1
74
 
75
- metrics = cm.compute()
76
- return metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def main() -> None:
80
- """Main evaluation entry point."""
81
- parser = argparse.ArgumentParser(description="Evaluate change detection model")
82
- parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
83
- parser.add_argument("--checkpoint", type=Path, required=True)
84
- parser.add_argument("--model", type=str, default=None, help="Override model name")
85
- parser.add_argument("--threshold", type=float, default=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  args = parser.parse_args()
87
 
88
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 
 
 
89
 
90
- with open(args.config, "r") as f:
91
- config = yaml.safe_load(f)
 
 
 
 
92
 
93
- model_name = args.model or config["model"]["name"]
94
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
- threshold = args.threshold or config.get("evaluation", {}).get("threshold", 0.5)
96
-
97
- # Resolve paths
98
- colab = config.get("colab", {})
99
- if colab.get("enabled", False):
100
- data_dir = Path(colab["data_dir"])
101
- output_dir = Path(colab["output_dir"])
102
- else:
103
- data_dir = Path(config["paths"]["processed_data"])
104
- output_dir = Path(config["paths"]["output_dir"])
105
-
106
- # Model
107
  model = get_model(model_name, config).to(device)
108
  ckpt = torch.load(args.checkpoint, map_location=device)
109
  model.load_state_dict(ckpt["model_state_dict"])
110
- logger.info("Loaded checkpoint: %s (epoch %d, F1 %.4f)",
111
- args.checkpoint, ckpt.get("epoch", -1), ckpt.get("best_f1", -1))
112
 
113
- # Test data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  ds_cfg = config.get("dataset", {})
115
- test_ds = ChangeDetectionDataset(data_dir / "test", split="test", config=config)
 
 
116
  test_loader = DataLoader(
117
- test_ds, batch_size=8, shuffle=False,
 
 
118
  num_workers=ds_cfg.get("num_workers", 4),
119
  pin_memory=ds_cfg.get("pin_memory", True),
120
  )
 
 
 
 
121
 
122
- # Evaluate
123
- metrics = evaluate(model, test_loader, device, threshold, output_dir)
 
124
 
125
- # Print results
126
- logger.info("=" * 50)
127
- logger.info("TEST SET RESULTS — %s", model_name)
128
- logger.info("=" * 50)
129
- for name, value in metrics.items():
130
- logger.info(" %-12s: %.4f", name.upper(), value)
131
- logger.info("=" * 50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  if __name__ == "__main__":
 
1
+ """Evaluate a trained change-detection model on the test set.
2
 
3
+ Computes all metrics (F1, IoU, Precision, Recall, OA), saves a
4
+ ``results.json``, generates a 20-sample prediction grid, and produces
5
+ overlay images for the top-10 predictions with the largest predicted
6
+ change area.
7
 
8
  Usage:
9
+ python evaluate.py --config configs/config.yaml \
10
+ --checkpoint checkpoints/unet_pp_best.pth
11
+
12
+ python evaluate.py --config configs/config.yaml \
13
+ --checkpoint checkpoints/changeformer_best.pth \
14
+ --model changeformer --output_dir ./my_outputs
15
  """
16
 
17
  import argparse
18
+ import json
19
  import logging
20
+ import time
21
  from pathlib import Path
22
+ from typing import Any, Dict, List, Tuple
23
 
24
+ import matplotlib
25
+ matplotlib.use("Agg")
26
+ import matplotlib.pyplot as plt
27
+ import numpy as np
28
  import torch
29
  import torch.nn as nn
30
  from torch.utils.data import DataLoader
 
33
 
34
  from data.dataset import ChangeDetectionDataset
35
  from models import get_model
36
+ from utils.metrics import MetricTracker
37
+ from utils.visualization import overlay_changes, plot_prediction
38
 
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
+ # ---------------------------------------------------------------------------
43
+ # GPU / batch-size helpers
44
+ # ---------------------------------------------------------------------------
45
+
46
+ def _detect_gpu_type() -> str:
47
+ """Detect the current GPU type for batch-size selection.
48
+
49
+ Returns:
50
+ One of ``'T4'``, ``'V100'``, or ``'default'``.
51
+ """
52
+ if not torch.cuda.is_available():
53
+ return "default"
54
+ name = torch.cuda.get_device_name(0).upper()
55
+ if "T4" in name:
56
+ return "T4"
57
+ elif "V100" in name:
58
+ return "V100"
59
+ return "default"
60
+
61
+
62
+ def get_train_batch_size(config: Dict[str, Any], model_name: str) -> int:
63
+ """Look up the *training* batch size for the current GPU + model.
64
+
65
+ Args:
66
+ config: Full project config dict.
67
+ model_name: Model identifier string.
68
+
69
+ Returns:
70
+ Training batch size as an integer.
71
+ """
72
+ gpu_type = _detect_gpu_type()
73
+ model_sizes = config.get("batch_sizes", {}).get(model_name, {})
74
+ return model_sizes.get(gpu_type, model_sizes.get("default", 4))
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Path resolution (same logic as train.py)
79
+ # ---------------------------------------------------------------------------
80
+
81
def resolve_paths(config: Dict[str, Any]) -> Dict[str, Path]:
    """Resolve project directories, honouring Colab mode when enabled.

    Args:
        config: Full project config dict.

    Returns:
        Dict with keys ``'data'``, ``'checkpoints'``, ``'logs'`` and
        ``'outputs'`` mapping to ``Path`` objects.
    """
    colab_cfg = config.get("colab", {})
    if colab_cfg.get("enabled", False):
        # Colab configs must spell out every directory explicitly.
        key_map = {
            "data": "data_dir",
            "checkpoints": "checkpoint_dir",
            "logs": "log_dir",
            "outputs": "output_dir",
        }
        return {name: Path(colab_cfg[cfg_key]) for name, cfg_key in key_map.items()}

    # Local run: each entry has a sensible relative-path default.
    local_cfg = config.get("paths", {})
    defaults = {
        "data": ("processed_data", "./processed_data"),
        "checkpoints": ("checkpoint_dir", "./checkpoints"),
        "logs": ("log_dir", "./logs"),
        "outputs": ("output_dir", "./outputs"),
    }
    return {
        name: Path(local_cfg.get(cfg_key, fallback))
        for name, (cfg_key, fallback) in defaults.items()
    }
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Evaluation pass
111
+ # ---------------------------------------------------------------------------
112
+
113
@torch.no_grad()
def run_evaluation(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    tracker: MetricTracker,
) -> Tuple[Dict[str, float], List[Dict[str, torch.Tensor]]]:
    """Run inference over the whole test loader, tracking metrics.

    Args:
        model: Trained change-detection model (switched to eval mode here).
        loader: Test ``DataLoader``.
        device: Target device.
        tracker: ``MetricTracker`` (reset externally before this call).

    Returns:
        Tuple of ``(metrics_dict, samples_list)``.
        Each entry in ``samples_list`` is a dict with keys
        ``'A'``, ``'B'``, ``'mask'``, ``'pred'``, ``'change_area'``
        (all single-sample tensors on CPU).
    """
    model.eval()
    collected: List[Dict[str, Any]] = []

    for batch in tqdm(loader, desc="Evaluating", dynamic_ncols=True):
        before = batch["A"].to(device, non_blocking=True)
        after = batch["B"].to(device, non_blocking=True)
        target = batch["mask"].to(device, non_blocking=True)

        logits = model(before, after)
        # Tracker consumes raw logits; binarisation below is only for
        # visualisation/ranking.
        tracker.update(logits, target)

        binary = (torch.sigmoid(logits) >= tracker.threshold).float()

        # Keep every sample on CPU so we can rank and visualise later.
        for a_i, b_i, m_i, p_i in zip(before, after, target, binary):
            pred_cpu = p_i.cpu()
            collected.append({
                "A": a_i.cpu(),
                "B": b_i.cpu(),
                "mask": m_i.cpu(),
                "pred": pred_cpu,
                "change_area": pred_cpu.sum().item(),
            })

    return tracker.compute(), collected
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # Visualisation helpers
165
+ # ---------------------------------------------------------------------------
166
+
167
def save_prediction_grid(
    samples: List[Dict[str, torch.Tensor]],
    save_path: Path,
    num_rows: int = 5,
) -> None:
    """Save a grid of sample predictions (Before | After | GT | Pred).

    Args:
        samples: List of per-sample dicts from ``run_evaluation``.
        save_path: Destination image path.
        num_rows: Number of rows in the grid (4 columns each).
    """
    num_samples = min(num_rows, len(samples))
    if num_samples <= 0:
        # Nothing to draw — avoid writing a degenerate, empty figure.
        logger.warning("No samples available for prediction grid; skipping %s", save_path)
        return

    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples))

    if num_samples == 1:
        # plt.subplots squeezes a single row to a 1-D array; restore 2-D indexing.
        axes = axes[np.newaxis, :]

    # Local import keeps the private helpers out of the module namespace.
    from utils.visualization import _denorm_tensor, _mask_to_numpy

    col_titles = ["Before (A)", "After (B)", "Ground Truth", "Prediction"]
    # RGB panels use the default colormap; binary masks render in gray.
    cmaps = [None, None, "gray", "gray"]

    for row in range(num_samples):
        s = samples[row]
        images = [
            _denorm_tensor(s["A"]),
            _denorm_tensor(s["B"]),
            _mask_to_numpy(s["mask"]),
            (_mask_to_numpy(s["pred"]) > 0.5).astype(np.float32),
        ]

        for col in range(4):
            ax = axes[row, col]
            ax.imshow(images[col], cmap=cmaps[col], vmin=0, vmax=1)
            ax.axis("off")
            if row == 0:
                ax.set_title(col_titles[col], fontsize=12)

    fig.tight_layout(pad=1.0)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    logger.info("Saved prediction grid (%d samples): %s", num_samples, save_path)
211
 
 
 
 
212
 
213
def save_top_overlays(
    samples: List[Dict[str, torch.Tensor]],
    output_dir: Path,
    top_k: int = 10,
) -> None:
    """Write overlay PNGs for the K samples with the largest predicted change.

    Args:
        samples: List of per-sample dicts from ``run_evaluation``.
        output_dir: Directory to save overlay PNGs (under ``overlays/``).
        top_k: Number of overlays to save.
    """
    import cv2

    overlay_dir = output_dir / "overlays"
    overlay_dir.mkdir(parents=True, exist_ok=True)

    # Largest predicted change area first — the most "interesting" scenes.
    ranked = sorted(samples, key=lambda item: item["change_area"], reverse=True)
    keep = min(top_k, len(ranked))

    # zip stops at the shorter sequence, so a non-positive `keep` saves nothing.
    for rank, sample in zip(range(1, keep + 1), ranked):
        rgb = overlay_changes(
            img_after=sample["B"],
            mask_pred=sample["pred"],
            alpha=0.4,
            color=(255, 0, 0),
        )
        target = overlay_dir / f"top_{rank:02d}_area_{sample['change_area']:.0f}.png"
        # OpenCV writes BGR channel order to disk, so convert before saving.
        cv2.imwrite(str(target), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    logger.info("Saved %d overlay images: %s", keep, overlay_dir)
246
 
247
 
248
+ # ---------------------------------------------------------------------------
249
+ # Console formatting
250
+ # ---------------------------------------------------------------------------
251
+
252
def print_metrics_table(
    metrics: Dict[str, float],
    model_name: str,
    checkpoint_path: Path,
    epoch: int,
) -> None:
    """Log a human-readable summary table of the test-set results.

    Args:
        metrics: Dict of metric name to value.
        model_name: Model architecture name.
        checkpoint_path: Path to the loaded checkpoint.
        epoch: Training epoch the checkpoint was saved at.
    """
    border = "=" * 50

    logger.info(border)
    logger.info(" TEST SET RESULTS")
    logger.info(border)

    # Header rows: (format string, value) pairs, logged in order.
    for fmt, value in (
        (" Model : %s", model_name),
        (" Checkpoint : %s", checkpoint_path),
        (" Epoch : %d", epoch),
    ):
        logger.info(fmt, value)

    logger.info(border)
    logger.info(" %-12s %s", "METRIC", "VALUE")
    logger.info(" " + "-" * 24)
    for name, score in metrics.items():
        logger.info(" %-12s %.4f", name.upper(), score)
    logger.info(border)
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Main
283
+ # ---------------------------------------------------------------------------
284
+
285
def main() -> None:
    """Entry point: parse CLI args, evaluate the model, save all outputs."""
    parser = argparse.ArgumentParser(
        description="Evaluate a trained change-detection model on the test set",
    )
    parser.add_argument(
        "--config", type=Path, default=Path("configs/config.yaml"),
        help="Path to the YAML configuration file.",
    )
    parser.add_argument(
        "--checkpoint", type=Path, required=True,
        help="Path to the model checkpoint (.pth).",
    )
    parser.add_argument(
        "--model", type=str, default=None,
        help="Override the model name from config.",
    )
    parser.add_argument(
        "--output_dir", type=Path, default=None,
        help="Override the output directory (default: from config).",
    )
    parser.add_argument(
        "--threshold", type=float, default=None,
        help="Override the binarisation threshold (default: from config).",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # ---- Config -------------------------------------------------------
    with open(args.config, "r") as fh:
        config: Dict[str, Any] = yaml.safe_load(fh)

    model_name: str = args.model or config["model"]["name"]
    # Compare against None explicitly: a plain `or` would silently discard
    # a legitimate user-supplied `--threshold 0.0`.
    if args.threshold is not None:
        threshold: float = args.threshold
    else:
        threshold = config.get("evaluation", {}).get("threshold", 0.5)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Device: %s", device)

    # ---- Paths --------------------------------------------------------
    paths = resolve_paths(config)
    # Same explicit None-check rationale as the threshold override above.
    output_dir = args.output_dir if args.output_dir is not None else paths["outputs"]
    output_dir = Path(output_dir) / model_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # ---- Load model ---------------------------------------------------
    model = get_model(model_name, config).to(device)
    # SECURITY: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])

    ckpt_epoch = ckpt.get("epoch", -1)
    ckpt_f1 = ckpt.get("best_f1", -1.0)
    logger.info(
        "Loaded checkpoint: %s (epoch %d, best F1 %.4f)",
        args.checkpoint, ckpt_epoch, ckpt_f1,
    )

    param_count = sum(p.numel() for p in model.parameters()) / 1e6
    logger.info("Model: %s (%.2fM parameters)", model_name, param_count)

    # ---- Test data ----------------------------------------------------
    # No gradients stored during eval → safe to use 2x training batch size
    train_bs = get_train_batch_size(config, model_name)
    eval_bs = train_bs * 2

    ds_cfg = config.get("dataset", {})
    test_ds = ChangeDetectionDataset(
        root=paths["data"] / "test", split="test", config=config,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=eval_bs,
        shuffle=False,
        num_workers=ds_cfg.get("num_workers", 4),
        pin_memory=ds_cfg.get("pin_memory", True),
    )
    logger.info(
        "Test set: %d samples, %d batches (batch_size=%d, 2x train)",
        len(test_ds), len(test_loader), eval_bs,
    )

    # ---- Run evaluation -----------------------------------------------
    tracker = MetricTracker(threshold=threshold)
    wall_start = time.monotonic()

    metrics, all_samples = run_evaluation(model, test_loader, device, tracker)

    eval_time = time.monotonic() - wall_start
    logger.info("Evaluation completed in %.1fs", eval_time)

    # ---- Print formatted table ----------------------------------------
    print_metrics_table(metrics, model_name, args.checkpoint, ckpt_epoch)

    # ---- Save results.json --------------------------------------------
    results = {
        "model": model_name,
        "checkpoint": str(args.checkpoint),
        "epoch": ckpt_epoch,
        "threshold": threshold,
        "num_test_samples": len(test_ds),
        "eval_time_seconds": round(eval_time, 2),
        "metrics": {k: round(v, 6) for k, v in metrics.items()},
    }
    results_path = output_dir / "results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    logger.info("Saved results: %s", results_path)

    # ---- Prediction grid (20 samples, 5 rows x 4 cols) ----------------
    save_prediction_grid(
        samples=all_samples,
        save_path=output_dir / "prediction_grid.png",
        num_rows=min(5, len(all_samples)),
    )

    # ---- Individual sample plots (up to 20) ---------------------------
    vis_dir = output_dir / "predictions"
    vis_dir.mkdir(parents=True, exist_ok=True)
    num_individual = min(20, len(all_samples))
    for idx in range(num_individual):
        s = all_samples[idx]
        plot_prediction(
            img_a=s["A"],
            img_b=s["B"],
            mask_true=s["mask"],
            mask_pred=s["pred"],
            filename=vis_dir / f"sample_{idx + 1:03d}.png",
        )
    logger.info("Saved %d individual prediction plots: %s", num_individual, vis_dir)

    # ---- Top-10 overlay images (by predicted change area) -------------
    save_top_overlays(
        samples=all_samples,
        output_dir=output_dir,
        top_k=10,
    )

    logger.info("All outputs saved to: %s", output_dir)
427
 
428
 
429
  if __name__ == "__main__":