Spaces:
Sleeping
Sleeping
"""Quick OSCD RGB download script.

Downloads the ``blanchon/OSCD_RGB`` change-detection dataset from the Hugging
Face Hub and writes the first few samples as PNG files under
``examples/before/``, ``examples/after/`` and ``examples/ground_truth/``
(named ``oscd_XX.png``).
"""
import os
import sys
import traceback
from pathlib import Path

# Suppress the symlinks warning on Windows. This must be set BEFORE importing
# `datasets` (which imports `huggingface_hub`): the hub library reads this
# variable once, at import time, so setting it afterwards has no effect.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

import cv2  # noqa: E402  (must follow the env-var setup above)
import numpy as np  # noqa: E402
from datasets import load_dataset  # noqa: E402

DATASET_ID = "blanchon/OSCD_RGB"
NUM_SAMPLES = 5
EXAMPLES_DIR = Path("examples")
SUBDIRS = ("before", "after", "ground_truth")


def _to_bgr(image: np.ndarray) -> np.ndarray:
    """Return *image* in OpenCV's BGR channel order.

    Images with 3 or more channels are treated as RGB(A) (as the original
    script assumed); channels beyond the third (e.g. alpha) are dropped.
    Single-channel images have no channel order and are returned unchanged,
    since ``cv2.COLOR_RGB2BGR`` would reject them.
    """
    if image.ndim == 3 and image.shape[2] >= 3:
        return cv2.cvtColor(image[:, :, :3], cv2.COLOR_RGB2BGR)
    return image


def _mask_to_png(mask: np.ndarray) -> np.ndarray:
    """Scale a {0, 1} mask to {0, 255} so it is visible when viewed as a PNG.

    Masks that already use a wider value range are returned unchanged.
    """
    if mask.size and mask.max() <= 1:
        return mask * 255
    return mask


def _write_png(path: Path, image: np.ndarray) -> None:
    """Write *image* to *path*.

    Raises:
        OSError: if OpenCV reports the write failed. ``cv2.imwrite`` signals
            failure by returning False rather than raising, so the result
            must be checked explicitly.
    """
    if not cv2.imwrite(str(path), image):
        raise OSError(f"cv2.imwrite failed for {path}")


def _save_sample(sample: dict, idx: int, examples_dir: Path) -> bool:
    """Save one sample's before/after/mask images as PNGs.

    Returns True when all three images were written, False when the sample
    lacks the expected keys. Write errors propagate as exceptions.
    """
    # OSCD_RGB dataset uses 'image1', 'image2', 'mask' keys
    if 'image1' not in sample or 'image2' not in sample or 'mask' not in sample:
        print(f"  Expected keys not found. Available keys: {sample.keys()}")
        return False

    before = _to_bgr(np.array(sample['image1'], dtype=np.uint8))
    after = _to_bgr(np.array(sample['image2'], dtype=np.uint8))
    gt = np.array(sample['mask'], dtype=np.uint8)

    filename = f"oscd_{idx:02d}.png"
    _write_png(examples_dir / "before" / filename, before)
    _write_png(examples_dir / "after" / filename, after)
    _write_png(examples_dir / "ground_truth" / filename, _mask_to_png(gt))
    print(f"✓ Saved sample {idx+1}: before={before.shape}, after={after.shape}, gt={gt.shape}")
    return True


def main(num_samples: int = NUM_SAMPLES, examples_dir: Path = EXAMPLES_DIR) -> int:
    """Download the dataset and save the first *num_samples* samples.

    Returns a process exit status: 0 if at least one sample was saved,
    1 if the download failed, the dataset was empty, or nothing was saved.
    """
    print("Downloading OSCD RGB dataset...")
    try:
        ds = load_dataset(DATASET_ID, split="train")
    except Exception as e:  # top-level boundary: report any download failure
        print(f"Error downloading dataset: {e}")
        traceback.print_exc()
        print("\nMake sure internet is connected and try again")
        return 1

    print(f"Downloaded {len(ds)} samples (taking first {num_samples})")
    if len(ds) == 0:
        # Without this guard ds[0] raises IndexError, which used to be
        # misreported as a download/network error.
        print("Dataset is empty; nothing to save.")
        return 1
    print(f"Sample keys: {ds[0].keys()}")

    for subdir in SUBDIRS:
        (examples_dir / subdir).mkdir(parents=True, exist_ok=True)

    attempted = min(num_samples, len(ds))
    saved = 0
    for idx in range(attempted):
        try:
            if _save_sample(ds[idx], idx, examples_dir):
                saved += 1
        except Exception as e:  # best-effort: one bad sample must not stop the rest
            print(f"✗ Error saving sample {idx}: {e}")
            traceback.print_exc()

    if saved == 0:
        print("\nNo OSCD RGB samples were saved.")
        return 1
    print(f"\n OSCD RGB images downloaded successfully! ({saved}/{attempted} saved)")
    return 0


if __name__ == "__main__":
    sys.exit(main())