| """ |
| Utility functions for geometric operations (torch only). |
| """ |
| import torch |
|
|
|
|
def rots_mul_vecs(m, v):
    """Apply (batched) rotation matrices ``m`` to vectors ``v``.

    The matrix-vector product is written out component-wise, so the leading
    (batch) dimensions of ``m[..., :, :]`` and ``v[..., :]`` broadcast
    against each other exactly like element-wise tensor arithmetic.

    Args:
        m: tensor whose last two dims are the matrix rows/columns
            (indices 0..2 are used).
        v: tensor whose last dim holds the vector components
            (indices 0..2 are used).

    Returns:
        Tensor with the broadcast batch shape and a trailing dim of size 3.
    """
    def _row(i):
        # Dot product of row ``i`` of ``m`` with ``v``.
        return (m[..., i, 0] * v[..., 0]
                + m[..., i, 1] * v[..., 1]
                + m[..., i, 2] * v[..., 2])

    return torch.stack([_row(i) for i in range(3)], dim=-1)
| |
def distance(p, eps=1e-10):
    """Return the Euclidean distance between a pair of points stacked on dim -2.

    Args:
        p: tensor of shape (..., K, D) with K >= 2; the first two entries
            along dim -2, ``p[..., 0, :]`` and ``p[..., 1, :]``, are the points.
        eps: small constant added under the square root, so the result (and
            the gradient of the square root) stays finite when the points
            coincide.

    Returns:
        Tensor of shape (...).
    """
    delta = p[..., 0, :] - p[..., 1, :]
    squared_length = torch.sum(delta ** 2, dim=-1)
    return (squared_length + eps) ** 0.5
|
|
def dihedral(p, eps=1e-10):
    """Encode the dihedral angle of four points (stacked on dim -2) as (cos, sin).

    With points p0..p3 taken from ``p[..., 0:4, :]``, the bond vectors are
    b1 = p1 - p0, b2 = p2 - p1, b3 = p3 - p2 and the plane normals are
    n1 = b1 x b2, n2 = b2 x b3.  The encoding is

        cos = n1 . n2 / (|n1| |n2|)
        sin = b2 . (n1 x n2) / (|b2| |n1| |n2|)

    Args:
        p: tensor of shape (..., 4, 3) (``torch.cross`` on dim -1 needs size 3).
        eps: added inside every norm's square root, so degenerate (collinear)
            geometries do not divide by zero.

    Returns:
        Tensor of shape (..., 2) holding [cos, sin] of the dihedral angle.
    """
    def _norm(x):
        return (eps + torch.sum(x ** 2, dim=-1)) ** 0.5

    b1, b2, b3 = (p[..., k + 1, :] - p[..., k, :] for k in range(3))

    n1 = torch.cross(b1, b2, dim=-1)
    n2 = torch.cross(b2, b3, dim=-1)

    b2_len = _norm(b2)
    n1_len = _norm(n1)
    n2_len = _norm(n2)

    cos_part = torch.einsum('...d,...d->...', n1, n2) / (n1_len * n2_len)
    sin_num = torch.einsum('...d,...d->...', b2, torch.cross(n1, n2, dim=-1))
    sin_part = sin_num / (b2_len * n1_len * n2_len)
    return torch.stack((cos_part, sin_part), dim=-1)
|
|
def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int):
    """One-hot bin all pairwise distances between the points in ``pos``.

    With ``edges = linspace(min_bin, max_bin, num_bins)``, bin ``k`` is the
    open interval (edges[k], edges[k + 1]); the last bin is closed off at 1e8.
    A distance <= ``min_bin``, >= 1e8, or exactly on a bin edge gets an
    all-zero row (e.g. each point's zero distance to itself when
    ``min_bin >= 0``).

    Args:
        pos: (..., N, D) point coordinates.
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin.
        num_bins: number of bins.

    Returns:
        (..., N, N, num_bins) tensor with dtype ``pos.dtype``.
    """
    # (..., N, N, 1): pairwise Euclidean distances; the trailing axis lets
    # them broadcast against the bin edges.
    pair_dists = torch.linalg.norm(
        pos.unsqueeze(-2) - pos.unsqueeze(-3), dim=-1
    ).unsqueeze(-1)

    lower_edges = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    upper_edges = torch.cat(
        [lower_edges[1:], lower_edges.new_tensor([1e8])], dim=-1
    )

    in_bin = (pair_dists > lower_edges) & (pair_dists < upper_edges)
    return in_bin.to(pos.dtype)
|
|
def rmsd(xyz1, xyz2):
    """Root-mean-square deviation between two point clouds after alignment.

    Shorthand for ``squared_deviation(xyz1, xyz2, 'rmsd')``.
    """
    return squared_deviation(xyz1, xyz2, reduction='rmsd')
|
|
def squared_deviation(xyz1, xyz2, reduction='none'):
    """Squared point-wise deviation between two point clouds after alignment.

    ``xyz1`` is rigidly superimposed onto ``xyz2`` with
    ``_find_rigid_alignment`` before the per-point squared distances are taken.

    Args:
        xyz1: (*, L, 3) tensor or array-like, the cloud to be transformed.
        xyz2: (*, L, 3) tensor or array-like, the reference.
        reduction: 'none' for per-point squared deviations, or 'rmsd' for the
            square root of their mean over the L points.

    Returns:
        rmsd: (*, ) or none: (*, L).  A NumPy array when ``xyz1`` was not a
        torch tensor, otherwise a tensor.

    Raises:
        NotImplementedError: for any other ``reduction``.
    """
    if reduction not in ('none', 'rmsd'):
        raise NotImplementedError(f"Unsupported reduction: {reduction!r}")

    map_to_np = not torch.is_tensor(xyz1)
    xyz1 = torch.as_tensor(xyz1)
    # Convert the reference unconditionally (a tensor xyz1 may come with an
    # array xyz2) and give both clouds a common dtype/device for the matmuls.
    xyz2 = torch.as_tensor(xyz2, device=xyz1.device)
    common_dtype = torch.promote_types(xyz1.dtype, xyz2.dtype)
    xyz1, xyz2 = xyz1.to(common_dtype), xyz2.to(common_dtype)

    R, t = _find_rigid_alignment(xyz1, xyz2)  # R: (*, 3, 3), t: (*, 3)

    # Rotate every point: (*, 3, 3) @ (*, 3, L) -> (*, 3, L) -> (*, L, 3).
    # The translation gets a point axis inserted before its last dim so it
    # broadcasts over L for any batch shape; unsqueeze(0) broke batched input.
    xyz1_aligned = (
        torch.matmul(R, xyz1.transpose(-2, -1)).transpose(-2, -1)
        + t.unsqueeze(-2)
    )

    sd = ((xyz1_aligned - xyz2) ** 2).sum(dim=-1)
    assert sd.shape == xyz1.shape[:-1]

    if reduction == 'rmsd':
        sd = torch.sqrt(sd.mean(dim=-1))

    return sd.numpy() if map_to_np else sd
|
|
def _find_rigid_alignment(src, tgt, allow_reflection=False):
    """Find the rigid transform (R, t) that best maps ``src`` onto ``tgt``.

    Kabsch algorithm (https://en.wikipedia.org/wiki/Kabsch_algorithm),
    inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
    https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8

    2-D or 3-D registration with known correspondences.  Registration occurs
    in the zero-centred coordinate system; t then transports the result back,
    so that ``R @ src[..., i, :] + t`` approximates ``tgt[..., i, :]``.

    Args:
        src: Torch tensor of shape (*, L, D) -- Point Cloud to Align (source)
        tgt: Torch tensor of shape (*, L, D) -- Reference Point Cloud (target)
        allow_reflection: if False (default), R is constrained to a proper
            rotation (det = +1) via the Kabsch sign correction.  If True, the
            unconstrained orthogonal optimum is returned, which may be a
            reflection (det = -1), so a mirror image aligns with ~zero RMSD.

    Returns:
        R: optimal rotation (*, D, D)
        t: optimal translation (*, D)

    Example (rotation + translation, then a mirrored target):
    >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]])
    >>> R0 = torch.tensor([[0., -1.], [1., 0.]])
    >>> B = A @ R0.T + torch.tensor([3., 3.])
    >>> R, t = _find_rigid_alignment(A, B)
    >>> torch.allclose(A @ R.T + t, B, atol=1e-5)
    True
    >>> B_mirror = B * torch.tensor([-1., 1.])
    >>> bool(torch.linalg.det(_find_rigid_alignment(A, B_mirror)[0]) > 0)
    True
    >>> R, _ = _find_rigid_alignment(A, B_mirror, allow_reflection=True)
    >>> bool(torch.linalg.det(R) < 0)
    True
    """
    assert src.shape[-2] > 1
    src_com = src.mean(dim=-2, keepdim=True)  # (*, 1, D)
    tgt_com = tgt.mean(dim=-2, keepdim=True)  # (*, 1, D)
    src_centered = src - src_com
    tgt_centered = tgt - tgt_com

    # Cross-covariance: (*, D, L) @ (*, L, D) -> (*, D, D).
    H = torch.matmul(src_centered.transpose(-2, -1), tgt_centered)

    # H = U diag(S) V^T.  torch.svd is deprecated; torch.linalg.svd returns
    # V^T (singular values in descending order).
    U, S, Vh = torch.linalg.svd(H, full_matrices=False)
    V = Vh.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    if allow_reflection:
        R = torch.matmul(V, Ut)
    else:
        # Kabsch sign correction: when V U^T is a reflection, flip the axis of
        # the smallest singular value, i.e. R = V diag(1, ..., 1, d) U^T.
        d = torch.sign(torch.linalg.det(torch.matmul(V, Ut)))  # (*,)
        signs = torch.cat(
            [torch.ones_like(S[..., :-1]), d.unsqueeze(-1)], dim=-1
        )  # (*, D)
        # V * signs[..., None, :] scales the columns of V == V @ diag(signs).
        R = torch.matmul(V * signs.unsqueeze(-2), Ut)

    # t = tgt_com - R @ src_com, computed as row vectors: (*, 1, D).
    t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)
    return R, t.squeeze(-2)
|
|