# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F

# ============================================================================
# Modular Rotation Estimation Functions
# ============================================================================


def _skew(v):
    """Return the cross-product matrices [v]_x for vectors ``v``.

    Args:
        v: Vectors (..., 3)

    Returns:
        K: Skew-symmetric matrices (..., 3, 3) such that K @ x == cross(v, x)
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    zero = torch.zeros_like(x)
    return torch.stack(
        [
            torch.stack([zero, -z, y], dim=-1),
            torch.stack([z, zero, -x], dim=-1),
            torch.stack([-y, x, zero], dim=-1),
        ],
        dim=-2,
    )


def compute_covariance(A, B, virtual_normal=True, eps=1e-8):
    """Compute covariance matrix H = A^T @ B for rotation estimation.

    Args:
        A: Target vectors (..., N, 3)
        B: Source vectors (..., N, 3)
        virtual_normal: If True (and N >= 2), add one synthetic correspondence
            built from the cross product of the first two vectors of each set.
        eps: Small constant added to the normal length before dividing

    Returns:
        H: Covariance matrix (..., 3, 3)
    """
    # Basic covariance: H[i, j] = sum_n A[n, i] * B[n, j]
    H = torch.einsum("...ni,...nj->...ij", A, B)

    if not virtual_normal or A.shape[-2] < 2:
        return H

    a0, a1 = A[..., 0, :], A[..., 1, :]
    b0, b1 = B[..., 0, :], B[..., 1, :]

    # Normal of the plane spanned by the first two vectors of each set
    n_a = torch.cross(a0, a1, dim=-1)
    n_b = torch.cross(b0, b1, dim=-1)
    len_a = torch.linalg.norm(n_a, dim=-1, keepdim=True)
    len_b = torch.linalg.norm(n_b, dim=-1, keepdim=True)

    # Rescale each normal to the length of the first vector of its own set
    v_a = n_a * (torch.linalg.norm(a0, dim=-1, keepdim=True) / (len_a + eps))
    v_b = n_b * (torch.linalg.norm(b0, dim=-1, keepdim=True) / (len_b + eps))

    # A (near-)zero normal means the first two vectors are collinear; such
    # batch entries get no virtual contribution.
    valid = (len_a > 1e-9) & (len_b > 1e-9)  # (..., 1)

    virtual = torch.einsum("...i,...j->...ij", v_a, v_b)
    # Masking with torch.where (instead of branching on torch.any) keeps the
    # function free of data-dependent Python control flow.
    virtual = torch.where(valid[..., None], virtual, torch.zeros_like(virtual))
    return H + virtual


def kabsch(H):
    """Compute rotation matrix from covariance using Kabsch algorithm (SVD).

    Args:
        H: Covariance matrix (..., 3, 3)

    Returns:
        R: Rotation matrix (..., 3, 3) with det(R) = 1
    """
    U, _, Vh = torch.linalg.svd(H)

    # If U @ Vh is a reflection, flip the direction of the last singular
    # vector so the result is a proper rotation.
    det = torch.linalg.det(U @ Vh)
    one = torch.ones_like(det)
    sign = torch.where(det < 0, -one, one)
    d = torch.stack([one, one, sign], dim=-1)  # diagonal of the correction

    # U @ diag(d) @ Vh, with diag(d) applied as a column scaling of U
    return (U * d[..., None, :]) @ Vh


def newton_schulz(H, num_iters=20, eps=1e-8):
    """Compute rotation matrix from covariance using Newton-Schulz iteration.

    This is primarily a reference implementation for testing and comparing
    against the Warp-accelerated Newton-Schulz kernel. For production use,
    prefer `kabsch()` which is more stable and has better-defined gradients
    through SVD.

    Args:
        H: Covariance matrix (..., 3, 3)
        num_iters: Number of iterations (default 20)
        eps: Small constant added to the scaling norm before dividing

    Returns:
        R: Matrix (..., 3, 3) whose last column has been sign-flipped where
            needed so that det(R) >= 0; it is a rotation once the iteration
            has converged.

    Note:
        Convergence depends on conditioning of H. Ill-conditioned matrices may
        require more iterations or may not converge to high precision.
    """
    # Scale by the infinity norm (max absolute row sum) before iterating
    row_sums = torch.abs(H).sum(dim=-1)
    scale = row_sums.max(dim=-1, keepdim=True)[0].unsqueeze(-1)
    R = H / (scale + eps)

    eye = torch.eye(3, dtype=H.dtype, device=H.device)

    # Newton-Schulz iteration: R_{k+1} = R_k @ (3*I - R_k^T @ R_k) / 2
    for _ in range(num_iters):
        R = (R @ (3.0 * eye - R.swapaxes(-2, -1) @ R)) * 0.5

    # Determinant correction: negate the last column where det(R) < 0
    det = torch.linalg.det(R)
    one = torch.ones_like(det)
    sign = torch.where(det < 0, -one, one)
    return torch.cat([R[..., :2], R[..., 2:] * sign[..., None, None]], dim=-1)


def rodrigues_rotation(a, b, eps=1e-8):
    """Compute rotation matrix that aligns vector b to vector a.

    Uses the shortest arc rotation approach similar to SciPy's align_vectors.
    Antiparallel inputs (normalized dot product below -1 + 1e-6) get a
    180-degree rotation about an axis perpendicular to b.

    Args:
        a: Target vector (..., 3)
        b: Source vector (..., 3)
        eps: Lower bound applied to norms before dividing

    Returns:
        R: Rotation matrix (..., 3, 3) such that R @ b ≈ a (up to scale)
    """
    dtype, device = a.dtype, a.device

    a_u = a / torch.clamp(torch.linalg.norm(a, dim=-1, keepdim=True), min=eps)
    b_u = b / torch.clamp(torch.linalg.norm(b, dim=-1, keepdim=True), min=eps)

    dot = torch.clamp((a_u * b_u).sum(dim=-1, keepdim=True), -1.0, 1.0)  # (..., 1)
    antiparallel = dot < -1.0 + 1e-6  # (..., 1)

    eye = torch.eye(3, dtype=dtype, device=device)
    K = _skew(torch.cross(b_u, a_u, dim=-1))

    # Guard the denominator: 1 + dot is ~0 for antiparallel inputs. Those
    # entries are replaced below, but an inf here would still poison the
    # backward pass (0 * inf = NaN).
    safe_denom = torch.where(antiparallel, torch.ones_like(dot), 1.0 + dot)
    R = eye + K + (K @ K) / safe_denom[..., None]

    # 180-degree case: rotate about an axis perpendicular to b. The reference
    # vector is the y axis when b is mostly along x, otherwise the x axis, so
    # the cross product never degenerates for a unit b.
    x_axis = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)
    y_axis = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device)
    ref = torch.where(torch.abs(b_u[..., :1]) > 0.6, y_axis, x_axis)
    axis = torch.cross(b_u, ref, dim=-1)
    axis = axis / torch.clamp(torch.linalg.norm(axis, dim=-1, keepdim=True), min=eps)
    R_180 = 2.0 * axis[..., :, None] * axis[..., None, :] - eye

    return torch.where(antiparallel[..., None], R_180, R)


# ============================================================================
# High-Level Alignment Function
# ============================================================================


def align_vectors(A, B, eps=1e-8, method="kabsch"):
    """
    SciPy-compatible: return rotation C such that C @ b ≈ a.
    Supports broadcasting across leading batch dims. Inputs: (..., N, 3).

    Args:
        A: Target vectors (..., N, 3)
        B: Source vectors (..., N, 3)
        eps: Small constant for numerical stability
        method: 'kabsch' (SVD-based) or 'newton-schulz' (iterative). Ignored
            for N == 1, where the closed-form shortest-arc rotation is used.

    Returns:
        Rotation matrices (..., 3, 3)

    Raises:
        NotImplementedError: If the last dimension of A or B is not 3.
        ValueError: If the inputs have fewer than 2 dims, N differs between A
            and B, or `method` is not recognized.
    """
    if A.ndim < 2 or B.ndim < 2:
        raise ValueError(
            f"Expected inputs of shape (..., N, 3), got {tuple(A.shape)} and {tuple(B.shape)}."
        )
    if A.shape[-1] != 3 or B.shape[-1] != 3:
        raise NotImplementedError("Only 3D vectors are supported (last dim must be 3).")
    if A.shape[-2] != B.shape[-2]:
        raise ValueError(f"N must match, got {A.shape[-2]} vs {B.shape[-2]}.")
    # Validate before doing any work so a typo is reported even when N == 1.
    if method not in ("kabsch", "newton-schulz"):
        raise ValueError(f"Unknown method: {method}. Use 'kabsch' or 'newton-schulz'.")

    if A.shape[-2] == 1:
        return rodrigues_rotation(A[..., 0, :], B[..., 0, :], eps=eps)

    H = compute_covariance(A, B, virtual_normal=True, eps=eps)
    if method == "newton-schulz":
        return newton_schulz(H, num_iters=20, eps=eps)
    return kabsch(H)


def SE3_from_Rt(R, t):
    """
    autograd-safe SE(3) transform construction from rotation R and translation t.

    Args:
        R: (..., 3, 3)
        t: (..., 3)

    Returns:
        T: (..., 4, 4) with R in the upper-left block, t in the last column
            and [0, 0, 0, 1] as the last row
    """
    upper = torch.cat([R, t[..., None]], dim=-1)  # (..., 3, 4)
    bottom = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=R.dtype, device=R.device)
    bottom = bottom.expand(*upper.shape[:-2], 1, 4)  # (..., 1, 4)
    return torch.cat([upper, bottom], dim=-2)  # (..., 4, 4)


def SE3_inverse(T):
    """
    Invert SE(3) transform(s) in homogeneous coordinates.

    Uses the closed form [R | t]^-1 = [R^T | -R^T t] rather than a general
    matrix inverse.

    Args:
        T: (..., 4, 4) torch.Tensor

    Returns:
        Tinv: (..., 4, 4)
    """
    R_T = T[..., :3, :3].swapaxes(-2, -1)  # (..., 3, 3)
    t_inv = -(R_T @ T[..., :3, 3:4])  # (..., 3, 1)
    return SE3_from_Rt(R_T, t_inv[..., 0])


# --- SO(3) conversions --------------------------------------------------------


def matrix_to_rotvec(R, eps=1e-6):
    """
    (...,3,3) rotation matrices -> (...,3) rotation vectors (axis * angle).
    Robust for small angles and near-pi.

    Args:
        R: Rotation matrices (..., 3, 3)
        eps: Lower bound used to guard divisions

    Returns:
        Rotation vectors (..., 3) with angle in [0, pi]

    Raises:
        ValueError: If the trailing dimensions of R are not (3, 3).
    """
    if R.shape[-2:] != (3, 3):
        raise ValueError(f"Expected (...,3,3), got {R.shape}")

    tr = torch.diagonal(R, dim1=-2, dim2=-1).sum(-1)
    cos_theta = torch.clamp((tr - 1.0) * 0.5, -1.0, 1.0)

    # Vee of the skew-symmetric part: v = 2 * sin(theta) * axis
    v = torch.stack(
        [
            R[..., 2, 1] - R[..., 1, 2],
            R[..., 0, 2] - R[..., 2, 0],
            R[..., 1, 0] - R[..., 0, 1],
        ],
        dim=-1,
    )
    sin_theta = 0.5 * torch.linalg.norm(v, dim=-1)

    # atan2 stays accurate near 0 and pi and, unlike acos, has a finite
    # derivative there, so no NaN leaks through torch.where in backward.
    theta = torch.atan2(sin_theta, cos_theta)

    # Regions
    small = theta <= 1e-3
    near_pi = theta >= (torch.pi - 1e-3)

    # Small-angle series: theta / (2 sin theta) ~= 1/2 + theta^2 / 12
    w_small = v * (0.5 + theta * theta / 12.0)[..., None]

    # Generic: w = v * theta / (2 sin theta)
    denom = torch.where(sin_theta < eps, torch.full_like(sin_theta, eps), 2.0 * sin_theta)
    w_gen = v * (theta / denom)[..., None]

    # Near-pi: v ~ 0 carries almost no axis information, so take the axis
    # from the symmetric part, (R + R^T)/2 - cos(theta) I = (1 - cos) u u^T.
    # The column with the largest diagonal entry is the best-conditioned
    # multiple of u and keeps the relative signs of its components.
    eye = torch.eye(3, dtype=R.dtype, device=R.device)
    outer = 0.5 * (R + R.swapaxes(-2, -1)) - cos_theta[..., None, None] * eye
    k = torch.diagonal(outer, dim1=-2, dim2=-1).argmax(dim=-1)
    idx = k[..., None, None].expand(*k.shape, 3, 1)
    col = torch.gather(outer, -1, idx)[..., 0]  # (..., 3)
    # Resolve the overall sign with v (arbitrary at exactly pi, where both
    # signs describe the same rotation).
    flip = (col * v).sum(dim=-1, keepdim=True) < 0
    col = torch.where(flip, -col, col)
    axis_pi = col / torch.clamp(torch.linalg.norm(col, dim=-1, keepdim=True), min=eps)
    w_pi = axis_pi * theta[..., None]

    return torch.where(near_pi[..., None], w_pi, torch.where(small[..., None], w_small, w_gen))


def rotvec_to_matrix(rotvec, eps=1e-8):
    """
    (...,3) rotation vectors -> (...,3,3) rotation matrices.
    Robust near zero.

    Evaluates R = I + (sin t / t) W + ((1 - cos t) / t^2) W^2 with
    W = skew(rotvec) and t = |rotvec|, switching to the Taylor series of both
    coefficients for very small t.

    Args:
        rotvec: Rotation vectors (..., 3), axis * angle (radians)
        eps: Angles below max(eps, 1e-6) use the series expansion

    Returns:
        Rotation matrices (..., 3, 3)

    Raises:
        ValueError: If the last dimension of rotvec is not 3.
    """
    if rotvec.shape[-1] != 3:
        raise ValueError(f"Expected (...,3), got {rotvec.shape}")

    theta = torch.linalg.norm(rotvec, dim=-1)
    theta2 = theta * theta
    small = theta < max(eps, 1e-6)

    # Replace tiny angles by 1 inside the closed-form branch so neither the
    # value nor the gradient of the unselected branch divides by ~0.
    safe_theta = torch.where(small, torch.ones_like(theta), theta)
    coeff_w = torch.where(small, 1.0 - theta2 / 6.0, torch.sin(safe_theta) / safe_theta)
    coeff_w2 = torch.where(
        small,
        0.5 - theta2 / 24.0,
        (1.0 - torch.cos(safe_theta)) / (safe_theta * safe_theta),
    )

    W = _skew(rotvec)
    eye = torch.eye(3, dtype=rotvec.dtype, device=rotvec.device)
    return eye + coeff_w[..., None, None] * W + coeff_w2[..., None, None] * (W @ W)


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3); the three orthonormal
        vectors are stacked as rows

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    # Remove the component of a2 along b1, then normalize
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)