| import torch |
|
|
| |
| |
| |
def segment_geom(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-9):
    """Compute the basic geometry of the segments p -> q.

    Args:
        p, q: segment endpoints, shape (..., 3).
        eps: softening length added (squared) under the square root so that
            zero-length segments get a finite length and a zero direction.
            For float16/bfloat16 inputs it is raised to at least
            ``torch.finfo(dtype).eps``.

    Returns:
        Tuple ``(d, a, ell, u)``:
            d   = q - p                      (..., 3)
            a   = ||d||^2                    (...)
            ell = sqrt(a + eps^2)            (...)
            u   = d / ell                    (..., 3)
    """
    low_precision = p.dtype in (torch.float16, torch.bfloat16)
    # Half-precision types cannot resolve the default eps; floor it at the
    # dtype's machine epsilon.
    soft = max(eps, float(torch.finfo(p.dtype).eps)) if low_precision else eps

    delta = q - p
    sq_len = torch.sum(delta * delta, dim=-1)
    length = torch.sqrt(sq_len + soft * soft)
    direction = delta / length.unsqueeze(-1)
    return delta, sq_len, length, direction
|
|
def sample_points(p: torch.Tensor, q: torch.Tensor, nodes01: torch.Tensor):
    """Sample points p + t * (q - p) for every fraction t in ``nodes01``.

    Args:
        p, q: segment endpoints, shape (..., D).
        nodes01: 1-D tensor of K fractional positions along each segment.
            It is moved to p's device and dtype.

    Returns:
        Tensor of shape (..., K, D): a new node axis inserted just before the
        coordinate axis.
    """
    t = nodes01.to(device=p.device, dtype=p.dtype)
    # Shape the nodes as (1, ..., 1, K, 1) so they broadcast over the leading
    # axes of p and across the coordinate axis.
    leading_ones = (1,) * (p.dim() - 1)
    t = t.view(*leading_ones, t.shape[0], 1)
    origin = p.unsqueeze(-2)
    step = (q - p).unsqueeze(-2)
    return origin + t * step
|
|
|
|
| |
# Three-point Gauss-Lobatto rule on [0, 1], i.e. Simpson's rule: nodes at both
# segment endpoints and at the midpoint.
LOBATTO3_NODES = torch.tensor([0.0, 0.5, 1.0])
# Simpson weights on [0, 1]: (1/6, 4/6, 1/6). They sum to 1 and integrate
# polynomials up to degree 3 exactly. (Equal 1/3 weights are not a valid
# Simpson/Lobatto rule and are not exact even for x^2.)
LOBATTO3_W = torch.tensor([1 / 6, 4 / 6, 1 / 6])
# Tensor-product weights for the double integral over a pair of segments,
# shape (3, 3).
LOBATTO3_W2 = LOBATTO3_W[:, None] * LOBATTO3_W[None, :]
|
|
|
|
| def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool): |
| sigmas_t = torch.as_tensor(sigmas, device=device, dtype=dtype).clamp_min(1e-6) |
| alpha_t = torch.as_tensor(alpha, device=device, dtype=dtype) |
| if normalize_alpha: |
| alpha_t = alpha_t / alpha_t.sum().clamp_min(1e-12) |
| return sigmas_t, alpha_t |
|
|
| |
| |
| |
| def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None: |
| if w is None: |
| return None |
| w = torch.as_tensor(w, device=device, dtype=dtype) |
| if w.dim() == 1: |
| if w.shape[0] != n: |
| raise ValueError(f"weight length {w.shape[0]} != {n}") |
| w = w.unsqueeze(0).expand(b, -1) |
| elif w.dim() == 2: |
| if w.shape[0] != b or w.shape[1] != n: |
| raise ValueError(f"weight shape {tuple(w.shape)} != ({b}, {n})") |
| else: |
| raise ValueError("weights must be 1D or 2D") |
| return w |
|
|
|
|
def cross_simpson3(
    pA,
    qA,
    pB,
    qB,
    sigma: float | torch.Tensor,
    wA: torch.Tensor | None = None,
    wB: torch.Tensor | None = None,
):
    """Gaussian-kernel cross term between two sets of segments.

    For each segment pair (i in A, j in B) this evaluates

        ell_i * ell_j * [w_i * w_j] * (u_i . u_j)^2
            * sum_{s,t} W2[s, t] * exp(-|x_i(s) - y_j(t)|^2 / (2 sigma^2))

    and sums over all pairs. Here ell/u come from ``segment_geom``, x_i(s) and
    y_j(t) are the ``LOBATTO3_NODES`` quadrature points along each segment,
    and W2 is ``LOBATTO3_W2``.

    Args:
        pA, qA: endpoints of the A segments, (bsz, nA, 3), or (nA, 3) for an
            unbatched call.
        pB, qB: endpoints of the B segments, (bsz, nB, 3) or (nB, 3).
        sigma: kernel width; a scalar, or a tensor with one entry per batch
            item.
        wA, wB: optional per-segment weights (see ``_prep_weight``). If only
            one is given, the other is treated as all ones.

    Returns:
        Tensor of shape (bsz,), or a 0-d tensor for unbatched inputs.

    Raises:
        ValueError: on a badly shaped weight, or a per-batch sigma whose
            first dimension is not bsz.
    """
    device, dtype = pA.device, pA.dtype
    unbatched = pA.dim() != 3
    if unbatched:
        pA, qA, pB, qB = (t.unsqueeze(0) for t in (pA, qA, pB, qB))

    bsz, nA, _ = pA.shape
    nB = pB.shape[1]
    wA = _prep_weight(wA, nA, bsz, device, dtype)
    wB = _prep_weight(wB, nB, bsz, device, dtype)

    nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    quad_w2 = LOBATTO3_W2.to(device=device, dtype=dtype)

    # Softened lengths and unit directions of every segment.
    ellA, uA = segment_geom(pA, qA)[2:]
    ellB, uB = segment_geom(pB, qB)[2:]

    # Orientation factor (u_i . u_j)^2 and length(*weight) factor, both (bsz, nA, nB).
    ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
    lenfac = ellA[:, :, None] * ellB[:, None, :]
    if any(w is not None for w in (wA, wB)):
        wA = torch.ones((bsz, nA), device=device, dtype=dtype) if wA is None else wA
        wB = torch.ones((bsz, nB), device=device, dtype=dtype) if wB is None else wB
        lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])

    # Squared distances between every pair of quadrature points:
    # (bsz, nA, nB, 3, 3) after summing over the coordinate axis.
    ptsA = sample_points(pA, qA, nodes)
    ptsB = sample_points(pB, qB, nodes)
    delta = ptsA[:, :, None, :, None, :] - ptsB[:, None, :, None, :, :]
    r2 = (delta * delta).sum(dim=-1)

    sig = torch.as_tensor(sigma, device=device, dtype=dtype)
    inv2s2 = 1.0 / (2.0 * sig * sig)
    if sig.ndim != 0:
        # Per-batch widths: broadcast over (nA, nB, 3, 3).
        if sig.shape[0] != bsz:
            raise ValueError(f"sigma batch {sig.shape[0]} != {bsz}")
        inv2s2 = inv2s2.view(bsz, 1, 1, 1, 1)
    kernel = torch.exp(-r2 * inv2s2)

    spatial = (kernel * quad_w2).sum(dim=-1).sum(dim=-1)
    total = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
    return total[0] if unbatched else total
|
|
|
|
| |
| |
| |
|
|
def loss_simpson3_batch(
    p_pred: torch.Tensor,
    q_pred: torch.Tensor,
    p_gt: torch.Tensor,
    q_gt: torch.Tensor,
    sigma: float | torch.Tensor,
    w_gt: torch.Tensor | None = None,
    w_pred: torch.Tensor | None = None,
    cross_only: bool = False,
) -> torch.Tensor:
    """Simpson-3 segment-matching loss at a single kernel width.

    Returns ``S(pred, pred) - 2 * S(pred, gt)``, where S is ``cross_simpson3``.
    If S is read as an inner product <., .>, this is ||pred - gt||^2 minus the
    constant ||gt||^2, which is never computed. With ``cross_only=True`` the
    pred self-term is skipped as well, and only ``-2 * S(pred, gt)`` is
    returned.

    Args:
        p_pred, q_pred: predicted segment endpoints.
        p_gt, q_gt: ground-truth segment endpoints.
        sigma: kernel width, forwarded to ``cross_simpson3``.
        w_gt, w_pred: optional per-segment weights for each set.
        cross_only: skip the S(pred, pred) term.
    """
    cross_term = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wA=w_pred, wB=w_gt)
    if not cross_only:
        self_term = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma, wA=w_pred, wB=w_pred)
        return self_term - 2.0 * cross_term
    return -2.0 * cross_term
|
|
|
|
def loss_simpson3_mix_batch(
    p_pred: torch.Tensor,
    q_pred: torch.Tensor,
    p_gt: torch.Tensor,
    q_gt: torch.Tensor,
    sigmas,
    alpha,
    w_gt: torch.Tensor | None = None,
    w_pred: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    cross_only: bool = False,
) -> torch.Tensor:
    """Multi-scale Simpson-3 loss: an alpha-weighted mix over kernel widths.

    Args:
        p_pred, q_pred, p_gt, q_gt, w_gt, w_pred, cross_only: forwarded
            unchanged to ``loss_simpson3_batch`` for every scale.
        sigmas: kernel widths, clamped to at least 1e-6. Either 1-D of length
            K (one scalar width per scale), or 2-D of shape (bsz, K), where
            column k holds a per-batch-item width for scale k.
        alpha: K mixing weights, one per scale.
        normalize_alpha: if True, alpha is divided by its sum (with the sum
            clamped to at least 1e-12).

    Returns:
        ``sum_k alpha[k] * loss_k``, with the same shape as a single
        ``loss_simpson3_batch`` result: (bsz,) for batched inputs, 0-d for
        unbatched ones.

    Raises:
        ValueError: if sigmas is not 1-D/2-D or contains no scales, or if
            alpha does not have exactly one entry per scale.
    """
    device, dtype = p_pred.device, p_pred.dtype
    sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)

    if sigmas_t.ndim == 1:
        scale_sigmas = list(sigmas_t)  # 0-d tensor per scale
    elif sigmas_t.ndim == 2:
        scale_sigmas = [sigmas_t[:, k] for k in range(sigmas_t.shape[1])]  # (bsz,) per scale
    else:
        raise ValueError("sigmas must be 1D or 2D for batch loss")

    n_scales = len(scale_sigmas)
    if n_scales == 0:
        raise ValueError("sigmas must contain at least one scale")
    if alpha_t.ndim != 1 or alpha_t.shape[0] != n_scales:
        raise ValueError(f"alpha shape {tuple(alpha_t.shape)} != ({n_scales},)")

    losses = torch.stack(
        [
            loss_simpson3_batch(
                p_pred, q_pred, p_gt, q_gt, s, w_gt=w_gt, w_pred=w_pred, cross_only=cross_only
            )
            for s in scale_sigmas
        ],
        dim=0,
    )
    # Contract the leading scale axis for any per-scale loss shape. Broadcasting
    # alpha as (K, 1) is only correct for (bsz,) losses; for unbatched (0-d)
    # losses it would form a (K, K) product and return a (K,) vector.
    return torch.tensordot(alpha_t, losses, dims=1)
|
|