deep123shah456's picture
Upload 3 files
4a3ae84 verified
import os
import shutil
import zipfile
import multiprocessing
from typing import Optional, Callable, Tuple
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import v2 # [GSOC UPGRADE 1] Needed for MixUp/CutMix
from sklearn.model_selection import train_test_split
from torch.utils.data.dataloader import default_collate # [GSOC UPGRADE 1]
# ─────────────────────────────────────────────────────────────────────────────
# DATA STAGING
# ─────────────────────────────────────────────────────────────────────────────
def stage_data_locally(
drive_zip_path: Optional[str],
local_extract_dir: str = "/content/local_dataset",
) -> Optional[str]:
"""
Automates data staging from Google Drive to Colab's local disk.
Prevents PyTorch DataLoader ConnectionAbortedError by bypassing Drive I/O.
FIX: Added an environment guard so this function fails gracefully when
called outside of Google Colab (e.g. local dev machine or CI).
The hardcoded /content/ paths are Colab-specific; on any other system
the function now prints a warning and returns None instead of crashing.
Args:
drive_zip_path (str | None): Full path to the .zip file on Google Drive.
local_extract_dir (str): Target directory on Colab's fast local disk.
Returns:
str | None: Path to the extracted local directory, or None if skipped.
"""
# ── Guard: skip silently outside Colab ───────────────────────────────
if not drive_zip_path:
print("⚠️ No zip path provided. Skipping local staging.")
return None
if not os.path.exists(drive_zip_path):
print(f"⚠️ Zip file not found at: {drive_zip_path}. Skipping local staging.")
return None
# ── Already staged ────────────────────────────────────────────────────
if os.path.exists(local_extract_dir):
print(f"βœ… Local staging already complete at: {local_extract_dir}")
return local_extract_dir
# ── Stage ─────────────────────────────────────────────────────────────
print(f"πŸ“¦ Staging data locally to '{local_extract_dir}' for high-speed I/O...")
os.makedirs(local_extract_dir, exist_ok=True)
# Use a temp path inside the same directory tree so it's always writable,
# even outside /content/ (handles non-Colab environments gracefully).
local_zip = os.path.join(os.path.dirname(local_extract_dir), "_temp_dataset.zip")
try:
shutil.copy2(drive_zip_path, local_zip)
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
zip_ref.extractall(local_extract_dir)
finally:
# Always clean up the temp zip, even if extraction fails
if os.path.exists(local_zip):
os.remove(local_zip)
print("βœ… Data staging complete.")
return local_extract_dir
# ─────────────────────────────────────────────────────────────────────────────
# DATASET CLASS
# ─────────────────────────────────────────────────────────────────────────────
class DeepLenseDataset(Dataset):
"""
Unified PyTorch Dataset for DeepLense gravitational lensing images.
Supports both 'L' (Grayscale / 1-channel) and 'RGB' (3-channel) modes.
The class label is derived from the 'class' column of the metadata CSV.
Class Mapping:
no_sub β†’ 0 (Smooth lens, no dark matter substructure)
cdm β†’ 1 (Cold Dark Matter substructure)
vortex β†’ 2 (Vortex / quantum condensate dark matter)
Args:
dataframe (pd.DataFrame): Subset dataframe (train or val split).
root_dir (str): Base directory containing class-named subdirectories.
transform (callable, optional): Torchvision transform pipeline.
mode (str): PIL image mode β€” 'L' for grayscale, 'RGB' for 3-channel.
"""
CLASS_MAP = {'no_sub': 0, 'cdm': 1, 'vortex': 2}
CLASS_NAMES = ['no_sub', 'cdm', 'vortex']
def __init__(
self,
dataframe: pd.DataFrame,
root_dir: str,
transform: Optional[Callable] = None,
mode: str = 'L',
) -> None:
self.dataframe = dataframe.reset_index(drop=True)
self.root_dir = root_dir
self.transform = transform
self.mode = mode
def __len__(self) -> int:
return len(self.dataframe)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
row = self.dataframe.iloc[idx]
img_path = os.path.join(self.root_dir, row['class'], row['filename'])
image = Image.open(img_path).convert(self.mode)
if self.transform:
image = self.transform(image)
label = self.CLASS_MAP[row['class']]
return image, label
# ─────────────────────────────────────────────────────────────────────────────
# TRANSFORM FACTORIES
# ─────────────────────────────────────────────────────────────────────────────
def get_train_transform(
mode: str = 'RGB',
image_size: int = 224,
augment: bool = True,
) -> transforms.Compose:
"""
Builds the training transform pipeline.
"""
if mode == 'RGB':
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
else:
normalize = transforms.Normalize(mean=[0.5], std=[0.5])
if augment:
aug_ops = [
transforms.RandomHorizontalFlip(p=0.5), # No preferred axis
transforms.RandomVerticalFlip(p=0.5), # No preferred axis
transforms.RandomRotation(degrees=360, fill=0), # Full rotational symmetry
transforms.ColorJitter(brightness=0.2), # Source brightness variation
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)), # PSF smearing
]
pipeline = aug_ops + [
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
normalize,
]
else:
pipeline = [
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
normalize,
]
return transforms.Compose(pipeline)
def get_val_transform(
mode: str = 'RGB',
image_size: int = 224,
) -> transforms.Compose:
"""
Builds the validation/test transform pipeline.
"""
if mode == 'RGB':
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
else:
normalize = transforms.Normalize(mean=[0.5], std=[0.5])
return transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
normalize,
])
# ─────────────────────────────────────────────────────────────────────────────
# [GSOC UPGRADE 1] MIXUP & CUTMIX COLLATE FUNCTION
# ─────────────────────────────────────────────────────────────────────────────
def get_mixup_cutmix_collate(num_classes: int = 3):
"""
Returns a collate function that randomly applies MixUp or CutMix to batches.
Crucial for regularizing the Vision Transformer (ViT) on small datasets.
"""
mixup = v2.MixUp(alpha=0.2, num_classes=num_classes)
cutmix = v2.CutMix(alpha=1.0, num_classes=num_classes)
choice = v2.RandomChoice([mixup, cutmix])
def collate_fn(batch):
images, labels = default_collate(batch)
return choice(images, labels)
return collate_fn
# ─────────────────────────────────────────────────────────────────────────────
# DATALOADER FACTORY
# ─────────────────────────────────────────────────────────────────────────────
def get_dataloaders(
csv_path: str,
base_dir: str,
mode: str = 'RGB',
image_size: int = 224,
batch_size: int = 32,
augment: bool = True,
worker_init_fn: Optional[Callable] = None,
generator: Optional[torch.Generator] = None,
apply_mixup: bool = False, # [GSOC UPGRADE 1] Trigger for ViT
):
"""
Builds stratified train/val/test DataLoaders.
"""
df = pd.read_csv(csv_path)
print("\nπŸ“Š Dataset Class Distribution:")
dist = df['class'].value_counts().sort_index()
total = len(df)
for cls, count in dist.items():
print(f" {cls:<10} : {count:>5} samples ({100 * count / total:.1f}%)")
print(f" {'TOTAL':<10} : {total:>5} samples\n")
# ── [GSOC UPGRADE 2] STRICT TRAIN / VAL / TEST SPLIT ─────────────────
# 70% Train, 15% Val, 15% Test
train_df, temp_df = train_test_split(
df,
test_size=0.30,
random_state=42,
stratify=df['class'],
)
val_df, test_df = train_test_split(
temp_df,
test_size=0.50,
random_state=42,
stratify=temp_df['class'],
)
print(f"βœ… Split β€” Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")
# ── Transforms ───────────────────────────────────────────────────────
train_transform = get_train_transform(mode=mode, image_size=image_size, augment=augment)
val_transform = get_val_transform(mode=mode, image_size=image_size)
# ── Dataset objects ──────────────────────────────────────────────────
train_dataset = DeepLenseDataset(train_df, base_dir, transform=train_transform, mode=mode)
val_dataset = DeepLenseDataset(val_df, base_dir, transform=val_transform, mode=mode)
test_dataset = DeepLenseDataset(test_df, base_dir, transform=val_transform, mode=mode) # [GSOC UPGRADE 2]
# ── [GSOC UPGRADE 3] HARDWARE-AWARE DATALOADERS ──────────────────────
hw_num_workers = min(4, multiprocessing.cpu_count()) if torch.cuda.is_available() else 0
hw_pin_memory = torch.cuda.is_available()
collate_fn = get_mixup_cutmix_collate(num_classes=3) if (augment and apply_mixup) else None
# ── DataLoaders ───────────────────────────────────────────────────────
train_loader = DataLoader(
train_dataset,
batch_size = batch_size,
shuffle = True,
num_workers = hw_num_workers,
pin_memory = hw_pin_memory,
drop_last = True,
worker_init_fn = worker_init_fn,
generator = generator,
collate_fn = collate_fn # [GSOC UPGRADE 1]
)
val_loader = DataLoader(
val_dataset,
batch_size = batch_size,
shuffle = False,
num_workers = hw_num_workers,
pin_memory = hw_pin_memory,
worker_init_fn = worker_init_fn,
)
test_loader = DataLoader( # [GSOC UPGRADE 2]
test_dataset,
batch_size = batch_size,
shuffle = False,
num_workers = hw_num_workers,
pin_memory = hw_pin_memory,
worker_init_fn = worker_init_fn,
)
return train_loader, val_loader, test_loader, train_df, val_df, test_df