"""Loss computation for wireframe prediction."""
from __future__ import annotations
import torch
from .varifold import varifold_loss_batch
from .sinkhorn import batched_sinkhorn_loss
# Varifold config
VARIANT = "simpson3"
SIGMAS = [0.5, 1.0, 2.0] # meters (divided by per-scene scale at runtime)
ALPHAS = [0.2, 0.6, 0.2]
LEN_POW = 1.0
VARIFOLD_CROSS_ONLY = False # Set to True to drop self-energy (avoids O(S^2) blowup)
# Sinkhorn config (note: gradients are near zero at eps=0.05, so the term is effectively disabled)
SINKHORN_EPS = 0.05
SINKHORN_ITERS = 10
# Sinkhorn dustbin cost: controls the OT "not matching" penalty.
# Like tau, this is an OT behavior parameter, NOT a physical distance.
# Must be comparable to typical matching costs in normalized space (~0.1).
# Do NOT divide by scale.
SINKHORN_DUSTBIN = 0.1
MAX_GT = 64 # fixed pad size for compile-friendly shapes
# Precomputed constants (created once on first call)
_loss_constants = {}
def _get_loss_constants(device, dtype):
key = (device, dtype)
if key not in _loss_constants:
_loss_constants[key] = {
"sigmas": torch.tensor(SIGMAS, device=device, dtype=dtype),
"alphas": torch.tensor(ALPHAS, device=device, dtype=dtype),
}
return _loss_constants[key]
def pad_gt_fixed(gt_list, device, dtype):
"""Pad GT segments to fixed MAX_GT for compile-friendly shapes."""
B = len(gt_list)
gt_pad = torch.zeros((B, MAX_GT, 2, 3), device=device, dtype=dtype)
gt_mask = torch.zeros((B, MAX_GT), device=device, dtype=torch.bool)
gt_lengths = torch.zeros(B, device=device, dtype=dtype)
    for i, g in enumerate(gt_list):
        n = min(g.shape[0], MAX_GT)  # truncate scenes with more than MAX_GT segments
        if n > 0:
            gt_pad[i, :n] = g[:n]
            gt_mask[i, :n] = True
            gt_lengths[i] = torch.linalg.norm(g[:n, 1] - g[:n, 0], dim=-1).sum()
return gt_pad, gt_mask, gt_lengths
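# A comment-only sketch of the padding behavior (hypothetical inputs, for
# illustration; not part of the original module):
#   gt_list = [torch.randn(12, 2, 3), torch.zeros(0, 2, 3)]
#   gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, "cpu", torch.float32)
#   gt_pad: [2, 64, 2, 3]; gt_mask: [2, 64] with 12 / 0 True entries;
#   gt_lengths: [2] total segment length per scene (0.0 for the empty scene)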
def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
sigmas, alphas, varifold_w):
"""Pure tensor loss -- no Python control flow, no boolean indexing."""
has_gt = (gt_lengths > 0).float()
sigmas_eff = sigmas / scales[:, None]
loss_batch = varifold_loss_batch(
pred_segments, gt_pad, gt_mask=gt_mask,
variant=VARIANT, sigmas=sigmas_eff, alpha=alphas, len_pow=LEN_POW,
cross_only=VARIFOLD_CROSS_ONLY,
)
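    # Normalize per scene by total GT length, then average over scenes that
    # actually have GT (empty scenes contribute zero and are excluded from the mean).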
v = loss_batch / gt_lengths.clamp(min=1.0)
v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)
total = varifold_w * v
return total, v
# Will be replaced with compiled version on CUDA
_loss_fn = _loss_inner
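# One possible hook for that swap, as an illustrative sketch (the name
# _compile_loss_fn and the call site are assumptions, not part of the
# original training setup):
def _compile_loss_fn():
    global _loss_fn
    if torch.cuda.is_available():
        _loss_fn = torch.compile(_loss_inner, fullgraph=True)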
def compute_loss(pred_segments, gt_list, scales, device,
varifold_w, sinkhorn_w,
endpoint_w=0.0,
conf_logits=None, conf_weight=0.0, conf_mode="sinkhorn",
sinkhorn_eps=None, sinkhorn_iters=None,
sinkhorn_dustbin=None, conf_clamp_min=None):
"""Combined loss with fixed-size GT padding.

    conf_mode: "sinkhorn" = conf-weighted sinkhorn, "sinkhorn_detach" = detached conf.
    Returns (total_loss, dict of detached per-term scalars).
    """
if conf_logits is not None and conf_clamp_min is not None:
conf_logits = conf_logits.clamp(min=conf_clamp_min)
gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
c = _get_loss_constants(device, pred_segments.dtype)
total, v = _loss_fn(
pred_segments, gt_pad, gt_mask, gt_lengths, scales,
c["sigmas"], c["alphas"], varifold_w)
terms = {}
if varifold_w > 0:
terms["varifold"] = v.detach()
if sinkhorn_w > 0:
has_gt = (gt_lengths > 0).float()
if conf_logits is not None and conf_mode == "sinkhorn":
pred_mass = torch.sigmoid(conf_logits)
elif conf_logits is not None and conf_mode == "sinkhorn_detach":
pred_mass = torch.sigmoid(conf_logits.detach())
else:
pred_mass = None
eps = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
iters = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
dustbin = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
S = pred_segments.shape[1]
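        # Per-scene sinkhorn cost, normalized by total GT length and the
        # number of predicted segments S.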
sink_per = batched_sinkhorn_loss(
pred_segments, gt_pad, gt_mask,
eps, iters, dustbin,
pred_mass=pred_mass,
) / (gt_lengths.clamp(min=1.0) * S)
s = (sink_per * has_gt).sum() / has_gt.sum().clamp(min=1.0)
total = total + sinkhorn_w * s
terms["sinkhorn"] = s.detach()
if conf_logits is not None and conf_weight > 0:
if conf_mode in ("sinkhorn", "sinkhorn_detach"):
conf_w = torch.sigmoid(conf_logits)
S = conf_logits.shape[1]
gt_counts = gt_mask.sum(dim=1).float()
conf_sum = conf_w.sum(dim=1)
reg = (((conf_sum - gt_counts) / S) ** 2).mean()
total = total + conf_weight * reg
terms["conf_reg"] = reg.detach()
else:
raise ValueError(f"Unknown conf_mode: {conf_mode}")
if endpoint_w > 0:
has_gt = (gt_lengths > 0).float()
eps_ep = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
iters_ep = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
dustbin_ep = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
B, S = pred_segments.shape[:2]
M = gt_pad.shape[1]
# Compute hard assignment via sinkhorn (detached -- matching is not trained)
with torch.no_grad():
pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
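            # Assignment cost per (pred, GT) pair: midpoint distance +
            # (1 - |cos angle|) between directions + half-length gap, computed
            # on a midpoint/half-vector parameterization of each segment.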
p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
dir_p, dir_g = half_p / len_p, half_g / len_g
cos_a = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
d_dir = 1.0 - cos_a.abs()
d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
cost = d_mid + d_dir + d_len
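            # Invalid (padded) GT columns get a high cost (10x dustbin); then
            # the matrix is padded with a dustbin row/column so unmatched mass
            # has a finite-cost place to go (bin-to-bin matches cost 0).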
dc = torch.as_tensor(dustbin_ep, device=cost.device, dtype=cost.dtype)
cost = torch.where(gt_mask.unsqueeze(1), cost, dc * 10.0)
cost_pad = dc.expand(B, S + 1, M + 1).clone()
cost_pad[:, :S, :M] = cost
cost_pad[:, -1, -1] = 0.0
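            # Marginals for the padded problem: prediction mass (confidence
            # weights when available, uniform otherwise) on one side, GT
            # segment mass on the other; the dustbin entries absorb the
            # difference so both marginals sum to the same total.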
gt_counts = gt_mask.sum(dim=1).float()
if pred_mass_ep is not None:
pm = pred_mass_ep.clamp(min=0.0)
a = torch.cat([pm, (gt_counts - pm.sum(1)).clamp(min=0).unsqueeze(1)], dim=1)
b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype)
b_val[:, :M] = gt_mask.float()
b_val[:, -1] = (pm.sum(1) - gt_counts).clamp(min=0)
else:
n = float(S)
denom = n + gt_counts
a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone()
a[:, -1] = gt_counts / denom
b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone()
b_val[:, -1] = n / denom
b_val[:, :M] = b_val[:, :M] * gt_mask.float()
log_a = torch.log(a + 1e-9)
log_b = torch.log(b_val + 1e-9)
log_k = -cost_pad / eps_ep
log_u = torch.zeros_like(a)
log_v = torch.zeros_like(b_val)
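            # Log-domain Sinkhorn iterations: alternately rescale rows and
            # columns so transport = exp(log_u + log_k + log_v) matches the
            # marginals a and b_val.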
for _ in range(iters_ep):
log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
assignment = transport[:, :S, :M+1].argmax(dim=2)
assignment[assignment >= M] = -1
# Everything below is WITH gradients (assignment is detached but pred_segments is live)
matched = (assignment >= 0) # [B, S]
n_matched = matched.float().sum().clamp(min=1.0)
assign_safe = assignment.clamp(min=0)
gt_matched = gt_pad[
torch.arange(B, device=device)[:, None].expand(B, S),
assign_safe] # [B, S, 2, 3]
# Symmetric endpoint distance
ref_ep1 = pred_segments[:, :, 0]
ref_ep2 = pred_segments[:, :, 1]
gt_ep1 = gt_matched[:, :, 0]
gt_ep2 = gt_matched[:, :, 1]
dist_fwd = (ref_ep1 - gt_ep1).norm(dim=-1) + (ref_ep2 - gt_ep2).norm(dim=-1)
dist_rev = (ref_ep1 - gt_ep2).norm(dim=-1) + (ref_ep2 - gt_ep1).norm(dim=-1)
ep_dist = torch.min(dist_fwd, dist_rev)
        # Average the symmetric endpoint distance over matched predictions
ep_loss = (ep_dist * matched.float()).sum() / n_matched
total = total + endpoint_w * ep_loss
terms["endpoint"] = ep_loss.detach()
return total, terms
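if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). Assumes the sibling .varifold and .sinkhorn modules are
    # importable, i.e. run as `python -m <package>.<this_module>`.
    torch.manual_seed(0)
    B, S = 2, 32
    pred = torch.randn(B, S, 2, 3, requires_grad=True)
    gt = [torch.randn(12, 2, 3), torch.zeros(0, 2, 3)]  # second scene: no GT
    scales = torch.ones(B)
    total, terms = compute_loss(
        pred, gt, scales, device="cpu",
        varifold_w=1.0, sinkhorn_w=0.1, endpoint_w=0.1,
        conf_logits=torch.zeros(B, S), conf_weight=0.01,
    )
    total.backward()
    print(f"total={total.item():.4f}", {k: round(v.item(), 4) for k, v in terms.items()})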