| import torch |
|
|
| from .wire_varifold_kernels import ( |
| loss_simpson3_batch, |
| loss_simpson3_mix_batch, |
| ) |
|
|
|
|
def segments_to_vertices_edges(segments: torch.Tensor):
    """Flatten line segments into a vertex tensor plus an edge index list.

    ``segments`` is converted with ``torch.as_tensor(..., dtype=torch.float32)``
    and reshaped to rows of 3 coordinates, so consecutive rows of the result
    are the points stored in ``segments`` in memory order. Presumably the
    input is ``(N, 2, 3)`` (start/end point per segment) — TODO confirm with
    callers. The edge list assumes exactly two vertex rows per entry along
    the first axis of ``segments``.

    Returns:
        ``(vertices, edges)``: ``vertices`` is a float32 tensor of shape
        ``(-1, 3)``; ``edges`` is a list of ``(2 * i, 2 * i + 1)`` index
        pairs, one per entry along the first axis of ``segments``.
    """
    as_float = torch.as_tensor(segments, dtype=torch.float32)
    flat_vertices = as_float.reshape(-1, 3)
    # Each segment owns two consecutive vertex rows: (0, 1), (2, 3), ...
    starts = range(0, 2 * as_float.shape[0], 2)
    index_pairs = [(s, s + 1) for s in starts]
    return flat_vertices, index_pairs
|
|
|
|
| def varifold_loss_batch( |
| pred_segments: torch.Tensor, |
| gt_segments: torch.Tensor, |
| *, |
| sigma: float = 0.1, |
| variant: str = "semi_lobatto3", |
| t_nodes01: torch.Tensor | None = None, |
| t_w: torch.Tensor | None = None, |
| sigmas: torch.Tensor | None = None, |
| alpha: torch.Tensor | None = None, |
| normalize_alpha: bool = True, |
| len_pow: float | None = None, |
| gt_mask: torch.Tensor | None = None, |
| pred_weights: torch.Tensor | None = None, |
| cross_only: bool = False, |
| ) -> torch.Tensor: |
| if pred_segments.dim() != 4 or gt_segments.dim() != 4: |
| raise ValueError("pred_segments and gt_segments must be (B, N, 2, 3)") |
| p_pred, q_pred = pred_segments[:, :, 0], pred_segments[:, :, 1] |
| p_gt, q_gt = gt_segments[:, :, 0], gt_segments[:, :, 1] |
|
|
| w_gt = None |
| if gt_mask is not None: |
| w_gt = gt_mask.to(device=pred_segments.device, dtype=pred_segments.dtype) |
|
|
| w_pred = None |
| if pred_weights is not None: |
| w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype) |
|
|
| if variant != "simpson3": |
| raise ValueError( |
| f"Unsupported varifold variant: {variant!r}. " |
| f"Only 'simpson3' is supported in batch mode.") |
| if sigmas is not None or alpha is not None: |
| if sigmas is None or alpha is None: |
| raise ValueError("sigmas and alpha are required for simpson3 mix") |
| return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only) |
| return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred) |
|
|