SentinelWatch / download_oscd.py
VishaliniS456's picture
Upload 5 files
4cda8b7 verified
"""Quick OSCD RGB download script."""
import os
from pathlib import Path

# Suppress the symlinks warning on Windows.
# NOTE: this has to run *before* `datasets` is imported. Importing `datasets`
# pulls in huggingface_hub, which reads HF_HUB_DISABLE_SYMLINKS_WARNING into a
# module-level constant at import time; assigning it afterwards is a no-op.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

import cv2  # noqa: E402  (must follow the environment setup above)
import numpy as np  # noqa: E402
from datasets import load_dataset  # noqa: E402
import traceback

# Number of leading samples of the split that are exported as example images.
MAX_SAMPLES = 5
# Keys this script reads from each sample of blanchon/OSCD_RGB.
REQUIRED_KEYS = ("image1", "image2", "mask")


def _to_uint8_image(raw) -> np.ndarray:
    """Convert a dataset image to a uint8 array, dropping channels past the third."""
    image = np.array(raw, dtype=np.uint8)
    if image.ndim == 3 and image.shape[2] >= 3:
        image = image[:, :, :3]
    return image


def _rgb_to_bgr(image: np.ndarray) -> np.ndarray:
    """Return ``image`` in the BGR channel order used by ``cv2.imwrite``.

    Only 3-channel arrays are swapped. Anything else (e.g. a 2-D grayscale
    array) is returned unchanged, since COLOR_RGB2BGR does not apply to it.
    """
    if image.ndim == 3 and image.shape[2] == 3:
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    return image


def _mask_for_png(mask: np.ndarray) -> np.ndarray:
    """Stretch a 0/1 mask to 0/255 so it is visible; keep other masks as-is."""
    if mask.size and mask.max() <= 1:
        return mask * 255
    return mask


def _write_png(path: Path, image: np.ndarray) -> None:
    """Write ``image`` to ``path``.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure. It signals some failures
            by returning False instead of raising, so the result is checked.
    """
    if not cv2.imwrite(str(path), image):
        raise OSError(f"cv2.imwrite could not write {path}")


def _export_samples(ds, examples_dir: Path, limit: int) -> tuple:
    """Save the first ``limit`` samples of ``ds`` as PNG files.

    Images go to ``before/``, ``after/`` and ``ground_truth/`` below
    ``examples_dir``. A failure on one sample is reported and does not stop
    the remaining samples.

    Returns:
        ``(saved, attempted)`` -- how many samples were fully written and how
        many were tried.
    """
    attempted = min(limit, len(ds))
    print(f"Downloaded {len(ds)} samples (taking first {attempted})")
    if attempted:
        # Guarded so an empty split does not raise IndexError here.
        print(f"Sample keys: {ds[0].keys()}")

    for subdir in ("before", "after", "ground_truth"):
        (examples_dir / subdir).mkdir(parents=True, exist_ok=True)

    saved = 0
    for idx in range(attempted):
        try:
            sample = ds[idx]

            # OSCD_RGB dataset uses 'image1', 'image2', 'mask' keys
            if any(key not in sample for key in REQUIRED_KEYS):
                print(f" Expected keys not found. Available keys: {sample.keys()}")
                continue

            before = _to_uint8_image(sample['image1'])
            after = _to_uint8_image(sample['image2'])
            gt = np.array(sample['mask'], dtype=np.uint8)

            filename = f"oscd_{idx:02d}.png"
            _write_png(examples_dir / "before" / filename, _rgb_to_bgr(before))
            _write_png(examples_dir / "after" / filename, _rgb_to_bgr(after))
            _write_png(examples_dir / "ground_truth" / filename, _mask_for_png(gt))

            saved += 1
            print(f"✓ Saved sample {idx+1}: before={before.shape}, after={after.shape}, gt={gt.shape}")
        except Exception as e:
            # Best effort: report this sample and carry on with the next one.
            print(f"✗ Error saving sample {idx}: {e}")
            traceback.print_exc()

    return saved, attempted


print("Downloading OSCD RGB dataset...")

try:
    ds = load_dataset("blanchon/OSCD_RGB", split="train")
except Exception as e:
    # Only the download itself is blamed on connectivity.
    print(f"Error downloading dataset: {e}")
    traceback.print_exc()
    print("\nMake sure internet is connected and try again")
else:
    try:
        saved, attempted = _export_samples(ds, Path("examples"), MAX_SAMPLES)
    except Exception as e:
        # E.g. the output directories could not be created.
        print(f"Error saving samples: {e}")
        traceback.print_exc()
    else:
        # Claim success only when every attempted sample was written.
        if saved and saved == attempted:
            print("\n OSCD RGB images downloaded successfully!")
        elif saved:
            print(f"\nSaved {saved} of {attempted} samples; see the messages above")
        else:
            print("\nNo samples were saved; see the messages above")