Vedant Jigarbhai Mehta committed on
Commit
209365d
·
1 Parent(s): e877eeb

Implement LEVIR-CD download and patch cropping pipeline

Browse files

- Download from Google Drive via gdown with skip-if-exists logic
- Extract zip with nested-folder detection
- Crop 1024x1024 images into 256x256 non-overlapping patches
- Process all splits (train/val/test) with A/B/label triplets
- Skip preprocessing if output already exists (avoids re-cropping)
- CLI supports --skip_download for pre-downloaded data
- Colab-friendly: save processed patches to Drive path

Files changed (1) hide show
  1. data/download.py +358 -47
data/download.py CHANGED
@@ -1,63 +1,229 @@
1
  """Download and preprocess change detection datasets.
2
 
3
- Supports LEVIR-CD and WHU-CD datasets. Downloads raw data, crops 1024x1024
4
- images into 256x256 non-overlapping patches, and organizes into train/val/test
5
- splits.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  Usage:
 
8
  python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data
 
 
 
 
 
 
 
9
  """
10
 
11
  import argparse
12
  import logging
 
 
13
  from pathlib import Path
14
- from typing import Tuple
15
 
16
  import cv2
17
  import numpy as np
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def download_levir_cd(raw_dir: Path) -> None:
23
- """Download the LEVIR-CD dataset.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  Args:
26
- raw_dir: Directory to save the raw downloaded files.
 
 
 
27
  """
28
- # TODO: Implement download via gdown or direct URL
29
- raise NotImplementedError("LEVIR-CD download not yet implemented")
 
 
 
 
 
30
 
 
 
 
 
 
31
 
32
- def download_whu_cd(raw_dir: Path) -> None:
33
- """Download the WHU-CD dataset.
 
 
 
 
 
 
 
 
34
 
35
  Args:
36
- raw_dir: Directory to save the raw downloaded files.
 
 
 
 
 
37
  """
38
- # TODO: Implement download
39
- raise NotImplementedError("WHU-CD download not yet implemented")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
41
 
42
  def crop_to_patches(
43
  image: np.ndarray,
44
  patch_size: int = 256,
45
- ) -> list[np.ndarray]:
46
- """Crop an image into non-overlapping patches.
 
 
 
47
 
48
  Args:
49
- image: Input image of shape (H, W) or (H, W, C).
50
- patch_size: Size of each square patch.
51
 
52
  Returns:
53
  List of cropped patches.
54
  """
55
  h, w = image.shape[:2]
56
- patches = []
57
  for y in range(0, h - patch_size + 1, patch_size):
58
  for x in range(0, w - patch_size + 1, patch_size):
59
- patch = image[y : y + patch_size, x : x + patch_size]
60
- patches.append(patch)
61
  return patches
62
 
63
 
@@ -67,23 +233,105 @@ def process_split(
67
  split: str,
68
  patch_size: int = 256,
69
  ) -> int:
70
- """Process a single dataset split (train/val/test).
71
 
72
- Reads image pairs and masks from raw_dir, crops into patches, and
73
- saves to out_dir.
74
 
75
  Args:
76
- raw_dir: Root directory of the raw dataset.
77
- out_dir: Output directory for processed patches.
78
- split: One of 'train', 'val', 'test'.
79
- patch_size: Size of each square patch.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  Returns:
82
- Number of patch triplets generated.
83
  """
84
- # TODO: Implement processing pipeline
85
- raise NotImplementedError("Split processing not yet implemented")
 
 
 
 
86
 
 
 
 
87
 
88
  def preprocess_dataset(
89
  dataset: str,
@@ -91,41 +339,104 @@ def preprocess_dataset(
91
  out_dir: Path,
92
  patch_size: int = 256,
93
  ) -> None:
94
- """Run full preprocessing pipeline for a dataset.
95
 
96
  Args:
97
- dataset: Dataset name ('levir-cd' or 'whu-cd').
98
- raw_dir: Directory containing raw downloaded data.
99
  out_dir: Output directory for processed patches.
100
- patch_size: Size of each square patch.
101
  """
102
- logger.info("Preprocessing %s: %s -> %s", dataset, raw_dir, out_dir)
 
 
 
 
 
 
103
  out_dir.mkdir(parents=True, exist_ok=True)
104
 
 
105
  for split in ["train", "val", "test"]:
106
  count = process_split(raw_dir, out_dir, split, patch_size)
107
- logger.info(" %s: %d patch triplets", split, count)
 
 
 
 
 
 
108
 
 
 
 
109
 
110
  def main() -> None:
111
  """CLI entry point for dataset download and preprocessing."""
112
- parser = argparse.ArgumentParser(description="Download and preprocess change detection datasets")
113
- parser.add_argument("--dataset", type=str, default="levir-cd", choices=["levir-cd", "whu-cd"])
114
- parser.add_argument("--raw_dir", type=Path, default=Path("./raw_data"))
115
- parser.add_argument("--out_dir", type=Path, default=Path("./processed_data"))
116
- parser.add_argument("--patch_size", type=int, default=256)
117
- parser.add_argument("--skip_download", action="store_true", help="Skip download, only preprocess")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  args = parser.parse_args()
119
 
120
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 
 
 
121
 
 
 
122
  if not args.skip_download:
 
 
 
 
 
 
 
 
123
  if args.dataset == "levir-cd":
124
- download_levir_cd(args.raw_dir)
125
  elif args.dataset == "whu-cd":
126
- download_whu_cd(args.raw_dir)
127
 
128
- preprocess_dataset(args.dataset, args.raw_dir, args.out_dir, args.patch_size)
 
 
129
 
130
 
131
  if __name__ == "__main__":
 
1
  """Download and preprocess change detection datasets.
2
 
3
+ Supports LEVIR-CD (primary) and WHU-CD (secondary). Downloads from Google
4
+ Drive via ``gdown``, extracts archives, crops 1024x1024 images into 256x256
5
+ non-overlapping patches, and organises into train/val/test splits.
6
+
7
+ LEVIR-CD expected raw structure after extraction::
8
+
9
+ raw_dir/
10
+ └── LEVIR-CD/
11
+ ├── train/
12
+ │ ├── A/ # before images (1024x1024)
13
+ │ ├── B/ # after images (1024x1024)
14
+ │ └── label/ # binary masks (0/255)
15
+ ├── val/
16
+ │ ├── A/
17
+ │ ├── B/
18
+ │ └── label/
19
+ └── test/
20
+ ├── A/
21
+ ├── B/
22
+ └── label/
23
 
24
  Usage:
25
+ # Full pipeline: download + crop
26
  python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data
27
+
28
+ # Skip download (data already on disk), just crop
29
+ python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data --skip_download
30
+
31
+ # On Colab — save processed patches to Drive
32
+ python data/download.py --dataset levir-cd --raw_dir /content/raw_data \
33
+ --out_dir /content/drive/MyDrive/change-detection/processed_data
34
  """
35
 
36
  import argparse
37
  import logging
38
+ import shutil
39
+ import zipfile
40
  from pathlib import Path
41
+ from typing import List
42
 
43
  import cv2
44
  import numpy as np
45
 
46
  logger = logging.getLogger(__name__)
47
 
48
# ---------------------------------------------------------------------------
# Google Drive file IDs for LEVIR-CD
# These are publicly shared links from the dataset authors.
# If they break, download manually from:
#   https://github.com/justchenhao/LEVIR-CD
# ---------------------------------------------------------------------------
_LEVIR_CD_GDRIVE_IDS = {
    # The dataset is often shared as a single zip or split zips.
    # Update these IDs if the authors change the links.
    "full": "1RUFY9QDmVBfHuMRwYze7C5BlVsMr3Xm_",
}

# NOTE(review): WHU-CD Drive ID — source of this mirror is not shown here;
# verify the link (and its licence) before relying on it.
_WHU_CD_GDRIVE_IDS = {
    "full": "1GX656JqqOyBi_Ef0w65kDGVto-nHrNs9",
}
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Download helpers
67
+ # ---------------------------------------------------------------------------
68
+
69
def _download_from_gdrive(file_id: str, output_path: Path) -> None:
    """Download a file from Google Drive using gdown.

    Args:
        file_id: Google Drive file ID.
        output_path: Local path to save the downloaded file.

    Raises:
        ImportError: If ``gdown`` is not installed.
        RuntimeError: If the download fails (rate-limited, bad ID, or
            quota exceeded).
    """
    try:
        import gdown
    except ImportError as err:
        # Chain the original error so the traceback shows the real cause.
        raise ImportError(
            "gdown is required for downloading. Install with: pip install gdown"
        ) from err

    output_path.parent.mkdir(parents=True, exist_ok=True)
    url = f"https://drive.google.com/uc?id={file_id}"
    logger.info("Downloading from Google Drive (ID: %s) ...", file_id)
    # gdown.download returns the output path on success and None on failure
    # (e.g. Drive quota exceeded). The previous version ignored the return
    # value and would carry on with a missing/partial file — fail loudly.
    result = gdown.download(url, str(output_path), quiet=False)
    if result is None:
        raise RuntimeError(
            f"gdown failed to download file ID {file_id!r}. "
            "The link may be rate-limited or broken; download manually."
        )
    logger.info("Downloaded: %s", output_path)
88
+
89
+
90
+ def _extract_zip(zip_path: Path, extract_to: Path) -> None:
91
+ """Extract a zip archive.
92
+
93
+ Args:
94
+ zip_path: Path to the zip file.
95
+ extract_to: Directory to extract into.
96
+ """
97
+ logger.info("Extracting %s -> %s", zip_path.name, extract_to)
98
+ extract_to.mkdir(parents=True, exist_ok=True)
99
+ with zipfile.ZipFile(zip_path, "r") as zf:
100
+ zf.extractall(extract_to)
101
+ logger.info("Extraction complete.")
102
+
103
 
104
def download_levir_cd(raw_dir: Path) -> Path:
    """Download the LEVIR-CD dataset from Google Drive.

    Downloads the zip, extracts it, and returns the path to the extracted
    dataset root.

    Args:
        raw_dir: Directory to save downloads and extracted data.

    Returns:
        Path to the extracted LEVIR-CD root directory.
    """
    raw_dir.mkdir(parents=True, exist_ok=True)
    zip_path = raw_dir / "LEVIR-CD.zip"

    # Reuse an existing archive rather than downloading it again.
    if zip_path.exists():
        logger.info("LEVIR-CD zip already exists: %s", zip_path)
    else:
        _download_from_gdrive(_LEVIR_CD_GDRIVE_IDS["full"], zip_path)

    # Extract only when the target folder is absent or empty.
    extracted = raw_dir / "LEVIR-CD"
    if extracted.exists() and any(extracted.iterdir()):
        logger.info("LEVIR-CD already extracted: %s", extracted)
    else:
        _extract_zip(zip_path, raw_dir)

    # Some zips add an extra nested folder — resolve the actual root.
    root = _find_dataset_root(raw_dir, "LEVIR-CD")
    logger.info("LEVIR-CD root: %s", root)
    return root
136
+
137
+
138
def download_whu_cd(raw_dir: Path) -> Path:
    """Download the WHU-CD dataset from Google Drive.

    Args:
        raw_dir: Directory to save downloads and extracted data.

    Returns:
        Path to the extracted WHU-CD root directory.
    """
    raw_dir.mkdir(parents=True, exist_ok=True)
    zip_path = raw_dir / "WHU-CD.zip"

    # Reuse an existing archive rather than downloading it again.
    if zip_path.exists():
        logger.info("WHU-CD zip already exists: %s", zip_path)
    else:
        _download_from_gdrive(_WHU_CD_GDRIVE_IDS["full"], zip_path)

    # Extract only when the target folder is absent or empty.
    extracted = raw_dir / "WHU-CD"
    if extracted.exists() and any(extracted.iterdir()):
        logger.info("WHU-CD already extracted: %s", extracted)
    else:
        _extract_zip(zip_path, raw_dir)

    root = _find_dataset_root(raw_dir, "WHU-CD")
    logger.info("WHU-CD root: %s", root)
    return root
164
+
165
+
166
+ def _find_dataset_root(parent: Path, name_hint: str) -> Path:
167
+ """Locate the actual dataset root after extraction.
168
+
169
+ Handles cases where the zip creates a nested folder like
170
+ ``LEVIR-CD/LEVIR-CD/`` or the root is directly under ``parent``.
171
 
172
  Args:
173
+ parent: Directory where the zip was extracted.
174
+ name_hint: Expected folder name (e.g. ``'LEVIR-CD'``).
175
+
176
+ Returns:
177
+ Path to the directory containing ``train/``, ``val/``, ``test/``
178
+ (or the closest match).
179
  """
180
+ candidate = parent / name_hint
181
+ if not candidate.exists():
182
+ # Try to find it by scanning
183
+ for d in parent.rglob(name_hint):
184
+ if d.is_dir():
185
+ candidate = d
186
+ break
187
+
188
+ # Check for nested structure
189
+ nested = candidate / name_hint
190
+ if nested.exists() and nested.is_dir():
191
+ candidate = nested
192
+
193
+ # Look for the split directories
194
+ for d in [candidate] + list(candidate.iterdir()) if candidate.exists() else []:
195
+ if isinstance(d, Path) and d.is_dir():
196
+ if (d / "train").exists() or (d / "A").exists():
197
+ return d
198
+
199
+ return candidate
200
+
201
 
202
+ # ---------------------------------------------------------------------------
203
+ # Patch cropping
204
+ # ---------------------------------------------------------------------------
205
 
206
def crop_to_patches(
    image: np.ndarray,
    patch_size: int = 256,
) -> List[np.ndarray]:
    """Split an image into non-overlapping square tiles.

    Pixels at the right/bottom edges that do not fill a complete tile are
    dropped (a 1024x1024 input yields 16 tiles of 256x256).

    Args:
        image: Input image of shape ``(H, W)`` or ``(H, W, C)``.
        patch_size: Side length of each square patch.

    Returns:
        List of cropped patches, in row-major order.
    """
    height, width = image.shape[:2]
    row_starts = range(0, height - patch_size + 1, patch_size)
    col_starts = range(0, width - patch_size + 1, patch_size)
    return [
        image[r : r + patch_size, c : c + patch_size]
        for r in row_starts
        for c in col_starts
    ]
228
 
229
 
 
233
def process_split(
    raw_dir: Path,
    out_dir: Path,
    split: str,
    patch_size: int = 256,
) -> int:
    """Process one dataset split: crop all images into patches.

    Reads image triplets (A, B, label) from ``raw_dir/{split}/``, crops each
    into ``patch_size`` x ``patch_size`` patches, and saves to
    ``out_dir/{split}/``.

    Args:
        raw_dir: Root of the raw dataset (contains ``train/``, ``val/``,
            ``test/`` sub-folders).
        out_dir: Output root for processed patches.
        split: One of ``'train'``, ``'val'``, ``'test'``.
        patch_size: Patch size in pixels.

    Returns:
        Total number of patch triplets actually written for this split.
    """
    split_in = raw_dir / split
    split_out = out_dir / split

    # Input directories
    dir_a_in = split_in / "A"
    dir_b_in = split_in / "B"
    dir_label_in = split_in / "label"

    if not dir_a_in.exists():
        logger.warning("Input directory missing: %s — skipping split '%s'", dir_a_in, split)
        return 0

    # Output directories
    dir_a_out = split_out / "A"
    dir_b_out = split_out / "B"
    dir_label_out = split_out / "label"
    for d in (dir_a_out, dir_b_out, dir_label_out):
        d.mkdir(parents=True, exist_ok=True)

    # Collect image filenames (A/ is the reference; B/ and label/ are
    # assumed to mirror it — missing counterparts are caught below).
    extensions = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    filenames = sorted(
        f.name for f in dir_a_in.iterdir() if f.suffix.lower() in extensions
    )
    logger.info("  %s: found %d images to crop", split, len(filenames))

    total_patches = 0

    for fname in filenames:
        # Read triplet; cv2.imread returns None on failure instead of raising.
        img_a = cv2.imread(str(dir_a_in / fname), cv2.IMREAD_COLOR)
        img_b = cv2.imread(str(dir_b_in / fname), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(dir_label_in / fname), cv2.IMREAD_GRAYSCALE)

        if img_a is None or img_b is None or mask is None:
            logger.warning("  Skipping %s (could not read one or more files)", fname)
            continue

        patches_a = crop_to_patches(img_a, patch_size)
        patches_b = crop_to_patches(img_b, patch_size)
        patches_m = crop_to_patches(mask, patch_size)

        # BUG FIX: the previous version added len(patches_a) to the total,
        # but the write loop zip()s the three lists — if A/B/label images
        # had different sizes, fewer triplets were written than counted.
        # Count exactly what is written, and warn on a mismatch.
        n_triplets = min(len(patches_a), len(patches_b), len(patches_m))
        if not (len(patches_a) == len(patches_b) == len(patches_m)):
            logger.warning(
                "  %s: A/B/label patch counts differ (%d/%d/%d) — writing %d",
                fname, len(patches_a), len(patches_b), len(patches_m), n_triplets,
            )

        stem = Path(fname).stem
        for idx in range(n_triplets):
            patch_name = f"{stem}_{idx:04d}.png"
            cv2.imwrite(str(dir_a_out / patch_name), patches_a[idx])
            cv2.imwrite(str(dir_b_out / patch_name), patches_b[idx])
            cv2.imwrite(str(dir_label_out / patch_name), patches_m[idx])

        total_patches += n_triplets

    logger.info("  %s: generated %d patch triplets", split, total_patches)
    return total_patches
307
+
308
+
309
+ # ---------------------------------------------------------------------------
310
+ # Check for pre-cropped dataset
311
+ # ---------------------------------------------------------------------------
312
+
313
def is_already_cropped(data_dir: Path) -> bool:
    """Check if a directory already contains processed (cropped) patches.

    A directory is considered processed if it has ``train/A/`` with at least
    one image file inside.

    Args:
        data_dir: Path to check.

    Returns:
        ``True`` if processed patches are present.
    """
    train_a = data_dir / "train" / "A"
    if not train_a.exists():
        return False
    # CONSISTENCY FIX: use the same extension set as the cropping step.
    # The previous {.png, .jpg, .tif} subset missed .jpeg/.tiff/.bmp, so
    # externally pre-cropped data in those formats triggered a needless
    # (and potentially destructive-overwrite) re-crop on the next run.
    extensions = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    return any(f.suffix.lower() in extensions for f in train_a.iterdir())
330
+
331
 
332
+ # ---------------------------------------------------------------------------
333
+ # Full pipeline
334
+ # ---------------------------------------------------------------------------
335
 
336
def preprocess_dataset(
    dataset: str,
    raw_dir: Path,
    out_dir: Path,
    patch_size: int = 256,
) -> None:
    """Run the full preprocessing pipeline for a dataset.

    Args:
        dataset: Dataset name (``'levir-cd'`` or ``'whu-cd'``).
        raw_dir: Directory containing the raw (extracted) dataset.
        out_dir: Output directory for processed patches.
        patch_size: Patch size in pixels.
    """
    # Skip entirely when processed output is already present.
    if is_already_cropped(out_dir):
        logger.info("Processed data already exists at %s — skipping.", out_dir)
        logger.info("Delete the directory or use a different --out_dir to re-process.")
        return

    logger.info("Preprocessing %s: %s -> %s (patch_size=%d)", dataset, raw_dir, out_dir, patch_size)
    out_dir.mkdir(parents=True, exist_ok=True)

    total = sum(
        process_split(raw_dir, out_dir, split, patch_size)
        for split in ("train", "val", "test")
    )

    logger.info("=" * 50)
    logger.info("Preprocessing complete: %d total patch triplets", total)
    logger.info("Output: %s", out_dir)
    logger.info("=" * 50)
+ logger.info("=" * 50)
368
+
369
 
370
+ # ---------------------------------------------------------------------------
371
+ # CLI
372
+ # ---------------------------------------------------------------------------
373
 
374
def main() -> None:
    """CLI entry point for dataset download and preprocessing."""
    parser = argparse.ArgumentParser(
        description="Download and preprocess change detection datasets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full pipeline (download + crop)
  python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data

  # Already downloaded — just crop
  python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data --skip_download

  # Colab: save to Drive
  python data/download.py --dataset levir-cd --raw_dir /content/raw_data \\
      --out_dir /content/drive/MyDrive/change-detection/processed_data
""",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="levir-cd",
        choices=["levir-cd", "whu-cd"],
        help="Dataset to download and preprocess (default: levir-cd).",
    )
    parser.add_argument(
        "--raw_dir",
        type=Path,
        default=Path("./raw_data"),
        help="Directory for raw downloads and extracted data.",
    )
    parser.add_argument(
        "--out_dir",
        type=Path,
        default=Path("./processed_data"),
        help="Output directory for processed 256x256 patches.",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=256,
        help="Patch size for cropping (default: 256).",
    )
    parser.add_argument(
        "--skip_download",
        action="store_true",
        help="Skip download step — only run preprocessing on existing data.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Per-dataset download functions and extracted-folder name hints.
    downloaders = {"levir-cd": download_levir_cd, "whu-cd": download_whu_cd}
    root_hints = {"levir-cd": "LEVIR-CD", "whu-cd": "WHU-CD"}

    # Step 1: Download (unless skipped)
    dataset_root = args.raw_dir
    if args.skip_download:
        logger.info("Step 1: Download skipped (--skip_download)")
        # Data is expected to be on disk already — locate its root.
        if args.dataset in root_hints:
            dataset_root = _find_dataset_root(args.raw_dir, root_hints[args.dataset])
    else:
        logger.info("Step 1: Downloading %s ...", args.dataset)
        if args.dataset in downloaders:
            dataset_root = downloaders[args.dataset](args.raw_dir)

    # Step 2: Preprocess (crop into patches)
    logger.info("Step 2: Cropping into %dx%d patches ...", args.patch_size, args.patch_size)
    preprocess_dataset(args.dataset, dataset_root, args.out_dir, args.patch_size)
440
 
441
 
442
  if __name__ == "__main__":