| """ |
| Streaming data pipeline for The Well datasets. |
| Handles HF streaming and local loading with robust error recovery. |
| """ |
| import torch |
| from torch.utils.data import DataLoader |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def create_dataloader(
    dataset_name="turbulent_radiative_layer_2D",
    split="train",
    batch_size=4,
    n_steps_input=1,
    n_steps_output=1,
    num_workers=0,
    streaming=True,
    local_path=None,
    use_normalization=True,
):
    """Build a (DataLoader, WellDataset) pair for one split of a Well dataset.

    Args:
        dataset_name: Name of the Well dataset.
        split: 'train', 'valid', or 'test'.
        batch_size: Batch size.
        n_steps_input: Number of input timesteps.
        n_steps_output: Number of output timesteps.
        num_workers: DataLoader workers (0 for streaming recommended).
        streaming: If True, stream from HuggingFace Hub.
        local_path: Path to local data (used if streaming=False).
        use_normalization: Whether to normalize data.

    Returns:
        (DataLoader, WellDataset)

    Raises:
        ValueError: if streaming=False and no local_path was given.
    """
    # Imported lazily so the module loads even when the_well is absent.
    from the_well.data import WellDataset

    if streaming:
        base_path = "hf://datasets/polymathic-ai/"
    else:
        base_path = local_path
    if base_path is None:
        raise ValueError("Must provide local_path when streaming=False")

    logger.info(f"Creating dataset: {dataset_name}/{split} (streaming={streaming})")

    well_dataset = WellDataset(
        well_base_path=base_path,
        well_dataset_name=dataset_name,
        well_split_name=split,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        use_normalization=use_normalization,
        flatten_tensors=True,
    )

    # Shuffle only the training split; keep worker processes alive between
    # epochs whenever workers are actually in use.
    data_loader = DataLoader(
        well_dataset,
        batch_size=batch_size,
        shuffle=split == "train",
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=num_workers > 0,
    )

    return data_loader, well_dataset
|
|
|
|
def to_channels_first(x):
    """Convert Well format [B, T, H, W, C] to PyTorch [B, T*C, H, W].

    Also handles [B, H, W, C] -> [B, C, H, W] and [H, W, C] -> [C, H, W];
    any other rank is returned unchanged.
    """
    ndim = x.dim()
    if ndim == 5:
        b, t, h, w, c = x.shape
        # Move channels next to time, then fold T and C into one axis.
        return x.movedim(-1, 2).reshape(b, t * c, h, w)
    if ndim == 4:
        return x.movedim(-1, 1)
    if ndim == 3:
        return x.movedim(-1, 0)
    return x
|
|
|
|
def prepare_batch(batch, device="cuda"):
    """Move a Well batch to *device* and reshape it for the model.

    Returns:
        x_input: [B, Ti*C, H, W] condition frames (channels-first)
        x_output: [B, To*C, H, W] target frames (channels-first)
    """
    prepared = []
    for key in ("input_fields", "output_fields"):
        # non_blocking pairs with the DataLoader's pinned host memory.
        fields = batch[key].to(device, non_blocking=True)
        prepared.append(to_channels_first(fields).float())

    x_input, x_output = prepared
    return x_input, x_output
|
|
|
|
def get_data_info(dataset):
    """Probe dataset for shapes and channel counts.

    Returns a dict mapping each tensor-valued key of ``dataset[0]`` to its
    shape tuple; non-tensor entries are skipped.
    """
    first = dataset[0]
    return {
        name: tuple(value.shape)
        for name, value in first.items()
        if isinstance(value, torch.Tensor)
    }
|
|
|
|
def get_channel_info(dataset):
    """Get input/output channel counts for model construction.

    Probes ``dataset[0]`` and assumes [T, H, W, C] field layout — the
    flattened channel counts are timesteps * per-step channels.
    """
    first = dataset[0]
    t_in, height, width, c_in = first["input_fields"].shape
    t_out, _, _, c_out = first["output_fields"].shape

    return {
        "input_channels": t_in * c_in,
        "output_channels": t_out * c_out,
        "raw_channels": c_in,
        "height": height,
        "width": width,
        "n_steps_input": t_in,
        "n_steps_output": t_out,
    }
|
|