import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils.utils import kabsch
from .utils.rbf import grad_log_wrt_positions

class BiasForceTransformer(nn.Module):
    def __init__(self,
                 mds,
                 args,
                 d_model=256,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 ):
        super().__init__()
        self.device = args.device
        self.heavy_atoms = mds.heavy_atoms
        self.N = mds.num_particles

        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        # Per-atom features: aligned position (3) + aligned velocity (3)
        # + optional displacement to target (3) + distance to target (1).
        feat_dim = 3 + 3 + (3 if self.use_delta_to_target else 0) + 1

        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Per-atom heads: a positive scalar gain and a 3D vector
        # predicted in the target-aligned frame.
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        self.vec_head_aligned = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 3),
        )

        self.bias = args.bias

        # Learnable scalar log_z, consumed by the training objective elsewhere.
        self.log_z = nn.Parameter(torch.tensor(0.0))
        self.to(self.device)

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        # Softplus shifted by eps so the output is strictly positive.
        return F.softplus(x, beta=beta, threshold=threshold) + eps
    def forward(self, pos, vel, target):
        """
        pos, vel, target: (B, N, 3)
        Returns: bias force (B, N, 3)
        """
        B, N, _ = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"
        heavy = self.heavy_atoms.to(pos.device)

        # Rigid-body (Kabsch) alignment of the heavy atoms onto the target.
        pos_h, tgt_h = pos[:, heavy], target[:, heavy]
        R, t = kabsch(pos_h, tgt_h)

        # Map positions into the target frame; velocities rotate only.
        pos_al = pos @ R.transpose(-2, -1) + t
        vel_al = vel @ R.transpose(-2, -1)

        # Per-atom features in the aligned frame.
        delta_al = target - pos_al
        dist_al = torch.norm(delta_al, dim=-1, keepdim=True)
        if self.use_delta_to_target:
            feats = torch.cat([pos_al, vel_al, delta_al, dist_al], dim=-1)
        else:
            feats = torch.cat([pos_al, vel_al, dist_al], dim=-1)

        x = self.input_proj(feats)
        x = self.encoder(x)

        # Positive per-atom gain and a free vector predicted in the aligned frame.
        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)
        vec_aligned = self.vec_head_aligned(x)

        # Rotate the predicted vector back into the simulation frame
        # (the inverse of v -> v @ R^T, since R is orthogonal).
        vector = vec_aligned @ R

        # Express the target in the simulation frame (inverse of the alignment).
        target_posframe = (target - t) @ R

        # Direction toward the target: RBF-based gradient (detached)
        # or the straight-line displacement.
        if self.rbf:
            d = grad_log_wrt_positions(pos, target_posframe, self.sigma).detach()
        else:
            d = target_posframe - pos

        # Component along d with the learned positive magnitude
        # (broadcasting over the last dim replaces the explicit expand).
        scaled = scale.unsqueeze(-1) * d

        # Keep only the component of the free vector orthogonal to d,
        # so the learned gain alone sets the pull along d.
        eps = torch.finfo(pos.dtype).eps
        denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)
        vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
        vec_perp = vector - vec_parallel

        return vec_perp + scaled
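

# Usage sketch (illustrative only): `mds` and `args` are stand-ins whose
# attributes are inferred from __init__ above; the real objects are
# constructed elsewhere in this repo. Assumes 22 atoms, the first 10 heavy.
#
#   from types import SimpleNamespace
#   mds = SimpleNamespace(heavy_atoms=torch.arange(10), num_particles=22)
#   args = SimpleNamespace(device="cpu", use_delta_to_target=True,
#                          rbf=False, sigma=0.5, bias=True)
#   model = BiasForceTransformer(mds, args).eval()
#   B, N = 4, 22
#   pos, vel, target = (torch.randn(B, N, 3) for _ in range(3))
#   with torch.no_grad():
#       force = model(pos, vel, target)   # (B, N, 3) bias force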


class BiasForceTransformerNoVel(nn.Module):
    def __init__(self,
                 mds,
                 args,
                 d_model=256,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 ):
        super().__init__()
        self.device = args.device
        self.heavy_atoms = mds.heavy_atoms
        self.N = mds.num_particles

        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        # Per-atom features: aligned position (3)
        # + optional displacement to target (3) + distance to target (1).
        feat_dim = 3 + (3 if self.use_delta_to_target else 0) + 1

        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Per-atom heads: a positive scalar gain and a 3D vector
        # predicted in the target-aligned frame.
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        self.vec_head_aligned = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 3),
        )

        # Learnable scalar log_z, consumed by the training objective elsewhere.
        self.log_z = nn.Parameter(torch.tensor(0.0))
        self.to(self.device)

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        # Softplus shifted by eps so the output is strictly positive.
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    def forward(self, pos, target):
        """
        pos, target: (B, N, D)
        Returns: bias force (B, N, D)

        N: number of atoms
        D: spatial dimension (3)
        """
        B, N, _ = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"
        heavy = self.heavy_atoms.to(pos.device)

        # Rigid-body (Kabsch) alignment of the heavy atoms onto the target.
        pos_h, tgt_h = pos[:, heavy], target[:, heavy]
        R, t = kabsch(pos_h, tgt_h)

        # Map positions into the target frame.
        pos_al = pos @ R.transpose(-2, -1) + t

        # Per-atom features in the aligned frame.
        delta_al = target - pos_al
        dist_al = torch.norm(delta_al, dim=-1, keepdim=True)
        if self.use_delta_to_target:
            feats = torch.cat([pos_al, delta_al, dist_al], dim=-1)
        else:
            feats = torch.cat([pos_al, dist_al], dim=-1)

        x = self.input_proj(feats)
        x = self.encoder(x)

        # Positive per-atom gain and a free vector predicted in the aligned frame.
        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)
        vec_aligned = self.vec_head_aligned(x)

        # Rotate the predicted vector back into the simulation frame.
        vector = vec_aligned @ R

        # Express the target in the simulation frame (inverse of the alignment).
        target_posframe = (target - t) @ R

        # Direction toward the target: RBF-based gradient (detached)
        # or the straight-line displacement.
        if self.rbf:
            d = grad_log_wrt_positions(pos, target_posframe, self.sigma).detach()
        else:
            d = target_posframe - pos

        # Component along d with the learned positive magnitude.
        scaled = scale.unsqueeze(-1) * d

        # Keep only the component of the free vector orthogonal to d,
        # so the learned gain alone sets the pull along d.
        eps = torch.finfo(pos.dtype).eps
        denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)
        vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
        vec_perp = vector - vec_parallel

        return vec_perp + scaled
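

# Usage sketch for the velocity-free variant (same stand-ins as above):
#
#   model = BiasForceTransformerNoVel(mds, args)
#   force = model(pos, target)   # (B, N, 3) bias force
#
# In both classes the output decomposes into a free component orthogonal
# to d plus a softplus-positive gain along d, so the returned force always
# has a non-negative projection onto the direction toward the target.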