File size: 6,105 Bytes
f4487da 0f31e57 f4487da 0f31e57 f4487da | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import torch
# -----------------------------
# Helpers
# -----------------------------
def segment_geom(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-9):
    """
    Per-segment geometry for segments running from p to q.

    p, q: endpoint tensors with coordinates on the last axis
          (the original note gives the shape as (...,3)).
    eps:  smoothing constant; eps^2 is added under the square root so the
          length stays strictly positive for degenerate (p == q) segments.
          For float16 / bfloat16 inputs it is raised to at least the
          machine epsilon of that dtype.

    returns (d, a, ell, u):
        d   = q - p
        a   = ||d||^2            (summed over the last axis)
        ell = sqrt(a + eps^2)    (smoothed length)
        u   = d / ell            (approximately unit direction)
    """
    # Half-precision dtypes get a larger floor than the requested eps.
    is_low_precision = p.dtype in (torch.float16, torch.bfloat16)
    floor = max(eps, float(torch.finfo(p.dtype).eps)) if is_low_precision else eps

    delta = q - p
    sq_len = (delta * delta).sum(dim=-1)
    length = torch.sqrt(sq_len + floor * floor)
    direction = delta / length.unsqueeze(-1)
    return delta, sq_len, length, direction
def sample_points(p: torch.Tensor, q: torch.Tensor, nodes01: torch.Tensor):
    """
    Evaluate p + t * (q - p) at every parameter t in nodes01.

    p, q:    endpoint tensors, (...,C) with coordinates on the last axis.
    nodes01: 1-D tensor of K parameters; it is moved to p's device and dtype.

    returns: tensor of shape (...,K,C) with the node axis inserted just
             before the coordinate axis.
    """
    t = nodes01.to(device=p.device, dtype=p.dtype)
    # Shape the nodes as (1,...,1,K,1) so they broadcast against (...,1,C).
    leading_ones = (1,) * (p.dim() - 1)
    t = t.view(*leading_ones, t.shape[0], 1)
    start = p.unsqueeze(-2)
    span = (q - p).unsqueeze(-2)
    return start + t * span
# Fixed 3-point rule on [0,1]: both endpoints plus the midpoint.
LOBATTO3_NODES = torch.tensor([0.0, 0.5, 1.0])
# NOTE(review): the active weights are uniform (1/3 each), NOT the
# Lobatto-3 / Simpson weights [1/6, 4/6, 1/6] that the names suggest.
# The Simpson variant is kept below, disabled, as in the original; confirm
# which one is intended before relying on quadrature accuracy.
# LOBATTO3_W = torch.tensor([1.0/6.0, 4.0/6.0, 1.0/6.0])
LOBATTO3_W = torch.tensor([1 / 3] * 3)
# Product-rule weights: entry [k, l] pairs node k on one segment with node l on the other.
LOBATTO3_W2 = torch.outer(LOBATTO3_W, LOBATTO3_W)  # (3,3)
def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool):
    """
    Convert kernel widths and mixture coefficients to tensors.

    sigmas: widths; converted to a tensor on (device, dtype) and clamped
            from below at 1e-6.
    alpha:  mixture coefficients; converted to a tensor on (device, dtype).
    normalize_alpha: when true, alpha is divided by its total sum (the sum
            is clamped from below at 1e-12 to avoid division by zero).

    returns: (sigmas_t, alpha_t)
    """
    widths = torch.clamp(torch.as_tensor(sigmas, device=device, dtype=dtype), min=1e-6)
    coeffs = torch.as_tensor(alpha, device=device, dtype=dtype)
    if not normalize_alpha:
        return widths, coeffs
    total = torch.clamp(coeffs.sum(), min=1e-12)
    return widths, coeffs / total
# -----------------------------
# Simpson-3 on both segments (3x3 product rule)
# -----------------------------
def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
if w is None:
return None
w = torch.as_tensor(w, device=device, dtype=dtype)
if w.dim() == 1:
if w.shape[0] != n:
raise ValueError(f"weight length {w.shape[0]} != {n}")
w = w.unsqueeze(0).expand(b, -1)
elif w.dim() == 2:
if w.shape[0] != b or w.shape[1] != n:
raise ValueError(f"weight shape {tuple(w.shape)} != ({b}, {n})")
else:
raise ValueError("weights must be 1D or 2D")
return w
def cross_simpson3(
    pA,
    qA,
    pB,
    qB,
    sigma: float | torch.Tensor,
    wA: torch.Tensor | None = None,
    wB: torch.Tensor | None = None,
):
    """
    Gaussian cross term between two sets of segments via a 3x3 product rule.

    For every pair (segment i of A, segment j of B) the contribution is
        (uA_i . uB_j)^2 * ellA_i * ellB_j * wA_i * wB_j
            * sum_{k,l} W2[k,l] * exp(-||x_ik - y_jl||^2 / (2 sigma^2))
    where x_ik / y_jl are the points sampled at LOBATTO3_NODES on each
    segment and W2 is LOBATTO3_W2. Contributions are summed over all pairs.

    pA, qA: endpoints of set A, (B,N,C) batched or (N,C) unbatched
            (C is presumably 3 -- confirm against callers).
    pB, qB: endpoints of set B, (B,M,C) or (M,C). Whether the input is
            batched is decided from pA.dim() alone.
    sigma:  kernel width; a scalar, or a tensor whose first dimension
            equals the batch size (one width per batch entry).
    wA, wB: optional per-segment weights, 1-D (N,)/(M,) or 2-D (B,N)/(B,M).
            A missing side defaults to ones when the other side is given.

    returns: tensor of shape (B,) for batched input, 0-dim for unbatched.
    raises:  ValueError on weight shape mismatch or on a non-scalar sigma
             whose first dimension differs from the batch size.
    """
    device, dtype = pA.device, pA.dtype

    # Promote unbatched inputs to a batch of one; undone on return.
    batched = pA.dim() == 3
    if not batched:
        pA, qA, pB, qB = (t.unsqueeze(0) for t in (pA, qA, pB, qB))

    quad_nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    pair_w = LOBATTO3_W2.to(device=device, dtype=dtype)

    bsz, nA, _ = pA.shape
    nB = pB.shape[1]
    wA = _prep_weight(wA, nA, bsz, device, dtype)
    wB = _prep_weight(wB, nB, bsz, device, dtype)

    _, _, len_a, dir_a = segment_geom(pA, qA)
    _, _, len_b, dir_b = segment_geom(pB, qB)
    pts_a = sample_points(pA, qA, quad_nodes)  # (B,N,K,C), K = 3 nodes
    pts_b = sample_points(pB, qB, quad_nodes)  # (B,M,K,C)

    # Angular factor: squared dot product of the segment directions, (B,N,M).
    angular = torch.matmul(dir_a, dir_b.transpose(-1, -2)).pow(2)

    # Length factor (optionally weighted), (B,N,M).
    pair_scale = len_a[:, :, None] * len_b[:, None, :]
    if not (wA is None and wB is None):
        if wA is None:
            wA = torch.ones((bsz, nA), device=device, dtype=dtype)
        if wB is None:
            wB = torch.ones((bsz, nB), device=device, dtype=dtype)
        pair_scale = pair_scale * (wA[:, :, None] * wB[:, None, :])

    # Squared distances between every node pair of every segment pair.
    delta = pts_a.unsqueeze(2).unsqueeze(4) - pts_b.unsqueeze(1).unsqueeze(3)  # (B,N,M,K,K,C)
    r2 = (delta * delta).sum(dim=-1)  # (B,N,M,K,K)

    sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
    if sigma_t.ndim != 0:
        if sigma_t.shape[0] != bsz:
            raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
        # Per-batch width, shaped to broadcast against r2.
        inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
    else:
        inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)

    kernel = torch.exp(-r2 * inv2s2)  # (B,N,M,K,K)
    spatial = (kernel * pair_w).sum(dim=-1).sum(dim=-1)  # (B,N,M)
    total = (angular * pair_scale * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
    return total if batched else total[0]
# -----------------------------
# Batch losses
# -----------------------------
def loss_simpson3_batch(
    p_pred: torch.Tensor,
    q_pred: torch.Tensor,
    p_gt: torch.Tensor,
    q_gt: torch.Tensor,
    sigma: float | torch.Tensor,
    w_gt: torch.Tensor | None = None,
    w_pred: torch.Tensor | None = None,
    cross_only: bool = False,
) -> torch.Tensor:
    """
    Loss between predicted and ground-truth segments at a single kernel width.

    Computes  self(pred, pred) - 2 * cross(pred, gt)  using cross_simpson3.
    The ground-truth self term is not computed here.

    p_pred, q_pred: predicted segment endpoints.
    p_gt, q_gt:     ground-truth segment endpoints.
    sigma:          kernel width, forwarded to cross_simpson3.
    w_gt, w_pred:   optional per-segment weights for each set.
    cross_only:     when true, skip the prediction self term and return
                    only -2 * cross.

    returns: whatever shape cross_simpson3 returns for these inputs.
    """
    cross_term = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wA=w_pred, wB=w_gt)
    if not cross_only:
        self_term = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma, wA=w_pred, wB=w_pred)
        return self_term - 2.0 * cross_term
    # Original author's note: no self-energy here -- avoids the O(S^2)
    # blowup; repulsion is said to be handled by sinkhorn elsewhere.
    return -2.0 * cross_term
def loss_simpson3_mix_batch(
    p_pred: torch.Tensor,
    q_pred: torch.Tensor,
    p_gt: torch.Tensor,
    q_gt: torch.Tensor,
    sigmas,
    alpha,
    w_gt: torch.Tensor | None = None,
    w_pred: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    cross_only: bool = False,
) -> torch.Tensor:
    """
    Mixture of loss_simpson3_batch over several kernel widths.

    Returns  sum_s alpha[s] * loss_simpson3_batch(..., sigma_s).

    p_pred, q_pred: predicted segment endpoints.
    p_gt, q_gt:     ground-truth segment endpoints.
    sigmas: kernel widths, clamped from below at 1e-6. Either 1-D (S,),
            one width per scale shared by the whole batch, or 2-D (B,S),
            where column s holds the per-batch widths of scale s.
    alpha:  mixture coefficients, one per scale.
    w_gt, w_pred:    optional per-segment weights, forwarded unchanged.
    normalize_alpha: when true, alpha is divided by its sum.
    cross_only:      forwarded to loss_simpson3_batch.

    returns: one loss per batch entry for batched inputs, or a 0-dim
             tensor for unbatched inputs.
    raises:  ValueError if sigmas is neither 1-D nor 2-D.
    """
    # Shared helper performs the tensor conversion, clamping and normalisation.
    sigmas_t, alpha_t = _prepare_mix_weights(
        sigmas, alpha, p_pred.device, p_pred.dtype, normalize_alpha
    )

    if sigmas_t.ndim == 1:
        per_scale = sigmas_t.unbind(0)  # S scalars
    elif sigmas_t.ndim == 2:
        per_scale = sigmas_t.unbind(1)  # S columns of per-batch widths
    else:
        raise ValueError("sigmas must be 1D or 2D for batch loss")

    stacked = torch.stack(
        [
            loss_simpson3_batch(
                p_pred, q_pred, p_gt, q_gt, s,
                w_gt=w_gt, w_pred=w_pred, cross_only=cross_only,
            )
            for s in per_scale
        ],
        dim=0,
    )

    # Batched losses stack to (S,B) and need alpha as (S,1). Unbatched
    # losses stack to (S,); using (S,1) there would broadcast to (S,S) and
    # return a vector instead of the weighted sum, so alpha is used as-is.
    mix = alpha_t[:, None] if stacked.dim() > 1 else alpha_t
    return (stacked * mix).sum(dim=0)
|