import sys
import os

# Select a single GPU; this must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"


import torch
import torch.nn as nn
import torch.nn.functional as F


class BiasForceTransformer(nn.Module):
    def __init__(self,
                 args,
                 d_model=256,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 ):
        super().__init__()
        self.device = args.device
        self.N = args.num_particles

        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        G = args.dim

        # Per-particle features: pos (G) + vel (G) + optional delta-to-target (G) + distance (1).
        feat_dim = (2 * G) + (G if self.use_delta_to_target else 0) + 1

        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        # Two heads on the shared encoding: a scalar magnitude along the
        # target direction and a free G-dimensional vector.
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        self.vec_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, args.dim),
        )

        # Learnable scalar log Z (not used in forward).
        self.log_z = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        # Softplus shifted by eps so the output is strictly positive.
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    def forward(self, pos, vel, target):
        """
        pos, vel, target: (B, N, G)
        Returns: bias force (B, N, G)

        B: batch size
        N: number of cells per batch element
        G: dimension of the gene vector
        """
        B, N, G = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        # Displacement to the target and its norm, per particle.
        delta = target - pos
        dist = torch.norm(delta, dim=-1, keepdim=True)
        feats = torch.cat([pos, vel, delta, dist], dim=-1) \
            if self.use_delta_to_target else torch.cat([pos, vel, dist], dim=-1)

        x = self.input_proj(feats)
        x = self.encoder(x)

        # Strictly positive magnitude plus a free direction vector.
        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)
        vector = self.vec_head(x)

        # Attractive component along delta; broadcasting replaces the
        # explicit expand over the feature dimension.
        scaled = scale.unsqueeze(-1) * delta

        # Project out the part of `vector` parallel to delta, so the learned
        # vector only steers perpendicular to the target direction.
        eps = torch.finfo(pos.dtype).eps
        denom = delta.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)
        vec_parallel = ((vector * delta).sum(dim=-1, keepdim=True) / denom) * delta
        vec_perp = vector - vec_parallel

        return vec_perp + scaled
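
# A minimal usage sketch for BiasForceTransformer. The `args` fields below are
# hypothetical stand-ins for the real config object; kept as a comment so that
# importing this module stays side-effect free:
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(device="cpu", num_particles=32, dim=50,
#                            use_delta_to_target=True, rbf=False, sigma=1.0)
#     model = BiasForceTransformer(args)
#     pos, vel, target = (torch.randn(4, 32, 50) for _ in range(3))
#     force = model(pos, vel, target)   # (4, 32, 50)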

class BiasForceTransformerNoVel(nn.Module):
    """Variant of BiasForceTransformer that conditions on positions only (no velocities)."""

    def __init__(self,
                 args,
                 d_model=256,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 ):
        super().__init__()
        self.device = args.device
        self.N = args.num_particles

        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        G = args.dim

        # Per-particle features: pos (G) + optional delta-to-target (G) + distance (1).
        feat_dim = G + (G if self.use_delta_to_target else 0) + 1

        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        # Two heads on the shared encoding: a scalar magnitude along the
        # target direction and a free G-dimensional vector.
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        self.vec_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, args.dim),
        )

        # Learnable scalar log Z (not used in forward).
        self.log_z = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        # Softplus shifted by eps so the output is strictly positive.
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    def forward(self, pos, target):
        """
        pos, target: (B, N, G)
        Returns: bias force (B, N, G)

        B: batch size
        N: number of cells per batch element
        G: dimension of the gene vector
        """
        B, N, G = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        # Displacement to the target and its norm, per particle.
        delta = target - pos
        dist = torch.norm(delta, dim=-1, keepdim=True)
        feats = torch.cat([pos, delta, dist], dim=-1) \
            if self.use_delta_to_target else torch.cat([pos, dist], dim=-1)

        x = self.input_proj(feats)
        x = self.encoder(x)

        # Strictly positive magnitude plus a free direction vector.
        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)
        vector = self.vec_head(x)

        # Attractive component along delta.
        scaled = scale.unsqueeze(-1) * delta

        # Project out the part of `vector` parallel to delta, so the learned
        # vector only steers perpendicular to the target direction.
        eps = torch.finfo(pos.dtype).eps
        denom = delta.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)
        vec_parallel = ((vector * delta).sum(dim=-1, keepdim=True) / denom) * delta
        vec_perp = vector - vec_parallel

        return vec_perp + scaled
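

# Hedged smoke test: exercises both modules on random data. The `args`
# namespace is a stand-in for whatever config object the training code
# actually passes in; its field names mirror the attributes read above.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(device="cpu", num_particles=16, dim=10,
                           use_delta_to_target=True, rbf=False, sigma=1.0)
    pos = torch.randn(2, args.num_particles, args.dim)
    vel = torch.randn_like(pos)
    target = torch.randn_like(pos)

    force = BiasForceTransformer(args)(pos, vel, target)
    force_novel = BiasForceTransformerNoVel(args)(pos, target)
    print(force.shape, force_novel.shape)  # both torch.Size([2, 16, 10])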