| """Inspired by https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/r3_diffuser.py""" |
|
|
| from math import sqrt |
| import torch |
|
|
| from src.utils.tensor_utils import inflate_array_like |
|
|
class R3Diffuser:
    """VPSDE diffusion module for 3D coordinates.

    Noise schedule is linear in time, b(t) = min_b + t * (max_b - min_b) for
    t in [0, 1], and the forward SDE is

        dx = -0.5 * b(t) * x dt + sqrt(b(t)) dW.

    Raw coordinates are multiplied by ``coordinate_scaling`` (``scale``)
    before any diffusion math and divided by it again (``unscale``) on output.

    Args:
        min_b: schedule value b(0).
        max_b: schedule value b(1).
        coordinate_scaling: factor applied to raw coordinates by ``scale``.
    """

    def __init__(
        self,
        min_b: float = 0.1,
        max_b: float = 20.0,
        coordinate_scaling: float = 1.0,
    ):
        self.min_b = min_b
        self.max_b = max_b
        self.coordinate_scaling = coordinate_scaling

    def scale(self, x):
        """Map raw coordinates into the scaled space the SDE operates in."""
        return x * self.coordinate_scaling

    def unscale(self, x):
        """Inverse of ``scale``: map scaled coordinates back to raw units."""
        return x / self.coordinate_scaling

    def b_t(self, t: torch.Tensor):
        """Instantaneous noise rate b(t), linear between min_b and max_b.

        Args:
            t: tensor of times; every entry must lie in [0, 1].

        Raises:
            ValueError: if any entry of ``t`` is outside [0, 1].
        """
        if torch.any(t < 0) or torch.any(t > 1):
            raise ValueError(f'Invalid t={t}')
        return self.min_b + t * (self.max_b - self.min_b)

    def diffusion_coef(self, t):
        """Diffusion coefficient g(t) = sqrt(b(t))."""
        return torch.sqrt(self.b_t(t))

    def drift_coef(self, x, t):
        """Forward drift f(x, t) = -0.5 * b(t) * x."""
        return -0.5 * self.b_t(t) * x

    def sample_prior(self, shape, device=None):
        """Draw standard-normal samples of ``shape`` (the N(0, I) prior)."""
        return torch.randn(size=shape, device=device)

    def marginal_b_t(self, t):
        """Integrated schedule B(t) = integral of b(s) ds from 0 to t."""
        return t * self.min_b + 0.5 * (t ** 2) * (self.max_b - self.min_b)

    def calc_trans_0(self, score_t, x_t, t):
        """Recover x_0 from x_t and its score under the Gaussian marginal.

        Computes (x_t + var(t) * score_t) / exp(-B(t) / 2), where
        var(t) = 1 - exp(-B(t)).

        ``t`` gets two trailing singleton dims appended so it broadcasts over
        the last two dims of ``score_t`` / ``x_t`` (presumably t is [B] and
        x is [B, n, 3] -- confirm against callers). No coordinate scaling is
        applied here.
        """
        beta_t = self.marginal_b_t(t)
        beta_t = beta_t[..., None, None]
        cond_var = 1 - torch.exp(-beta_t)
        return (score_t * cond_var + x_t) / torch.exp(-0.5 * beta_t)

    def forward_marginal(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor
    ):
        """Samples marginal p(x(t) | x(0)).

        Args:
            x_0: [..., n, 3] initial positions in Angstroms.
            t: continuous time in [0, 1].

        Returns:
            x_t: [..., n, 3] positions at time t in Angstroms.
            score_t: [..., n, 3] score at time t in scaled Angstroms.
        """
        t = inflate_array_like(t, x_0)
        x_0 = self.scale(x_0)

        # B(t) feeds both the mean and the std of the Gaussian marginal;
        # evaluate it once.
        beta_t = self.marginal_b_t(t)
        loc = torch.exp(-0.5 * beta_t) * x_0
        std = torch.sqrt(1 - torch.exp(-beta_t))
        z = torch.randn_like(x_0)
        x_t = z * std + loc
        # Score is computed on the already-scaled x_t / x_0.
        score_t = self.score(x_t, x_0, t)

        x_t = self.unscale(x_t)
        return x_t, score_t

    def score_scaling(self, t: torch.Tensor):
        """Inverse standard deviation of p(x_t | x_0): 1 / sqrt(var(t))."""
        return 1.0 / torch.sqrt(self.conditional_var(t))

    def reverse(
        self,
        x_t: torch.Tensor,
        score_t: torch.Tensor,
        t: torch.Tensor,
        dt: float,
        mask: torch.Tensor = None,
        center: bool = True,
        noise_scale: float = 1.0,
        probability_flow: bool = True,
    ):
        """Simulates the reverse SDE (or probability-flow ODE) for 1 step.

        Args:
            x_t: [..., 3] current positions at time t in Angstroms.
            score_t: [..., 3] translation score at time t, w.r.t. the scaled
                coordinates (as returned by ``forward_marginal``).
            t: continuous time in [0, 1].
            dt: continuous step size in [0, 1].
            mask: True indicates which residues to diffuse; other positions
                are left unperturbed.
            center: subtract the centre of mass after the step.
            noise_scale: multiplier on the Gaussian noise (SDE mode only).
            probability_flow: whether to use the deterministic probability
                flow ODE instead of the stochastic reverse SDE.

        Returns:
            [..., 3] positions at time t - dt, in Angstroms.
        """
        t = inflate_array_like(t, x_t)
        x_t = self.scale(x_t)

        f_t = self.drift_coef(x_t, t)
        g_t = self.diffusion_coef(t)

        # Drawn unconditionally so the RNG stream is the same in both modes.
        z = noise_scale * torch.randn_like(score_t)

        if probability_flow:
            # Probability-flow ODE: dx = [f - 0.5 * g^2 * score] dt.
            # Only the score term is halved; the forward drift f keeps its
            # full weight.
            perturb = (f_t - 0.5 * g_t ** 2 * score_t) * dt
        else:
            # Reverse-time SDE: dx = [f - g^2 * score] dt + g dW.
            perturb = (f_t - g_t ** 2 * score_t) * dt + g_t * sqrt(dt) * z

        if mask is not None:
            perturb *= mask[..., None]
        else:
            mask = torch.ones_like(x_t[..., 0])
        x_t_1 = x_t - perturb
        if center:
            # NOTE(review): the sum covers every position but the divisor
            # counts only masked ones, so this is the masked centroid only if
            # un-masked positions are zero (e.g. padding) -- confirm intent.
            com = torch.sum(x_t_1, dim=-2) / torch.sum(mask, dim=-1)[..., None]
            x_t_1 -= com[..., None, :]

        x_t_1 = self.unscale(x_t_1)
        return x_t_1

    def conditional_var(self, t, use_torch=False):
        """Conditional variance of p(xt|x0).

        Var[x_t|x_0] = conditional_var(t) * I, with
        conditional_var(t) = 1 - exp(-B(t)).

        ``use_torch`` is accepted but ignored (the result is always computed
        with torch); it is kept only so existing call sites keep working.
        """
        return 1.0 - torch.exp(-self.marginal_b_t(t))

    def score(self, x_t, x_0, t, scale=False):
        """Score of the Gaussian p(x_t | x_0).

        Returns -(x_t - exp(-B(t) / 2) * x_0) / conditional_var(t).

        Args:
            x_t: positions at time t.
            x_0: positions at time 0.
            t: continuous time in [0, 1]; broadcast to ``x_t`` first.
            scale: if True, both ``x_t`` and ``x_0`` are passed through
                ``scale`` first (i.e. they are given in raw units).
        """
        t = inflate_array_like(t, x_t)
        if scale:
            x_t, x_0 = self.scale(x_t), self.scale(x_0)
        return -(x_t - torch.exp(-0.5 * self.marginal_b_t(t)) * x_0) / self.conditional_var(t)

    def distribution(self, x_t, score_t, t, mask, dt):
        """Mean and std of one Euler-Maruyama reverse-SDE step.

        Returns ``(mu, std)`` with mu = x_t - (f - g^2 * score) * dt and
        std = g * sqrt(dt). Both are in the scaled coordinate space (``mu``
        is not unscaled). Unlike ``reverse``, ``t`` is used as given (not
        broadcast), and ``mask`` (if given) zeroes ``mu`` itself rather than
        the perturbation.
        """
        x_t = self.scale(x_t)
        f_t = self.drift_coef(x_t, t)
        g_t = self.diffusion_coef(t)
        std = g_t * sqrt(dt)
        mu = x_t - (f_t - g_t ** 2 * score_t) * dt
        if mask is not None:
            mu *= mask[..., None]
        return mu, std