Spaces:
Runtime error
Runtime error
| """PyTorch Dataset for change detection tasks. | |
| Loads pre-cropped 256x256 image patches (before/after) and binary change masks. | |
| Supports synchronized augmentations via albumentations.ReplayCompose. | |
| """ | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Tuple | |
| import albumentations as A | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| logger = logging.getLogger(__name__) | |
| # ImageNet normalization constants | |
| IMAGENET_MEAN = (0.485, 0.456, 0.406) | |
| IMAGENET_STD = (0.229, 0.224, 0.225) | |
def get_train_transforms(config: Dict[str, Any]) -> A.ReplayCompose:
    """Build training augmentation pipeline with synchronized transforms.

    Args:
        config: Augmentation config dict from config.yaml.

    Returns:
        ReplayCompose that applies identical spatial transforms to A, B, and mask.
    """
    aug_cfg = config.get("augmentation", {})
    transforms = []
    # Each augmentation is enabled only when its configured probability is
    # positive, so a missing or zero entry disables it entirely.
    if aug_cfg.get("horizontal_flip", 0) > 0:
        transforms.append(A.HorizontalFlip(p=aug_cfg["horizontal_flip"]))
    if aug_cfg.get("vertical_flip", 0) > 0:
        transforms.append(A.VerticalFlip(p=aug_cfg["vertical_flip"]))
    if aug_cfg.get("random_rotate_90", 0) > 0:
        transforms.append(A.RandomRotate90(p=aug_cfg["random_rotate_90"]))
    # Normalize applies only to image targets; mask values pass through untouched.
    transforms.append(A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    # NOTE: "mask" is a built-in Compose target in albumentations and must not
    # be re-declared via additional_targets (it was previously listed here
    # redundantly). Only the second image needs registering so it receives the
    # same spatial transforms as "image".
    return A.ReplayCompose(
        transforms,
        additional_targets={"image_b": "image"},
    )
def get_val_transforms() -> A.Compose:
    """Build the validation/test transform pipeline.

    No spatial augmentation is applied, so the before/after images remain
    aligned with the mask by construction; only ImageNet normalization runs.

    Returns:
        Compose with ImageNet normalization only.
    """
    normalize = A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    extra_targets = {"image_b": "image"}
    return A.Compose([normalize], additional_targets=extra_targets)
class ChangeDetectionDataset(Dataset):
    """Dataset for loading change detection image pairs and masks.

    Expects directory structure:
        root/
        ├── A/      # before images
        ├── B/      # after images
        └── label/  # binary change masks (0=no change, 255=change)

    Args:
        root: Path to the split directory (e.g., processed_data/train).
        split: One of 'train', 'val', 'test'.
        config: Full config dict for augmentation settings.
        transform: Optional override for the transform pipeline.
    """

    # Image extensions accepted when scanning the A/ directory.
    _VALID_SUFFIXES = (".png", ".jpg", ".tif")

    def __init__(
        self,
        root: Path,
        split: str = "train",
        config: Optional[Dict[str, Any]] = None,
        transform: Optional[Any] = None,
    ) -> None:
        self.root = Path(root)
        self.split = split
        self.dir_a = self.root / "A"
        self.dir_b = self.root / "B"
        self.dir_label = self.root / "label"
        # Sorted file list keeps sample order deterministic across runs.
        self.filenames = sorted(
            f.name for f in self.dir_a.iterdir() if f.suffix in self._VALID_SUFFIXES
        )
        logger.info(
            "Loaded %d samples for split '%s' from %s", len(self.filenames), split, root
        )
        # Explicit override wins; otherwise choose by split (train gets
        # augmentation when a config is supplied, everything else normalize-only).
        if transform is not None:
            self.transform = transform
        elif split == "train" and config is not None:
            self.transform = get_train_transforms(config)
        else:
            self.transform = get_val_transforms()

    def __len__(self) -> int:
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path: Path) -> np.ndarray:
        """Read an image file as an RGB uint8 array.

        cv2.imread silently returns None on a missing or unreadable file,
        which would otherwise surface as a confusing error deep inside
        cvtColor or the transform pipeline — fail loudly here instead.

        Raises:
            FileNotFoundError: If the file cannot be read as an image.
        """
        img = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Failed to read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load a single sample.

        Args:
            idx: Sample index.

        Returns:
            Dict with keys 'A', 'B', 'mask', 'filename'.
                - A: before image tensor [3, H, W]
                - B: after image tensor [3, H, W]
                - mask: binary change mask tensor [1, H, W] (float, 0 or 1)
                - filename: original filename string

        Raises:
            FileNotFoundError: If any of the three files cannot be read.
        """
        fname = self.filenames[idx]
        # Lazy load — read from disk each time (no RAM caching).
        img_a = self._load_rgb(self.dir_a / fname)
        img_b = self._load_rgb(self.dir_b / fname)
        mask = cv2.imread(str(self.dir_label / fname), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Failed to read mask: {self.dir_label / fname}")
        # Normalize 0/255 -> 0/1.
        mask = (mask / 255.0).astype(np.float32)
        # Apply synchronized augmentations: ReplayCompose replays identical
        # spatial params on image, image_b, and mask.
        if isinstance(self.transform, A.ReplayCompose):
            transformed = self.transform(image=img_a, image_b=img_b, mask=mask)
            mask = transformed["mask"]
        else:
            # Val/test path: normalization only touches images; mask stays as-is.
            transformed = self.transform(image=img_a, image_b=img_b)
        img_a = transformed["image"]
        img_b = transformed["image_b"]
        # HWC -> CHW for images; add channel dim for mask.
        img_a = torch.from_numpy(img_a).permute(2, 0, 1).float()
        img_b = torch.from_numpy(img_b).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).unsqueeze(0).float()
        return {"A": img_a, "B": img_b, "mask": mask, "filename": fname}