"""Loss computation for wireframe prediction.""" from __future__ import annotations import torch from .varifold import varifold_loss_batch from .sinkhorn import batched_sinkhorn_loss # Varifold config VARIANT = "simpson3" SIGMAS = [0.5, 1.0, 2.0] # meters (divided by per-scene scale at runtime) ALPHAS = [0.2, 0.6, 0.2] LEN_POW = 1.0 VARIFOLD_CROSS_ONLY = False # Set to True to drop self-energy (avoids O(S^2) blowup) # Sinkhorn config (note: near-zero gradients at eps=0.05, effectively disabled) SINKHORN_EPS = 0.05 SINKHORN_ITERS = 10 # Sinkhorn dustbin cost: controls the OT "not matching" penalty. # Like tau, this is an OT behavior parameter, NOT a physical distance. # Must be comparable to typical matching costs in normalized space (~0.1). # Do NOT divide by scale. SINKHORN_DUSTBIN = 0.1 MAX_GT = 64 # fixed pad size for compile-friendly shapes # Precomputed constants (created once on first call) _loss_constants = {} def _get_loss_constants(device, dtype): key = (device, dtype) if key not in _loss_constants: _loss_constants[key] = { "sigmas": torch.tensor(SIGMAS, device=device, dtype=dtype), "alphas": torch.tensor(ALPHAS, device=device, dtype=dtype), } return _loss_constants[key] def pad_gt_fixed(gt_list, device, dtype): """Pad GT segments to fixed MAX_GT for compile-friendly shapes.""" B = len(gt_list) gt_pad = torch.zeros((B, MAX_GT, 2, 3), device=device, dtype=dtype) gt_mask = torch.zeros((B, MAX_GT), device=device, dtype=torch.bool) gt_lengths = torch.zeros(B, device=device, dtype=dtype) for i, g in enumerate(gt_list): n = g.shape[0] if n > 0: gt_pad[i, :n] = g gt_mask[i, :n] = True gt_lengths[i] = torch.linalg.norm(g[:, 1] - g[:, 0], dim=-1).sum() return gt_pad, gt_mask, gt_lengths def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales, sigmas, alphas, varifold_w): """Pure tensor loss -- no Python control flow, no boolean indexing.""" has_gt = (gt_lengths > 0).float() sigmas_eff = sigmas / scales[:, None] loss_batch = varifold_loss_batch( pred_segments, gt_pad, gt_mask=gt_mask, variant=VARIANT, sigmas=sigmas_eff, alpha=alphas, len_pow=LEN_POW, cross_only=VARIFOLD_CROSS_ONLY, ) v = loss_batch / gt_lengths.clamp(min=1.0) v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0) total = varifold_w * v return total, v # Will be replaced with compiled version on CUDA _loss_fn = _loss_inner def compute_loss(pred_segments, gt_list, scales, device, varifold_w, sinkhorn_w, endpoint_w=0.0, conf_logits=None, conf_weight=0.0, conf_mode="sinkhorn", sinkhorn_eps=None, sinkhorn_iters=None, sinkhorn_dustbin=None, conf_clamp_min=None): """Combined loss with fixed-size GT padding. conf_mode: "sinkhorn" = conf-weighted sinkhorn, "sinkhorn_detach" = detached conf. """ if conf_logits is not None and conf_clamp_min is not None: conf_logits = conf_logits.clamp(min=conf_clamp_min) gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype) c = _get_loss_constants(device, pred_segments.dtype) total, v = _loss_fn( pred_segments, gt_pad, gt_mask, gt_lengths, scales, c["sigmas"], c["alphas"], varifold_w) terms = {} if varifold_w > 0: terms["varifold"] = v.detach() if sinkhorn_w > 0: has_gt = (gt_lengths > 0).float() if conf_logits is not None and conf_mode == "sinkhorn": pred_mass = torch.sigmoid(conf_logits) elif conf_logits is not None and conf_mode == "sinkhorn_detach": pred_mass = torch.sigmoid(conf_logits.detach()) else: pred_mass = None eps = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS iters = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS dustbin = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN S = pred_segments.shape[1] sink_per = batched_sinkhorn_loss( pred_segments, gt_pad, gt_mask, eps, iters, dustbin, pred_mass=pred_mass, ) / (gt_lengths.clamp(min=1.0) * S) s = (sink_per * has_gt).sum() / has_gt.sum().clamp(min=1.0) total = total + sinkhorn_w * s terms["sinkhorn"] = s.detach() if conf_logits is not None and conf_weight > 0: if conf_mode in ("sinkhorn", "sinkhorn_detach"): conf_w = torch.sigmoid(conf_logits) S = conf_logits.shape[1] gt_counts = gt_mask.sum(dim=1).float() conf_sum = conf_w.sum(dim=1) reg = (((conf_sum - gt_counts) / S) ** 2).mean() total = total + conf_weight * reg terms["conf_reg"] = reg.detach() else: raise ValueError(f"Unknown conf_mode: {conf_mode}") if endpoint_w > 0: has_gt = (gt_lengths > 0).float() eps_ep = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS iters_ep = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS dustbin_ep = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN B, S = pred_segments.shape[:2] M = gt_pad.shape[1] # Compute hard assignment via sinkhorn (detached -- matching is not trained) with torch.no_grad(): pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None sink_loss_for_assign = batched_sinkhorn_loss( pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep, pred_mass=pred_mass_ep) p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1] g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1] mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0) mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0) d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1) len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6) len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6) dir_p, dir_g = half_p / len_p, half_g / len_g cos_a = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1) d_dir = 1.0 - cos_a.abs() d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs() cost = d_mid + d_dir + d_len dc = torch.as_tensor(dustbin_ep, device=cost.device, dtype=cost.dtype) cost = torch.where(gt_mask.unsqueeze(1), cost, dc * 10.0) cost_pad = dc.expand(B, S + 1, M + 1).clone() cost_pad[:, :S, :M] = cost cost_pad[:, -1, -1] = 0.0 gt_counts = gt_mask.sum(dim=1).float() if pred_mass_ep is not None: pm = pred_mass_ep.clamp(min=0.0) a = torch.cat([pm, (gt_counts - pm.sum(1)).clamp(min=0).unsqueeze(1)], dim=1) b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype) b_val[:, :M] = gt_mask.float() b_val[:, -1] = (pm.sum(1) - gt_counts).clamp(min=0) else: n = float(S) denom = n + gt_counts a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone() a[:, -1] = gt_counts / denom b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone() b_val[:, -1] = n / denom b_val[:, :M] = b_val[:, :M] * gt_mask.float() log_a = torch.log(a + 1e-9) log_b = torch.log(b_val + 1e-9) log_k = -cost_pad / eps_ep log_u = torch.zeros_like(a) log_v = torch.zeros_like(b_val) for _ in range(iters_ep): log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2) log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1) transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k) assignment = transport[:, :S, :M+1].argmax(dim=2) assignment[assignment >= M] = -1 # Everything below is WITH gradients (assignment is detached but pred_segments is live) matched = (assignment >= 0) # [B, S] n_matched = matched.float().sum().clamp(min=1.0) assign_safe = assignment.clamp(min=0) gt_matched = gt_pad[ torch.arange(B, device=device)[:, None].expand(B, S), assign_safe] # [B, S, 2, 3] # Symmetric endpoint distance ref_ep1 = pred_segments[:, :, 0] ref_ep2 = pred_segments[:, :, 1] gt_ep1 = gt_matched[:, :, 0] gt_ep2 = gt_matched[:, :, 1] dist_fwd = (ref_ep1 - gt_ep1).norm(dim=-1) + (ref_ep2 - gt_ep2).norm(dim=-1) dist_rev = (ref_ep1 - gt_ep2).norm(dim=-1) + (ref_ep2 - gt_ep1).norm(dim=-1) ep_dist = torch.min(dist_fwd, dist_rev) # Normalize by GT total length * S (same scale as sinkhorn) ep_loss = (ep_dist * matched.float()).sum() / n_matched total = total + endpoint_w * ep_loss terms["endpoint"] = ep_loss.detach() return total, terms