# Other
# PyTorch
# 3d-reconstruction
# wireframe
# building
# point-cloud
# s23dr
# cvpr-2026
# jacklangerman's picture
# 4096-release (#1)
# 0f31e57
import torch
from .wire_varifold_kernels import (
loss_simpson3_batch,
loss_simpson3_mix_batch,
)
def segments_to_vertices_edges(segments: torch.Tensor):
    """Flatten line segments into a vertex tensor plus endpoint-index edges.

    Every consecutive pair of 3-vectors in ``segments`` (in row-major order)
    is treated as one segment: vertex ``2*i`` is its first endpoint and
    vertex ``2*i + 1`` its second. The canonical input is ``(N, 2, 3)``, but
    any array-like that flattens to an even number of 3-vectors, such as
    ``(N, 6)`` or ``(B, N, 2, 3)``, is accepted.

    Args:
        segments: Tensor or array-like of segment endpoints. It is converted
            to ``float32`` with ``torch.as_tensor``.

    Returns:
        ``(vertices, edges)``, where ``vertices`` is a ``(2*S, 3)`` float32
        tensor and ``edges`` is a list of ``S`` ``(start, end)`` index tuples
        into ``vertices``.

    Raises:
        ValueError: If the flattened input holds an odd number of 3-vectors,
            meaning it cannot be split into whole segments. (``reshape``
            itself raises if the element count is not a multiple of 3.)
    """
    segs = torch.as_tensor(segments, dtype=torch.float32)
    vertices = segs.reshape(-1, 3)
    num_vertices = vertices.shape[0]
    # Count segments from the flattened vertices rather than segs.shape[0].
    # The latter only matches for (N, 2, 3) / (N, 6) input. For batched
    # input it dropped edges, and for odd point counts it produced indices
    # past the end of `vertices`.
    if num_vertices % 2:
        raise ValueError(
            f"segments must flatten to pairs of 3-vectors; got {num_vertices} "
            f"points from shape {tuple(segs.shape)}"
        )
    edges = [(2 * i, 2 * i + 1) for i in range(num_vertices // 2)]
    return vertices, edges
def varifold_loss_batch(
pred_segments: torch.Tensor,
gt_segments: torch.Tensor,
*,
sigma: float = 0.1,
variant: str = "semi_lobatto3",
t_nodes01: torch.Tensor | None = None,
t_w: torch.Tensor | None = None,
sigmas: torch.Tensor | None = None,
alpha: torch.Tensor | None = None,
normalize_alpha: bool = True,
len_pow: float | None = None,
gt_mask: torch.Tensor | None = None,
pred_weights: torch.Tensor | None = None,
cross_only: bool = False,
) -> torch.Tensor:
if pred_segments.dim() != 4 or gt_segments.dim() != 4:
raise ValueError("pred_segments and gt_segments must be (B, N, 2, 3)")
p_pred, q_pred = pred_segments[:, :, 0], pred_segments[:, :, 1]
p_gt, q_gt = gt_segments[:, :, 0], gt_segments[:, :, 1]
w_gt = None
if gt_mask is not None:
w_gt = gt_mask.to(device=pred_segments.device, dtype=pred_segments.dtype)
w_pred = None
if pred_weights is not None:
w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)
if variant != "simpson3":
raise ValueError(
f"Unsupported varifold variant: {variant!r}. "
f"Only 'simpson3' is supported in batch mode.")
if sigmas is not None or alpha is not None:
if sigmas is None or alpha is None:
raise ValueError("sigmas and alpha are required for simpson3 mix")
return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)