File size: 2,006 Bytes
f4487da 0f31e57 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | import torch
from .wire_varifold_kernels import (
loss_simpson3_batch,
loss_simpson3_mix_batch,
)
def segments_to_vertices_edges(
    segments: torch.Tensor,
) -> tuple[torch.Tensor, list[tuple[int, int]]]:
    """Flatten line segments into a vertex tensor plus an edge index list.

    ``segments`` is converted to a float32 tensor, and every group of three
    values is treated as one 3-D point. The input may therefore be
    ``(N, 2, 3)`` or ``(N, 6)``, or carry extra leading (e.g. batch)
    dimensions. Consecutive points ``2*i`` and ``2*i + 1`` form segment ``i``.

    Args:
        segments: Tensor or array-like of segment endpoints. Its total number
            of elements must be a multiple of 6 (two 3-D points per segment).

    Returns:
        ``(vertices, edges)``, where ``vertices`` is a ``(2*S, 3)`` float32
        tensor and ``edges`` is a list of ``S`` index pairs ``(2*i, 2*i + 1)``.

    Raises:
        ValueError: If the flattened points cannot be paired into segments.
        RuntimeError: Raised by ``reshape`` if the element count is not a
            multiple of 3.
    """
    segs = torch.as_tensor(segments, dtype=torch.float32)
    vertices = segs.reshape(-1, 3)
    num_points = vertices.shape[0]
    if num_points % 2:
        raise ValueError(
            f"segments must contain an even number of 3-D points, got {num_points}"
        )
    # Derive the edge count from the flattened vertices rather than from
    # ``segs.shape[0]``. This way, inputs with extra leading dimensions still
    # yield one edge per segment, and every index stays within ``vertices``.
    edges = [(2 * i, 2 * i + 1) for i in range(num_points // 2)]
    return vertices, edges
def varifold_loss_batch(
pred_segments: torch.Tensor,
gt_segments: torch.Tensor,
*,
sigma: float = 0.1,
variant: str = "semi_lobatto3",
t_nodes01: torch.Tensor | None = None,
t_w: torch.Tensor | None = None,
sigmas: torch.Tensor | None = None,
alpha: torch.Tensor | None = None,
normalize_alpha: bool = True,
len_pow: float | None = None,
gt_mask: torch.Tensor | None = None,
pred_weights: torch.Tensor | None = None,
cross_only: bool = False,
) -> torch.Tensor:
if pred_segments.dim() != 4 or gt_segments.dim() != 4:
raise ValueError("pred_segments and gt_segments must be (B, N, 2, 3)")
p_pred, q_pred = pred_segments[:, :, 0], pred_segments[:, :, 1]
p_gt, q_gt = gt_segments[:, :, 0], gt_segments[:, :, 1]
w_gt = None
if gt_mask is not None:
w_gt = gt_mask.to(device=pred_segments.device, dtype=pred_segments.dtype)
w_pred = None
if pred_weights is not None:
w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)
if variant != "simpson3":
raise ValueError(
f"Unsupported varifold variant: {variant!r}. "
f"Only 'simpson3' is supported in batch mode.")
if sigmas is not None or alpha is not None:
if sigmas is None or alpha is None:
raise ValueError("sigmas and alpha are required for simpson3 mix")
return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)
|