| import typing |
|
|
| import torch |
| import torch.nn.functional as F |
| from jaxtyping import Float |
|
|
| from src.loss.abstract_loss import AbstractLoss |
from src.utils.math import sobol_sphere


| def process_vector( |
| x: Float[torch.Tensor, "B D N"], |
| dirs: Float[torch.Tensor, "K D"], |
) -> Float[torch.Tensor, "K B*N"]:
| """ |
    Project a batch of D-dimensional sequences onto a bank of linear directions.
| |
| Args |
| ---- |
| x : (B, D, N) tensor – predictions or ground truth |
| dirs : (K, D) tensor – unit-length projection directions |
| |
| Returns |
| ------- |
    proj : (K, B*N) tensor of flattened projections
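
    Examples
    --------
    Shape sketch with illustrative values:

    >>> x = torch.randn(2, 8, 5)   # B=2, D=8, N=5
    >>> dirs = torch.randn(3, 8)   # K=3
    >>> process_vector(x, dirs).shape
    torch.Size([3, 10])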
| """ |
| B, D, N = x.shape |
    K, _ = dirs.shape


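    # Project in float32 for numerical stability, then cast back to the
    # input dtype.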
| proj = F.linear(x.transpose(1, 2).to(torch.float32), dirs.to(torch.float32)) |
| proj = proj.permute(2, 0, 1).reshape(K, -1).to(x.dtype) |
|
|
    return proj


| class VectorSWDLoss(AbstractLoss): |
| """ |
| 1-D Sliced-Wasserstein Distance on sequences. |
| |
| This loss computes the sliced Wasserstein distance between predicted and ground |
| truth sequences by projecting them onto random directions and computing the |
| Wasserstein distance in 1D. It supports reservoir sampling for adaptive direction |
| selection and various variance reduction techniques. |
| |
| Parameters |
| ---------- |
| num_proj : int, default=64 |
| Number of random projections to use per step (K). |
| |
| distance : {"l1", "l2"}, default="l1" |
| Distance metric to use for computing the Wasserstein distance. |
| |
    use_ucv : bool, default=False
        Whether to use an upper-bound control variate for variance reduction.
        Mutually exclusive with use_lcv.
| |
    use_lcv : bool, default=False
        Whether to use a lower-bound control variate for variance reduction.
        Mutually exclusive with use_ucv.
| |
| refresh_projections_every_n_steps : int, default=1 |
| How often to refresh the projection directions. A value of 1 means |
| refresh every step, higher values reuse directions for multiple steps. |
| |
| num_new_candidates : int, default=16 |
| Number of new candidate directions to generate per step (M). |
| If 0, reservoir sampling is disabled. Must not exceed num_proj. |
| |
| ess_alpha : float, default=0.5 |
| Effective sample size threshold for resetting the reservoir. |
| When ESS drops below ess_alpha * reservoir_size, the reservoir is reset. |
| |
| time_decay_tau : float or None, default=30.0 |
| Time decay parameter for reservoir weights. If None, no time decay is applied. |
| Weights decay exponentially with age: exp(-age / time_decay_tau). |
| |
| missing_value_method : {"random_replicate", "interpolate"}, |
| default="random_replicate" |
        Method for handling projected sequences of different lengths:
        - "random_replicate": tile the shorter sequence and randomly replicate
          entries to fill the remainder
        - "interpolate": linearly interpolate the shorter sequence to match
| |
| sampling_mode : {"gaussian", "qmc"}, default="qmc" |
| Method for generating random projection directions: |
| - "gaussian": Standard Gaussian sampling |
| - "qmc": Quasi-Monte Carlo sampling using Sobol sequences |
| |
| Notes |
| ----- |
| - Reservoir sampling is enabled when num_new_candidates > 0 |
| - Reservoir size = num_proj - num_new_candidates |
| - When use_ucv or use_lcv is True, variance reduction is applied using |
| control variates based on the difference between sample and population means |
| - The loss automatically handles sequences of different lengths using the |
| specified missing_value_method |
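
    Examples
    --------
    Illustrative usage; assumes ``AbstractLoss`` subclasses ``torch.nn.Module``:

    >>> loss_fn = VectorSWDLoss(num_proj=8, num_new_candidates=2,
    ...                         sampling_mode="gaussian")
    >>> pred = torch.randn(4, 16, 32, requires_grad=True)  # (B, D, N)
    >>> gt = torch.randn(4, 16, 32)
    >>> loss = loss_fn(pred, gt, step=0)
    >>> loss.backward()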
| """ |
|
|
| def __init__( |
| self, |
| num_proj: int = 64, |
| distance: typing.Literal["l1", "l2"] = "l1", |
| use_ucv: bool = False, |
| use_lcv: bool = False, |
| refresh_projections_every_n_steps: int = 1, |
| num_new_candidates: int = 16, |
| ess_alpha: float = 0.5, |
| time_decay_tau: float | None = 30.0, |
| missing_value_method: typing.Literal[ |
| "random_replicate", "interpolate" |
| ] = "random_replicate", |
| sampling_mode: typing.Literal[ |
| "gaussian", |
| "qmc", |
| ] = "qmc", |
| ): |
| super().__init__() |
|
|
| assert not (use_ucv and use_lcv), "use_ucv and use_lcv cannot both be True" |
|
|
| self.num_proj = num_proj |
| self.distance = distance |
| self.use_ucv = use_ucv |
| self.use_lcv = use_lcv |
|
|
| self.refresh_projections_every_n_steps = refresh_projections_every_n_steps |
| self.num_new_candidates = num_new_candidates |
| self.ess_alpha = ess_alpha |
| self.time_decay_tau = time_decay_tau |
| self.missing_value_method = missing_value_method |
|
|
        if num_new_candidates > 0 and self.refresh_projections_every_n_steps != 1:
            print(
                "WARNING: combining reservoir sampling (num_new_candidates > 0) "
                "with refresh_projections_every_n_steps != 1 is not recommended"
            )
| assert ( |
| num_new_candidates <= num_proj |
| ), "`num_new_candidates` must not exceed `num_proj`" |
|
|
| |
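        # Each step uses R = num_proj - num_new_candidates reservoir directions
        # plus M = num_new_candidates fresh proposals, K candidates in total.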
| self.restir_enabled = self.num_new_candidates > 0 |
| self.reservoir_size = self.num_proj - self.num_new_candidates |
| self.register_buffer("_reservoir_filters", torch.empty(0)) |
| self.register_buffer("_reservoir_weights", torch.empty(0)) |
| self.register_buffer("_reservoir_steps", torch.empty(0, dtype=torch.long)) |
| self.register_buffer("_reservoir_keys", torch.empty(0)) |
| self.register_buffer("_cumulative_weights", torch.tensor(0.0)) |
| self.register_buffer("_has_reservoir", torch.tensor(False, dtype=torch.bool)) |
|
|
| self._cached_dirs: typing.Optional[torch.Tensor] = None |
| self.sampling_mode = sampling_mode |
| self.sobol_engine = None |
|
|
| def _gaussian_proposals(self, k: int, d: int, device: torch.device) -> torch.Tensor: |
| """Generate Gaussian random projection directions.""" |
| w = torch.randn(k, d, device=device) |
| return w / (w.norm(dim=1, keepdim=True) + 1e-8) |
|
|
| def _qmc_proposals(self, k: int, d: int, device: torch.device) -> torch.Tensor: |
| """Generate quasi-Monte Carlo projection directions using Sobol sequences.""" |
| vecs, self.sobol_engine = sobol_sphere(k, d, device, self.sobol_engine) |
| return vecs.view(k, d) |
|
|
| def _draw_dirs(self, k: int, d: int, device: torch.device) -> torch.Tensor: |
| """Draw projection directions using the specified sampling mode.""" |
| if self.sampling_mode == "gaussian": |
| return self._gaussian_proposals(k, d, device) |
| if self.sampling_mode == "qmc": |
| return self._qmc_proposals(k, d, device) |
        raise ValueError(f"Unknown sampling_mode: {self.sampling_mode!r}")
|
|
    @staticmethod
    def _duplicate_to_match(
        a: torch.Tensor, b: torch.Tensor, method: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
| """ |
| Make two tensors have the same length by duplicating the shorter one. |
| |
| Args |
| ---- |
| a, b : (K, N₁) and (K, N₂) tensors |
| method : "random_replicate" or "interpolate" |
| |
| Returns |
| ------- |
    a, b : Tensors with matching second dimension, in the caller's argument order
| """ |
| if a.shape[1] == b.shape[1]: |
| return a, b |
        # Track the swap so the caller's argument order can be restored below.
        swapped = a.shape[1] < b.shape[1]
        if swapped:
            a, b = b, a
|
|
| K, NA = a.shape |
        NB = b.shape[1]


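        # "random_replicate": tile whole copies of the shorter sequence, then
        # randomly replicate entries to fill the remainder.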
| if method == "random_replicate": |
| repeats = NA // NB |
| b = torch.cat([b] * repeats, dim=1) |
| if b.shape[1] < NA: |
| idx = torch.randint(0, NB, (NA - b.shape[1],), device=b.device) |
| b = torch.cat([b, b[:, idx]], dim=1) |
| else: |
| b = F.interpolate( |
| b.unsqueeze(0), size=(NA,), mode="linear", align_corners=False |
| ).squeeze(0) |
        if swapped:
            a, b = b, a
        return a, b
|
|
| def reset(self): |
| """Reset the reservoir sampling state.""" |
| if self.restir_enabled: |
| self._reservoir_filters = torch.empty(0) |
| self._reservoir_weights = torch.empty(0) |
| self._cumulative_weights.data.fill_(0) |
| self._has_reservoir.fill_(False) |
| self._reservoir_steps = torch.empty(0, dtype=torch.long) |
| self._reservoir_keys = torch.empty(0) |
|
|
    def _wrs_multi(
        self, filters: torch.Tensor, weights: torch.Tensor, step: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
| """ |
| Weighted reservoir sampling that keeps exactly self.reservoir_size samples and |
| returns their indices inside the concatenated candidate set. |
| |
| Args |
| ---- |
| filters : (K+M, D) tensor of candidate directions |
| weights : (K+M,) tensor of importance weights |
| step : Current training step |
| |
| Returns |
| ------- |
| keep_idx : Indices of kept samples |
| keep_w : Normalized weights of kept samples |
| """ |
| R = self.reservoir_size |
| device = weights.device |
|
|
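        # Weighted reservoir sampling keys (Efraimidis-Spirakis): key = u^(1/w).
        # Keeping the largest keys samples without replacement in proportion to w.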
| u = torch.rand_like(weights) |
| keys = u.pow(1.0 / weights.clamp_min(1e-9)) |
|
|
| if not self._has_reservoir.item(): |
| self._reservoir_filters = filters[:R] |
| self._reservoir_weights = weights[:R] |
| self._reservoir_keys = keys[:R] |
| self._reservoir_steps = torch.full( |
| (R,), step, dtype=torch.long, device=device |
| ) |
| self._has_reservoir.fill_(True) |
|
|
| new_filters = filters[R:] |
| new_keys = keys[R:] |
| new_weights = weights[R:] |
| new_steps = torch.full( |
| (new_filters.size(0),), step, dtype=torch.long, device=device |
| ) |
|
|
| all_filters = torch.cat([self._reservoir_filters, new_filters], 0) |
| all_keys = torch.cat([self._reservoir_keys, new_keys], 0) |
| all_weights = torch.cat([self._reservoir_weights, new_weights], 0) |
| all_steps = torch.cat([self._reservoir_steps, new_steps], 0) |
|
|
| topk_keys, topk_idx = torch.topk(all_keys, R, largest=True) |
|
|
| self._reservoir_filters = all_filters[topk_idx] |
| self._reservoir_weights = all_weights[topk_idx] |
| self._reservoir_keys = topk_keys |
        self._reservoir_steps = all_steps[topk_idx]


        # `all_*` preserves the candidate order (reservoir first, then new), so
        # the top-k indices index directly into the concatenated candidate set.
        keep_idx = topk_idx
        keep_w = self._reservoir_weights / self._reservoir_weights.sum().clamp_min(
            1e-12
        )
        return keep_idx, keep_w
|
|
| def _apply_time_decay(self, step: int): |
| """ |
| Apply exponential time decay to stored reservoir weights. |
| |
| Args |
| ---- |
| step : Current training step |
| """ |
| if self.time_decay_tau is None or not self._has_reservoir.item(): |
| return |
| age = (step - self._reservoir_steps).to(torch.float32) |
| decay = torch.exp(-age / self.time_decay_tau).to(self._reservoir_weights.dtype) |
| self._reservoir_weights.mul_(decay) |
| self._reservoir_keys.mul_(decay) |
|
|
| def forward( |
| self, |
| pred: Float[torch.Tensor, "B D N"], |
| gt: Float[torch.Tensor, "B D N"], |
| step: int, |
| ): |
| """ |
| Compute the sliced Wasserstein distance between predicted and ground truth |
| sequences. |
| |
| Args |
| ---- |
| pred : (B, D, N) tensor of predicted sequences |
| gt : (B, D, N) tensor of ground truth sequences |
| step : Current training step for reservoir sampling |
| |
| Returns |
| ------- |
| loss : Scalar tensor containing the computed loss |
| """ |
| B, D, N = pred.shape |
| K = self.num_proj |
| M = self.num_new_candidates |
| R = self.reservoir_size |
| device = pred.device |
| gt = gt.detach() |
|
|
        self._apply_time_decay(step)


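        # Draw M fresh proposals when the reservoir is warm (reservoir + new
        # candidates = K); otherwise draw a full batch of K directions.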
| if step % self.refresh_projections_every_n_steps == 0: |
| new_dirs = self._draw_dirs( |
| M if self.restir_enabled and self._has_reservoir.item() else K, |
| D, |
| device, |
| ) |
| self._cached_dirs = new_dirs |
| else: |
| new_dirs = self._cached_dirs |
|
|
| if self.restir_enabled and self._has_reservoir.item(): |
| cand_dirs = torch.cat( |
| [self._reservoir_filters, new_dirs], dim=0 |
| ) |
| else: |
            cand_dirs = new_dirs


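        # Project both point sets onto every candidate direction.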
| cand_pred = process_vector(pred, cand_dirs) |
| cand_gt = process_vector(gt, cand_dirs) |
|
|
| cand_pred, cand_gt = self._duplicate_to_match( |
| cand_pred, cand_gt, self.missing_value_method |
| ) |
|
|
| cand_pred = cand_pred.sort(dim=1).values |
        cand_gt = cand_gt.sort(dim=1).values


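        # Rate each candidate by its current 1-D transport cost and use that as
        # its importance weight for weighted reservoir sampling.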
| if self.restir_enabled: |
| with torch.no_grad(): |
| base = cand_pred - cand_gt |
| base = base.abs() if self.distance == "l1" else base.square() |
| ris_weights = base.mean(1) |
| keep_idx, keep_w = self._wrs_multi(cand_dirs, ris_weights, step) |
|
|
            w = w_hat = keep_w
|
|
| dirs = cand_dirs[keep_idx] |
| proj_pred = cand_pred[keep_idx] |
| proj_gt = cand_gt[keep_idx] |
| else: |
| dirs = cand_dirs |
| proj_pred = cand_pred |
| proj_gt = cand_gt |
            w = torch.full((dirs.shape[0],), 1.0 / K, device=device)


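        # 1-D Wasserstein between the sorted projections, one value per slice.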
| diff = proj_pred - proj_gt |
| diff = diff.abs() if self.distance == "l1" else diff.square() |
| per_slice = diff.mean(1) |
|
|
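        # Control variate: shift the weighted estimate by a correlated statistic
        # G whose spherical mean G_bar is known in closed form, i.e.
        #     loss = A_hat - alpha_hat * sum_k w_k * (G_hat_k - G_bar),
        # with alpha_hat = Cov(A, G) / Var(G).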
| if self.use_ucv or self.use_lcv: |
| X_vecs = pred.permute(0, 2, 1).reshape(-1, D) |
| Y_vecs = gt.permute(0, 2, 1).reshape(-1, D) |
|
|
| m1 = X_vecs.mean(0) |
| m2 = Y_vecs.mean(0) |
| diff_m = m1 - m2 |
|
|
| theta = dirs |
|
|
| if self.use_ucv: |
| diff_X = X_vecs - m1 |
| diff_Y = Y_vecs - m2 |
|
|
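                # For theta uniform on the sphere, E[(theta^T v)^2] = ||v||^2 / d
                # and the mean squared entry of a centred matrix is tr(Sigma) / d.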
| d = D |
| trSigX = diff_X.pow(2).mean() |
| trSigY = diff_Y.pow(2).mean() |
| G_bar = (diff_m @ diff_m) / d + (trSigX + trSigY) |
|
|
| delta2 = (theta @ diff_m) ** 2 |
|
|
| proj_X = diff_X @ theta.t() |
| proj_Y = diff_Y @ theta.t() |
| varX = proj_X.pow(2).mean(0) |
| varY = proj_Y.pow(2).mean(0) |
| G_hat = delta2 + varX + varY |
| else: |
| d = D |
| G_bar = (diff_m @ diff_m) / d |
| G_hat = (theta @ diff_m) ** 2 |
|
|
| diff_hat_G_mean_G = G_hat - G_bar |
|
|
| hat_A = (w * per_slice).sum() |
| var_G = (w * diff_hat_G_mean_G.pow(2)).sum() |
| cov_AG = (w * (per_slice - hat_A) * diff_hat_G_mean_G).sum() |
| hat_alpha = cov_AG / (var_G + 1e-12) |
| loss = hat_A - hat_alpha * (w * diff_hat_G_mean_G).sum() |
| else: |
            loss = (w * per_slice).sum()


| if self.restir_enabled and self.ess_alpha > 0: |
| with torch.no_grad(): |
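            # Kish effective sample size: (sum w)^2 / sum w^2. A small ESS means
            # a few directions dominate, so the reservoir is reset.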
| ess = (w_hat.sum().square()) / (w_hat.square().sum() + 1e-12) |
| ess = torch.nan_to_num(ess, nan=0.0, posinf=R, neginf=0.0).item() |
| if ess < self.ess_alpha * R: |
                print(f"ESS {ess:.2f} < {self.ess_alpha * R:.2f}; resetting reservoir")
| self.reset() |
|
|
| return loss |
|
|
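
# Minimal smoke test (illustrative; not part of the loss). It assumes that
# `AbstractLoss` subclasses `torch.nn.Module`; `sampling_mode="gaussian"` keeps
# the sketch independent of the Sobol engine.
if __name__ == "__main__":
    torch.manual_seed(0)
    loss_fn = VectorSWDLoss(
        num_proj=16, num_new_candidates=4, sampling_mode="gaussian"
    )
    pred = torch.randn(2, 8, 64, requires_grad=True)  # (B, D, N)
    gt = torch.randn(2, 8, 64)
    for step in range(3):
        loss = loss_fn(pred, gt, step)
        loss.backward()
        print(f"step {step}: loss = {loss.item():.4f}")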