| """Loss computation for wireframe prediction.""" |
| from __future__ import annotations |
|
|
| import torch |
|
|
| from .varifold import varifold_loss_batch |
| from .sinkhorn import batched_sinkhorn_loss |
|
|
| |
# Varifold term configuration (consumed by varifold_loss_batch).
VARIANT = "simpson3"
SIGMAS = [0.5, 1.0, 2.0]   # kernel scales
ALPHAS = [0.2, 0.6, 0.2]   # per-scale mixture weights
LEN_POW = 1.0
VARIFOLD_CROSS_ONLY = False

# Sinkhorn (entropic OT) defaults; each can be overridden per call in
# compute_loss().
SINKHORN_EPS = 0.05
SINKHORN_ITERS = 10
SINKHORN_DUSTBIN = 0.1

# Fixed GT padding size: keeps batch shapes static for compile-friendliness.
MAX_GT = 64

# Lazy per-(device, dtype) cache of constant tensors.
_loss_constants = {}


def _get_loss_constants(device, dtype):
    key = (device, dtype)
    if key not in _loss_constants:
        _loss_constants[key] = {
            "sigmas": torch.tensor(SIGMAS, device=device, dtype=dtype),
            "alphas": torch.tensor(ALPHAS, device=device, dtype=dtype),
        }
    return _loss_constants[key]


def pad_gt_fixed(gt_list, device, dtype):
    """Pad GT segments to fixed MAX_GT for compile-friendly shapes."""
    B = len(gt_list)
    gt_pad = torch.zeros((B, MAX_GT, 2, 3), device=device, dtype=dtype)
    gt_mask = torch.zeros((B, MAX_GT), device=device, dtype=torch.bool)
    gt_lengths = torch.zeros(B, device=device, dtype=dtype)
    for i, g in enumerate(gt_list):
        # Truncate samples that exceed the fixed budget instead of erroring.
        n = min(g.shape[0], MAX_GT)
        if n > 0:
            gt_pad[i, :n] = g[:n]
            gt_mask[i, :n] = True
            gt_lengths[i] = torch.linalg.norm(g[:n, 1] - g[:n, 0], dim=-1).sum()
    return gt_pad, gt_mask, gt_lengths
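
# A shape-level sketch of the padding (illustrative values, not taken from
# the training pipeline): two samples, one with 5 segments and one empty.
#   gt_list = [torch.randn(5, 2, 3), torch.empty(0, 2, 3)]
#   gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, "cpu", torch.float32)
#   gt_pad.shape == (2, 64, 2, 3); gt_mask[0, :5] is all True; and
#   gt_lengths[1] == 0, so the empty sample is masked out via has_gt below.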


def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
                sigmas, alphas, varifold_w):
    """Pure tensor loss -- no Python control flow, no boolean indexing."""
    has_gt = (gt_lengths > 0).float()

    # One effective bandwidth set per sample, scaled by the sample's scale.
    sigmas_eff = sigmas / scales[:, None]
    loss_batch = varifold_loss_batch(
        pred_segments, gt_pad, gt_mask=gt_mask,
        variant=VARIANT, sigmas=sigmas_eff, alpha=alphas, len_pow=LEN_POW,
        cross_only=VARIFOLD_CROSS_ONLY,
    )
    # Normalise by total GT length, then average over samples that have GT.
    v = loss_batch / gt_lengths.clamp(min=1.0)
    v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)

    total = varifold_w * v
    return total, v
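
# In effect, with V_i the per-sample varifold discrepancy and L_i the total
# GT length: v = mean over {i : L_i > 0} of V_i / max(L_i, 1), guarded
# against batches with no GT at all.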


# Indirection point: training code can swap in a compiled version of the
# loss (the fixed-size padding above keeps its shapes static).
_loss_fn = _loss_inner
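
# One hedged option, assuming torch >= 2.0 and that compilation is wanted
# (not something this module does by itself):
#   _loss_fn = torch.compile(_loss_inner, dynamic=False)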


def compute_loss(pred_segments, gt_list, scales, device,
                 varifold_w, sinkhorn_w,
                 endpoint_w=0.0,
                 conf_logits=None, conf_weight=0.0, conf_mode="sinkhorn",
                 sinkhorn_eps=None, sinkhorn_iters=None,
                 sinkhorn_dustbin=None, conf_clamp_min=None):
    """Combined loss with fixed-size GT padding.

    conf_mode: "sinkhorn" weights the Sinkhorn masses by the predicted
    confidences; "sinkhorn_detach" does the same with detached confidences.
    """
    if conf_logits is not None and conf_clamp_min is not None:
        conf_logits = conf_logits.clamp(min=conf_clamp_min)
    gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
    c = _get_loss_constants(device, pred_segments.dtype)

    total, v = _loss_fn(
        pred_segments, gt_pad, gt_mask, gt_lengths, scales,
        c["sigmas"], c["alphas"], varifold_w)

    terms = {}
    if varifold_w > 0:
        terms["varifold"] = v.detach()

    if sinkhorn_w > 0:
        has_gt = (gt_lengths > 0).float()
        # Optional confidence-derived masses for the transport plan.
        if conf_logits is not None and conf_mode == "sinkhorn":
            pred_mass = torch.sigmoid(conf_logits)
        elif conf_logits is not None and conf_mode == "sinkhorn_detach":
            pred_mass = torch.sigmoid(conf_logits.detach())
        else:
            pred_mass = None
        eps = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
        iters = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
        dustbin = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
        S = pred_segments.shape[1]
        # Normalise per sample by total GT length and prediction count.
        sink_per = batched_sinkhorn_loss(
            pred_segments, gt_pad, gt_mask,
            eps, iters, dustbin,
            pred_mass=pred_mass,
        ) / (gt_lengths.clamp(min=1.0) * S)
        s = (sink_per * has_gt).sum() / has_gt.sum().clamp(min=1.0)
        total = total + sinkhorn_w * s
        terms["sinkhorn"] = s.detach()

    if conf_logits is not None and conf_weight > 0:
        if conf_mode in ("sinkhorn", "sinkhorn_detach"):
            # Pull the total predicted confidence mass toward the GT count.
            conf_w = torch.sigmoid(conf_logits)
            S = conf_logits.shape[1]
            gt_counts = gt_mask.sum(dim=1).float()
            conf_sum = conf_w.sum(dim=1)
            reg = (((conf_sum - gt_counts) / S) ** 2).mean()
            total = total + conf_weight * reg
            terms["conf_reg"] = reg.detach()
        else:
            raise ValueError(f"Unknown conf_mode: {conf_mode}")
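
    # Worked example with illustrative numbers: S = 100 predictions, 20 GT
    # segments, conf_sum = 30  ->  reg = ((30 - 20) / 100) ** 2 = 0.01.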
    if endpoint_w > 0:
        eps_ep = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
        iters_ep = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
        dustbin_ep = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
        B, S = pred_segments.shape[:2]
        M = gt_pad.shape[1]

        # Build a pred-to-GT assignment with a gradient-free Sinkhorn pass.
        with torch.no_grad():
            pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
            # Matching cost: midpoint distance + direction mismatch
            # (orientation-invariant via abs) + length difference.
            p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
            g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
            mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
            mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
            d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
            len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
            len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
            dir_p, dir_g = half_p / len_p, half_g / len_g
            cos_a = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
            d_dir = 1.0 - cos_a.abs()
            d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
            cost = d_mid + d_dir + d_len
            # Padded GT columns get a prohibitive cost; an extra dustbin
            # row/column absorbs unmatched mass (free dustbin-dustbin cell).
            dc = torch.as_tensor(dustbin_ep, device=cost.device, dtype=cost.dtype)
            cost = torch.where(gt_mask.unsqueeze(1), cost, dc * 10.0)
            cost_pad = dc.expand(B, S + 1, M + 1).clone()
            cost_pad[:, :S, :M] = cost
            cost_pad[:, -1, -1] = 0.0
            # Marginals: confidence masses when available, else uniform;
            # the dustbin entries balance the two sides.
            gt_counts = gt_mask.sum(dim=1).float()
            if pred_mass_ep is not None:
                pm = pred_mass_ep.clamp(min=0.0)
                a = torch.cat([pm, (gt_counts - pm.sum(1)).clamp(min=0).unsqueeze(1)], dim=1)
                b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype)
                b_val[:, :M] = gt_mask.float()
                b_val[:, -1] = (pm.sum(1) - gt_counts).clamp(min=0)
            else:
                n = float(S)
                denom = n + gt_counts
                a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone()
                a[:, -1] = gt_counts / denom
                b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone()
                b_val[:, -1] = n / denom
                b_val[:, :M] = b_val[:, :M] * gt_mask.float()
            # Log-domain Sinkhorn: numerically stable scaling updates.
            log_a = torch.log(a + 1e-9)
            log_b = torch.log(b_val + 1e-9)
            log_k = -cost_pad / eps_ep
            log_u = torch.zeros_like(a)
            log_v = torch.zeros_like(b_val)
            for _ in range(iters_ep):
                log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
                log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
            transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
            # Hard assignment over GT columns plus the dustbin; predictions
            # routed to the dustbin count as unmatched (-1).
            assignment = transport[:, :S, :M + 1].argmax(dim=2)
            assignment[assignment >= M] = -1

        # Gather the matched GT segment for every prediction; unmatched rows
        # read index 0 but are masked out of the mean below.
        matched = (assignment >= 0)
        n_matched = matched.float().sum().clamp(min=1.0)
        assign_safe = assignment.clamp(min=0)
        gt_matched = gt_pad[
            torch.arange(B, device=device)[:, None].expand(B, S),
            assign_safe]

        # Endpoint distance; segments are undirected, so take the cheaper
        # of the two endpoint orderings.
        ref_ep1 = pred_segments[:, :, 0]
        ref_ep2 = pred_segments[:, :, 1]
        gt_ep1 = gt_matched[:, :, 0]
        gt_ep2 = gt_matched[:, :, 1]
        dist_fwd = (ref_ep1 - gt_ep1).norm(dim=-1) + (ref_ep2 - gt_ep2).norm(dim=-1)
        dist_rev = (ref_ep1 - gt_ep2).norm(dim=-1) + (ref_ep2 - gt_ep1).norm(dim=-1)
        ep_dist = torch.min(dist_fwd, dist_rev)

        ep_loss = (ep_dist * matched.float()).sum() / n_matched
        total = total + endpoint_w * ep_loss
        terms["endpoint"] = ep_loss.detach()

    return total, terms
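
# Hedged usage sketch -- `model`, its conf head, and all the weights below
# are illustrative assumptions, not values taken from this project:
#   pred = model(imgs)                     # (B, S, 2, 3) endpoint pairs
#   conf = model.conf_head(feats)          # (B, S) logits (optional)
#   scales = torch.ones(pred.shape[0], device=pred.device)
#   total, terms = compute_loss(pred, gt_list, scales, pred.device,
#                               varifold_w=1.0, sinkhorn_w=0.1,
#                               endpoint_w=0.1, conf_logits=conf,
#                               conf_weight=0.01)
#   total.backward()
# `terms` holds detached scalars for logging; each key is present only when
# its corresponding weight is active.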