# --- Model-hub page metadata (scrape residue; kept as comments so the file parses) ---
# Tags: Other, PyTorch, 3d-reconstruction, wireframe, building, point-cloud, s23dr, cvpr-2026
# Path: learned-baseline-2026 / s23dr_2026_example / wire_varifold_kernels.py
# Author: jacklangerman — commit "4096-release (#1)" (0f31e57)
import torch
# -----------------------------
# Helpers
# -----------------------------
def segment_geom(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-9):
    """
    Basic geometry of the segment p -> q.

    p, q: (..., 3) endpoints.
    Returns (d, a, ell, u):
        d   = q - p                       (..., 3)
        a   = ||d||^2                     (...)
        ell = sqrt(a + stab^2)            (...), regularized length
        u   = d / ell                     (..., 3), near-unit direction
    """
    d = q - p
    a = (d * d).sum(dim=-1)
    # In half precision 1e-9 underflows, so floor the stabilizer at the
    # dtype's machine epsilon; full precision keeps `eps` as given.
    if p.dtype in (torch.float16, torch.bfloat16):
        stab = max(eps, float(torch.finfo(p.dtype).eps))
    else:
        stab = eps
    ell = torch.sqrt(a + stab * stab)
    u = d / ell.unsqueeze(-1)
    return d, a, ell, u
def sample_points(p: torch.Tensor, q: torch.Tensor, nodes01: torch.Tensor):
    """Evaluate p + t * (q - p) at each quadrature node t.

    p, q: (..., 3) endpoints; nodes01: (K,) parameters in [0, 1].
    Returns (..., K, 3).
    """
    t = nodes01.to(device=p.device, dtype=p.dtype)
    # Reshape nodes to (1, ..., 1, K, 1) so they broadcast against the
    # endpoint tensors expanded to (..., 1, 3).
    t = t.reshape((1,) * (p.dim() - 1) + (t.shape[0], 1))
    return p.unsqueeze(-2) + t * (q - p).unsqueeze(-2)
# Fixed Lobatto-3 / Simpson nodes+weights on [0,1]
LOBATTO3_NODES = torch.tensor([0.0, 0.5, 1.0])
# LOBATTO3_W = torch.tensor([1.0/6.0, 4.0/6.0, 1.0/6.0])
# NOTE(review): the active weights below are uniform [1/3, 1/3, 1/3], which is
# NOT Simpson / Lobatto-3 (that is the commented-out [1/6, 4/6, 1/6] line above);
# uniform weights are only exact for linear integrands. This looks like a
# deliberate change — confirm before "fixing", since restoring Simpson weights
# rescales every loss computed in this file.
LOBATTO3_W = torch.tensor([1/3, 1/3, 1/3])
# Outer product of the 1-D weights: weight for each (node_i, node_j) pair in
# the 3x3 product quadrature used by cross_simpson3.
LOBATTO3_W2 = LOBATTO3_W[:, None] * LOBATTO3_W[None, :]  # (3,3)
def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool):
sigmas_t = torch.as_tensor(sigmas, device=device, dtype=dtype).clamp_min(1e-6)
alpha_t = torch.as_tensor(alpha, device=device, dtype=dtype)
if normalize_alpha:
alpha_t = alpha_t / alpha_t.sum().clamp_min(1e-12)
return sigmas_t, alpha_t
# -----------------------------
# Simpson-3 on both segments (3x3 product rule)
# -----------------------------
def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
if w is None:
return None
w = torch.as_tensor(w, device=device, dtype=dtype)
if w.dim() == 1:
if w.shape[0] != n:
raise ValueError(f"weight length {w.shape[0]} != {n}")
w = w.unsqueeze(0).expand(b, -1)
elif w.dim() == 2:
if w.shape[0] != b or w.shape[1] != n:
raise ValueError(f"weight shape {tuple(w.shape)} != ({b}, {n})")
else:
raise ValueError("weights must be 1D or 2D")
return w
def cross_simpson3(
    pA: torch.Tensor,
    qA: torch.Tensor,
    pB: torch.Tensor,
    qB: torch.Tensor,
    sigma: float | torch.Tensor,
    wA: torch.Tensor | None = None,
    wB: torch.Tensor | None = None,
) -> torch.Tensor:
    """Varifold-style kernel cross-energy between two sets of 3D segments.

    Each (A-segment, B-segment) pair contributes
        <uA, uB>^2 * |A| * |B| * sum_ij w2[i,j] * exp(-||xA_i - yB_j||^2 / (2 sigma^2))
    where xA_i / yB_j are the 3 fixed quadrature points per segment
    (LOBATTO3_NODES) and w2 the 3x3 product weights (LOBATTO3_W2).

    pA, qA: (B, N, 3) or (N, 3) endpoints of set A; pB, qB: same for set B (M segments).
    sigma: Gaussian bandwidth — scalar, or tensor with leading dim == batch size.
    wA, wB: optional per-segment weights, (N,)/(M,) or (B, N)/(B, M); see _prep_weight.
    Returns a scalar for unbatched inputs, else a (B,) tensor.
    """
    device, dtype = pA.device, pA.dtype
    batched = pA.dim() == 3
    if not batched:
        # Promote to a batch of one; undone on return.
        pA = pA.unsqueeze(0)
        qA = qA.unsqueeze(0)
        pB = pB.unsqueeze(0)
        qB = qB.unsqueeze(0)
    nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    w2 = LOBATTO3_W2.to(device=device, dtype=dtype)
    bsz, nA, _ = pA.shape
    nB = pB.shape[1]
    # Broadcast/validate optional weights to (B, nA) / (B, nB); None stays None.
    wA = _prep_weight(wA, nA, bsz, device, dtype)
    wB = _prep_weight(wB, nB, bsz, device, dtype)
    # Regularized lengths ell* and (near-)unit directions u* per segment.
    _, _, ellA, uA = segment_geom(pA, qA)
    _, _, ellB, uB = segment_geom(pB, qB)
    XA = sample_points(pA, qA, nodes)  # (B,N,3,3)
    YB = sample_points(pB, qB, nodes)  # (B,M,3,3)
    # angular + length factors: (N,M)
    # ang = squared cosine between segment directions (orientation-invariant).
    ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
    lenfac = ellA[:, :, None] * ellB[:, None, :]
    if wA is not None or wB is not None:
        # If only one side supplies weights, treat the other as all-ones.
        if wA is None:
            wA = torch.ones((bsz, nA), device=device, dtype=dtype)
        if wB is None:
            wB = torch.ones((bsz, nB), device=device, dtype=dtype)
        # Fold the per-segment-pair weight into the length factor.
        lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])
    # spatial: build (N,M,3,3) kernel via broadcasting
    diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :]  # (B,N,M,3,3,3)
    r2 = (diff * diff).sum(dim=-1)  # (B,N,M,3,3)
    sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
    if sigma_t.ndim == 0:
        # Scalar bandwidth shared by the whole batch.
        inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
    else:
        # Per-sample bandwidth: must match the batch size; reshape for broadcast.
        if sigma_t.shape[0] != bsz:
            raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
        inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
    K = torch.exp(-r2 * inv2s2)  # (B,N,M,3,3)
    # Quadrature over the 3x3 node pairs, then sum over all segment pairs.
    spatial = (K * w2).sum(dim=-1).sum(dim=-1)  # (B,N,M)
    out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
    return out[0] if not batched else out
# -----------------------------
# Batch losses
# -----------------------------
def loss_simpson3_batch(
    p_pred: torch.Tensor,
    q_pred: torch.Tensor,
    p_gt: torch.Tensor,
    q_gt: torch.Tensor,
    sigma: float | torch.Tensor,
    w_gt: torch.Tensor | None = None,
    w_pred: torch.Tensor | None = None,
    cross_only: bool = False,
) -> torch.Tensor:
    """Kernel matching loss between predicted and GT segments at one bandwidth.

    Computes ||pred||^2 - 2 <pred, gt> in the segment kernel from cross_simpson3
    (the GT self-term is constant w.r.t. the prediction and omitted). With
    cross_only=True only the -2 <pred, gt> attraction term is returned.
    """
    cross_term = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wA=w_pred, wB=w_gt)
    if cross_only:
        # No self-energy: avoids O(S^2) blowup, sinkhorn handles repulsion
        return -2.0 * cross_term
    self_term = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma, wA=w_pred, wB=w_pred)
    return self_term - 2.0 * cross_term
def loss_simpson3_mix_batch(
    p_pred: torch.Tensor,
    q_pred: torch.Tensor,
    p_gt: torch.Tensor,
    q_gt: torch.Tensor,
    sigmas,
    alpha,
    w_gt: torch.Tensor | None = None,
    w_pred: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    cross_only: bool = False,
) -> torch.Tensor:
    """Multi-scale mixture of loss_simpson3_batch over several Gaussian bandwidths.

    sigmas: (S,) bandwidths shared across the batch, or (B, S) per-sample.
    alpha: (S,) mixture weights; normalized to sum to 1 unless normalize_alpha=False.
    w_gt / w_pred / cross_only are forwarded unchanged to loss_simpson3_batch.
    Raises ValueError when sigmas is not 1-D or 2-D.
    NOTE(review): the alpha_t[:, None] broadcast assumes batched (B,·,3) inputs
    so each per-scale loss is a (B,) tensor — confirm callers never pass
    unbatched endpoints here.
    """
    device, dtype = p_pred.device, p_pred.dtype
    # Reuse the shared helper (clamps sigmas at 1e-6, optionally normalizes
    # alpha) instead of duplicating its logic inline.
    sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
    if sigmas_t.ndim == 1:
        # One shared bandwidth per scale.
        losses = [
            loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, s, w_gt=w_gt, w_pred=w_pred, cross_only=cross_only)
            for s in sigmas_t
        ]
        return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
    if sigmas_t.ndim == 2:
        # Per-sample bandwidths: column i gives one sigma per batch element.
        losses = [
            loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigmas_t[:, i], w_gt=w_gt, w_pred=w_pred, cross_only=cross_only)
            for i in range(sigmas_t.shape[1])
        ]
        return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
    raise ValueError("sigmas must be 1D or 2D for batch loss")