File size: 8,929 Bytes
e23f433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e32ac2
 
 
 
e23f433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e32ac2
 
 
 
 
 
 
e23f433
 
 
 
3e32ac2
 
 
e23f433
3e32ac2
 
 
 
 
e23f433
3e32ac2
 
 
 
e23f433
 
 
 
3e32ac2
 
 
e23f433
 
 
 
3e32ac2
e23f433
3e32ac2
 
 
 
 
e23f433
 
 
3e32ac2
e23f433
3e32ac2
 
 
e23f433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e32ac2
 
 
 
 
e23f433
 
 
 
 
 
 
3e32ac2
 
 
 
e23f433
 
3e32ac2
 
 
e23f433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e32ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e23f433
 
3e32ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e23f433
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""sinkhorn_flow.py — Sinkhorn gradient flow and W_ε potential computation.

Core implementation of:
  - Sinkhorn divergence computation via GeomLoss
  - W_ε-potential gradients (∇f_{μ,μ} and ∇f_{μ,μ*})
  - Velocity field: v(x) = ∇f_{μ,μ}(x) - ∇f_{μ,μ*}(x)  (Theorem 1, Eq. 10)
  - Euler discretization of the Sinkhorn WGF (Algorithm 1)
  - Trajectory pool construction for velocity field matching

Reference: arXiv:2401.14069, Section 4.1, 4.3, Appendix A
"""

import torch
import torch.nn as nn
from typing import List, Tuple, Optional
from geomloss import SamplesLoss


class SinkhornPotentialComputer:
    """Computes W_ε-potentials and their gradients using GeomLoss.

    The velocity field of the Sinkhorn WGF is (Theorem 1):
        v(x) = ∇f_{μ,μ}(x) - ∇f_{μ,μ*}(x)

    IMPORTANT: GeomLoss SamplesLoss requires inputs as (N, D) or (B, N, D) tensors.
    Image-like data with four or more dimensions, e.g. (N, C, H, W), is flattened
    to (N, C*H*W) before calling geomloss, and gradients are reshaped back to the
    original shape afterwards. 2-D and 3-D inputs are passed through unchanged.

    Args:
        blur: GeomLoss blur parameter (related to ε: ε = blur^p).
        scaling: Multiscale scaling parameter for Sinkhorn iterations.
        p: Cost exponent (default 2 for squared Euclidean).
        backend: GeomLoss backend ('auto', 'tensorized', 'online').
    """

    def __init__(self, blur: float = 0.5, scaling: float = 0.80,
                 p: int = 2, backend: str = "tensorized"):
        self.blur = blur
        self.scaling = scaling
        self.p = p
        self.backend = backend

        # Returns dual potentials (F on the first argument, G on the second);
        # used to obtain potential gradients via autograd.
        self.loss_fn = SamplesLoss(
            loss="sinkhorn", p=p, blur=blur, scaling=scaling,
            backend=backend, potentials=True,
        )
        # Returns the scalar Sinkhorn loss; used only for monitoring.
        self.loss_monitor = SamplesLoss(
            loss="sinkhorn", p=p, blur=blur, scaling=scaling,
            backend=backend, potentials=False,
        )

    def _flatten_if_image(self, X: torch.Tensor) -> Tuple[torch.Tensor, bool, torch.Size]:
        """Flatten image-like tensors to (N, D) for geomloss.

        Tensors with four or more dimensions, e.g. (N, C, H, W), are reshaped to
        (N, prod(remaining dims)). Lower-rank tensors are returned unchanged.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (permuted or sliced tensors) are accepted. For contiguous inputs it
        still returns a view without copying.

        Returns:
            (flat_tensor, was_image, original_shape)
        """
        original_shape = X.shape
        if X.dim() >= 4:
            return X.reshape(X.shape[0], -1), True, original_shape
        return X, False, original_shape

    def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Compute the Sinkhorn WGF velocity field at particles X.

        v(X_i) = ∇f_{μ,μ}(X_i) - ∇f_{μ,μ*}(X_i)

        Handles both 2D point clouds (N,D) and image-like tensors (N,C,H,W,...)
        by flattening them before the geomloss calls.

        Args:
            X: Current particles (samples of μ). Not modified; a detached copy is used.
            Y: Target samples (samples of μ*). Used detached.

        Returns:
            Velocity tensor with the same shape as X, detached from any graph.
        """
        original_shape = X.shape

        # Flatten image-like tensors. X is cloned so that enabling gradients
        # below never touches the caller's tensor.
        X_flat, is_image, _ = self._flatten_if_image(X.detach().clone())
        Y_flat, _, _ = self._flatten_if_image(Y.detach())

        # NOTE(review): SamplesLoss is built without an explicit `debias`
        # argument. If GeomLoss's default debiasing also applies to the
        # returned potentials, F_self / F_cross are debiased quantities rather
        # than the raw f_{μ,μ} / f_{μ,μ*}. Confirm against the GeomLoss docs
        # before relying on the individual terms.

        # --- Self-potential: ∇f_{μ,μ}(X) ---
        # Only the first argument requires grad, so the gradient is taken with
        # respect to the evaluation points, not the measure's support.
        X_grad = X_flat.requires_grad_(True)
        X_self_detached = X_flat.detach().clone()
        F_self, _ = self.loss_fn(X_grad, X_self_detached)
        grad_self = torch.autograd.grad(
            F_self.sum(), X_grad, create_graph=False, retain_graph=False
        )[0]

        # --- Cross-potential: ∇f_{μ,μ*}(X) ---
        X_grad2 = X_flat.detach().clone().requires_grad_(True)
        F_cross, _ = self.loss_fn(X_grad2, Y_flat)
        grad_cross = torch.autograd.grad(
            F_cross.sum(), X_grad2, create_graph=False, retain_graph=False
        )[0]

        # Velocity = ∇f_{μ,μ} - ∇f_{μ,μ*}
        velocity = grad_self.detach() - grad_cross.detach()

        # Restore the caller's original layout for image-like inputs.
        if is_image:
            velocity = velocity.reshape(original_shape)

        return velocity

    def compute_sinkhorn_divergence(self, X: torch.Tensor, Y: torch.Tensor) -> float:
        """Compute Sinkhorn divergence S_ε(μ, μ*) as a Python float.

        Image-like tensors are flattened first. Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            X_flat, _, _ = self._flatten_if_image(X)
            Y_flat, _, _ = self._flatten_if_image(Y)
            return self.loss_monitor(X_flat, Y_flat).item()


class SinkhornGradientFlow:
    """Discrete (explicit Euler) Sinkhorn Wasserstein gradient flow.

    Starting from particles X^0, each of ``num_steps`` iterations applies

        X^{t+1} = X^t + η * v(X^t)

    where v is supplied by ``potential_computer.compute_velocity``.

    Args:
        potential_computer: Object providing ``compute_velocity(X, Y)``.
        eta: Euler step size η.
        num_steps: Number of Euler steps taken by ``run_flow``.
    """

    def __init__(self, potential_computer: SinkhornPotentialComputer,
                 eta: float = 1.0, num_steps: int = 5):
        self.potential_computer, self.eta, self.num_steps = (
            potential_computer, eta, num_steps
        )

    def run_flow(self, X0: torch.Tensor, Y: torch.Tensor,
                 store_trajectory: bool = True
                 ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, int]]]:
        """Run ``num_steps`` Euler steps from ``X0`` toward target samples ``Y``.

        Args:
            X0: Initial particles. Cloned first; ``X0`` itself is never passed on.
            Y: Target samples, forwarded unchanged to ``compute_velocity``.
            store_trajectory: When True, record a CPU copy of every pre-step
                state together with its velocity and step index.

        Returns:
            ``(X_T, trajectory)``: the final particles and the list of
            ``(x_t, v_t, t)`` records (empty when ``store_trajectory`` is False).
        """
        history: List[Tuple[torch.Tensor, torch.Tensor, int]] = []
        current = X0.clone()

        for step in range(self.num_steps):
            velocity = self.potential_computer.compute_velocity(current, Y)

            if store_trajectory:
                state_copy = current.detach().cpu().clone()
                velocity_copy = velocity.detach().cpu().clone()
                history.append((state_copy, velocity_copy, step))

            # Detach both terms so no autograd graph accumulates across steps.
            current = current.detach() + self.eta * velocity.detach()

        return current, history

    def run_flow_no_store(self, X0: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Run the flow without recording a trajectory; return only the final particles."""
        return self.run_flow(X0, Y, store_trajectory=False)[0]


class TrajectoryPool:
    """Stores (x, v, t) tuples from Sinkhorn gradient flow trajectories.

    Rows are kept in insertion order. Once ``max_size`` rows are held, the
    oldest rows are evicted first.

    After building, call finalize() to pre-concatenate tensors so that sample()
    is a single random-index gather. Without finalize(), every sample() call
    re-concatenates the whole pool with torch.cat (O(pool_size) per call).

    Args:
        max_size: Maximum number of rows (particles) kept in the pool.
    """

    def __init__(self, max_size: int = 1_000_000):
        self.max_size = max_size
        self.x_pool: List[torch.Tensor] = []  # one tensor per added chunk
        self.v_pool: List[torch.Tensor] = []  # aligned chunk-for-chunk with x_pool
        self.t_pool: List[int] = []           # one time index per stored row
        self._size = 0                        # total rows across all chunks
        self._finalized = False
        # Filled by finalize(); the list pools above are released at that point.
        self._all_x = None
        self._all_v = None
        self._all_t = None

    def add_trajectory(self, trajectory: List[Tuple[torch.Tensor, torch.Tensor, int]]):
        """Add (x, v, t) entries from a flow trajectory. Call before finalize().

        Each ``(x, v, t)`` is stored as one chunk whose rows all share time
        index ``t``. If a chunk would push the pool past ``max_size``, the oldest
        rows are evicted first. A chunk that alone exceeds ``max_size`` is
        truncated to its last ``max_size`` rows, so the pool never grows past
        its capacity.

        Raises:
            RuntimeError: if the pool has already been finalized.
        """
        if self._finalized:
            raise RuntimeError("Cannot add to a finalized pool. Create a new pool.")
        capacity = max(self.max_size, 0)
        for x, v, t in trajectory:
            n = x.shape[0]
            if n > capacity:
                # Eviction can free at most the current contents, so an
                # oversized chunk must be trimmed itself (newest rows kept).
                start = n - capacity
                x, v = x[start:], v[start:]
                n = x.shape[0]
            overflow = self._size + n - capacity
            if overflow > 0:
                self._drop_oldest(overflow)
            self.x_pool.append(x)
            self.v_pool.append(v)
            self.t_pool.extend([t] * n)
            self._size += n

    def _drop_oldest(self, n: int):
        """Evict the ``n`` oldest rows, trimming the front chunk if needed.

        Whole chunks are popped from the front while they fit in the remaining
        budget. The first chunk that does not fit is sliced to keep only its
        newest rows. ``t_pool`` is trimmed by the same row count so it stays
        aligned with the concatenation of ``x_pool``.
        """
        removed = 0
        while removed < n and len(self.x_pool) > 0:
            batch_size = self.x_pool[0].shape[0]
            if removed + batch_size <= n:
                # Entire front chunk is evicted.
                self.x_pool.pop(0)
                self.v_pool.pop(0)
                self.t_pool = self.t_pool[batch_size:]
                removed += batch_size
                self._size -= batch_size
            else:
                # Only part of the front chunk is evicted; keep its last `keep` rows.
                keep = batch_size - (n - removed)
                self.x_pool[0] = self.x_pool[0][-keep:]
                self.v_pool[0] = self.v_pool[0][-keep:]
                self.t_pool = self.t_pool[(batch_size - keep):]
                self._size -= (batch_size - keep)
                removed = n

    def finalize(self):
        """Pre-concatenate all pool data for fast sampling.

        Call this once after all trajectories have been added. After
        finalization, sample() is just random indexing. Calling finalize()
        again is a no-op.

        Raises:
            RuntimeError: if the pool is empty.
        """
        if self._finalized:
            # The list pools were already released; nothing left to do.
            return
        if self._size == 0:
            raise RuntimeError("Cannot finalize an empty pool.")
        self._all_x = torch.cat(self.x_pool, dim=0)
        self._all_v = torch.cat(self.v_pool, dim=0)
        self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
        # Free the lists to save memory
        self.x_pool = None
        self.v_pool = None
        self.t_pool = None
        self._finalized = True

    def sample(self, batch_size: int, device: str = "cpu"
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample a random batch (with replacement) from the pool.

        If finalize() was called, this is a single gather. Otherwise it falls
        back to concatenating the whole pool on every call (O(pool_size)).

        Returns:
            (x, v, t) moved to ``device``. ``t`` is float32.

        Raises:
            RuntimeError: if the pool is empty.
        """
        if self._size == 0:
            raise RuntimeError("Cannot sample from an empty pool.")
        if self._finalized:
            idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
            return (
                self._all_x[idx].to(device),
                self._all_v[idx].to(device),
                self._all_t[idx].to(device),
            )
        # Fallback: concatenate on the fly (slow for large pools)
        all_x = torch.cat(self.x_pool, dim=0)
        all_v = torch.cat(self.v_pool, dim=0)
        all_t = torch.tensor(self.t_pool, dtype=torch.float32)
        idx = torch.randint(0, all_x.shape[0], (batch_size,))
        return all_x[idx].to(device), all_v[idx].to(device), all_t[idx].to(device)

    @property
    def size(self) -> int:
        """Number of rows currently stored."""
        return self._size

    def __len__(self) -> int:
        return self._size