File size: 13,532 Bytes
4a3ae84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import os
import shutil
import zipfile
import multiprocessing
from typing import Optional, Callable, Tuple
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import v2  # [GSOC UPGRADE 1] Needed for MixUp/CutMix
from sklearn.model_selection import train_test_split
from torch.utils.data.dataloader import default_collate # [GSOC UPGRADE 1]


# ─────────────────────────────────────────────────────────────────────────────
# DATA STAGING
# ─────────────────────────────────────────────────────────────────────────────

def stage_data_locally(
    drive_zip_path: Optional[str],
    local_extract_dir: str = "/content/local_dataset",
) -> Optional[str]:
    """
    Copy a dataset archive to local disk and extract it there.

    Intended for Google Colab, where reading training images straight from a
    mounted Drive is slow; the archive is copied once and unpacked into
    ``local_extract_dir``. Nothing here is Colab-specific apart from the
    default target path: staging is skipped (returning ``None``) whenever no
    usable archive path is supplied, so the function can be called safely on
    machines where the archive does not exist.

    A target directory only counts as "already staged" when it exists and is
    non-empty. If copying or extraction fails, the partially populated target
    directory is removed again so a later call retries instead of reporting a
    broken dataset as complete.

    Args:
        drive_zip_path (str | None): Full path to the .zip archive.
        local_extract_dir (str): Directory the archive is extracted into.

    Returns:
        str | None: ``local_extract_dir`` once the data is available there,
        or ``None`` when staging was skipped.

    Raises:
        zipfile.BadZipFile: If the archive is not a valid zip file.
        OSError: If copying or extracting fails at the filesystem level.
    """
    # ── Guard: nothing to stage ──────────────────────────────────────────
    if not drive_zip_path:
        print("⚠️  No zip path provided. Skipping local staging.")
        return None

    # isfile (not exists) so that a directory passed by mistake is skipped
    # here instead of crashing later inside shutil.copy2.
    if not os.path.isfile(drive_zip_path):
        print(f"⚠️  Zip file not found at: {drive_zip_path}. Skipping local staging.")
        return None

    # ── Already staged ───────────────────────────────────────────────────
    # An empty directory is NOT treated as staged: it is what an interrupted
    # earlier run (or a manual mkdir) leaves behind.
    if os.path.isdir(local_extract_dir) and os.listdir(local_extract_dir):
        print(f"✅ Local staging already complete at: {local_extract_dir}")
        return local_extract_dir

    # ── Stage ────────────────────────────────────────────────────────────
    print(f"📦 Staging data locally to '{local_extract_dir}' for high-speed I/O...")
    os.makedirs(local_extract_dir, exist_ok=True)

    # The temporary copy lives next to the target directory, which is known
    # to be writable because the target itself was just created there.
    local_zip = os.path.join(os.path.dirname(local_extract_dir), "_temp_dataset.zip")

    staged = False
    try:
        shutil.copy2(drive_zip_path, local_zip)
        with zipfile.ZipFile(local_zip, 'r') as zip_ref:
            zip_ref.extractall(local_extract_dir)
        staged = True
    finally:
        # Always drop the temporary archive copy, success or not.
        if os.path.exists(local_zip):
            os.remove(local_zip)
        # The target was absent or empty before this call, so removing it on
        # failure cannot destroy pre-existing data; it only prevents the
        # "already staged" shortcut from firing on a half-extracted dataset.
        if not staged:
            shutil.rmtree(local_extract_dir, ignore_errors=True)

    print("✅ Data staging complete.")
    return local_extract_dir


# ─────────────────────────────────────────────────────────────────────────────
# DATASET CLASS
# ─────────────────────────────────────────────────────────────────────────────

class DeepLenseDataset(Dataset):
    """
    Map-style PyTorch dataset over DeepLense gravitational lensing images.

    Every sample is described by one metadata row. The image is read from
    ``<root_dir>/<class>/<filename>``, converted to the configured PIL mode,
    passed through the optional transform, and returned together with the
    integer label of its ``class`` value.

    Class mapping:
        no_sub → 0  (smooth lens, no dark matter substructure)
        cdm    → 1  (cold dark matter substructure)
        vortex → 2  (vortex / quantum condensate dark matter)

    Args:
        dataframe (pd.DataFrame): Rows of this split; must provide the
            ``class`` and ``filename`` columns. The index is reset so that
            positional access lines up with ``__len__``.
        root_dir (str): Base directory containing class-named subdirectories.
        transform (callable, optional): Applied to the PIL image when truthy.
        mode (str): PIL image mode — 'L' for grayscale, 'RGB' for 3-channel.
    """

    # Position in CLASS_NAMES is the integer label; CLASS_MAP is its inverse.
    CLASS_NAMES = ['no_sub', 'cdm', 'vortex']
    CLASS_MAP   = {name: index for index, name in enumerate(CLASS_NAMES)}

    def __init__(
        self,
        dataframe: pd.DataFrame,
        root_dir: str,
        transform: Optional[Callable] = None,
        mode: str = 'L',
    ) -> None:
        self.dataframe = dataframe.reset_index(drop=True)
        self.root_dir, self.transform, self.mode = root_dir, transform, mode

    def __len__(self) -> int:
        """Number of samples, i.e. rows in the metadata frame."""
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load sample ``idx`` and return ``(image, label)``."""
        record     = self.dataframe.iloc[idx]
        class_name = record['class']

        image = Image.open(
            os.path.join(self.root_dir, class_name, record['filename'])
        ).convert(self.mode)

        if self.transform:
            image = self.transform(image)

        return image, self.CLASS_MAP[class_name]


# ─────────────────────────────────────────────────────────────────────────────
# TRANSFORM FACTORIES
# ─────────────────────────────────────────────────────────────────────────────

def get_train_transform(
    mode: str = 'RGB',
    image_size: int = 224,
    augment: bool = True,
) -> transforms.Compose:
    """
    Assemble the transform pipeline applied to training images.

    The pipeline always ends with resize → tensor conversion → normalization.
    When ``augment`` is true, random flips, a random rotation, brightness
    jitter and a Gaussian blur are applied first, on the PIL image.

    Args:
        mode (str): 'RGB' selects 3-channel normalization statistics (the
            values match the widely used ImageNet mean/std); any other value
            selects single-channel mean 0.5 / std 0.5.
        image_size (int): Side length of the square output image.
        augment (bool): Whether to prepend the random augmentations.

    Returns:
        transforms.Compose: The composed training transform.
    """
    if mode == 'RGB':
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        mean, std = [0.5], [0.5]

    steps = []
    if augment:
        steps.extend([
            # Flips: lensing images have no preferred axis.
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            # Full rotational symmetry; corners exposed by rotation are zero-filled.
            transforms.RandomRotation(degrees=360, fill=0),
            # Source brightness variation.
            transforms.ColorJitter(brightness=0.2),
            # PSF smearing.
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
        ])

    # Deterministic tail shared by the augmented and the plain pipeline.
    steps.extend([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transforms.Compose(steps)


def get_val_transform(
    mode: str = 'RGB',
    image_size: int = 224,
) -> transforms.Compose:
    """
    Assemble the deterministic transform used for validation and test images.

    Applies resize → tensor conversion → normalization with no random
    augmentation.

    Args:
        mode (str): 'RGB' selects 3-channel normalization statistics (the
            values match the widely used ImageNet mean/std); any other value
            selects single-channel mean 0.5 / std 0.5.
        image_size (int): Side length of the square output image.

    Returns:
        transforms.Compose: The composed evaluation transform.
    """
    is_rgb = mode == 'RGB'
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406] if is_rgb else [0.5],
        std=[0.229, 0.224, 0.225] if is_rgb else [0.5],
    )
    resize = transforms.Resize((image_size, image_size))

    return transforms.Compose([resize, transforms.ToTensor(), normalize])

# ─────────────────────────────────────────────────────────────────────────────
# [GSOC UPGRADE 1] MIXUP & CUTMIX COLLATE FUNCTION
# ─────────────────────────────────────────────────────────────────────────────

class _MixupCutmixCollate:
    """
    Picklable collate callable that applies MixUp or CutMix to each batch.

    Defined at module level (rather than as a closure) so that a DataLoader
    using worker processes started with the ``spawn`` method can pickle it.
    """

    def __init__(self, num_classes: int) -> None:
        mixup = v2.MixUp(alpha=0.2, num_classes=num_classes)
        cutmix = v2.CutMix(alpha=1.0, num_classes=num_classes)
        # One of the two is picked at random for every batch.
        self._choice = v2.RandomChoice([mixup, cutmix])

    def __call__(self, batch):
        # Collate into batched tensors first: MixUp/CutMix operate on whole
        # batches, not on individual samples.
        images, labels = default_collate(batch)
        return self._choice(images, labels)


def get_mixup_cutmix_collate(num_classes: int = 3) -> Callable:
    """
    Build a DataLoader ``collate_fn`` that randomly applies MixUp or CutMix.

    The returned callable collates a list of ``(image, label)`` samples with
    ``default_collate`` and then passes the batch through either
    ``v2.MixUp(alpha=0.2)`` or ``v2.CutMix(alpha=1.0)``, chosen at random per
    batch. Intended as a regularizer when training on small datasets.

    The callable is an instance of a module-level class instead of a nested
    function, because nested functions cannot be pickled and would make
    DataLoader worker start-up fail on platforms that use ``spawn``.

    Args:
        num_classes (int): Number of target classes, forwarded to both
            ``v2.MixUp`` and ``v2.CutMix``.

    Returns:
        Callable: ``collate_fn(batch)`` returning whatever the chosen
        transform returns for ``(images, labels)``.
    """
    return _MixupCutmixCollate(num_classes)


# ─────────────────────────────────────────────────────────────────────────────
# DATALOADER FACTORY
# ─────────────────────────────────────────────────────────────────────────────

def get_dataloaders(
    csv_path: str,
    base_dir: str,
    mode: str = 'RGB',
    image_size: int = 224,
    batch_size: int = 32,
    augment: bool = True,
    worker_init_fn: Optional[Callable] = None,
    generator: Optional[torch.Generator] = None,
    apply_mixup: bool = False,  # [GSOC UPGRADE 1] Trigger for ViT
) -> Tuple[DataLoader, DataLoader, DataLoader, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Build stratified train / validation / test DataLoaders from a metadata CSV.

    The CSV is split 70 % / 15 % / 15 % with a fixed random seed, stratified
    on the ``class`` column. The training split uses the training transform
    (optionally augmented); validation and test share the deterministic
    evaluation transform. The class distribution and split sizes are printed.

    Args:
        csv_path (str): Path to the metadata CSV; must contain the ``class``
            and ``filename`` columns.
        base_dir (str): Directory containing the class-named image folders.
        mode (str): PIL image mode passed to the datasets and transforms.
        image_size (int): Side length of the square model input.
        batch_size (int): Batch size for all three loaders.
        augment (bool): Enable random augmentation on the training split.
        worker_init_fn (callable, optional): Forwarded to every DataLoader.
        generator (torch.Generator, optional): Forwarded to the training
            DataLoader only (it is the only one that shuffles).
        apply_mixup (bool): Use the MixUp/CutMix collate function for training
            batches. Only takes effect when ``augment`` is also true.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader,
        train_df, val_df, test_df)``.

    Raises:
        KeyError: If the CSV lacks the ``class`` or ``filename`` column.
    """
    df = pd.read_csv(csv_path)

    # Fail fast: without this, a missing 'filename' column only surfaces as a
    # KeyError inside a worker process on the first batch.
    missing = [col for col in ('class', 'filename') if col not in df.columns]
    if missing:
        raise KeyError(f"Metadata CSV '{csv_path}' is missing required column(s): {missing}")

    print("\n📊 Dataset Class Distribution:")
    dist  = df['class'].value_counts().sort_index()
    total = len(df)
    for cls, count in dist.items():
        print(f"   {cls:<10} : {count:>5} samples  ({100 * count / total:.1f}%)")
    print(f"   {'TOTAL':<10} : {total:>5} samples\n")

    # ── [GSOC UPGRADE 2] STRICT TRAIN / VAL / TEST SPLIT ─────────────────
    # 70% train first, then the remaining 30% is halved into val and test.
    train_df, temp_df = train_test_split(
        df,
        test_size=0.30,
        random_state=42,
        stratify=df['class'],
    )

    val_df, test_df = train_test_split(
        temp_df,
        test_size=0.50,
        random_state=42,
        stratify=temp_df['class'],
    )
    print(f"✅ Split — Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

    # ── Transforms ───────────────────────────────────────────────────────
    train_transform = get_train_transform(mode=mode, image_size=image_size, augment=augment)
    val_transform   = get_val_transform(mode=mode, image_size=image_size)

    # ── Dataset objects ──────────────────────────────────────────────────
    train_dataset = DeepLenseDataset(train_df, base_dir, transform=train_transform, mode=mode)
    val_dataset   = DeepLenseDataset(val_df,   base_dir, transform=val_transform,   mode=mode)
    test_dataset  = DeepLenseDataset(test_df,  base_dir, transform=val_transform,   mode=mode)  # [GSOC UPGRADE 2]

    # ── [GSOC UPGRADE 3] HARDWARE-AWARE DATALOADERS ──────────────────────
    # Worker processes and pinned memory are only enabled when CUDA is present.
    hw_num_workers = min(4, multiprocessing.cpu_count()) if torch.cuda.is_available() else 0
    hw_pin_memory = torch.cuda.is_available()

    # The class count comes from the dataset's own label map so the two cannot
    # drift apart if a class is ever added.
    collate_fn = (
        get_mixup_cutmix_collate(num_classes=len(DeepLenseDataset.CLASS_MAP))
        if (augment and apply_mixup)
        else None
    )

    # ── DataLoaders ──────────────────────────────────────────────────────
    train_loader = DataLoader(
        train_dataset,
        batch_size      = batch_size,
        shuffle         = True,
        num_workers     = hw_num_workers,
        pin_memory      = hw_pin_memory,
        drop_last       = True,
        worker_init_fn  = worker_init_fn,
        generator       = generator,
        collate_fn      = collate_fn,      # [GSOC UPGRADE 1]
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size      = batch_size,
        shuffle         = False,
        num_workers     = hw_num_workers,
        pin_memory      = hw_pin_memory,
        worker_init_fn  = worker_init_fn,
    )

    test_loader = DataLoader(              # [GSOC UPGRADE 2]
        test_dataset,
        batch_size      = batch_size,
        shuffle         = False,
        num_workers     = hw_num_workers,
        pin_memory      = hw_pin_memory,
        worker_init_fn  = worker_init_fn,
    )

    return train_loader, val_loader, test_loader, train_df, val_df, test_df