Vedant Jigarbhai Mehta commited on
Commit
0cbf4d6
·
1 Parent(s): b25c087

Implement full training loop and visualization utilities

Browse files

train.py: AMP, gradient accumulation, gradient clipping, warmup +
cosine scheduler, MetricTracker integration, early stopping on val F1,
checkpoint resume (model + optimizer + scheduler + scaler state),
auto GPU batch-size detection, TensorBoard logging with prediction grids,
conditional Colab/local paths, training time summary.

utils/visualization.py: Agg backend for headless environments,
plot_prediction (1x4 grid), overlay_changes (uint8 output),
plot_metrics_history (per-metric subplots),
log_predictions_to_tensorboard (interleaved sample grid).

Files changed (2) hide show
  1. train.py +474 -211
  2. utils/visualization.py +217 -61
train.py CHANGED
@@ -1,25 +1,28 @@
1
  """Main training script for change detection models.
2
 
3
- Supports AMP, gradient clipping, early stopping, checkpoint saving to Google
4
- Drive, and resume from checkpoint after Colab disconnects.
 
5
 
6
  Usage:
7
  python train.py --config configs/config.yaml --model unet_pp
8
- python train.py --config configs/config.yaml --model changeformer --resume checkpoints/changeformer_last.pth
 
9
  """
10
 
11
  import argparse
12
  import logging
13
  import random
 
14
  from pathlib import Path
15
- from typing import Any, Dict, Tuple
16
 
17
  import numpy as np
18
  import torch
19
  import torch.nn as nn
20
  from torch.cuda.amp import GradScaler, autocast
21
  from torch.optim import AdamW
22
- from torch.optim.lr_scheduler import CosineAnnealingLR
23
  from torch.utils.data import DataLoader
24
  from torch.utils.tensorboard import SummaryWriter
25
  from tqdm import tqdm
@@ -28,14 +31,21 @@ import yaml
28
  from data.dataset import ChangeDetectionDataset
29
  from models import get_model
30
  from utils.losses import get_loss
31
- from utils.metrics import ConfusionMatrix
32
- from utils.visualization import plot_prediction
33
 
34
  logger = logging.getLogger(__name__)
35
 
36
 
 
 
 
 
37
  def set_seed(seed: int) -> None:
38
- """Set random seeds for reproducibility.
 
 
 
39
 
40
  Args:
41
  seed: Random seed value.
@@ -48,11 +58,15 @@ def set_seed(seed: int) -> None:
48
  torch.backends.cudnn.benchmark = False
49
 
50
 
 
 
 
 
51
  def detect_gpu_type() -> str:
52
- """Detect the current GPU type for batch size selection.
53
 
54
  Returns:
55
- GPU type string ('T4', 'V100', or 'default').
56
  """
57
  if not torch.cuda.is_available():
58
  return "default"
@@ -65,80 +79,252 @@ def detect_gpu_type() -> str:
65
 
66
 
67
  def get_batch_size(config: Dict[str, Any], model_name: str) -> int:
68
- """Get appropriate batch size based on GPU and model.
69
 
70
  Args:
71
- config: Full config dict.
72
- model_name: Model name string.
73
 
74
  Returns:
75
- Batch size integer.
76
  """
77
  gpu_type = detect_gpu_type()
78
- batch_sizes = config.get("batch_sizes", {}).get(model_name, {})
79
- return batch_sizes.get(gpu_type, batch_sizes.get("default", 4))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
 
 
81
 
82
- def get_paths(config: Dict[str, Any]) -> Dict[str, Path]:
83
- """Resolve paths based on whether running on Colab or locally.
84
 
85
  Args:
86
- config: Full config dict.
87
 
88
  Returns:
89
- Dict with keys: 'data', 'checkpoints', 'logs', 'outputs'.
 
90
  """
91
  if config.get("colab", {}).get("enabled", False):
92
- colab = config["colab"]
93
  return {
94
- "data": Path(colab["data_dir"]),
95
- "checkpoints": Path(colab["checkpoint_dir"]),
96
- "logs": Path(colab["log_dir"]),
97
- "outputs": Path(colab["output_dir"]),
98
- }
99
- else:
100
- paths = config.get("paths", {})
101
- return {
102
- "data": Path(paths.get("processed_data", "./processed_data")),
103
- "checkpoints": Path(paths.get("checkpoint_dir", "./checkpoints")),
104
- "logs": Path(paths.get("log_dir", "./logs")),
105
- "outputs": Path(paths.get("output_dir", "./outputs")),
106
  }
107
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def build_dataloaders(
110
  config: Dict[str, Any],
111
  data_dir: Path,
112
  batch_size: int,
113
  ) -> Tuple[DataLoader, DataLoader]:
114
- """Create train and validation DataLoaders.
115
 
116
  Args:
117
- config: Full config dict.
118
- data_dir: Path to processed dataset root.
119
- batch_size: Batch size.
120
 
121
  Returns:
122
- Tuple of (train_loader, val_loader).
123
  """
124
  ds_cfg = config.get("dataset", {})
125
  num_workers = ds_cfg.get("num_workers", 4)
126
  pin_memory = ds_cfg.get("pin_memory", True)
127
 
128
- train_ds = ChangeDetectionDataset(data_dir / "train", split="train", config=config)
129
- val_ds = ChangeDetectionDataset(data_dir / "val", split="val", config=config)
 
 
 
 
130
 
131
  train_loader = DataLoader(
132
- train_ds, batch_size=batch_size, shuffle=True,
133
- num_workers=num_workers, pin_memory=pin_memory, drop_last=True,
 
 
 
 
134
  )
135
  val_loader = DataLoader(
136
- val_ds, batch_size=batch_size, shuffle=False,
137
- num_workers=num_workers, pin_memory=pin_memory,
 
 
 
138
  )
139
  return train_loader, val_loader
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def train_one_epoch(
143
  model: nn.Module,
144
  loader: DataLoader,
@@ -146,36 +332,37 @@ def train_one_epoch(
146
  optimizer: torch.optim.Optimizer,
147
  scaler: GradScaler,
148
  device: torch.device,
149
- config: Dict[str, Any],
 
 
150
  ) -> Tuple[float, Dict[str, float]]:
151
- """Run one training epoch.
152
 
153
  Args:
154
- model: The change detection model.
155
- loader: Training DataLoader.
156
- criterion: Loss function.
157
- optimizer: Optimizer.
158
- scaler: GradScaler for AMP.
159
- device: Target device.
160
- config: Full config dict.
 
 
161
 
162
  Returns:
163
- Tuple of (average loss, metrics dict).
164
  """
165
  model.train()
166
  running_loss = 0.0
167
- cm = ConfusionMatrix()
168
- train_cfg = config.get("training", {})
169
- accum_steps = train_cfg.get("gradient_accumulation_steps", 1)
170
- grad_clip = train_cfg.get("grad_clip_max_norm", 1.0)
171
- threshold = config.get("evaluation", {}).get("threshold", 0.5)
172
 
173
- optimizer.zero_grad()
174
 
175
- for step, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
176
- img_a = batch["A"].to(device)
177
- img_b = batch["B"].to(device)
178
- mask = batch["mask"].to(device)
 
179
 
180
  with autocast():
181
  logits = model(img_a, img_b)
@@ -188,17 +375,19 @@ def train_one_epoch(
188
  nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
189
  scaler.step(optimizer)
190
  scaler.update()
191
- optimizer.zero_grad()
192
 
 
193
  running_loss += loss.item() * accum_steps
 
194
 
195
- # Metrics
196
- with torch.no_grad():
197
- preds = (torch.sigmoid(logits) > threshold).float()
198
- cm.update(preds, mask)
199
 
200
- avg_loss = running_loss / len(loader)
201
- metrics = cm.compute()
 
 
202
  return avg_loss, metrics
203
 
204
 
@@ -208,210 +397,284 @@ def validate(
208
  loader: DataLoader,
209
  criterion: nn.Module,
210
  device: torch.device,
211
- threshold: float = 0.5,
212
- ) -> Tuple[float, Dict[str, float]]:
213
- """Run validation.
214
 
215
  Args:
216
- model: The change detection model.
217
- loader: Validation DataLoader.
218
- criterion: Loss function.
219
  device: Target device.
220
- threshold: Binarization threshold.
221
 
222
  Returns:
223
- Tuple of (average loss, metrics dict).
 
224
  """
225
  model.eval()
226
  running_loss = 0.0
227
- cm = ConfusionMatrix()
 
228
 
229
- for batch in tqdm(loader, desc="Val", leave=False):
230
- img_a = batch["A"].to(device)
231
- img_b = batch["B"].to(device)
232
- mask = batch["mask"].to(device)
 
233
 
234
  logits = model(img_a, img_b)
235
  loss = criterion(logits, mask)
236
- running_loss += loss.item()
237
-
238
- preds = (torch.sigmoid(logits) > threshold).float()
239
- cm.update(preds, mask)
240
-
241
- avg_loss = running_loss / len(loader)
242
- metrics = cm.compute()
243
- return avg_loss, metrics
244
-
245
 
246
- def save_checkpoint(
247
- model: nn.Module,
248
- optimizer: torch.optim.Optimizer,
249
- scheduler: Any,
250
- scaler: GradScaler,
251
- epoch: int,
252
- best_f1: float,
253
- save_path: Path,
254
- ) -> None:
255
- """Save a training checkpoint.
 
256
 
257
- Args:
258
- model: Model to save.
259
- optimizer: Optimizer state.
260
- scheduler: LR scheduler state.
261
- scaler: GradScaler state.
262
- epoch: Current epoch number.
263
- best_f1: Best validation F1 so far.
264
- save_path: Path to save the checkpoint.
265
- """
266
- save_path.parent.mkdir(parents=True, exist_ok=True)
267
- torch.save({
268
- "epoch": epoch,
269
- "model_state_dict": model.state_dict(),
270
- "optimizer_state_dict": optimizer.state_dict(),
271
- "scheduler_state_dict": scheduler.state_dict(),
272
- "scaler_state_dict": scaler.state_dict(),
273
- "best_f1": best_f1,
274
- }, save_path)
275
- logger.info("Saved checkpoint: %s", save_path)
276
 
 
 
 
277
 
278
- def load_checkpoint(
279
- path: Path,
280
- model: nn.Module,
281
- optimizer: torch.optim.Optimizer,
282
- scheduler: Any,
283
- scaler: GradScaler,
284
- device: torch.device,
285
- ) -> Tuple[int, float]:
286
- """Load a training checkpoint for resume.
287
-
288
- Args:
289
- path: Path to the checkpoint file.
290
- model: Model to load weights into.
291
- optimizer: Optimizer to load state into.
292
- scheduler: Scheduler to load state into.
293
- scaler: GradScaler to load state into.
294
- device: Target device.
295
-
296
- Returns:
297
- Tuple of (start_epoch, best_f1).
298
- """
299
- ckpt = torch.load(path, map_location=device)
300
- model.load_state_dict(ckpt["model_state_dict"])
301
- optimizer.load_state_dict(ckpt["optimizer_state_dict"])
302
- scheduler.load_state_dict(ckpt["scheduler_state_dict"])
303
- scaler.load_state_dict(ckpt["scaler_state_dict"])
304
- logger.info("Resumed from epoch %d (best F1: %.4f)", ckpt["epoch"], ckpt["best_f1"])
305
- return ckpt["epoch"], ckpt["best_f1"]
306
 
 
 
 
307
 
308
  def main() -> None:
309
- """Main training entry point."""
310
- parser = argparse.ArgumentParser(description="Train change detection model")
311
- parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
312
- parser.add_argument("--model", type=str, default=None, help="Override model name from config")
313
- parser.add_argument("--resume", type=Path, default=None, help="Path to checkpoint for resume")
 
 
 
 
 
 
 
 
 
 
 
 
314
  args = parser.parse_args()
315
 
316
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 
 
 
 
 
 
 
317
 
318
- # Load config
319
- with open(args.config, "r") as f:
320
- config = yaml.safe_load(f)
321
 
322
- model_name = args.model or config["model"]["name"]
323
- seed = config.get("project", {}).get("seed", 42)
324
  set_seed(seed)
325
 
326
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
327
- logger.info("Device: %s", device)
 
328
 
329
- # Resolve paths
330
- paths = get_paths(config)
331
  for p in paths.values():
332
  p.mkdir(parents=True, exist_ok=True)
333
 
334
- # Model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  model = get_model(model_name, config).to(device)
336
- logger.info("Model: %s (%.1fM params)", model_name,
337
- sum(p.numel() for p in model.parameters()) / 1e6)
338
 
339
- # Data
340
- batch_size = get_batch_size(config, model_name)
341
  train_loader, val_loader = build_dataloaders(config, paths["data"], batch_size)
 
 
 
 
342
 
343
- # Loss, optimizer, scheduler
344
- criterion = get_loss(config)
345
- lr = config.get("learning_rates", {}).get(model_name, config["training"]["learning_rate"])
346
- epochs = config.get("epoch_counts", {}).get(model_name, config["training"]["epochs"])
347
 
348
- optimizer = AdamW(model.parameters(), lr=lr, weight_decay=config["training"]["weight_decay"])
349
- scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
350
- scaler = GradScaler()
 
 
351
 
352
- # TensorBoard
 
 
 
353
  writer = SummaryWriter(log_dir=str(paths["logs"] / model_name))
354
 
355
- # Resume
356
- start_epoch = 0
357
- best_f1 = 0.0
358
- if args.resume and args.resume.exists():
359
- start_epoch, best_f1 = load_checkpoint(
360
- args.resume, model, optimizer, scheduler, scaler, device
 
 
 
 
 
 
361
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
- # Early stopping state
364
- es_cfg = config["training"]["early_stopping"]
365
- patience = es_cfg.get("patience", 15)
366
- patience_counter = 0
367
- threshold = config.get("evaluation", {}).get("threshold", 0.5)
368
 
369
- # Training loop
370
- for epoch in range(start_epoch, epochs):
371
- logger.info("Epoch %d/%d", epoch + 1, epochs)
372
 
 
 
373
  train_loss, train_metrics = train_one_epoch(
374
- model, train_loader, criterion, optimizer, scaler, device, config
 
375
  )
376
- val_loss, val_metrics = validate(model, val_loader, criterion, device, threshold)
 
 
 
 
 
 
 
377
  scheduler.step()
378
 
379
- # Log
380
- writer.add_scalar("Loss/train", train_loss, epoch)
381
- writer.add_scalar("Loss/val", val_loss, epoch)
382
- for k, v in val_metrics.items():
383
- writer.add_scalar(f"Val/{k}", v, epoch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
 
 
385
  logger.info(
386
- " Train Loss: %.4f | Val Loss: %.4f | Val F1: %.4f | Val IoU: %.4f",
387
- train_loss, val_loss, val_metrics["f1"], val_metrics["iou"],
388
  )
 
 
 
 
 
 
 
 
 
 
389
 
390
- # Save last checkpoint (always)
391
  save_checkpoint(
392
- model, optimizer, scheduler, scaler, epoch + 1, best_f1,
393
- paths["checkpoints"] / f"{model_name}_last.pth",
 
 
 
394
  )
395
 
396
- # Save best checkpoint
397
  if val_metrics["f1"] > best_f1:
398
  best_f1 = val_metrics["f1"]
 
399
  patience_counter = 0
400
  save_checkpoint(
401
- model, optimizer, scheduler, scaler, epoch + 1, best_f1,
402
- paths["checkpoints"] / f"{model_name}_best.pth",
 
 
 
403
  )
404
- logger.info(" New best F1: %.4f", best_f1)
405
  else:
406
  patience_counter += 1
 
 
 
407
 
408
- # Early stopping
409
- if es_cfg.get("enabled", True) and patience_counter >= patience:
410
- logger.info("Early stopping triggered at epoch %d", epoch + 1)
 
 
 
411
  break
412
 
 
413
  writer.close()
414
- logger.info("Training complete. Best F1: %.4f", best_f1)
 
 
 
 
 
 
 
 
 
415
 
416
 
417
  if __name__ == "__main__":
 
1
  """Main training script for change detection models.
2
 
3
+ Supports mixed-precision training, gradient accumulation, gradient clipping,
4
+ early stopping on validation F1, checkpoint saving (best + last) to Google
5
+ Drive or local disk, and full resume from checkpoint after Colab disconnects.
6
 
7
  Usage:
8
  python train.py --config configs/config.yaml --model unet_pp
9
+ python train.py --config configs/config.yaml --model changeformer \
10
+ --resume /content/drive/MyDrive/change-detection/checkpoints/changeformer_last.pth
11
  """
12
 
13
  import argparse
14
  import logging
15
  import random
16
+ import time
17
  from pathlib import Path
18
+ from typing import Any, Dict, Optional, Tuple
19
 
20
  import numpy as np
21
  import torch
22
  import torch.nn as nn
23
  from torch.cuda.amp import GradScaler, autocast
24
  from torch.optim import AdamW
25
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
26
  from torch.utils.data import DataLoader
27
  from torch.utils.tensorboard import SummaryWriter
28
  from tqdm import tqdm
 
31
  from data.dataset import ChangeDetectionDataset
32
  from models import get_model
33
  from utils.losses import get_loss
34
+ from utils.metrics import MetricTracker
35
+ from utils.visualization import log_predictions_to_tensorboard
36
 
37
  logger = logging.getLogger(__name__)
38
 
39
 
40
+ # ---------------------------------------------------------------------------
41
+ # Reproducibility
42
+ # ---------------------------------------------------------------------------
43
+
44
  def set_seed(seed: int) -> None:
45
+ """Set all random seeds for reproducibility.
46
+
47
+ Configures Python, NumPy, PyTorch (CPU + CUDA), and cuDNN for
48
+ deterministic behaviour.
49
 
50
  Args:
51
  seed: Random seed value.
 
58
  torch.backends.cudnn.benchmark = False
59
 
60
 
61
+ # ---------------------------------------------------------------------------
62
+ # GPU / config helpers
63
+ # ---------------------------------------------------------------------------
64
+
65
  def detect_gpu_type() -> str:
66
+ """Detect the current GPU type for automatic batch-size selection.
67
 
68
  Returns:
69
+ One of ``'T4'``, ``'V100'``, or ``'default'``.
70
  """
71
  if not torch.cuda.is_available():
72
  return "default"
 
79
 
80
 
81
  def get_batch_size(config: Dict[str, Any], model_name: str) -> int:
82
+ """Look up the batch size for the current GPU + model combination.
83
 
84
  Args:
85
+ config: Full project config dict.
86
+ model_name: Model identifier string.
87
 
88
  Returns:
89
+ Batch size as an integer.
90
  """
91
  gpu_type = detect_gpu_type()
92
+ model_sizes = config.get("batch_sizes", {}).get(model_name, {})
93
+ return model_sizes.get(gpu_type, model_sizes.get("default", 4))
94
+
95
+
96
+ def get_learning_rate(config: Dict[str, Any], model_name: str) -> float:
97
+ """Look up the per-model learning rate, falling back to the global default.
98
+
99
+ Args:
100
+ config: Full project config dict.
101
+ model_name: Model identifier string.
102
+
103
+ Returns:
104
+ Learning rate as a float.
105
+ """
106
+ return config.get("learning_rates", {}).get(
107
+ model_name, config["training"]["learning_rate"]
108
+ )
109
+
110
+
111
+ def get_num_epochs(config: Dict[str, Any], model_name: str) -> int:
112
+ """Look up the per-model epoch count, falling back to the global default.
113
+
114
+ Args:
115
+ config: Full project config dict.
116
+ model_name: Model identifier string.
117
+
118
+ Returns:
119
+ Number of epochs as an integer.
120
+ """
121
+ return config.get("epoch_counts", {}).get(
122
+ model_name, config["training"]["epochs"]
123
+ )
124
+
125
 
126
+ def resolve_paths(config: Dict[str, Any]) -> Dict[str, Path]:
127
+ """Build a path dict based on whether Colab mode is enabled.
128
 
129
+ When ``config["colab"]["enabled"]`` is ``True`` all persistent artefacts
130
+ point to Google Drive; otherwise they use the local ``paths`` section.
131
 
132
  Args:
133
+ config: Full project config dict.
134
 
135
  Returns:
136
+ Dict with keys ``'data'``, ``'checkpoints'``, ``'logs'``,
137
+ ``'outputs'``.
138
  """
139
  if config.get("colab", {}).get("enabled", False):
140
+ c = config["colab"]
141
  return {
142
+ "data": Path(c["data_dir"]),
143
+ "checkpoints": Path(c["checkpoint_dir"]),
144
+ "logs": Path(c["log_dir"]),
145
+ "outputs": Path(c["output_dir"]),
 
 
 
 
 
 
 
 
146
  }
147
 
148
+ p = config.get("paths", {})
149
+ return {
150
+ "data": Path(p.get("processed_data", "./processed_data")),
151
+ "checkpoints": Path(p.get("checkpoint_dir", "./checkpoints")),
152
+ "logs": Path(p.get("log_dir", "./logs")),
153
+ "outputs": Path(p.get("output_dir", "./outputs")),
154
+ }
155
+
156
+
157
+ # ---------------------------------------------------------------------------
158
+ # Data
159
+ # ---------------------------------------------------------------------------
160
 
161
  def build_dataloaders(
162
  config: Dict[str, Any],
163
  data_dir: Path,
164
  batch_size: int,
165
  ) -> Tuple[DataLoader, DataLoader]:
166
+ """Create training and validation ``DataLoader`` instances.
167
 
168
  Args:
169
+ config: Full project config dict.
170
+ data_dir: Root of the processed dataset (contains ``train/``, ``val/``).
171
+ batch_size: Mini-batch size.
172
 
173
  Returns:
174
+ Tuple of ``(train_loader, val_loader)``.
175
  """
176
  ds_cfg = config.get("dataset", {})
177
  num_workers = ds_cfg.get("num_workers", 4)
178
  pin_memory = ds_cfg.get("pin_memory", True)
179
 
180
+ train_ds = ChangeDetectionDataset(
181
+ root=data_dir / "train", split="train", config=config,
182
+ )
183
+ val_ds = ChangeDetectionDataset(
184
+ root=data_dir / "val", split="val", config=config,
185
+ )
186
 
187
  train_loader = DataLoader(
188
+ train_ds,
189
+ batch_size=batch_size,
190
+ shuffle=True,
191
+ num_workers=num_workers,
192
+ pin_memory=pin_memory,
193
+ drop_last=True,
194
  )
195
  val_loader = DataLoader(
196
+ val_ds,
197
+ batch_size=batch_size,
198
+ shuffle=False,
199
+ num_workers=num_workers,
200
+ pin_memory=pin_memory,
201
  )
202
  return train_loader, val_loader
203
 
204
 
205
+ # ---------------------------------------------------------------------------
206
+ # Scheduler with linear warmup
207
+ # ---------------------------------------------------------------------------
208
+
209
+ def build_scheduler(
210
+ optimizer: torch.optim.Optimizer,
211
+ total_epochs: int,
212
+ warmup_epochs: int,
213
+ ) -> torch.optim.lr_scheduler._LRScheduler:
214
+ """Create a CosineAnnealingLR scheduler preceded by linear warmup.
215
+
216
+ During the first ``warmup_epochs`` the LR ramps linearly from
217
+ ``start_factor`` to the base LR, then cosine-decays for the remainder.
218
+
219
+ Args:
220
+ optimizer: Optimizer whose LR groups will be scheduled.
221
+ total_epochs: Total number of training epochs.
222
+ warmup_epochs: Number of warmup epochs (0 to disable).
223
+
224
+ Returns:
225
+ A learning-rate scheduler instance.
226
+ """
227
+ if warmup_epochs > 0 and warmup_epochs < total_epochs:
228
+ warmup = LinearLR(
229
+ optimizer,
230
+ start_factor=0.01,
231
+ end_factor=1.0,
232
+ total_iters=warmup_epochs,
233
+ )
234
+ cosine = CosineAnnealingLR(
235
+ optimizer,
236
+ T_max=total_epochs - warmup_epochs,
237
+ )
238
+ return SequentialLR(
239
+ optimizer,
240
+ schedulers=[warmup, cosine],
241
+ milestones=[warmup_epochs],
242
+ )
243
+
244
+ return CosineAnnealingLR(optimizer, T_max=total_epochs)
245
+
246
+
247
+ # ---------------------------------------------------------------------------
248
+ # Checkpointing
249
+ # ---------------------------------------------------------------------------
250
+
251
+ def save_checkpoint(
252
+ model: nn.Module,
253
+ optimizer: torch.optim.Optimizer,
254
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
255
+ scaler: GradScaler,
256
+ epoch: int,
257
+ best_f1: float,
258
+ best_epoch: int,
259
+ save_path: Path,
260
+ ) -> None:
261
+ """Persist a full training checkpoint to disk.
262
+
263
+ Args:
264
+ model: Model whose weights to save.
265
+ optimizer: Optimizer state to save.
266
+ scheduler: LR scheduler state to save.
267
+ scaler: ``GradScaler`` state to save.
268
+ epoch: Epoch number just completed (1-indexed).
269
+ best_f1: Best validation F1 achieved so far.
270
+ best_epoch: Epoch that achieved ``best_f1``.
271
+ save_path: Destination file path.
272
+ """
273
+ save_path.parent.mkdir(parents=True, exist_ok=True)
274
+ torch.save(
275
+ {
276
+ "epoch": epoch,
277
+ "model_state_dict": model.state_dict(),
278
+ "optimizer_state_dict": optimizer.state_dict(),
279
+ "scheduler_state_dict": scheduler.state_dict(),
280
+ "scaler_state_dict": scaler.state_dict(),
281
+ "best_f1": best_f1,
282
+ "best_epoch": best_epoch,
283
+ },
284
+ save_path,
285
+ )
286
+ logger.info("Checkpoint saved → %s", save_path)
287
+
288
+
289
+ def load_checkpoint(
290
+ path: Path,
291
+ model: nn.Module,
292
+ optimizer: torch.optim.Optimizer,
293
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
294
+ scaler: GradScaler,
295
+ device: torch.device,
296
+ ) -> Tuple[int, float, int]:
297
+ """Restore training state from a checkpoint.
298
+
299
+ Args:
300
+ path: Checkpoint file to load.
301
+ model: Model to receive saved weights.
302
+ optimizer: Optimizer to receive saved state.
303
+ scheduler: Scheduler to receive saved state.
304
+ scaler: ``GradScaler`` to receive saved state.
305
+ device: Target device for ``map_location``.
306
+
307
+ Returns:
308
+ Tuple of ``(start_epoch, best_f1, best_epoch)``.
309
+ """
310
+ ckpt = torch.load(path, map_location=device)
311
+ model.load_state_dict(ckpt["model_state_dict"])
312
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
313
+ scheduler.load_state_dict(ckpt["scheduler_state_dict"])
314
+ scaler.load_state_dict(ckpt["scaler_state_dict"])
315
+ best_f1 = ckpt["best_f1"]
316
+ best_epoch = ckpt.get("best_epoch", ckpt["epoch"])
317
+ logger.info(
318
+ "Resumed from epoch %d (best F1: %.4f @ epoch %d)",
319
+ ckpt["epoch"], best_f1, best_epoch,
320
+ )
321
+ return ckpt["epoch"], best_f1, best_epoch
322
+
323
+
324
+ # ---------------------------------------------------------------------------
325
+ # Train / validate one epoch
326
+ # ---------------------------------------------------------------------------
327
+
328
  def train_one_epoch(
329
  model: nn.Module,
330
  loader: DataLoader,
 
332
  optimizer: torch.optim.Optimizer,
333
  scaler: GradScaler,
334
  device: torch.device,
335
+ tracker: MetricTracker,
336
+ accum_steps: int,
337
+ grad_clip: float,
338
  ) -> Tuple[float, Dict[str, float]]:
339
+ """Execute one full training epoch.
340
 
341
  Args:
342
+ model: Change-detection model.
343
+ loader: Training ``DataLoader``.
344
+ criterion: Loss module (operates on raw logits).
345
+ optimizer: Optimiser instance.
346
+ scaler: ``GradScaler`` for mixed-precision training.
347
+ device: Target CUDA / CPU device.
348
+ tracker: ``MetricTracker`` (reset externally before this call).
349
+ accum_steps: Number of gradient-accumulation micro-steps.
350
+ grad_clip: Maximum gradient norm for clipping.
351
 
352
  Returns:
353
+ Tuple of ``(average_loss, metrics_dict)``.
354
  """
355
  model.train()
356
  running_loss = 0.0
357
+ num_batches = 0
 
 
 
 
358
 
359
+ optimizer.zero_grad(set_to_none=True)
360
 
361
+ pbar = tqdm(loader, desc=" Train", leave=False, dynamic_ncols=True)
362
+ for step, batch in enumerate(pbar):
363
+ img_a = batch["A"].to(device, non_blocking=True)
364
+ img_b = batch["B"].to(device, non_blocking=True)
365
+ mask = batch["mask"].to(device, non_blocking=True)
366
 
367
  with autocast():
368
  logits = model(img_a, img_b)
 
375
  nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
376
  scaler.step(optimizer)
377
  scaler.update()
378
+ optimizer.zero_grad(set_to_none=True)
379
 
380
+ # Track loss (undo the accumulation scaling for logging)
381
  running_loss += loss.item() * accum_steps
382
+ num_batches += 1
383
 
384
+ # Track metrics (MetricTracker handles sigmoid + threshold internally)
385
+ tracker.update(logits.detach(), mask)
 
 
386
 
387
+ pbar.set_postfix(loss=f"{running_loss / num_batches:.4f}")
388
+
389
+ avg_loss = running_loss / max(num_batches, 1)
390
+ metrics = tracker.compute()
391
  return avg_loss, metrics
392
 
393
 
 
397
  loader: DataLoader,
398
  criterion: nn.Module,
399
  device: torch.device,
400
+ tracker: MetricTracker,
401
+ ) -> Tuple[float, Dict[str, float], Optional[Dict[str, torch.Tensor]]]:
402
+ """Run one full validation pass.
403
 
404
  Args:
405
+ model: Change-detection model (set to eval internally).
406
+ loader: Validation ``DataLoader``.
407
+ criterion: Loss module (operates on raw logits).
408
  device: Target device.
409
+ tracker: ``MetricTracker`` (reset externally before this call).
410
 
411
  Returns:
412
+ Tuple of ``(average_loss, metrics_dict, last_batch)`` where
413
+ ``last_batch`` is the final mini-batch dict (for visualisation).
414
  """
415
  model.eval()
416
  running_loss = 0.0
417
+ num_batches = 0
418
+ last_batch: Optional[Dict[str, torch.Tensor]] = None
419
 
420
+ pbar = tqdm(loader, desc=" Val ", leave=False, dynamic_ncols=True)
421
+ for batch in pbar:
422
+ img_a = batch["A"].to(device, non_blocking=True)
423
+ img_b = batch["B"].to(device, non_blocking=True)
424
+ mask = batch["mask"].to(device, non_blocking=True)
425
 
426
  logits = model(img_a, img_b)
427
  loss = criterion(logits, mask)
 
 
 
 
 
 
 
 
 
428
 
429
+ running_loss += loss.item()
430
+ num_batches += 1
431
+ tracker.update(logits, mask)
432
+
433
+ # Keep the last batch for TensorBoard visualisation
434
+ last_batch = {
435
+ "A": img_a,
436
+ "B": img_b,
437
+ "mask": mask,
438
+ "logits": logits,
439
+ }
440
 
441
+ pbar.set_postfix(loss=f"{running_loss / num_batches:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
+ avg_loss = running_loss / max(num_batches, 1)
444
+ metrics = tracker.compute()
445
+ return avg_loss, metrics, last_batch
446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
+ # ---------------------------------------------------------------------------
449
+ # Main
450
+ # ---------------------------------------------------------------------------
451
 
452
  def main() -> None:
453
+ """Entry point parse CLI args, build components, run training loop."""
454
+ # ---- CLI ----------------------------------------------------------
455
+ parser = argparse.ArgumentParser(
456
+ description="Train a change-detection model",
457
+ )
458
+ parser.add_argument(
459
+ "--config", type=Path, default=Path("configs/config.yaml"),
460
+ help="Path to the YAML configuration file.",
461
+ )
462
+ parser.add_argument(
463
+ "--model", type=str, default=None,
464
+ help="Override the model name from config (siamese_cnn | unet_pp | changeformer).",
465
+ )
466
+ parser.add_argument(
467
+ "--resume", type=Path, default=None,
468
+ help="Path to a checkpoint file to resume training from.",
469
+ )
470
  args = parser.parse_args()
471
 
472
+ logging.basicConfig(
473
+ level=logging.INFO,
474
+ format="%(asctime)s [%(levelname)s] %(message)s",
475
+ datefmt="%Y-%m-%d %H:%M:%S",
476
+ )
477
+
478
+ # ---- Config -------------------------------------------------------
479
+ with open(args.config, "r") as fh:
480
+ config: Dict[str, Any] = yaml.safe_load(fh)
481
 
482
+ model_name: str = args.model or config["model"]["name"]
483
+ train_cfg = config["training"]
484
+ seed: int = config.get("project", {}).get("seed", 42)
485
 
 
 
486
  set_seed(seed)
487
 
488
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
489
+ gpu_type = detect_gpu_type()
490
+ logger.info("Device: %s | GPU type: %s", device, gpu_type)
491
 
492
+ # ---- Paths --------------------------------------------------------
493
+ paths = resolve_paths(config)
494
  for p in paths.values():
495
  p.mkdir(parents=True, exist_ok=True)
496
 
497
+ # ---- Hyperparams (auto from per-model tables) ---------------------
498
+ batch_size = get_batch_size(config, model_name)
499
+ lr = get_learning_rate(config, model_name)
500
+ num_epochs = get_num_epochs(config, model_name)
501
+ accum_steps: int = train_cfg.get("gradient_accumulation_steps", 1)
502
+ grad_clip: float = train_cfg.get("grad_clip_max_norm", 1.0)
503
+ warmup_epochs: int = train_cfg.get("warmup_epochs", 5)
504
+ vis_interval: int = train_cfg.get("vis_interval", 5)
505
+ threshold: float = config.get("evaluation", {}).get("threshold", 0.5)
506
+
507
+ logger.info(
508
+ "Hyperparams → model=%s bs=%d lr=%.1e epochs=%d accum=%d warmup=%d",
509
+ model_name, batch_size, lr, num_epochs, accum_steps, warmup_epochs,
510
+ )
511
+
512
+ # ---- Model --------------------------------------------------------
513
  model = get_model(model_name, config).to(device)
514
+ param_count = sum(p.numel() for p in model.parameters()) / 1e6
515
+ logger.info("Model: %s (%.2fM parameters)", model_name, param_count)
516
 
517
+ # ---- Data ---------------------------------------------------------
 
518
  train_loader, val_loader = build_dataloaders(config, paths["data"], batch_size)
519
+ logger.info(
520
+ "Data: %d train batches, %d val batches (batch_size=%d)",
521
+ len(train_loader), len(val_loader), batch_size,
522
+ )
523
 
524
+ # ---- Loss / optimiser / scheduler ---------------------------------
525
+ criterion = get_loss(config).to(device)
 
 
526
 
527
+ optimizer = AdamW(
528
+ model.parameters(),
529
+ lr=lr,
530
+ weight_decay=train_cfg["weight_decay"],
531
+ )
532
 
533
+ scheduler = build_scheduler(optimizer, num_epochs, warmup_epochs)
534
+ scaler = GradScaler(enabled=train_cfg.get("amp", True))
535
+
536
+ # ---- TensorBoard --------------------------------------------------
537
  writer = SummaryWriter(log_dir=str(paths["logs"] / model_name))
538
 
539
+ # ---- MetricTrackers -----------------------------------------------
540
+ train_tracker = MetricTracker(threshold=threshold)
541
+ val_tracker = MetricTracker(threshold=threshold)
542
+
543
+ # ---- Resume -------------------------------------------------------
544
+ start_epoch: int = 0
545
+ best_f1: float = 0.0
546
+ best_epoch: int = 0
547
+
548
+ if args.resume is not None and args.resume.exists():
549
+ start_epoch, best_f1, best_epoch = load_checkpoint(
550
+ args.resume, model, optimizer, scheduler, scaler, device,
551
  )
552
+ elif args.resume is not None:
553
+ logger.warning("Resume path does not exist: %s — training from scratch", args.resume)
554
+
555
+ # ---- Early stopping state -----------------------------------------
556
+ es_cfg = train_cfg.get("early_stopping", {})
557
+ es_enabled: bool = es_cfg.get("enabled", True)
558
+ patience: int = es_cfg.get("patience", 15)
559
+ patience_counter: int = 0
560
+
561
+ # ---- Training loop ------------------------------------------------
562
+ wall_start = time.monotonic()
563
+
564
+ logger.info("=" * 60)
565
+ logger.info("Starting training from epoch %d", start_epoch + 1)
566
+ logger.info("=" * 60)
567
 
568
+ for epoch in range(start_epoch, num_epochs):
569
+ epoch_start = time.monotonic()
570
+ epoch_num = epoch + 1 # 1-indexed for display / checkpoints
 
 
571
 
572
+ current_lr = optimizer.param_groups[0]["lr"]
573
+ logger.info("Epoch %d/%d (lr=%.2e)", epoch_num, num_epochs, current_lr)
 
574
 
575
+ # -- Train ------------------------------------------------------
576
+ train_tracker.reset()
577
  train_loss, train_metrics = train_one_epoch(
578
+ model, train_loader, criterion, optimizer, scaler, device,
579
+ train_tracker, accum_steps, grad_clip,
580
  )
581
+
582
+ # -- Validate ---------------------------------------------------
583
+ val_tracker.reset()
584
+ val_loss, val_metrics, last_val_batch = validate(
585
+ model, val_loader, criterion, device, val_tracker,
586
+ )
587
+
588
+ # -- Step scheduler (after both train + val) --------------------
589
  scheduler.step()
590
 
591
+ # -- TensorBoard scalars ----------------------------------------
592
+ writer.add_scalar("Loss/train", train_loss, epoch_num)
593
+ writer.add_scalar("Loss/val", val_loss, epoch_num)
594
+ writer.add_scalar("LR", current_lr, epoch_num)
595
+
596
+ for key, value in train_metrics.items():
597
+ writer.add_scalar(f"Train/{key}", value, epoch_num)
598
+ for key, value in val_metrics.items():
599
+ writer.add_scalar(f"Val/{key}", value, epoch_num)
600
+
601
+ # -- TensorBoard prediction images ------------------------------
602
+ if last_val_batch is not None and epoch_num % vis_interval == 0:
603
+ log_predictions_to_tensorboard(
604
+ writer,
605
+ img_a=last_val_batch["A"],
606
+ img_b=last_val_batch["B"],
607
+ mask_true=last_val_batch["mask"],
608
+ mask_pred=last_val_batch["logits"],
609
+ step=epoch_num,
610
+ num_samples=4,
611
+ )
612
 
613
+ # -- Console log ------------------------------------------------
614
+ epoch_time = time.monotonic() - epoch_start
615
  logger.info(
616
+ " Train loss: %.4f | F1: %.4f | IoU: %.4f",
617
+ train_loss, train_metrics["f1"], train_metrics["iou"],
618
  )
619
+ logger.info(
620
+ " Val — loss: %.4f | F1: %.4f | IoU: %.4f | Prec: %.4f | Rec: %.4f | OA: %.4f",
621
+ val_loss,
622
+ val_metrics["f1"],
623
+ val_metrics["iou"],
624
+ val_metrics["precision"],
625
+ val_metrics["recall"],
626
+ val_metrics["oa"],
627
+ )
628
+ logger.info(" Epoch time: %.1fs", epoch_time)
629
 
630
+ # -- Save last checkpoint (every epoch) -------------------------
631
  save_checkpoint(
632
+ model, optimizer, scheduler, scaler,
633
+ epoch=epoch_num,
634
+ best_f1=best_f1,
635
+ best_epoch=best_epoch,
636
+ save_path=paths["checkpoints"] / f"{model_name}_last.pth",
637
  )
638
 
639
+ # -- Save best checkpoint (if improved) -------------------------
640
  if val_metrics["f1"] > best_f1:
641
  best_f1 = val_metrics["f1"]
642
+ best_epoch = epoch_num
643
  patience_counter = 0
644
  save_checkpoint(
645
+ model, optimizer, scheduler, scaler,
646
+ epoch=epoch_num,
647
+ best_f1=best_f1,
648
+ best_epoch=best_epoch,
649
+ save_path=paths["checkpoints"] / f"{model_name}_best.pth",
650
  )
651
+ logger.info(" New best F1: %.4f (epoch %d)", best_f1, best_epoch)
652
  else:
653
  patience_counter += 1
654
+ logger.info(
655
+ " No improvement (%d/%d patience)", patience_counter, patience,
656
+ )
657
 
658
+ # -- Early stopping ---------------------------------------------
659
+ if es_enabled and patience_counter >= patience:
660
+ logger.info(
661
+ "Early stopping triggered after %d epochs without improvement.",
662
+ patience,
663
+ )
664
  break
665
 
666
+ # ---- Summary ------------------------------------------------------
667
  writer.close()
668
+ total_time = time.monotonic() - wall_start
669
+ hours, remainder = divmod(total_time, 3600)
670
+ minutes, seconds = divmod(remainder, 60)
671
+
672
+ logger.info("=" * 60)
673
+ logger.info("Training complete.")
674
+ logger.info(" Best val F1 : %.4f (epoch %d)", best_f1, best_epoch)
675
+ logger.info(" Total time : %dh %dm %ds", int(hours), int(minutes), int(seconds))
676
+ logger.info(" Checkpoints : %s", paths["checkpoints"])
677
+ logger.info("=" * 60)
678
 
679
 
680
  if __name__ == "__main__":
utils/visualization.py CHANGED
@@ -1,141 +1,297 @@
1
  """Visualization utilities for change detection results.
2
 
3
- Provides functions to plot predictions, overlay change maps, and track
4
- training metrics over time.
 
 
 
 
 
 
 
 
5
  """
6
 
 
 
 
 
7
  from pathlib import Path
8
- from typing import Dict, List, Optional
9
 
10
  import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def denormalize(
16
  img: np.ndarray,
17
- mean: tuple = (0.485, 0.456, 0.406),
18
- std: tuple = (0.229, 0.224, 0.225),
19
  ) -> np.ndarray:
20
- """Reverse ImageNet normalization for display.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  Args:
23
- img: Normalized image array [H, W, 3].
24
- mean: Channel means used for normalization.
25
- std: Channel stds used for normalization.
26
 
27
  Returns:
28
- Denormalized image clipped to [0, 1].
29
  """
30
- img = img * np.array(std) + np.array(mean)
31
- return np.clip(img, 0, 1)
32
 
33
 
 
 
 
 
34
  def plot_prediction(
35
  img_a: torch.Tensor,
36
  img_b: torch.Tensor,
37
- mask_gt: torch.Tensor,
38
  mask_pred: torch.Tensor,
39
- save_path: Optional[Path] = None,
40
  ) -> plt.Figure:
41
- """Plot a single change detection prediction.
 
 
42
 
43
- Shows: Before | After | Ground Truth | Prediction in a 1x4 grid.
 
44
 
45
  Args:
46
- img_a: Before image tensor [3, H, W] (normalized).
47
- img_b: After image tensor [3, H, W] (normalized).
48
- mask_gt: Ground truth mask [1, H, W] (binary).
49
- mask_pred: Predicted mask [1, H, W] (binary or probability).
50
- save_path: Optional path to save the figure.
 
51
 
52
  Returns:
53
- Matplotlib figure.
54
  """
55
- fig, axes = plt.subplots(1, 4, figsize=(16, 4))
 
 
 
56
 
57
- # Convert tensors to numpy
58
- a = denormalize(img_a.permute(1, 2, 0).cpu().numpy())
59
- b = denormalize(img_b.permute(1, 2, 0).cpu().numpy())
60
- gt = mask_gt.squeeze(0).cpu().numpy()
61
- pred = mask_pred.squeeze(0).cpu().numpy()
62
 
 
63
  titles = ["Before (A)", "After (B)", "Ground Truth", "Prediction"]
64
- images = [a, b, gt, pred]
65
  cmaps = [None, None, "gray", "gray"]
66
 
67
  for ax, img, title, cmap in zip(axes, images, titles, cmaps):
68
  ax.imshow(img, cmap=cmap, vmin=0, vmax=1)
69
- ax.set_title(title)
70
  ax.axis("off")
71
 
72
- plt.tight_layout()
73
 
74
- if save_path is not None:
75
- fig.savefig(save_path, dpi=150, bbox_inches="tight")
 
 
 
 
76
 
77
  return fig
78
 
79
 
 
 
 
 
80
  def overlay_changes(
81
- img_b: torch.Tensor,
82
  mask_pred: torch.Tensor,
83
  alpha: float = 0.4,
84
- color: tuple = (1.0, 0.0, 0.0),
85
  ) -> np.ndarray:
86
- """Overlay predicted change mask on the after image.
 
 
 
87
 
88
  Args:
89
- img_b: After image tensor [3, H, W] (normalized).
90
- mask_pred: Predicted binary mask [1, H, W].
91
- alpha: Overlay transparency.
92
- color: RGB color for the overlay (default: red).
 
 
93
 
94
  Returns:
95
- Overlaid image as numpy array [H, W, 3].
 
96
  """
97
- b = denormalize(img_b.permute(1, 2, 0).cpu().numpy())
98
- mask = mask_pred.squeeze(0).cpu().numpy()
 
 
 
99
 
100
- overlay = b.copy()
 
101
  for c in range(3):
102
  overlay[:, :, c] = np.where(
103
- mask > 0.5,
104
- b[:, :, c] * (1 - alpha) + color[c] * alpha,
105
- b[:, :, c],
106
  )
107
- return overlay
 
108
 
109
 
 
 
 
 
110
  def plot_metrics_history(
111
- history: Dict[str, List[float]],
112
- save_path: Optional[Path] = None,
113
  ) -> plt.Figure:
114
- """Plot training metric curves over epochs.
 
 
 
115
 
116
  Args:
117
- history: Dict mapping metric names to lists of per-epoch values.
118
- save_path: Optional path to save the figure.
 
119
 
120
  Returns:
121
- Matplotlib figure.
122
  """
123
- n_metrics = len(history)
124
- fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4))
 
 
125
 
 
126
  if n_metrics == 1:
127
  axes = [axes]
128
 
129
- for ax, (name, values) in zip(axes, history.items()):
130
- ax.plot(values, marker="o", markersize=2)
131
- ax.set_title(name)
 
132
  ax.set_xlabel("Epoch")
133
  ax.set_ylabel(name)
134
  ax.grid(True, alpha=0.3)
135
 
136
- plt.tight_layout()
137
 
138
  if save_path is not None:
139
- fig.savefig(save_path, dpi=150, bbox_inches="tight")
 
 
 
 
140
 
141
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Visualization utilities for change detection results.
2
 
3
+ Provides helpers for:
4
+ - Plotting side-by-side predictions (Before | After | GT | Pred)
5
+ - Overlaying predicted change masks on satellite images
6
+ - Plotting metric curves across epochs
7
+ - Logging sample prediction grids to TensorBoard
8
+
9
+ All public functions accept **ImageNet-normalised** ``torch.Tensor`` inputs
10
+ with shape ``[C, H, W]`` and handle denormalisation internally. The Agg
11
+ backend is set at import time so the module works in headless environments
12
+ (Google Colab, CI, remote servers).
13
  """
14
 
15
+ import matplotlib
16
+ matplotlib.use("Agg") # headless backend — must be set before pyplot import
17
+
18
+ import logging
19
  from pathlib import Path
20
+ from typing import Dict, List, Optional, Tuple, Union
21
 
22
  import matplotlib.pyplot as plt
23
  import numpy as np
24
  import torch
25
+ from torch.utils.tensorboard import SummaryWriter
26
+ import torchvision.utils as vutils
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # ImageNet constants (duplicated here to avoid circular imports from data/)
31
+ _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
32
+ _IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Internal helpers
37
+ # ---------------------------------------------------------------------------
38
+
39
def _to_numpy_hwc(tensor: torch.Tensor) -> np.ndarray:
    """Turn a ``[C, H, W]`` image tensor into a channels-last numpy array.

    Args:
        tensor: Image tensor of shape ``[C, H, W]`` (any device / dtype).

    Returns:
        Numpy array of shape ``[H, W, C]`` (float32).
    """
    # Detach + move to CPU first so autograd graphs and GPU tensors are safe.
    chw = tensor.detach().cpu().float()
    return chw.permute(1, 2, 0).numpy()
49
+
50
+
51
def _mask_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Turn a ``[1, H, W]`` mask tensor into a 2-D numpy array.

    Args:
        tensor: Mask tensor of shape ``[1, H, W]``.

    Returns:
        Numpy array of shape ``[H, W]`` (float32).
    """
    # Drop the channel axis after moving off-device; keeps float semantics.
    mask = tensor.detach().cpu().float()
    return mask.squeeze(0).numpy()
61
 
62
 
63
def denormalize(
    img: np.ndarray,
    mean: np.ndarray = _IMAGENET_MEAN,
    std: np.ndarray = _IMAGENET_STD,
) -> np.ndarray:
    """Undo ImageNet normalisation so an image can be displayed.

    Args:
        img: Normalised image of shape ``[H, W, 3]`` (float32).
        mean: Per-channel means applied during normalisation.
        std: Per-channel standard deviations applied during normalisation.

    Returns:
        Image with values clipped to ``[0, 1]``.
    """
    restored = img * std + mean
    return np.clip(restored, 0.0, 1.0)
79
+
80
+
81
def _denorm_tensor(tensor: torch.Tensor) -> np.ndarray:
    """Convenience: normalised ``[C, H, W]`` tensor → displayable numpy image.

    Args:
        tensor: ImageNet-normalised image ``[C, H, W]``.

    Returns:
        Denormalised numpy array ``[H, W, C]`` with values in ``[0, 1]``.
    """
    hwc = _to_numpy_hwc(tensor)
    return denormalize(hwc)
 
91
 
92
 
93
+ # ---------------------------------------------------------------------------
94
+ # 1. plot_prediction
95
+ # ---------------------------------------------------------------------------
96
+
97
def plot_prediction(
    img_a: torch.Tensor,
    img_b: torch.Tensor,
    mask_true: torch.Tensor,
    mask_pred: torch.Tensor,
    filename: Optional[Union[str, Path]] = None,
) -> plt.Figure:
    """Render one change-detection sample as a 1×4 panel figure.

    Panels, left to right: **Before (A)** | **After (B)** | **Ground Truth**
    | **Prediction**. Input images are denormalised from ImageNet statistics
    before display; the prediction is binarised at 0.5 so probability maps
    render cleanly in black / white.

    Args:
        img_a: Before image ``[3, H, W]`` (ImageNet-normalised).
        img_b: After image ``[3, H, W]`` (ImageNet-normalised).
        mask_true: Ground-truth binary mask ``[1, H, W]`` (0 or 1).
        mask_pred: Predicted mask ``[1, H, W]`` (binary or probability).
        filename: When given, the figure is written to this path and closed;
            otherwise the caller is responsible for ``plt.close(fig)``.

    Returns:
        The ``matplotlib.figure.Figure`` object.
    """
    # Threshold once so probability maps and hard masks display identically.
    pred_binary = (_mask_to_numpy(mask_pred) > 0.5).astype(np.float32)

    panels = [
        ("Before (A)", _denorm_tensor(img_a), None),
        ("After (B)", _denorm_tensor(img_b), None),
        ("Ground Truth", _mask_to_numpy(mask_true), "gray"),
        ("Prediction", pred_binary, "gray"),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (title, panel, cmap) in zip(axes, panels):
        ax.imshow(panel, cmap=cmap, vmin=0, vmax=1)
        ax.set_title(title, fontsize=11)
        ax.axis("off")

    fig.tight_layout(pad=1.0)

    if filename is not None:
        out_path = Path(filename)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        logger.debug("Saved prediction plot: %s", out_path)

    return fig
150
 
151
 
152
+ # ---------------------------------------------------------------------------
153
+ # 2. overlay_changes
154
+ # ---------------------------------------------------------------------------
155
+
156
def overlay_changes(
    img_after: torch.Tensor,
    mask_pred: torch.Tensor,
    alpha: float = 0.4,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Tint predicted change pixels on top of the *after* image.

    Pixels where the mask exceeds 0.5 are alpha-blended with ``color``;
    every other pixel passes through untouched.

    Args:
        img_after: After image ``[3, H, W]`` (ImageNet-normalised).
        mask_pred: Predicted binary mask ``[1, H, W]`` (0 or 1).
        alpha: Blend weight of the overlay colour (0 = invisible,
            1 = fully opaque).
        color: Overlay colour as RGB **uint8** values in ``[0, 255]``
            (default red).

    Returns:
        Composited RGB image as a **uint8** numpy array ``[H, W, 3]`` with
        values in ``[0, 255]``, ready for ``cv2.imwrite`` or display.
    """
    base = _denorm_tensor(img_after)            # [H, W, 3] float32 in [0, 1]
    changed = _mask_to_numpy(mask_pred) > 0.5   # [H, W] boolean

    # Bring the uint8 colour into the same [0, 1] range as the image.
    color_f = np.asarray(color, dtype=np.float32) / 255.0

    # Vectorised alpha blend applied only where the mask fired.
    blended = base.copy()
    blended[changed] = base[changed] * (1.0 - alpha) + color_f * alpha

    return (blended * 255.0).astype(np.uint8)
195
 
196
 
197
+ # ---------------------------------------------------------------------------
198
+ # 3. plot_metrics_history
199
+ # ---------------------------------------------------------------------------
200
+
201
def plot_metrics_history(
    history_dict: Dict[str, List[float]],
    save_path: Optional[Union[str, Path]] = None,
) -> plt.Figure:
    """Draw one subplot per metric showing its trajectory across epochs.

    Suitable for inclusion in reports or for logging as an image.

    Args:
        history_dict: Mapping from metric name to its list of per-epoch
            values, e.g. ``{"f1": [0.5, 0.6, ...], "loss": [0.8, ...]}``.
        save_path: When given, the figure is saved to this path and closed;
            otherwise the caller owns the figure.

    Returns:
        The ``matplotlib.figure.Figure`` object.
    """
    # Degenerate case: nothing to plot — hand back an empty figure.
    if not history_dict:
        fig, _ = plt.subplots()
        return fig

    n_metrics = len(history_dict)
    fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4))
    # A single subplot comes back as a bare Axes, not an array.
    axis_list = [axes] if n_metrics == 1 else axes

    for ax, (name, values) in zip(axis_list, history_dict.items()):
        xs = list(range(1, len(values) + 1))  # epochs are 1-indexed
        ax.plot(xs, values, marker="o", markersize=3, linewidth=1.5)
        ax.set_title(name.upper(), fontsize=11)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(name)
        ax.grid(True, alpha=0.3)

    fig.tight_layout(pad=1.5)

    if save_path is not None:
        out_path = Path(save_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        logger.debug("Saved metrics plot: %s", out_path)

    return fig
245
+
246
+
247
+ # ---------------------------------------------------------------------------
248
+ # 4. log_predictions_to_tensorboard
249
+ # ---------------------------------------------------------------------------
250
+
251
def log_predictions_to_tensorboard(
    writer: SummaryWriter,
    img_a: torch.Tensor,
    img_b: torch.Tensor,
    mask_true: torch.Tensor,
    mask_pred: torch.Tensor,
    step: int,
    num_samples: int = 4,
    threshold: float = 0.5,
) -> None:
    """Log a grid of sample predictions to TensorBoard.

    Each batch sample occupies one grid row with four columns:
    *Before* | *After* | *Ground Truth* | *Prediction*.

    Images are denormalised from ImageNet statistics; masks are expanded to
    three channels so every cell of the grid has a consistent shape.

    Args:
        writer: Active ``SummaryWriter`` instance.
        img_a: Before images ``[B, 3, H, W]`` (ImageNet-normalised).
        img_b: After images ``[B, 3, H, W]`` (ImageNet-normalised).
        mask_true: Ground-truth masks ``[B, 1, H, W]`` (binary).
        mask_pred: Predicted masks ``[B, 1, H, W]`` (binary or probability).
        step: Global training step (used as the x-axis in TensorBoard).
        num_samples: How many samples from the batch to include (taken
            from the front of the batch dimension).
        threshold: Probability cut-off used to binarise ``mask_pred``
            (default 0.5).
    """
    n = min(num_samples, img_a.size(0))
    if n == 0:
        # Empty batch — nothing to render (make_grid rejects an empty list).
        return

    # Denormalise images on CPU (keep as tensors for vutils.make_grid).
    mean = torch.tensor(_IMAGENET_MEAN).view(1, 3, 1, 1)
    std = torch.tensor(_IMAGENET_STD).view(1, 3, 1, 1)

    a = (img_a[:n].cpu().float() * std + mean).clamp(0.0, 1.0)
    b = (img_b[:n].cpu().float() * std + mean).clamp(0.0, 1.0)

    # Expand single-channel masks to 3-channel for the grid.
    gt = mask_true[:n].cpu().float().expand(-1, 3, -1, -1)
    pred = (mask_pred[:n].cpu().float() > threshold).float().expand(-1, 3, -1, -1)

    # Interleave per sample: [a0, b0, gt0, pred0, a1, b1, ...] so that
    # nrow=4 renders exactly one sample per grid row.
    rows = []
    for i in range(n):
        rows.extend([a[i], b[i], gt[i], pred[i]])

    grid = vutils.make_grid(rows, nrow=4, padding=2, normalize=False)
    writer.add_image("Predictions/before_after_gt_pred", grid, global_step=step)