File size: 6,467 Bytes
ad40ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e3fccc
 
 
 
ad40ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""dataset_loader.py — Data loading for NSGF/NSGF++ experiments.

Handles:
  - 2D synthetic datasets (8gaussians, moons, scurve, checkerboard)
  - MNIST / CIFAR-10 for image experiments
  - Source distributions (standard Gaussian)

Reference: arXiv:2401.14069, Appendix E.1 and E.2
"""

import math
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons, make_s_curve


# ============================================================
# 2D Synthetic Datasets (following Tong et al. 2023 / Grathwohl et al. 2018)
# ============================================================

def sample_8gaussians(n: int, scale: float = 4.0, std: float = 0.5) -> torch.Tensor:
    """Draw ``n`` points from a mixture of 8 Gaussians placed on a circle.

    Args:
        n: Number of points to draw.
        scale: Radius of the circle on which the 8 mode centers lie.
        std: Standard deviation of the isotropic noise added around each center.

    Returns:
        Float tensor of shape ``(n, 2)``.
    """
    num_modes = 8
    angles = [2 * math.pi * k / num_modes for k in range(num_modes)]
    mode_centers = np.array(
        [(scale * math.cos(theta), scale * math.sin(theta)) for theta in angles]
    )
    # Draw order (mode assignment first, then noise) is kept fixed so that
    # results under a given NumPy seed stay reproducible.
    assignment = np.random.randint(0, num_modes, n)
    jitter = np.random.randn(n, 2) * std
    samples = mode_centers[assignment] + jitter
    return torch.FloatTensor(samples)


def sample_moons(n: int, noise: float = 0.05) -> torch.Tensor:
    """Draw ``n`` points from scikit-learn's two-moons dataset, rescaled.

    The raw points are multiplied by 3 and then shifted by -1 along the
    first coordinate only.

    Args:
        n: Number of points to draw.
        noise: Forwarded to ``make_moons`` as its ``noise`` argument.

    Returns:
        Float tensor of shape ``(n, 2)``.
    """
    raw_points, _labels = make_moons(n_samples=n, noise=noise)
    shift = np.array([1.0, 0.0])
    rescaled = raw_points * 3.0 - shift
    return torch.FloatTensor(rescaled)


def sample_scurve(n: int, noise: float = 0.0) -> torch.Tensor:
    """Draw ``n`` points from scikit-learn's S-curve, projected to 2D.

    Only coordinates 0 and 2 of the 3D points are kept, and the result is
    scaled by 3.

    Args:
        n: Number of points to draw.
        noise: Forwarded to ``make_s_curve`` as its ``noise`` argument.

    Returns:
        Float tensor of shape ``(n, 2)``.
    """
    points_3d, _position = make_s_curve(n_samples=n, noise=noise)
    kept_axes = [0, 2]  # drop the middle coordinate
    planar = points_3d[:, kept_axes]
    return torch.FloatTensor(planar * 3.0)


def sample_checkerboard(n: int) -> torch.Tensor:
    """Draw ``n`` points from a 2D checkerboard pattern.

    The first coordinate is uniform on ``[-2, 2)``. The second is a uniform
    fraction in ``[0, 1)`` lowered by 0 or 2, then raised by the parity of
    ``floor`` of the first coordinate. Both coordinates are finally doubled.

    Args:
        n: Number of points to draw.

    Returns:
        Float tensor of shape ``(n, 2)``.
    """
    # The three NumPy draws happen in this exact order so that seeded runs
    # reproduce: first coordinate, fractional part, then the row bit.
    first = np.random.rand(n) * 4 - 2
    fraction = np.random.rand(n)
    row_bit = np.random.randint(0, 2, n)
    second = fraction - row_bit * 2
    # Alternate unit-wide columns are offset by one cell.
    second = second + (np.floor(first) % 2)
    board = np.column_stack((first, second)) * 2
    return torch.FloatTensor(board)


def sample_8gaussians_moons(n: int) -> torch.Tensor:
    """Draw ``n`` points from an even mixture of 8gaussians and moons.

    ``n // 2`` points come from ``sample_8gaussians`` and the remaining
    ``n - n // 2`` from ``sample_moons``; the rows are then shuffled.

    Args:
        n: Total number of points to draw.

    Returns:
        Tensor with ``n`` rows in random order.
    """
    num_gaussian = n // 2
    gaussian_part = sample_8gaussians(num_gaussian)
    moon_part = sample_moons(n - num_gaussian)
    stacked = torch.cat((gaussian_part, moon_part), dim=0)
    # Shuffle so the two components are interleaved rather than blocked.
    return stacked[torch.randperm(n)]


# Registry of 2D target samplers, keyed by dataset name. get_2d_dataset looks
# names up here and calls the sampler with the requested sample count; the
# insertion order is the order shown in its "Available" error message.
DATASET_2D = dict(
    [
        ("8gaussians", sample_8gaussians),
        ("moons", sample_moons),
        ("scurve", sample_scurve),
        ("checkerboard", sample_checkerboard),
        ("8gaussians_moons", sample_8gaussians_moons),
    ]
)


def get_2d_dataset(name: str, n: int) -> torch.Tensor:
    """Draw ``n`` samples from the 2D dataset registered under ``name``.

    Args:
        name: Key into ``DATASET_2D``.
        n: Number of samples, passed to the registered sampler.

    Returns:
        Whatever the registered sampler returns for ``n``.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASET_2D``.
    """
    sampler = DATASET_2D.get(name)
    if sampler is None:
        available = list(DATASET_2D.keys())
        raise ValueError(f"Unknown 2D dataset: {name}. Available: {available}")
    return sampler(n)


def sample_source_2d(n: int, dim: int = 2) -> torch.Tensor:
    """Draw ``n`` standard-normal source points of dimension ``dim``.

    Returns:
        Tensor of shape ``(n, dim)`` sampled with ``torch.randn``.
    """
    shape = (n, dim)
    return torch.randn(shape)


# ============================================================
# Image Datasets (MNIST, CIFAR-10)
# ============================================================

def get_image_dataloader(
    dataset_name: str,
    batch_size: int,
    train: bool = True,
    data_root: str = "./data",
    num_workers: int = 2,
    normalize_range: tuple = (-1.0, 1.0),
) -> DataLoader:
    """Build a DataLoader over MNIST or CIFAR-10 with rescaled pixel values.

    Args:
        dataset_name: ``"mnist"`` or ``"cifar10"``.
        batch_size: Number of images per batch.
        train: Selects the train split when true, the test split otherwise.
            Also controls shuffling (only the train split is shuffled).
        data_root: Directory passed to torchvision as the dataset root; the
            dataset is downloaded there if missing.
        num_workers: Number of DataLoader worker processes.
        normalize_range: ``(lo, hi)`` target range for pixel values. The
            default ``(-1.0, 1.0)`` corresponds to mean 0.5 / std 0.5.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches. Incomplete final
        batches are dropped (``drop_last=True``).

    Raises:
        ValueError: If ``dataset_name`` is not supported, or if
            ``normalize_range`` is not an increasing ``(lo, hi)`` pair.
    """
    channels_by_dataset = {"mnist": 1, "cifar10": 3}
    # Validate before importing torchvision so bad arguments fail fast.
    if dataset_name not in channels_by_dataset:
        raise ValueError(f"Unknown image dataset: {dataset_name}")

    lo, hi = normalize_range
    if not hi > lo:
        raise ValueError(
            f"normalize_range must satisfy lo < hi, got {normalize_range}"
        )

    import torchvision
    import torchvision.transforms as T

    # ToTensor produces values in [0, 1] and Normalize computes
    # (x - mean) / std. Choosing std = 1 / (hi - lo) and mean = -lo / (hi - lo)
    # therefore maps [0, 1] onto [lo, hi]; for (-1, 1) this is mean = std = 0.5.
    span = hi - lo
    mean = -lo / span
    std = 1.0 / span
    channels = channels_by_dataset[dataset_name]
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[mean] * channels, std=[std] * channels),
    ])

    dataset_classes = {
        "mnist": torchvision.datasets.MNIST,
        "cifar10": torchvision.datasets.CIFAR10,
    }
    ds = dataset_classes[dataset_name](
        root=data_root, train=train, download=True, transform=transform
    )

    return DataLoader(
        ds, batch_size=batch_size, shuffle=train,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )


def sample_source_image(n: int, channels: int, image_size: int) -> torch.Tensor:
    """Draw ``n`` standard-normal source images.

    Returns:
        Tensor of shape ``(n, channels, image_size, image_size)`` sampled
        with ``torch.randn``.
    """
    shape = (n, channels, image_size, image_size)
    return torch.randn(shape)


# ============================================================
# DatasetLoader class (unified interface)
# ============================================================

class DatasetLoader:
    """Unified sampling interface over the 2D synthetic and image datasets.

    The dataset is chosen by ``config["dataset"]`` (default ``"8gaussians"``).
    ``"mnist"`` and ``"cifar10"`` are served through ``get_image_dataloader``;
    every other name is forwarded to ``get_2d_dataset``.

    Config keys read:
        dataset: Dataset name.
        in_channels, image_size: Image shape for image datasets; when absent
            the dataset's native shape is used.
        model.input_dim: Dimension of 2D-path source samples (default 2).
    """

    # Native (channels, image_size) of each supported image dataset, used when
    # the config does not override them. Without per-dataset defaults, CIFAR-10
    # source noise would be drawn as 1x28x28 and not match the 3x32x32 targets.
    _IMAGE_DEFAULTS = {"mnist": (1, 28), "cifar10": (3, 32)}

    def __init__(self, config: dict):
        """Store ``config`` and resolve which dataset family is in use."""
        self.config = config
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in self._IMAGE_DEFAULTS
        # Lazily-built training loader state used by sample_target.
        self._image_loader = None
        self._image_iter = None
        self._image_batch_size = None

    def _image_shape(self) -> tuple:
        """Return ``(channels, image_size)`` from config or dataset defaults."""
        default_channels, default_size = self._IMAGE_DEFAULTS[self.dataset_name]
        channels = self.config.get("in_channels", default_channels)
        image_size = self.config.get("image_size", default_size)
        return channels, image_size

    def _input_dim(self) -> int:
        """Return ``config["model"]["input_dim"]``, defaulting to 2."""
        # Tolerate a missing or None "model" section.
        model_cfg = self.config.get("model") or {}
        return model_cfg.get("input_dim", 2)

    def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Draw ``n`` target samples and move them to ``device``.

        For image datasets, batches are taken from a cached training
        DataLoader that is restarted when exhausted.

        Raises:
            ValueError: If the image loader cannot yield a full batch of
                size ``n``, or (2D path) if the dataset name is unknown.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        # Recreate DataLoader if batch size changed (different training phases
        # use different batch sizes, e.g. 256 for pool building, 128 for NSF)
        if self._image_loader is None or self._image_batch_size != n:
            self._image_loader = get_image_dataloader(
                self.dataset_name, batch_size=n, train=True
            )
            self._image_batch_size = n
            self._image_iter = iter(self._image_loader)

        try:
            images, _ = next(self._image_iter)
        except StopIteration:
            # End of epoch: start a new pass over the data.
            self._image_iter = iter(self._image_loader)
            try:
                images, _ = next(self._image_iter)
            except StopIteration:
                # A fresh iterator that yields nothing means no full batch of
                # this size exists; surface that instead of a bare StopIteration.
                raise ValueError(
                    f"Cannot draw a batch of {n} images from "
                    f"{self.dataset_name}: the loader yields no full batch"
                ) from None
        return images.to(device)

    def sample_source(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Draw ``n`` standard-normal source samples and move them to ``device``.

        Image datasets get noise shaped like the images; otherwise the noise
        has ``model.input_dim`` features.
        """
        if self.is_image:
            channels, image_size = self._image_shape()
            return sample_source_image(n, channels, image_size).to(device)
        return sample_source_2d(n, self._input_dim()).to(device)

    def get_test_samples(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` evaluation samples on ``device``.

        For image datasets this is the first batch of the test split; for 2D
        datasets fresh samples are drawn from the generator.

        Raises:
            ValueError: If the test loader cannot yield a full batch of
                size ``n``, or (2D path) if the dataset name is unknown.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        # Only one batch is consumed, so load it in the main process rather
        # than starting worker processes for a single fetch.
        loader = get_image_dataloader(
            self.dataset_name, batch_size=n, train=False, num_workers=0
        )
        try:
            images, _ = next(iter(loader))
        except StopIteration:
            raise ValueError(
                f"Cannot draw a batch of {n} images from "
                f"{self.dataset_name}: the loader yields no full batch"
            ) from None
        return images.to(device)

    @property
    def data_dim(self) -> int:
        """Flattened size of one sample (``C * H * W`` for images)."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return channels * image_size * image_size
        return self._input_dim()