rogermt commited on
Commit
3e32ac2
·
verified ·
1 Parent(s): 88f3058

Fix geomloss tensor shape bug for images + optimize pool sampling

Browse files

- compute_velocity now flattens (N,C,H,W) → (N,D) before geomloss calls,
reshapes gradients back to original shape after. Fixes MNIST/CIFAR crash.
- compute_sinkhorn_divergence also handles image tensors.
- TrajectoryPool.finalize() pre-concatenates after pool building (O(1) sampling
instead of O(pool_size) per step)."

Files changed (1) hide show
  1. sinkhorn_flow.py +83 -12
sinkhorn_flow.py CHANGED
@@ -22,6 +22,10 @@ class SinkhornPotentialComputer:
22
  The velocity field of the Sinkhorn WGF is (Theorem 1):
23
  v(x) = ∇f_{μ,μ}(x) - ∇f_{μ,μ*}(x)
24
 
 
 
 
 
25
  Args:
26
  blur: GeomLoss blur parameter (related to ε: ε = blur^p).
27
  scaling: Multiscale scaling parameter for Sinkhorn iterations.
@@ -45,31 +49,57 @@ class SinkhornPotentialComputer:
45
  backend=backend, potentials=False,
46
  )
47
 
 
 
 
 
 
 
 
48
  def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
49
  """Compute the Sinkhorn WGF velocity field at particles X.
50
 
51
  v(X_i) = ∇f_{μ,μ}(X_i) - ∇f_{μ,μ*}(X_i)
 
 
 
52
  """
53
- X_grad = X.detach().clone().requires_grad_(True)
54
- Y_det = Y.detach()
 
 
 
55
 
56
- F_self, _ = self.loss_fn(X_grad, X_grad.detach().clone())
 
 
 
57
  grad_self = torch.autograd.grad(
58
  F_self.sum(), X_grad, create_graph=False, retain_graph=False
59
  )[0]
60
 
61
- X_grad2 = X.detach().clone().requires_grad_(True)
62
- F_cross, _ = self.loss_fn(X_grad2, Y_det)
 
63
  grad_cross = torch.autograd.grad(
64
  F_cross.sum(), X_grad2, create_graph=False, retain_graph=False
65
  )[0]
66
 
 
67
  velocity = grad_self.detach() - grad_cross.detach()
 
 
 
 
 
68
  return velocity
69
 
70
  def compute_sinkhorn_divergence(self, X: torch.Tensor, Y: torch.Tensor) -> float:
 
71
  with torch.no_grad():
72
- return self.loss_monitor(X, Y).item()
 
 
73
 
74
 
75
  class SinkhornGradientFlow:
@@ -109,7 +139,11 @@ class SinkhornGradientFlow:
109
 
110
 
111
  class TrajectoryPool:
112
- """Stores (x, v, t) tuples from Sinkhorn gradient flow trajectories."""
 
 
 
 
113
 
114
  def __init__(self, max_size: int = 1_000_000):
115
  self.max_size = max_size
@@ -117,8 +151,15 @@ class TrajectoryPool:
117
  self.v_pool: List[torch.Tensor] = []
118
  self.t_pool: List[int] = []
119
  self._size = 0
 
 
 
 
120
 
121
  def add_trajectory(self, trajectory: List[Tuple[torch.Tensor, torch.Tensor, int]]):
 
 
 
122
  for x, v, t in trajectory:
123
  n = x.shape[0]
124
  if self._size + n > self.max_size:
@@ -147,13 +188,43 @@ class TrajectoryPool:
147
  self._size -= (batch_size - keep)
148
  removed = n
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def sample(self, batch_size: int, device: str = "cpu"
151
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
152
- all_x = torch.cat(self.x_pool, dim=0)
153
- all_v = torch.cat(self.v_pool, dim=0)
154
- all_t = torch.tensor(self.t_pool, dtype=torch.float32)
155
- idx = torch.randint(0, all_x.shape[0], (batch_size,))
156
- return all_x[idx].to(device), all_v[idx].to(device), all_t[idx].to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @property
159
  def size(self) -> int:
 
22
  The velocity field of the Sinkhorn WGF is (Theorem 1):
23
  v(x) = ∇f_{μ,μ}(x) - ∇f_{μ,μ*}(x)
24
 
25
+ IMPORTANT: GeomLoss SamplesLoss requires inputs as (N, D) or (B, N, D) tensors.
26
+ For image data (N, C, H, W), we flatten to (N, C*H*W) before calling geomloss,
27
+ then reshape gradients back to (N, C, H, W).
28
+
29
  Args:
30
  blur: GeomLoss blur parameter (related to ε: ε = blur^p).
31
  scaling: Multiscale scaling parameter for Sinkhorn iterations.
 
49
  backend=backend, potentials=False,
50
  )
51
 
52
+ def _flatten_if_image(self, X: torch.Tensor) -> Tuple[torch.Tensor, bool, torch.Size]:
53
+ """Flatten (N,C,H,W) → (N,D) for geomloss. Returns (flat_tensor, was_image, original_shape)."""
54
+ original_shape = X.shape
55
+ if X.dim() == 4:
56
+ return X.view(X.shape[0], -1), True, original_shape
57
+ return X, False, original_shape
58
+
59
  def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
60
  """Compute the Sinkhorn WGF velocity field at particles X.
61
 
62
  v(X_i) = ∇f_{μ,μ}(X_i) - ∇f_{μ,μ*}(X_i)
63
+
64
+ Handles both 2D point clouds (N,D) and images (N,C,H,W) by
65
+ flattening images before geomloss calls.
66
  """
67
+ original_shape = X.shape
68
+
69
+ # Flatten if image tensors
70
+ X_flat, is_image, _ = self._flatten_if_image(X.detach().clone())
71
+ Y_flat, _, _ = self._flatten_if_image(Y.detach())
72
 
73
+ # --- Self-potential: ∇f_{μ,μ}(X) ---
74
+ X_grad = X_flat.requires_grad_(True)
75
+ X_self_detached = X_flat.detach().clone()
76
+ F_self, _ = self.loss_fn(X_grad, X_self_detached)
77
  grad_self = torch.autograd.grad(
78
  F_self.sum(), X_grad, create_graph=False, retain_graph=False
79
  )[0]
80
 
81
+ # --- Cross-potential: ∇f_{μ,μ*}(X) ---
82
+ X_grad2 = X_flat.detach().clone().requires_grad_(True)
83
+ F_cross, _ = self.loss_fn(X_grad2, Y_flat)
84
  grad_cross = torch.autograd.grad(
85
  F_cross.sum(), X_grad2, create_graph=False, retain_graph=False
86
  )[0]
87
 
88
+ # Velocity = ∇f_{μ,μ} - ∇f_{μ,μ*}
89
  velocity = grad_self.detach() - grad_cross.detach()
90
+
91
+ # Reshape back to original shape if image
92
+ if is_image:
93
+ velocity = velocity.view(original_shape)
94
+
95
  return velocity
96
 
97
  def compute_sinkhorn_divergence(self, X: torch.Tensor, Y: torch.Tensor) -> float:
98
+ """Compute Sinkhorn divergence S_ε(μ, μ*). Handles image tensors."""
99
  with torch.no_grad():
100
+ X_flat, _, _ = self._flatten_if_image(X)
101
+ Y_flat, _, _ = self._flatten_if_image(Y)
102
+ return self.loss_monitor(X_flat, Y_flat).item()
103
 
104
 
105
  class SinkhornGradientFlow:
 
139
 
140
 
141
  class TrajectoryPool:
142
+ """Stores (x, v, t) tuples from Sinkhorn gradient flow trajectories.
143
+
144
+ After building, call finalize() to pre-concatenate tensors for O(1) sampling.
145
+ Without finalize(), sampling is O(pool_size) per call due to torch.cat.
146
+ """
147
 
148
  def __init__(self, max_size: int = 1_000_000):
149
  self.max_size = max_size
 
151
  self.v_pool: List[torch.Tensor] = []
152
  self.t_pool: List[int] = []
153
  self._size = 0
154
+ self._finalized = False
155
+ self._all_x = None
156
+ self._all_v = None
157
+ self._all_t = None
158
 
159
  def add_trajectory(self, trajectory: List[Tuple[torch.Tensor, torch.Tensor, int]]):
160
+ """Add (x, v, t) entries from a flow trajectory. Call before finalize()."""
161
+ if self._finalized:
162
+ raise RuntimeError("Cannot add to a finalized pool. Create a new pool.")
163
  for x, v, t in trajectory:
164
  n = x.shape[0]
165
  if self._size + n > self.max_size:
 
188
  self._size -= (batch_size - keep)
189
  removed = n
190
 
191
+ def finalize(self):
192
+ """Pre-concatenate all pool data for fast O(1) sampling.
193
+
194
+ Call this once after all trajectories have been added.
195
+ After finalization, sample() is fast (just random indexing).
196
+ """
197
+ if self._size == 0:
198
+ raise RuntimeError("Cannot finalize an empty pool.")
199
+ self._all_x = torch.cat(self.x_pool, dim=0)
200
+ self._all_v = torch.cat(self.v_pool, dim=0)
201
+ self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
202
+ # Free the lists to save memory
203
+ self.x_pool = None
204
+ self.v_pool = None
205
+ self.t_pool = None
206
+ self._finalized = True
207
+
208
  def sample(self, batch_size: int, device: str = "cpu"
209
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
210
+ """Sample a random batch from the pool.
211
+
212
+ If finalize() was called, this is O(1). Otherwise falls back to O(pool_size).
213
+ """
214
+ if self._finalized:
215
+ idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
216
+ return (
217
+ self._all_x[idx].to(device),
218
+ self._all_v[idx].to(device),
219
+ self._all_t[idx].to(device),
220
+ )
221
+ else:
222
+ # Fallback: concatenate on the fly (slow for large pools)
223
+ all_x = torch.cat(self.x_pool, dim=0)
224
+ all_v = torch.cat(self.v_pool, dim=0)
225
+ all_t = torch.tensor(self.t_pool, dtype=torch.float32)
226
+ idx = torch.randint(0, all_x.shape[0], (batch_size,))
227
+ return all_x[idx].to(device), all_v[idx].to(device), all_t[idx].to(device)
228
 
229
  @property
230
  def size(self) -> int: