"""dataset_loader.py — Data loading for NSGF/NSGF++ experiments.
Handles:
- 2D synthetic datasets (8gaussians, moons, scurve, checkerboard, and the
  8gaussians_moons mixture)
- MNIST / CIFAR-10 for image experiments
- Source distributions (standard Gaussian)
Reference: arXiv:2401.14069, Appendix E.1 and E.2
"""
import math
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons, make_s_curve
# ============================================================
# 2D Synthetic Datasets (following Tong et al. 2023 / Grathwohl et al. 2018)
# ============================================================
def sample_8gaussians(n: int, scale: float = 4.0, std: float = 0.5) -> torch.Tensor:
    """Draw ``n`` points from eight isotropic Gaussians spaced evenly on a circle.

    Args:
        n: Number of samples.
        scale: Radius of the circle on which the eight mode centers lie.
        std: Standard deviation of the noise added around each center.

    Returns:
        Float32 tensor of shape ``(n, 2)``. Randomness comes from NumPy's
        global RNG (``np.random``).
    """
    angles = [2 * math.pi * k / 8 for k in range(8)]
    modes = np.array([(scale * math.cos(a), scale * math.sin(a)) for a in angles])
    # Mode labels are drawn before the noise; this order matters for
    # reproducibility under a fixed np.random seed.
    labels = np.random.randint(0, 8, n)
    jitter = np.random.randn(n, 2) * std
    return torch.FloatTensor(modes[labels] + jitter)
def sample_moons(n: int, noise: float = 0.05) -> torch.Tensor:
    """Sample scikit-learn's "two moons" and rescale them for 2D experiments.

    The raw points are multiplied by 3, then shifted by -1 along x only
    (y is scaled but not shifted).

    Args:
        n: Number of samples.
        noise: Passed through to ``make_moons`` as its noise level.

    Returns:
        Float32 tensor of shape ``(n, 2)``.
    """
    points, _labels = make_moons(n_samples=n, noise=noise)
    x_shift = np.array([1.0, 0.0])
    return torch.FloatTensor(points * 3.0 - x_shift)
def sample_scurve(n: int, noise: float = 0.0) -> torch.Tensor:
    """Sample scikit-learn's S-curve and keep its first and third coordinates.

    Args:
        n: Number of samples.
        noise: Passed through to ``make_s_curve``.

    Returns:
        Float32 tensor of shape ``(n, 2)``, with coordinates scaled by 3.
    """
    xyz, _position = make_s_curve(n_samples=n, noise=noise)
    xz = xyz[:, [0, 2]]
    return torch.FloatTensor(3.0 * xz)
def sample_checkerboard(n: int) -> torch.Tensor:
    """Sample uniformly from the filled cells of a 4x4 checkerboard on [-4, 4)^2.

    Args:
        n: Number of samples.

    Returns:
        Float32 tensor of shape ``(n, 2)``; NumPy's global RNG is used.
    """
    col = np.random.rand(n) * 4 - 2  # x in [-2, 2): four unit-wide columns
    # y starts in either [0, 1) or [-2, -1). The draw order (rand, then
    # randint) matters for reproducibility under a fixed seed.
    row = np.random.rand(n) - np.random.randint(0, 2, n) * 2
    # Odd columns move up by one cell, so filled cells alternate like a board.
    row = row + np.floor(col) % 2
    return torch.FloatTensor(2 * np.column_stack([col, row]))
def sample_8gaussians_moons(n: int) -> torch.Tensor:
    """Sample a roughly half/half mixture of the 8-Gaussian ring and the moons.

    ``n // 2`` points come from ``sample_8gaussians`` and the remaining
    ``n - n // 2`` from ``sample_moons``. Rows are then permuted with
    ``torch.randperm``, so the two sources are interleaved.

    Args:
        n: Total number of samples.

    Returns:
        Tensor of shape ``(n, 2)``.
    """
    half = n // 2
    pieces = [sample_8gaussians(half), sample_moons(n - half)]
    order = torch.randperm(n)
    return torch.cat(pieces, dim=0)[order]
# Registry of 2D target samplers keyed by dataset name. get_2d_dataset calls
# each entry as ``sampler(n)``.
DATASET_2D = dict(
    (
        ("8gaussians", sample_8gaussians),
        ("moons", sample_moons),
        ("scurve", sample_scurve),
        ("checkerboard", sample_checkerboard),
        ("8gaussians_moons", sample_8gaussians_moons),
    )
)
def get_2d_dataset(name: str, n: int) -> torch.Tensor:
    """Draw ``n`` samples from the 2D dataset registered under ``name``.

    Raises:
        ValueError: if ``name`` is not a key of ``DATASET_2D``.
    """
    sampler = DATASET_2D.get(name)
    if sampler is None:
        raise ValueError(f"Unknown 2D dataset: {name}. Available: {list(DATASET_2D.keys())}")
    return sampler(n)
def sample_source_2d(n: int, dim: int = 2) -> torch.Tensor:
    """Draw ``n`` i.i.d. standard-normal source points in ``dim`` dimensions.

    Returns:
        Tensor of shape ``(n, dim)`` drawn from torch's global RNG.
    """
    shape = (n, dim)
    return torch.randn(shape)
# ============================================================
# Image Datasets (MNIST, CIFAR-10)
# ============================================================
def get_image_dataloader(
    dataset_name: str,
    batch_size: int,
    train: bool = True,
    data_root: str = "./data",
    num_workers: int = 2,
    normalize_range: tuple = (-1.0, 1.0),
) -> DataLoader:
    """Build a DataLoader over MNIST or CIFAR-10 with pixels mapped onto a range.

    Images go through ``ToTensor`` (pixel values in [0, 1]) and are then
    mapped affinely onto ``normalize_range``.

    Args:
        dataset_name: ``"mnist"`` or ``"cifar10"``.
        batch_size: Samples per batch.
        train: Training split (shuffled) if True, test split (unshuffled) if False.
        data_root: Directory torchvision downloads the dataset into / reads from.
        num_workers: Number of DataLoader worker processes.
        normalize_range: ``(lo, hi)`` target pixel range. The default
            ``(-1.0, 1.0)`` is equivalent to ``Normalize(mean=0.5, std=0.5)``.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches. An incomplete final
        batch is dropped, so every batch has exactly ``batch_size`` samples.

    Raises:
        ValueError: if ``dataset_name`` is unknown or ``normalize_range`` does
            not satisfy ``lo < hi``.
    """
    channels_by_name = {"mnist": 1, "cifar10": 3}
    if dataset_name not in channels_by_name:
        raise ValueError(f"Unknown image dataset: {dataset_name}")
    channels = channels_by_name[dataset_name]

    lo, hi = normalize_range
    if not lo < hi:
        raise ValueError(f"normalize_range must satisfy lo < hi, got {normalize_range}")

    # torchvision is imported only when an image loader is actually built.
    import torchvision
    import torchvision.transforms as T

    # Normalize computes (x - mean) / std. With mean = -lo / span and
    # std = 1 / span, x in [0, 1] maps to lo + x * span, i.e. onto [lo, hi].
    span = hi - lo
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[-lo / span] * channels, std=[1.0 / span] * channels),
    ])

    if dataset_name == "mnist":
        dataset_cls = torchvision.datasets.MNIST
    else:
        dataset_cls = torchvision.datasets.CIFAR10
    ds = dataset_cls(root=data_root, train=train, download=True, transform=transform)

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        # Pinned host memory only speeds up copies to a CUDA device; skipping it
        # on CPU-only hosts avoids torch's "no accelerator" warning.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )
def sample_source_image(n: int, channels: int, image_size: int) -> torch.Tensor:
    """Draw ``n`` standard-normal noise images.

    Returns:
        Tensor of shape ``(n, channels, image_size, image_size)`` from torch's
        global RNG.
    """
    shape = (n, channels, image_size, image_size)
    return torch.randn(*shape)
# ============================================================
# DatasetLoader class (unified interface)
# ============================================================
class DatasetLoader:
    """Unified sampler for target data and source noise (2D or image datasets).

    Config keys read by this class:
        dataset: dataset name, default ``"8gaussians"``. ``"mnist"`` and
            ``"cifar10"`` select the image path; any other name the 2D path.
        in_channels / image_size: image geometry for source noise and
            ``data_dim``. When absent, each defaults to the dataset's native
            geometry (MNIST 1x28x28, CIFAR-10 3x32x32).
        model.input_dim: dimensionality of 2D source samples, default 2.
    """

    # Native (channels, height == width) of the supported image datasets. Used
    # when the config omits in_channels / image_size, so source noise matches
    # the shape of the target images.
    _IMAGE_DEFAULTS = {"mnist": (1, 28), "cifar10": (3, 32)}

    def __init__(self, config: dict):
        self.config = config
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in self._IMAGE_DEFAULTS
        # Training loader state, built lazily by sample_target and rebuilt
        # whenever the requested batch size changes.
        self._image_loader = None
        self._image_iter = None
        self._image_batch_size = None

    def _image_shape(self) -> tuple:
        """Return ``(channels, image_size)`` from config or the dataset defaults."""
        default_channels, default_size = self._IMAGE_DEFAULTS[self.dataset_name]
        return (
            self.config.get("in_channels", default_channels),
            self.config.get("image_size", default_size),
        )

    def _input_dim_2d(self) -> int:
        """Return the 2D-path dimensionality from ``config["model"]["input_dim"]``."""
        return self.config.get("model", {}).get("input_dim", 2)

    def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return a batch of ``n`` target samples on ``device``.

        Image datasets stream from a cached training DataLoader and start a
        new pass when it is exhausted. 2D datasets are sampled fresh each call.

        Raises:
            ValueError: if the image loader yields no batch at all (e.g. ``n``
                is larger than the training split, whose incomplete final
                batch is dropped).
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        # Different training phases use different batch sizes (e.g. 256 for
        # pool building, 128 for NSF), so rebuild the loader when n changes.
        if self._image_loader is None or self._image_batch_size != n:
            self._image_batch_size = n
            self._image_loader = get_image_dataloader(
                self.dataset_name, batch_size=n, train=True
            )
            self._image_iter = iter(self._image_loader)

        batch = next(self._image_iter, None)
        if batch is None:
            # End of an epoch: begin a new pass over the training split.
            self._image_iter = iter(self._image_loader)
            batch = next(self._image_iter, None)
        if batch is None:
            raise ValueError(
                f"No batch of size {n} available from the {self.dataset_name} training set"
            )
        images, _ = batch
        return images.to(device)

    def sample_source(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` standard-normal source samples shaped like the data."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return sample_source_image(n, channels, image_size).to(device)
        return sample_source_2d(n, self._input_dim_2d()).to(device)

    def get_test_samples(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` held-out samples on ``device``.

        Image datasets yield the first ``n`` images of the unshuffled test split.
        2D datasets draw a fresh sample.

        Raises:
            ValueError: if the test loader yields no batch (e.g. ``n`` is larger
                than the test split).
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)
        loader = get_image_dataloader(self.dataset_name, batch_size=n, train=False)
        batch = next(iter(loader), None)
        if batch is None:
            raise ValueError(
                f"No batch of size {n} available from the {self.dataset_name} test set"
            )
        images, _ = batch
        return images.to(device)

    @property
    def data_dim(self) -> int:
        """Flattened dimensionality of one data sample."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return channels * image_size * image_size
        return self._input_dim_2d()
|