Other
PyTorch
3d-reconstruction
wireframe
building
point-cloud
s23dr
cvpr-2026
File size: 2,006 Bytes
f4487da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f31e57
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch

from .wire_varifold_kernels import (
    loss_simpson3_batch,
    loss_simpson3_mix_batch,
)


def segments_to_vertices_edges(segments: torch.Tensor):
    """Flatten line segments into a vertex tensor plus an edge index list.

    Consecutive pairs of 3-D points in ``segments`` (after flattening) are
    treated as the two endpoints of one segment.

    Args:
        segments: Tensor or anything ``torch.as_tensor`` accepts, holding
            segment endpoints, e.g. shape ``(N, 2, 3)``. Leading dimensions
            are flattened, so batched ``(B, N, 2, 3)`` input also works.

    Returns:
        ``(vertices, edges)``, where ``vertices`` is a float32 tensor of shape
        ``(2 * S, 3)`` and ``edges`` is a list of ``S`` index pairs
        ``(2 * i, 2 * i + 1)`` into ``vertices``.

    Raises:
        ValueError: if the flattened points cannot be paired into segments
            (odd number of points).
    """
    segs = torch.as_tensor(segments, dtype=torch.float32)
    vertices = segs.reshape(-1, 3)
    num_points = vertices.shape[0]
    # Count segments from the flattened points rather than ``segs.shape[0]``:
    # the latter is the batch size for (B, N, 2, 3) input and yields
    # out-of-range edge indices for layouts such as (K, 3) or (6,).
    if num_points % 2:
        raise ValueError(
            f"segments must contain an even number of 3-D points, got {num_points}"
        )
    edges = [(i, i + 1) for i in range(0, num_points, 2)]
    return vertices, edges


def varifold_loss_batch(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    *,
    sigma: float = 0.1,
    variant: str = "semi_lobatto3",
    t_nodes01: torch.Tensor | None = None,
    t_w: torch.Tensor | None = None,
    sigmas: torch.Tensor | None = None,
    alpha: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    len_pow: float | None = None,
    gt_mask: torch.Tensor | None = None,
    pred_weights: torch.Tensor | None = None,
    cross_only: bool = False,
) -> torch.Tensor:
    if pred_segments.dim() != 4 or gt_segments.dim() != 4:
        raise ValueError("pred_segments and gt_segments must be (B, N, 2, 3)")
    p_pred, q_pred = pred_segments[:, :, 0], pred_segments[:, :, 1]
    p_gt, q_gt = gt_segments[:, :, 0], gt_segments[:, :, 1]

    w_gt = None
    if gt_mask is not None:
        w_gt = gt_mask.to(device=pred_segments.device, dtype=pred_segments.dtype)

    w_pred = None
    if pred_weights is not None:
        w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)

    if variant != "simpson3":
        raise ValueError(
            f"Unsupported varifold variant: {variant!r}. "
            f"Only 'simpson3' is supported in batch mode.")
    if sigmas is not None or alpha is not None:
        if sigmas is None or alpha is None:
            raise ValueError("sigmas and alpha are required for simpson3 mix")
        return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
    return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)