Irwiny123's picture
添加PepFlow模型初始代码
ef423c5
import torch
from torch import nn
from models_con import ipa_pytorch as ipa_pytorch
from data import utils as du
from models_con.utils import get_index_embedding, get_time_embedding
from pepflow.modules.protein.constants import ANG_TO_NM_SCALE, NM_TO_ANG_SCALE
from pepflow.modules.common.layers import AngularEncoding
import math
class GAEncoder(nn.Module):
    """Invariant-point-attention trunk with angle and sequence prediction heads.

    Per-residue features are built by mixing the incoming node embedding with
    an embedding of the current sequence tokens, a timestep embedding and an
    angular encoding of the current angles. The result is refined by
    ``ipa_conf.num_blocks`` blocks, each consisting of invariant point
    attention, a sequence transformer, a node transition and a rigid-frame
    update (plus an edge transition on every block except the last).

    Config fields read from ``ipa_conf``: ``c_s`` (node width), ``c_z`` (edge
    width), ``num_blocks``, ``seq_tfmr_num_heads`` and ``seq_tfmr_num_layers``.
    The same object is also handed to ``InvariantPointAttention``.
    """

    def __init__(self, ipa_conf):
        """Build the embedders, prediction heads and the attention trunk.

        Args:
            ipa_conf: configuration object exposing the fields listed in the
                class docstring.
        """
        super().__init__()
        self._ipa_conf = ipa_conf

        # Angles. The width of the encoded angles is whatever
        # ``get_out_dim(in_dim=5)`` reports; it is consumed by the mixer below.
        self.angles_embedder = AngularEncoding(num_funcs=12)
        # Head mapping the final node features to 5 angle values per residue.
        self.angle_net = nn.Sequential(
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s), nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s), nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, 5),
        )

        # Conditioning on the current sequence: 22 token ids are embedded
        # (presumably 20 residue types plus special tokens -- TODO confirm),
        # while the head below predicts 20 values per residue.
        self.current_seq_embedder = nn.Embedding(22, self._ipa_conf.c_s)
        self.seq_net = nn.Sequential(
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s), nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s), nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, 20),
        )

        # Mixer: the input is the concatenation of three c_s-wide features
        # (node embedding, sequence embedding, time embedding) and the
        # angular encoding of the 5 angles.
        self.res_feat_mixer = nn.Sequential(
            nn.Linear(
                3 * self._ipa_conf.c_s + self.angles_embedder.get_out_dim(in_dim=5),
                self._ipa_conf.c_s,
            ),
            nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
        )

        # Width of the timestep embedding produced by ``embed_t``.
        self.feat_dim = self._ipa_conf.c_s

        # Attention trunk. Submodules are stored under '<name>_<block index>'
        # keys; these keys are part of the checkpoint format.
        self.trunk = nn.ModuleDict()
        for b in range(self._ipa_conf.num_blocks):
            self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)

            tfmr_in = self._ipa_conf.c_s
            tfmr_layer = torch.nn.TransformerEncoderLayer(
                d_model=tfmr_in,
                nhead=self._ipa_conf.seq_tfmr_num_heads,
                dim_feedforward=tfmr_in,
                batch_first=True,
                dropout=0.0,
                norm_first=False
            )
            self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
                tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
            self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(
                tfmr_in, self._ipa_conf.c_s, init="final")
            self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(
                c=self._ipa_conf.c_s)
            self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                self._ipa_conf.c_s, use_rot_updates=True)

            if b < self._ipa_conf.num_blocks - 1:
                # No edge update on the last block.
                edge_in = self._ipa_conf.c_z
                self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
                    node_embed_size=self._ipa_conf.c_s,
                    edge_embed_in=edge_in,
                    edge_embed_out=self._ipa_conf.c_z,
                )

    def embed_t(self, timesteps, mask):
        """Embed the timestep and tile it along the residue axis.

        Args:
            timesteps: tensor indexed as ``timesteps[:, 0]``, i.e. at least
                2-D with one time value per batch element in column 0.
            mask: tensor whose second dimension gives the number of residues.

        Returns:
            The output of ``get_time_embedding`` (requested width
            ``self.feat_dim``) with a new axis inserted at position 1 and
            repeated ``mask.shape[1]`` times along it.
        """
        timestep_emb = get_time_embedding(
            timesteps[:, 0],
            self.feat_dim,
            max_positions=2056
        )[:, None, :].repeat(1, mask.shape[1], 1)
        return timestep_emb

    def forward(self, t, rotmats_t, trans_t, angles_t, seqs_t, node_embed, edge_embed, generate_mask, res_mask):
        """Refine the current state and predict frames, angles and sequence.

        Args:
            t: timesteps; see ``embed_t`` for the expected layout.
            rotmats_t: current rotations, passed to ``du.create_rigid``.
            trans_t: current translations, passed to ``du.create_rigid``.
            angles_t: current angles; after encoding they are reshaped to
                (num_batch, num_res, -1). Assumed to carry 5 angles per
                residue to match the mixer input width -- TODO confirm.
            seqs_t: integer token ids of shape (num_batch, num_res).
            node_embed: per-residue features, last dimension ``c_s``.
            edge_embed: pairwise features handed to the IPA blocks.
            generate_mask: accepted for interface compatibility; not used
                by this method.
            res_mask: (num_batch, num_res) mask, non-zero for valid residues.

        Returns:
            Tuple ``(pred_rotmats1, pred_trans1, pred_angles1, pred_seqs1_prob)``:
            rotation matrices and translations read from the final rigid
            frames, the ``angle_net`` output wrapped into [0, 2*pi), and the
            raw ``seq_net`` output (no softmax is applied here).
        """
        num_batch, num_res = seqs_t.shape

        node_mask = res_mask
        edge_mask = node_mask[:, None] * node_mask[:, :, None]
        # True marks positions the sequence transformer must ignore. Written
        # as a logical NOT instead of ``1 - mask`` because subtraction is not
        # defined for boolean tensors; for 0/1 masks the result is identical.
        # It does not change between blocks, so it is computed once.
        padding_mask = ~node_mask.bool()

        # Incorporate the current sequence, timestep and angles.
        node_embed = self.res_feat_mixer(torch.cat([
            node_embed,
            self.current_seq_embedder(seqs_t),
            self.embed_t(t, node_mask),
            self.angles_embedder(angles_t).reshape(num_batch, num_res, -1),
        ], dim=-1))
        node_embed = node_embed * node_mask[..., None]

        curr_rigids = du.create_rigid(rotmats_t, trans_t)
        for b in range(self._ipa_conf.num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed,
                edge_embed,
                curr_rigids,
                node_mask)
            ipa_embed *= node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)

            seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
                node_embed, src_key_padding_mask=padding_mask)
            node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            node_embed = node_embed * node_mask[..., None]

            # node_embed was masked on the line above, so it is passed as-is.
            rigid_update = self.trunk[f'bb_update_{b}'](node_embed)
            curr_rigids = curr_rigids.compose_q_update_vec(
                rigid_update, node_mask[..., None])

            if b < self._ipa_conf.num_blocks - 1:
                edge_embed = self.trunk[f'edge_transition_{b}'](
                    node_embed, edge_embed)
                edge_embed *= edge_mask[..., None]

        pred_trans1 = curr_rigids.get_trans()
        pred_rotmats1 = curr_rigids.get_rots().get_rot_mats()
        pred_seqs1_prob = self.seq_net(node_embed)
        pred_angles1 = self.angle_net(node_embed)
        # Inductive bias: wrap predicted angles into [0, 2*pi).
        pred_angles1 = pred_angles1 % (2 * math.pi)
        return pred_rotmats1, pred_trans1, pred_angles1, pred_seqs1_prob