BaseChange / data /dataset.py
Vedant Jigarbhai Mehta
Initial scaffolding for military base change detection project
b25c087
"""PyTorch Dataset for change detection tasks.
Loads pre-cropped 256x256 image patches (before/after) and binary change masks.
Supports synchronized augmentations via albumentations.ReplayCompose.
"""
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import albumentations as A
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ImageNet normalization constants: per-channel mean and standard deviation
# handed to A.Normalize by both the training and the validation pipelines.
# Images are converted BGR -> RGB before the transforms run, so the channel
# order here is presumably (R, G, B) -- confirm if the loading code changes.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def get_train_transforms(config: Dict[str, Any]) -> A.ReplayCompose:
    """Build the training augmentation pipeline with synchronized transforms.

    A flip/rotation is added only when its probability in the
    ``augmentation`` section of ``config`` is greater than zero. ImageNet
    normalization is always appended as the last step.

    Args:
        config: Full config dict (e.g. parsed from config.yaml). Its
            ``augmentation`` section may be absent or empty (``None``); in
            that case only normalization is applied. Recognized keys, each
            holding a probability: ``horizontal_flip``, ``vertical_flip``,
            ``random_rotate_90``.

    Returns:
        ReplayCompose that accepts ``image``, ``image_b`` and ``mask`` in a
        single call so the same spatial transforms hit A, B, and the mask.
    """
    # An empty `augmentation:` key in YAML parses to None rather than {},
    # so fall back to an empty dict instead of calling .get on None.
    aug_cfg = config.get("augmentation") or {}

    # (config key, transform class) pairs, applied in this order.
    spatial_ops = (
        ("horizontal_flip", A.HorizontalFlip),
        ("vertical_flip", A.VerticalFlip),
        ("random_rotate_90", A.RandomRotate90),
    )

    transforms = []
    for key, op in spatial_ops:
        # A key that is present but unset (None) counts as disabled.
        prob = aug_cfg.get(key) or 0
        if prob > 0:
            transforms.append(op(p=prob))

    transforms.append(A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return A.ReplayCompose(
        transforms,
        additional_targets={"image_b": "image", "mask": "mask"},
    )
def get_val_transforms() -> A.Compose:
    """Build the deterministic pipeline used for validation and testing.

    No augmentation is performed; the only step is ImageNet normalization.

    Returns:
        Compose that normalizes ``image`` and the extra ``image_b`` target.
    """
    normalize = A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    # Register the second image so it is treated exactly like `image`.
    extra_targets = {"image_b": "image"}
    return A.Compose([normalize], additional_targets=extra_targets)
class ChangeDetectionDataset(Dataset):
    """Dataset for loading change detection image pairs and masks.

    Expects directory structure:
        root/
        ├── A/      # before images
        ├── B/      # after images
        └── label/  # binary change masks (0=no change, 255=change)

    The sample list is taken from ``A/``; the matching file in ``B/`` and
    ``label/`` must have the identical filename.

    Args:
        root: Path to the split directory (e.g., processed_data/train).
        split: One of 'train', 'val', 'test'.
        config: Full config dict for augmentation settings. Only used when
            ``split == "train"`` and no explicit ``transform`` is given.
        transform: Optional override for the transform pipeline. It is called
            with ``image`` and ``image_b`` keyword arguments (plus ``mask``
            when it is an albumentations ``Compose``) and must return a
            mapping with the same keys.
    """

    # Extensions (compared case-insensitively) that count as samples in A/.
    _IMAGE_SUFFIXES = (".png", ".jpg", ".tif")

    def __init__(
        self,
        root: Path,
        split: str = "train",
        config: Optional[Dict[str, Any]] = None,
        transform: Optional[Any] = None,
    ) -> None:
        super().__init__()
        self.root = Path(root)
        self.split = split
        self.dir_a = self.root / "A"
        self.dir_b = self.root / "B"
        self.dir_label = self.root / "label"

        # Collect a sorted file list; lower-casing the suffix keeps files such
        # as "tile.PNG" from being silently dropped.
        self.filenames = sorted(
            f.name
            for f in self.dir_a.iterdir()
            if f.suffix.lower() in self._IMAGE_SUFFIXES
        )
        logger.info("Loaded %d samples for split '%s' from %s", len(self.filenames), split, root)

        # Set up transforms: explicit override > train pipeline > val pipeline.
        if transform is not None:
            self.transform = transform
        elif split == "train" and config is not None:
            self.transform = get_train_transforms(config)
        else:
            self.transform = get_val_transforms()

    def __len__(self) -> int:
        """Return the number of samples found in ``A/``."""
        return len(self.filenames)

    @staticmethod
    def _read_image(path: Path, flags: int) -> np.ndarray:
        """Read one image with OpenCV, raising if it cannot be loaded."""
        img = cv2.imread(str(path), flags)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load a single sample.

        Args:
            idx: Sample index.

        Returns:
            Dict with keys 'A', 'B', 'mask', 'filename'.
            - A: before image tensor [3, H, W]
            - B: after image tensor [3, H, W]
            - mask: change mask tensor [1, H, W] (float; 0 or 1 when the
              stored mask only contains 0 and 255)
            - filename: original filename string

        Raises:
            FileNotFoundError: If the A, B, or label file for this sample is
                missing or cannot be decoded.
        """
        fname = self.filenames[idx]

        # Lazy load: read from disk each time (no RAM caching).
        img_a = self._read_image(self.dir_a / fname, cv2.IMREAD_COLOR)
        img_b = self._read_image(self.dir_b / fname, cv2.IMREAD_COLOR)
        mask = self._read_image(self.dir_label / fname, cv2.IMREAD_GRAYSCALE)

        # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
        img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
        img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)

        # Normalize 0/255 -> 0/1
        mask = (mask / 255.0).astype(np.float32)

        if isinstance(self.transform, A.Compose):
            # Any albumentations Compose (ReplayCompose included) gets the
            # mask too, so spatial ops in a user-supplied pipeline stay in
            # sync with the images. Normalize only touches image targets.
            transformed = self.transform(image=img_a, image_b=img_b, mask=mask)
            mask = transformed["mask"]
        else:
            # Arbitrary callables keep the original image-only contract.
            transformed = self.transform(image=img_a, image_b=img_b)
        img_a = transformed["image"]
        img_b = transformed["image_b"]

        # HWC -> CHW for images, add channel dim for mask. The arrays are
        # made contiguous first because torch.from_numpy rejects
        # negative-stride views, which flips and rotations can produce.
        img_a = torch.from_numpy(np.ascontiguousarray(img_a.transpose(2, 0, 1))).float()
        img_b = torch.from_numpy(np.ascontiguousarray(img_b.transpose(2, 0, 1))).float()
        mask = torch.from_numpy(np.ascontiguousarray(mask)).unsqueeze(0).float()

        return {"A": img_a, "B": img_b, "mask": mask, "filename": fname}