rogermt committed on
Commit
ad40ad2
·
verified ·
1 Parent(s): 370da72

Upload dataset_loader.py

Browse files
Files changed (1) hide show
  1. dataset_loader.py +183 -0
dataset_loader.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """dataset_loader.py — Data loading for NSGF/NSGF++ experiments.
2
+
3
+ Handles:
4
+ - 2D synthetic datasets (8gaussians, moons, scurve, checkerboard, 8gaussians_moons)
5
+ - MNIST / CIFAR-10 for image experiments
6
+ - Source distributions (standard Gaussian)
7
+
8
+ Reference: arXiv:2401.14069, Appendix E.1 and E.2
9
+ """
10
+
11
+ import math
12
+ import numpy as np
13
+ import torch
14
+ from torch.utils.data import DataLoader, TensorDataset
15
+ from sklearn.datasets import make_moons, make_s_curve
16
+
17
+
18
+ # ============================================================
19
+ # 2D Synthetic Datasets (following Tong et al. 2023 / Grathwohl et al. 2018)
20
+ # ============================================================
21
+
22
def sample_8gaussians(n: int, scale: float = 4.0, std: float = 0.5) -> torch.Tensor:
    """Draw ``n`` points from a ring of eight isotropic Gaussian blobs.

    The blob centres lie on a circle of radius ``scale`` at angles
    ``2*pi*k/8`` for ``k = 0..7``. Each sample picks one centre uniformly at
    random (NumPy global RNG) and adds Gaussian noise with standard
    deviation ``std``.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    num_modes = 8
    angles = [2 * math.pi * k / num_modes for k in range(num_modes)]
    mode_centers = np.array(
        [(scale * math.cos(theta), scale * math.sin(theta)) for theta in angles]
    )
    # Draw order matters for seeded reproducibility: mode assignment first,
    # then the per-point noise.
    assignment = np.random.randint(0, num_modes, n)
    jitter = np.random.randn(n, 2) * std
    return torch.FloatTensor(mode_centers[assignment] + jitter)
32
+
33
+
34
def sample_moons(n: int, noise: float = 0.05) -> torch.Tensor:
    """Sample ``n`` points from scikit-learn's two-moons dataset.

    The raw ``make_moons`` coordinates are multiplied by 3 and then shifted
    by -1 along the first axis only. The class labels are discarded.

    Returns:
        A float32 tensor built from the rescaled coordinates.
    """
    points, _labels = make_moons(n_samples=n, noise=noise)
    shift = np.array([1.0, 0.0])
    rescaled = points * 3.0 - shift
    return torch.FloatTensor(rescaled)
39
+
40
+
41
def sample_scurve(n: int, noise: float = 0.0) -> torch.Tensor:
    """Sample ``n`` points from scikit-learn's S-curve, flattened to 2D.

    Only columns 0 and 2 of the ``make_s_curve`` output are kept, and the
    result is multiplied by 3. The second return value of ``make_s_curve``
    is discarded.

    Returns:
        A float32 tensor built from the two retained, rescaled columns.
    """
    points3d, _position = make_s_curve(n_samples=n, noise=noise)
    planar = points3d[:, [0, 2]]
    return torch.FloatTensor(planar * 3.0)
46
+
47
+
48
def sample_checkerboard(n: int) -> torch.Tensor:
    """Sample ``n`` points from a checkerboard pattern.

    The first coordinate is uniform on ``[-2, 2)``. The second is a uniform
    fraction in ``[0, 1)`` shifted down by 0 or 2 (fair coin) and then up by
    the parity of ``floor(first coordinate)``. Both coordinates are finally
    multiplied by 2. All randomness comes from NumPy's global RNG.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    horizontal = np.random.rand(n) * 4 - 2
    # Draw order matters for seeded reproducibility: the uniform fraction
    # is drawn before the coin flip.
    fractional = np.random.rand(n)
    band = np.random.randint(0, 2, n) * 2
    vertical = (fractional - band) + np.floor(horizontal) % 2
    return torch.FloatTensor(np.column_stack([horizontal, vertical]) * 2)
55
+
56
+
57
def sample_8gaussians_moons(n: int) -> torch.Tensor:
    """Sample ``n`` points from an 8-Gaussians / moons mixture.

    ``n // 2`` points come from ``sample_8gaussians`` and the remaining
    ``n - n // 2`` from ``sample_moons`` (both with their default
    parameters). The concatenated rows are then shuffled with a random
    permutation.
    """
    half = n // 2
    # The list is evaluated left to right, so the Gaussians are sampled
    # before the moons.
    parts = [sample_8gaussians(half), sample_moons(n - half)]
    mixed = torch.cat(parts, dim=0)
    return mixed[torch.randperm(n)]
66
+
67
+
68
# Registry mapping each 2D dataset name to its sampler function.
# get_2d_dataset looks a name up here and calls the sampler with only the
# sample count, so any extra sampler parameters keep their default values.
DATASET_2D = {
    "8gaussians": sample_8gaussians,
    "moons": sample_moons,
    "scurve": sample_scurve,
    "checkerboard": sample_checkerboard,
    "8gaussians_moons": sample_8gaussians_moons,
}
75
+
76
+
77
def get_2d_dataset(name: str, n: int) -> torch.Tensor:
    """Sample ``n`` points from the 2D dataset registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASET_2D``; the message
            lists the available names.
    """
    sampler = DATASET_2D.get(name)
    if sampler is None:
        raise ValueError(f"Unknown 2D dataset: {name}. Available: {list(DATASET_2D.keys())}")
    return sampler(n)
81
+
82
+
83
def sample_source_2d(n: int, dim: int = 2) -> torch.Tensor:
    """Draw ``n`` standard-normal source samples of dimension ``dim``.

    Returns:
        A tensor of shape ``(n, dim)`` drawn with ``torch.randn``.
    """
    shape = (n, dim)
    return torch.randn(*shape)
85
+
86
+
87
+ # ============================================================
88
+ # Image Datasets (MNIST, CIFAR-10)
89
+ # ============================================================
90
+
91
def get_image_dataloader(
    dataset_name: str,
    batch_size: int,
    train: bool = True,
    data_root: str = "./data",
    num_workers: int = 2,
    normalize_range: tuple = (-1.0, 1.0),
) -> DataLoader:
    """Build a ``DataLoader`` over MNIST or CIFAR-10.

    Images are converted with ``ToTensor`` (values in ``[0, 1]``) and then
    rescaled linearly so that pixel values span ``normalize_range``.

    Args:
        dataset_name: ``"mnist"`` or ``"cifar10"``.
        batch_size: Number of images per batch.
        train: Selects the training split and enables shuffling when True;
            otherwise the test split is used without shuffling.
        data_root: Directory the dataset is stored in / downloaded to.
        num_workers: Number of loader worker processes.
        normalize_range: ``(lo, hi)`` target range for pixel values. The
            default ``(-1.0, 1.0)`` matches mean 0.5 / std 0.5 normalization.

    Returns:
        A ``DataLoader`` with ``pin_memory=True`` and ``drop_last=True``, so
        a trailing partial batch is never yielded.

    Raises:
        ValueError: If ``dataset_name`` is unknown or ``normalize_range``
            does not satisfy ``lo < hi``.
    """
    # Validate everything up front so bad arguments fail before the
    # torchvision import and before any dataset download is attempted.
    channels_by_dataset = {"mnist": 1, "cifar10": 3}
    if dataset_name not in channels_by_dataset:
        raise ValueError(f"Unknown image dataset: {dataset_name}")
    lo, hi = normalize_range
    if not hi > lo:
        raise ValueError(
            f"normalize_range must satisfy lo < hi, got {normalize_range}"
        )

    import torchvision
    import torchvision.transforms as T

    # Normalize computes (x - mean) / std per channel. With x in [0, 1],
    # mean = -lo / (hi - lo) and std = 1 / (hi - lo) map the input onto
    # [lo, hi]; for (-1, 1) this gives exactly mean = std = 0.5.
    span = hi - lo
    channels = channels_by_dataset[dataset_name]
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[-lo / span] * channels, std=[1.0 / span] * channels),
    ])

    if dataset_name == "mnist":
        dataset_cls = torchvision.datasets.MNIST
    else:
        dataset_cls = torchvision.datasets.CIFAR10
    ds = dataset_cls(root=data_root, train=train, download=True, transform=transform)

    return DataLoader(
        ds, batch_size=batch_size, shuffle=train,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )
125
+
126
+
127
def sample_source_image(n: int, channels: int, image_size: int) -> torch.Tensor:
    """Draw ``n`` standard-normal noise images.

    Returns:
        A tensor of shape ``(n, channels, image_size, image_size)`` drawn
        with ``torch.randn``.
    """
    batch_shape = (n, channels, image_size, image_size)
    return torch.randn(batch_shape)
129
+
130
+
131
+ # ============================================================
132
+ # DatasetLoader class (unified interface)
133
+ # ============================================================
134
+
135
class DatasetLoader:
    """Unified sampling interface over the 2D synthetic and image datasets.

    Configuration keys read from ``config`` (all optional):
        ``dataset``: Dataset name, default ``"8gaussians"``. ``"mnist"`` and
            ``"cifar10"`` select image mode; any other name is treated as a
            2D dataset and resolved by ``get_2d_dataset``.
        ``in_channels`` / ``image_size``: Image shape overrides. When absent
            they default to the dataset's native shape (MNIST 1x28x28,
            CIFAR-10 3x32x32).
        ``model.input_dim``: Dimension of 2D source samples, default 2.
    """

    # Native (channels, side length) per supported image dataset. Used as
    # defaults so CIFAR-10 noise does not silently get MNIST's shape.
    _IMAGE_SHAPES = {"mnist": (1, 28), "cifar10": (3, 32)}

    def __init__(self, config: dict):
        self.config = config
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in self._IMAGE_SHAPES
        # Lazily created training loader, its live iterator, and the batch
        # size it was built for.
        self._image_loader = None
        self._image_iter = None
        self._image_batch_size = None

    def _image_shape(self) -> tuple:
        """Return ``(channels, image_size)``, honouring config overrides."""
        default_channels, default_size = self._IMAGE_SHAPES[self.dataset_name]
        channels = self.config.get("in_channels", default_channels)
        image_size = self.config.get("image_size", default_size)
        return channels, image_size

    def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return a batch of ``n`` target samples on ``device``.

        For image datasets the training loader is cached between calls and
        restarted when exhausted. It is rebuilt whenever ``n`` differs from
        the batch size it was created with, so the batch always has ``n``
        images.

        Raises:
            ValueError: If the image loader yields no batch of size ``n``.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        if self._image_loader is None or self._image_batch_size != n:
            self._image_loader = get_image_dataloader(
                self.dataset_name, batch_size=n, train=True
            )
            self._image_batch_size = n
            self._image_iter = iter(self._image_loader)

        try:
            images, _ = next(self._image_iter)
        except StopIteration:
            # Epoch finished: start a new pass over the data.
            self._image_iter = iter(self._image_loader)
            try:
                images, _ = next(self._image_iter)
            except StopIteration:
                # A fresh iterator that is already empty means the loader
                # has no full batch of this size at all.
                raise ValueError(
                    f"Cannot draw a batch of {n} samples from "
                    f"{self.dataset_name}: the loader yields no batches"
                ) from None
        return images.to(device)

    def sample_source(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` standard-normal source samples on ``device``.

        Image datasets get noise shaped like the images; 2D datasets get
        noise of dimension ``config["model"]["input_dim"]`` (default 2).
        """
        if self.is_image:
            channels, image_size = self._image_shape()
            return sample_source_image(n, channels, image_size).to(device)
        dim = self.config.get("model", {}).get("input_dim", 2)
        return sample_source_2d(n, dim).to(device)

    def get_test_samples(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` evaluation samples on ``device``.

        Image datasets return the first batch of a freshly built test-split
        loader; 2D datasets draw a new sample from the same sampler used
        for training targets.

        Raises:
            ValueError: If the test loader yields no batch of size ``n``.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        loader = get_image_dataloader(
            self.dataset_name, batch_size=n, train=False
        )
        try:
            images, _ = next(iter(loader))
        except StopIteration:
            raise ValueError(
                f"Cannot draw a batch of {n} test samples from "
                f"{self.dataset_name}: the loader yields no batches"
            ) from None
        return images.to(device)

    @property
    def data_dim(self) -> int:
        """Flattened size of one sample.

        ``channels * image_size**2`` for image datasets, otherwise
        ``config["model"]["input_dim"]`` (default 2).
        """
        if self.is_image:
            channels, image_size = self._image_shape()
            return channels * image_size * image_size
        return self.config.get("model", {}).get("input_dim", 2)