| import torch |
|
|
| from .utils import kabsch |
| |
def rbf(positions, target_position, sigma):
    """
    Log of a Gaussian (RBF) kernel between ``positions`` and ``target_position``
    after a rigid alignment.

    The rotation/translation pair comes from ``kabsch`` evaluated on *detached*
    copies of both inputs. The alignment is therefore a constant for autograd:
    gradients reach ``positions`` only through the residual term, not through
    ``R`` or ``t``.

    Args:
        positions: Positions to align. Presumably shaped (..., N, 3); TODO confirm.
        target_position: Reference positions compared against the aligned
            ``positions``.
        sigma: Kernel bandwidth (a scalar or a broadcastable tensor).

    Returns:
        ``-0.5 / sigma**2 * mean((aligned - target_position) ** 2)``, with the
        mean taken over the last two dimensions. This gives one value per
        leading (batch) index.
    """
    rot, shift = kabsch(positions.detach(), target_position.detach())
    # Alignment is applied to row vectors: x @ R^T + t.
    aligned = positions @ rot.transpose(-2, -1) + shift
    residual = aligned - target_position
    mean_sq = torch.square(residual).mean(dim=(-2, -1))
    # Same evaluation order as (-0.5 / sigma**2) * mean_sq, kept for exact results.
    coef = -0.5 / sigma**2
    return coef * mean_sq
|
|
def grad_log_wrt_positions(positions, target_position, sigma):
    """
    Gradient of the log kernel (``rbf``) w.r.t. the ORIGINAL, unaligned positions.

    Args:
        positions: Positions at which to evaluate the gradient, shape (..., N, 3).
        target_position: Reference positions, forwarded unchanged to ``rbf``.
        sigma: Kernel bandwidth, forwarded unchanged to ``rbf``.

    Returns:
        A tensor with the same shape as ``positions``: the gradient of
        ``rbf(positions, target_position, sigma).sum()`` w.r.t. ``positions``.
        It is detached from any caller graph (no ``create_graph``).

    Notes:
        Graph construction is re-enabled locally with ``torch.enable_grad()``.
        This makes the function work even when the caller is inside a
        ``torch.no_grad()`` block. Without it, ``torch.autograd.grad`` raises
        because ``log_ri`` would have no ``grad_fn``.
    """
    with torch.enable_grad():
        # Fresh leaf tensor: detach first so the copy is not recorded on any
        # graph ``positions`` may belong to, then track gradients on the copy.
        pos = positions.detach().clone().requires_grad_(True)
        log_ri = rbf(pos, target_position, sigma)
        # ``sum`` only reduces to a scalar for autograd. Assuming each batch
        # entry of ``log_ri`` depends only on its own positions (i.e. ``kabsch``
        # aligns batch entries independently; TODO confirm), this yields
        # per-entry gradients.
        (grad_pos,) = torch.autograd.grad(log_ri.sum(), pos)
    return grad_pos