| import collections |
| import copy |
| import enum |
| import math |
| import os |
| import time |
| from argparse import ArgumentParser |
| from enum import Enum, auto |
| from os.path import basename, dirname, isfile, join |
| from typing import List, Optional, Tuple, Union |
|
|
| import networkx as nx |
| import numpy as np |
| import torch |
| import torch.hub |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from biopandas.pdb import PandasPdb |
| from scipy.spatial.distance import cdist |
| from torch import Tensor |
| from torch.nn import Dropout, LayerNorm, Linear |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
def reset_parameters_helper(m: nn.Module):
    """helper function for resetting model parameters, meant to be used with model.apply()"""
    public_reset = getattr(m, "reset_parameters", None)
    private_reset = getattr(m, "_reset_parameters", None)

    has_public = callable(public_reset)
    has_private = callable(private_reset)

    # a module exposing both variants is ambiguous -- refuse to guess
    if has_public and has_private:
        raise RuntimeError(
            "Module has both public and private methods for resetting parameters. "
            "This is unexpected... probably should just call the public one."
        )

    if has_public:
        m.reset_parameters()

    if has_private:
        m._reset_parameters()
|
|
|
|
class SequentialWithArgs(nn.Sequential):
    """nn.Sequential variant that forwards keyword arguments to the modules
    that can accept them (RelativeTransformerEncoder and nested
    SequentialWithArgs); every other module is called with the tensor only."""

    def forward(self, x, **kwargs):
        for layer in self:
            accepts_kwargs = isinstance(
                layer, (RelativeTransformerEncoder, SequentialWithArgs)
            )
            x = layer(x, **kwargs) if accepts_kwargs else layer(x)
        return x
|
|
|
|
class PositionalEncoding(nn.Module):
    """classic fixed sinusoidal positional encoding, added to the input and
    followed by dropout"""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # positions: (max_len, 1); frequencies follow the 1/10000^(2i/d) schedule
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(positions * frequencies)
        encoding[:, 1::2] = torch.cos(positions * frequencies)

        # prepend a batch dimension and keep as a non-trainable buffer so it
        # moves with the module across devices but is not a parameter
        self.register_buffer("pe", encoding.unsqueeze(0))

    def forward(self, x, **kwargs):
        # add the encodings for the first x.size(1) positions, then dropout
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
|
|
|
|
class ScaledEmbedding(nn.Module):
    """token embedding layer with optional sqrt(embedding_dim) scaling of the
    output, as used in the original transformer"""

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool):
        super(ScaledEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.emb_size = embedding_dim
        # precompute sqrt(d) once; applied in forward() when scale is True
        self.embed_scale = math.sqrt(self.emb_size)
        self.scale = scale

        self.init_weights()

    def init_weights(self):
        # uniform init in [-0.1, 0.1]
        bound = 0.1
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, tokens: Tensor, **kwargs):
        embedded = self.embedding(tokens.long())
        if not self.scale:
            return embedded
        return embedded * self.embed_scale
|
|
|
|
class FCBlock(nn.Module):
    """a fully connected block with options for batchnorm and dropout
    can extend in the future with option for different activation, etc"""

    def __init__(
        self,
        in_features: int,
        num_hidden_nodes: int = 64,
        use_batchnorm: bool = False,
        use_layernorm: bool = False,
        norm_before_activation: bool = False,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        activation: str = "relu",
    ):
        super().__init__()

        # the two norm types are mutually exclusive
        if use_batchnorm and use_layernorm:
            raise ValueError(
                "Only one of use_batchnorm or use_layernorm can be set to True"
            )

        self.use_batchnorm = use_batchnorm
        self.use_dropout = use_dropout
        self.use_layernorm = use_layernorm
        self.norm_before_activation = norm_before_activation

        self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes)
        self.activation = get_activation_fn(activation, functional=False)

        if use_batchnorm:
            self.norm = nn.BatchNorm1d(num_hidden_nodes)
        elif use_layernorm:
            self.norm = nn.LayerNorm(num_hidden_nodes)

        if use_dropout:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, **kwargs):
        x = self.fc(x)

        has_norm = self.use_batchnorm or self.use_layernorm

        # norm can sit on either side of the activation depending on config
        if has_norm and self.norm_before_activation:
            x = self.norm(x)

        x = self.activation(x)

        if has_norm and not self.norm_before_activation:
            x = self.norm(x)

        # dropout is always last
        if self.use_dropout:
            x = self.dropout(x)

        return x
|
|
|
|
class TaskSpecificPredictionLayers(nn.Module):
    """Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input
    into a single output node. All num_tasks outputs are then concatenated into a single tensor.
    """

    def __init__(
        self,
        num_tasks: int,
        in_features: int,
        num_hidden_nodes: int = 64,
        use_batchnorm: bool = False,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        activation: str = "relu",
    ):
        super().__init__()

        # one independent [FCBlock -> Linear(1)] head per task
        self.task_specific_pred_layers = nn.ModuleList(
            nn.Sequential(
                FCBlock(
                    in_features=in_features,
                    num_hidden_nodes=num_hidden_nodes,
                    use_batchnorm=use_batchnorm,
                    use_dropout=use_dropout,
                    dropout_rate=dropout_rate,
                    activation=activation,
                ),
                nn.Linear(in_features=num_hidden_nodes, out_features=1),
            )
            for _ in range(num_tasks)
        )

    def forward(self, x, **kwargs):
        # run every head on the same input, then concat along the task dim
        per_task = [head(x) for head in self.task_specific_pred_layers]
        return torch.cat(per_task, dim=1)
|
|
|
|
class GlobalAveragePooling(nn.Module):
    """helper class for global average pooling"""

    def __init__(self, dim=1):
        super().__init__()
        # the dimension to average over (1 for (batch, seq, d) inputs)
        self.dim = dim

    def forward(self, x, **kwargs):
        return x.mean(dim=self.dim)
|
|
|
|
class CLSPooling(nn.Module):
    """helper class for CLS token extraction"""

    def __init__(self, cls_position=0):
        super().__init__()
        # index of the CLS token along dim 1 (the sequence dimension)
        self.cls_position = cls_position

    def forward(self, x, **kwargs):
        # (batch, seq, d) -> (batch, d): pick out the CLS token's vector
        return x[:, self.cls_position, :]
|
|
|
|
class TransformerEncoderWrapper(nn.TransformerEncoder):
    """wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters,
    so each transformer encoder layer has a different initialization"""

    def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
        super().__init__(encoder_layer, num_layers, norm)
        # nn.TransformerEncoder deep-copies a single layer, so all layers start
        # with identical weights; re-running the resets gives each its own draw
        if not reset_params:
            return
        self.apply(reset_parameters_helper)
|
|
|
|
class AttnModel(nn.Module):
    """Transformer-encoder model for sequence-based prediction.

    Supports absolute, relative, and structure-aware (relative_3D) positional
    encodings, several pooling strategies (global average, CLS token, or
    flatten), and configurable prediction heads.

    Fix: corrected the --pos_encoding_dropout help text ("out much" -> "how much").
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """add this model's command-line arguments to the given parser"""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument(
            "--pos_encoding",
            type=str,
            default="absolute",
            choices=["none", "absolute", "relative", "relative_3D"],
            help="what type of positional encoding to use",
        )
        parser.add_argument(
            "--pos_encoding_dropout",
            type=float,
            default=0.1,
            help="how much dropout to use in positional encoding, for pos_encoding==absolute",
        )
        parser.add_argument(
            "--clipping_threshold",
            type=int,
            default=3,
            help="clipping threshold for relative position embedding, for relative and relative_3D",
        )
        parser.add_argument(
            "--contact_threshold",
            type=int,
            default=7,
            help="threshold, in angstroms, for contact map, for relative_3D",
        )
        parser.add_argument("--embedding_len", type=int, default=128)
        parser.add_argument("--num_heads", type=int, default=2)
        parser.add_argument("--num_hidden", type=int, default=64)
        parser.add_argument("--num_enc_layers", type=int, default=2)
        parser.add_argument("--enc_layer_dropout", type=float, default=0.1)
        parser.add_argument(
            "--use_final_encoder_norm", action="store_true", default=False
        )

        parser.add_argument(
            "--global_average_pooling", action="store_true", default=False
        )
        parser.add_argument("--cls_pooling", action="store_true", default=False)

        parser.add_argument(
            "--use_task_specific_layers",
            action="store_true",
            default=False,
            help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer"
            " if both flags are set",
        )
        parser.add_argument("--task_specific_hidden_nodes", type=int, default=64)
        parser.add_argument(
            "--use_final_hidden_layer", action="store_true", default=False
        )
        parser.add_argument("--final_hidden_size", type=int, default=64)
        parser.add_argument(
            "--use_final_hidden_layer_norm", action="store_true", default=False
        )
        parser.add_argument(
            "--final_hidden_layer_norm_before_activation",
            action="store_true",
            default=False,
        )
        parser.add_argument(
            "--use_final_hidden_layer_dropout", action="store_true", default=False
        )
        parser.add_argument(
            "--final_hidden_layer_dropout_rate", type=float, default=0.2
        )

        parser.add_argument(
            "--activation",
            type=str,
            default="relu",
            help="activation function used for all activations in the network",
        )
        return parser

    def __init__(
        self,
        # data dimensions
        num_tasks: int,
        aa_seq_len: int,
        num_tokens: int,
        # positional encoding / encoder stack
        pos_encoding: str = "absolute",
        pos_encoding_dropout: float = 0.1,
        clipping_threshold: int = 3,
        contact_threshold: int = 7,
        pdb_fns: List[str] = None,
        embedding_len: int = 64,
        num_heads: int = 2,
        num_hidden: int = 64,
        num_enc_layers: int = 2,
        enc_layer_dropout: float = 0.1,
        use_final_encoder_norm: bool = False,
        # pooling
        global_average_pooling: bool = True,
        cls_pooling: bool = False,
        # prediction head
        use_task_specific_layers: bool = False,
        task_specific_hidden_nodes: int = 64,
        use_final_hidden_layer: bool = False,
        final_hidden_size: int = 64,
        use_final_hidden_layer_norm: bool = False,
        final_hidden_layer_norm_before_activation: bool = False,
        use_final_hidden_layer_dropout: bool = False,
        final_hidden_layer_dropout_rate: float = 0.2,
        # other
        activation: str = "relu",
        *args,
        **kwargs,
    ):

        super().__init__()

        # stored for downstream shape computations / introspection
        self.embedding_len = embedding_len
        self.aa_seq_len = aa_seq_len

        # the network is assembled as an ordered dict of named layers
        layers = collections.OrderedDict()

        # token embedding, scaled by sqrt(embedding_len)
        layers["embedder"] = ScaledEmbedding(
            num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True
        )

        # absolute positional encoding is added before the encoder stack
        if pos_encoding == "absolute":
            # NOTE(review): max_len is hard-coded to 512, so sequences longer
            # than 512 will fail inside PositionalEncoding -- confirm
            # aa_seq_len <= 512 for all datasets
            layers["pos_encoder"] = PositionalEncoding(
                embedding_len, dropout=pos_encoding_dropout, max_len=512
            )

        # standard PyTorch encoder stack for none/absolute positional encoding
        if pos_encoding in ["none", "absolute"]:
            encoder_layer = torch.nn.TransformerEncoderLayer(
                d_model=embedding_len,
                nhead=num_heads,
                dim_feedforward=num_hidden,
                dropout=enc_layer_dropout,
                activation=get_activation_fn(activation),
                norm_first=True,
                batch_first=True,
            )

            # optional layer norm applied after the final encoder layer
            encoder_norm = None
            if use_final_encoder_norm:
                encoder_norm = nn.LayerNorm(embedding_len)

            layers["tr_encoder"] = TransformerEncoderWrapper(
                encoder_layer=encoder_layer,
                num_layers=num_enc_layers,
                norm=encoder_norm,
            )

        # relative / structure-based encodings use the custom encoder stack
        elif pos_encoding in ["relative", "relative_3D"]:
            relative_encoder_layer = RelativeTransformerEncoderLayer(
                d_model=embedding_len,
                nhead=num_heads,
                pos_encoding=pos_encoding,
                clipping_threshold=clipping_threshold,
                contact_threshold=contact_threshold,
                pdb_fns=pdb_fns,
                dim_feedforward=num_hidden,
                dropout=enc_layer_dropout,
                activation=get_activation_fn(activation),
                norm_first=True,
            )

            encoder_norm = None
            if use_final_encoder_norm:
                encoder_norm = nn.LayerNorm(embedding_len)

            layers["tr_encoder"] = RelativeTransformerEncoder(
                encoder_layer=relative_encoder_layer,
                num_layers=num_enc_layers,
                norm=encoder_norm,
            )

        # pooling: average over sequence, take the CLS token, or flatten
        if global_average_pooling:
            layers["avg_pooling"] = GlobalAveragePooling(dim=1)
            pred_layer_input_features = embedding_len
        elif cls_pooling:
            layers["cls_pooling"] = CLSPooling(cls_position=0)
            pred_layer_input_features = embedding_len
        else:
            # no pooling: the full (seq_len x embedding) representation is used
            layers["flatten"] = nn.Flatten()
            pred_layer_input_features = embedding_len * aa_seq_len

        # prediction head; task-specific heads take priority over the shared FC
        if use_task_specific_layers:
            layers["prediction"] = TaskSpecificPredictionLayers(
                num_tasks=num_tasks,
                in_features=pred_layer_input_features,
                num_hidden_nodes=task_specific_hidden_nodes,
                activation=activation,
            )
        elif use_final_hidden_layer:
            # shared hidden layer followed by a single linear output layer
            layers["fc1"] = FCBlock(
                in_features=pred_layer_input_features,
                num_hidden_nodes=final_hidden_size,
                use_batchnorm=False,
                use_layernorm=use_final_hidden_layer_norm,
                norm_before_activation=final_hidden_layer_norm_before_activation,
                use_dropout=use_final_hidden_layer_dropout,
                dropout_rate=final_hidden_layer_dropout_rate,
                activation=activation,
            )

            layers["prediction"] = nn.Linear(
                in_features=final_hidden_size, out_features=num_tasks
            )
        else:
            layers["prediction"] = nn.Linear(
                in_features=pred_layer_input_features, out_features=num_tasks
            )

        # SequentialWithArgs forwards kwargs to the layers that accept them
        self.model = SequentialWithArgs(layers)

    def forward(self, x, **kwargs):
        return self.model(x, **kwargs)
|
|
|
|
class Transpose(nn.Module):
    """helper layer to swap data from (batch, seq, channels) to (batch, channels, seq)
    used as a helper in the convolutional network which pytorch defaults to channels-first
    """

    def __init__(self, dims: Tuple[int, ...] = (1, 2)):
        super().__init__()
        # pair of dimensions handed straight to Tensor.transpose
        self.dims = dims

    def forward(self, x, **kwargs):
        # contiguous() so downstream views/flattens don't hit stride issues
        return x.transpose(*self.dims).contiguous()
|
|
|
|
def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1):
    """Output length of a 1D convolution, per the PyTorch Conv1d formula:
    floor((seq_len + 2*pad - dilation*(kernel_size - 1) - 1) / stride) + 1.

    Fix: the original expression was `... - 1 // stride` -- operator precedence
    divided only the constant 1 by stride, so results were wrong for any
    stride > 1 (the default stride=1 happened to give the right answer).
    """
    return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1) // stride + 1
|
|
|
|
class ConvBlock(nn.Module):
    """a 1D convolution followed by optional norm, an activation, and optional
    dropout; norm can be placed before or after the activation"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
        padding: str = "same",
        use_batchnorm: bool = False,
        use_layernorm: bool = False,
        norm_before_activation: bool = False,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        activation: str = "relu",
    ):
        super().__init__()

        # the two norm types are mutually exclusive
        if use_batchnorm and use_layernorm:
            raise ValueError(
                "Only one of use_batchnorm or use_layernorm can be set to True"
            )

        self.use_batchnorm = use_batchnorm
        self.use_layernorm = use_layernorm
        self.norm_before_activation = norm_before_activation
        self.use_dropout = use_dropout

        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )

        self.activation = get_activation_fn(activation, functional=False)

        if use_batchnorm:
            self.norm = nn.BatchNorm1d(out_channels)
        elif use_layernorm:
            self.norm = nn.LayerNorm(out_channels)

        if use_dropout:
            self.dropout = nn.Dropout(p=dropout_rate)

    def _apply_norm(self, x):
        # BatchNorm1d consumes (batch, channels, seq) directly; LayerNorm
        # normalizes the last dim, so swap channels to the end and back
        if self.use_batchnorm:
            return self.norm(x)
        return self.norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x, **kwargs):
        x = self.conv(x)

        has_norm = self.use_batchnorm or self.use_layernorm

        if has_norm and self.norm_before_activation:
            x = self._apply_norm(x)

        x = self.activation(x)

        if has_norm and not self.norm_before_activation:
            x = self._apply_norm(x)

        # dropout is always last
        if self.use_dropout:
            x = self.dropout(x)

        return x
|
|
|
|
class ConvModel2(nn.Module):
    """convolutional source model that supports padded inputs, pooling, etc"""

    @staticmethod
    def add_model_specific_args(parent_parser):
        # command-line arguments mirroring the __init__ keyword parameters
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--use_embedding", action="store_true", default=False)
        parser.add_argument("--embedding_len", type=int, default=128)

        # per-layer lists: one entry per conv layer
        parser.add_argument("--num_conv_layers", type=int, default=1)
        parser.add_argument("--kernel_sizes", type=int, nargs="+", default=[7])
        parser.add_argument("--out_channels", type=int, nargs="+", default=[128])
        parser.add_argument("--dilations", type=int, nargs="+", default=[1])
        parser.add_argument(
            "--padding", type=str, default="valid", choices=["valid", "same"]
        )
        parser.add_argument("--use_conv_layer_norm", action="store_true", default=False)
        parser.add_argument(
            "--conv_layer_norm_before_activation", action="store_true", default=False
        )
        parser.add_argument(
            "--use_conv_layer_dropout", action="store_true", default=False
        )
        parser.add_argument("--conv_layer_dropout_rate", type=float, default=0.2)

        parser.add_argument(
            "--global_average_pooling", action="store_true", default=False
        )

        parser.add_argument(
            "--use_task_specific_layers", action="store_true", default=False
        )
        parser.add_argument("--task_specific_hidden_nodes", type=int, default=64)
        parser.add_argument(
            "--use_final_hidden_layer", action="store_true", default=False
        )
        parser.add_argument("--final_hidden_size", type=int, default=64)
        parser.add_argument(
            "--use_final_hidden_layer_norm", action="store_true", default=False
        )
        parser.add_argument(
            "--final_hidden_layer_norm_before_activation",
            action="store_true",
            default=False,
        )
        parser.add_argument(
            "--use_final_hidden_layer_dropout", action="store_true", default=False
        )
        parser.add_argument(
            "--final_hidden_layer_dropout_rate", type=float, default=0.2
        )

        parser.add_argument(
            "--activation",
            type=str,
            default="relu",
            help="activation function used for all activations in the network",
        )

        return parser

    def __init__(
        self,
        # data dimensions
        num_tasks: int,
        aa_seq_len: int,
        aa_encoding_len: int,
        num_tokens: int,
        # conv stack configuration (one list entry per conv layer)
        use_embedding: bool = False,
        embedding_len: int = 64,
        num_conv_layers: int = 1,
        kernel_sizes: List[int] = (7,),
        out_channels: List[int] = (128,),
        dilations: List[int] = (1,),
        padding: str = "valid",
        use_conv_layer_norm: bool = False,
        conv_layer_norm_before_activation: bool = False,
        use_conv_layer_dropout: bool = False,
        conv_layer_dropout_rate: float = 0.2,
        # pooling
        global_average_pooling: bool = True,
        # prediction head
        use_task_specific_layers: bool = False,
        task_specific_hidden_nodes: int = 64,
        use_final_hidden_layer: bool = False,
        final_hidden_size: int = 64,
        use_final_hidden_layer_norm: bool = False,
        final_hidden_layer_norm_before_activation: bool = False,
        use_final_hidden_layer_dropout: bool = False,
        final_hidden_layer_dropout_rate: float = 0.2,
        # other
        activation: str = "relu",
        *args,
        **kwargs,
    ):

        super(ConvModel2, self).__init__()

        # the network is assembled as an ordered dict of named layers
        layers = collections.OrderedDict()

        # optional token embedding (unscaled) instead of raw encodings
        if use_embedding:
            layers["embedder"] = ScaledEmbedding(
                num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False
            )

        # conv layers expect channels-first: (batch, seq, feat) -> (batch, feat, seq)
        layers["transpose"] = Transpose(dims=(1, 2))

        # the conv stack; input channel count depends on the embedding choice
        for layer_num in range(num_conv_layers):
            # first layer consumes either the embedding or the raw AA encoding;
            # later layers consume the previous layer's output channels
            if layer_num == 0 and use_embedding:
                in_channels = embedding_len
            elif layer_num == 0 and not use_embedding:
                in_channels = aa_encoding_len
            else:
                in_channels = out_channels[layer_num - 1]

            layers[f"conv{layer_num}"] = ConvBlock(
                in_channels=in_channels,
                out_channels=out_channels[layer_num],
                kernel_size=kernel_sizes[layer_num],
                dilation=dilations[layer_num],
                padding=padding,
                use_batchnorm=False,
                use_layernorm=use_conv_layer_norm,
                norm_before_activation=conv_layer_norm_before_activation,
                use_dropout=use_conv_layer_dropout,
                dropout_rate=conv_layer_dropout_rate,
                activation=activation,
            )

        # pooling: average over the sequence dim, or flatten everything
        if global_average_pooling:
            # data is channels-first here, so the sequence dim is the last one
            layers["avg_pooling"] = GlobalAveragePooling(dim=-1)
            # after pooling, only the channel dimension remains per example
            pred_layer_input_features = out_channels[-1]

        else:
            layers["flatten"] = nn.Flatten()
            # "valid" padding shrinks the sequence at every conv layer, so the
            # flattened size must be computed layer by layer
            if padding == "valid":
                conv_out_len = conv1d_out_shape(
                    aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0]
                )
                for layer_num in range(1, num_conv_layers):
                    conv_out_len = conv1d_out_shape(
                        conv_out_len,
                        kernel_size=kernel_sizes[layer_num],
                        dilation=dilations[layer_num],
                    )
                pred_layer_input_features = conv_out_len * out_channels[-1]
            else:
                # "same" padding preserves the sequence length
                pred_layer_input_features = aa_seq_len * out_channels[-1]

        # prediction head: per-task heads, shared FC + linear, or linear only
        if use_task_specific_layers:
            layers["prediction"] = TaskSpecificPredictionLayers(
                num_tasks=num_tasks,
                in_features=pred_layer_input_features,
                num_hidden_nodes=task_specific_hidden_nodes,
                activation=activation,
            )

        elif use_final_hidden_layer:
            layers["fc1"] = FCBlock(
                in_features=pred_layer_input_features,
                num_hidden_nodes=final_hidden_size,
                use_batchnorm=False,
                use_layernorm=use_final_hidden_layer_norm,
                norm_before_activation=final_hidden_layer_norm_before_activation,
                use_dropout=use_final_hidden_layer_dropout,
                dropout_rate=final_hidden_layer_dropout_rate,
                activation=activation,
            )
            layers["prediction"] = nn.Linear(
                in_features=final_hidden_size, out_features=num_tasks
            )

        else:
            layers["prediction"] = nn.Linear(
                in_features=pred_layer_input_features, out_features=num_tasks
            )

        self.model = nn.Sequential(layers)

    def forward(self, x, **kwargs):
        output = self.model(x)
        return output
|
|
|
|
class ConvModel(nn.Module):
    """a convolutional network with convolutional layers followed by a fully connected layer"""

    @staticmethod
    def add_model_specific_args(parent_parser):
        # command-line arguments mirroring the __init__ keyword parameters
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--num_conv_layers", type=int, default=1)
        parser.add_argument("--kernel_sizes", type=int, nargs="+", default=[7])
        parser.add_argument("--out_channels", type=int, nargs="+", default=[128])
        parser.add_argument(
            "--padding", type=str, default="valid", choices=["valid", "same"]
        )
        parser.add_argument(
            "--use_final_hidden_layer",
            action="store_true",
            help="whether to use a final hidden layer",
        )
        parser.add_argument(
            "--final_hidden_size",
            type=int,
            default=128,
            help="number of nodes in the final hidden layer",
        )
        parser.add_argument(
            "--use_dropout",
            action="store_true",
            help="whether to use dropout in the final hidden layer",
        )
        parser.add_argument(
            "--dropout_rate",
            type=float,
            default=0.2,
            help="dropout rate in the final hidden layer",
        )
        parser.add_argument(
            "--use_task_specific_layers", action="store_true", default=False
        )
        parser.add_argument("--task_specific_hidden_nodes", type=int, default=64)
        return parser

    def __init__(
        self,
        num_tasks: int,
        aa_seq_len: int,
        aa_encoding_len: int,
        num_conv_layers: int = 1,
        kernel_sizes: List[int] = (7,),
        out_channels: List[int] = (128,),
        padding: str = "valid",
        use_final_hidden_layer: bool = True,
        final_hidden_size: int = 128,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        use_task_specific_layers: bool = False,
        task_specific_hidden_nodes: int = 64,
        *args,
        **kwargs,
    ):

        super(ConvModel, self).__init__()

        # the network is assembled as an ordered dict of named layers
        layers = collections.OrderedDict()

        # conv layers expect channels-first: (batch, seq, feat) -> (batch, feat, seq)
        layers["transpose"] = Transpose(dims=(1, 2))

        for layer_num in range(num_conv_layers):
            # first layer consumes the raw AA encoding; later layers consume
            # the previous layer's output channels
            in_channels = (
                aa_encoding_len if layer_num == 0 else out_channels[layer_num - 1]
            )

            layers["conv{}".format(layer_num)] = nn.Sequential(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels[layer_num],
                    kernel_size=kernel_sizes[layer_num],
                    padding=padding,
                ),
                nn.ReLU(),
            )

        layers["flatten"] = nn.Flatten()

        # compute the flattened feature size feeding the dense layers:
        # "valid" padding shrinks the sequence at each conv layer
        if padding == "valid":
            conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0])
            for layer_num in range(1, num_conv_layers):
                conv_out_len = conv1d_out_shape(
                    conv_out_len, kernel_size=kernel_sizes[layer_num]
                )
            next_dim = conv_out_len * out_channels[-1]
        elif padding == "same":
            # "same" padding preserves the sequence length
            next_dim = aa_seq_len * out_channels[-1]
        else:
            raise ValueError("unexpected value for padding: {}".format(padding))

        # optional shared hidden layer before the prediction layer
        if use_final_hidden_layer:
            layers["fc1"] = FCBlock(
                in_features=next_dim,
                num_hidden_nodes=final_hidden_size,
                use_batchnorm=False,
                use_dropout=use_dropout,
                dropout_rate=dropout_rate,
            )
            next_dim = final_hidden_size

        # prediction head: per-task heads or a single linear layer
        if use_task_specific_layers:
            layers["prediction"] = TaskSpecificPredictionLayers(
                num_tasks=num_tasks,
                in_features=next_dim,
                num_hidden_nodes=task_specific_hidden_nodes,
            )
        else:
            layers["prediction"] = nn.Linear(
                in_features=next_dim, out_features=num_tasks
            )

        self.model = nn.Sequential(layers)

    def forward(self, x, **kwargs):
        output = self.model(x)
        return output
|
|
|
|
class FCModel(nn.Module):
    """a fully connected network: flatten, a stack of FCBlocks, then a linear
    output layer producing one value per task"""

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--num_layers", type=int, default=1)
        parser.add_argument("--num_hidden", nargs="+", type=int, default=[128])
        parser.add_argument("--use_batchnorm", action="store_true", default=False)
        parser.add_argument("--use_layernorm", action="store_true", default=False)
        parser.add_argument(
            "--norm_before_activation", action="store_true", default=False
        )
        parser.add_argument("--use_dropout", action="store_true", default=False)
        parser.add_argument("--dropout_rate", type=float, default=0.2)
        return parser

    def __init__(
        self,
        num_tasks: int,
        seq_encoding_len: int,
        num_layers: int = 1,
        num_hidden: List[int] = (128,),
        use_batchnorm: bool = False,
        use_layernorm: bool = False,
        norm_before_activation: bool = False,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        activation: str = "relu",
        *args,
        **kwargs,
    ):
        super().__init__()

        blocks = collections.OrderedDict()

        # flatten the encoded sequence into a single feature vector
        blocks["flatten"] = nn.Flatten()

        # hidden stack: the first block consumes the flattened encoding, the
        # rest consume the previous block's width
        prev_width = seq_encoding_len
        for idx in range(num_layers):
            blocks["fc{}".format(idx)] = FCBlock(
                in_features=prev_width,
                num_hidden_nodes=num_hidden[idx],
                use_batchnorm=use_batchnorm,
                use_layernorm=use_layernorm,
                norm_before_activation=norm_before_activation,
                use_dropout=use_dropout,
                dropout_rate=dropout_rate,
                activation=activation,
            )
            prev_width = num_hidden[idx]

        # output layer width follows num_hidden[-1] (or the raw encoding if
        # there are no hidden layers at all)
        out_in_features = num_hidden[-1] if num_layers > 0 else seq_encoding_len
        blocks["output"] = nn.Linear(in_features=out_in_features, out_features=num_tasks)

        self.model = nn.Sequential(blocks)

    def forward(self, x, **kwargs):
        return self.model(x)
|
|
|
|
class LRModel(nn.Module):
    """a simple linear model"""

    def __init__(self, num_tasks, seq_encoding_len, *args, **kwargs):
        super().__init__()

        # flatten the encoded sequence, then map straight to the task outputs
        flatten = nn.Flatten()
        linear = nn.Linear(seq_encoding_len, out_features=num_tasks)
        self.model = nn.Sequential(flatten, linear)

    def forward(self, x, **kwargs):
        return self.model(x)
|
|
|
|
| class TransferModel(nn.Module): |
| """transfer learning model""" |
|
|
| @staticmethod |
| def add_model_specific_args(parent_parser): |
|
|
| def none_or_int(value: str): |
| return None if value.lower() == "none" else int(value) |
|
|
| p = ArgumentParser(parents=[parent_parser], add_help=False) |
|
|
| |
| p.add_argument("--pretrained_ckpt_path", type=str, default=None) |
|
|
| |
| p.add_argument( |
| "--backbone_cutoff", |
| type=none_or_int, |
| default=-1, |
| help="where to cut off the backbone. can be a negative int, indexing back from " |
| "pretrained_model.model.model. a value of -1 would chop off the backbone prediction head. " |
| "a value of -2 chops the prediction head and FC layer. a value of -3 chops" |
| "the above, as well as the global average pooling layer. all depends on architecture.", |
| ) |
|
|
| p.add_argument( |
| "--pred_layer_input_features", |
| type=int, |
| default=None, |
| help="if None, number of features will be determined based on backbone_cutoff and standard " |
| "architecture. otherwise, specify the number of input features for the prediction layer", |
| ) |
|
|
| |
| p.add_argument( |
| "--top_net_type", |
| type=str, |
| default="linear", |
| choices=["linear", "nonlinear", "sklearn"], |
| ) |
| p.add_argument("--top_net_hidden_nodes", type=int, default=256) |
| p.add_argument("--top_net_use_batchnorm", action="store_true") |
| p.add_argument("--top_net_use_dropout", action="store_true") |
| p.add_argument("--top_net_dropout_rate", type=float, default=0.1) |
|
|
| return p |
|
|
| def __init__( |
| self, |
| |
| pretrained_ckpt_path: Optional[str] = None, |
| pretrained_hparams: Optional[dict] = None, |
| backbone_cutoff: Optional[int] = -1, |
| |
| pred_layer_input_features: Optional[int] = None, |
| top_net_type: str = "linear", |
| top_net_hidden_nodes: int = 256, |
| top_net_use_batchnorm: bool = False, |
| top_net_use_dropout: bool = False, |
| top_net_dropout_rate: float = 0.1, |
| *args, |
| **kwargs, |
| ): |
|
|
| super().__init__() |
|
|
| |
| if pretrained_ckpt_path is None and pretrained_hparams is None: |
| raise ValueError( |
| "Either pretrained_ckpt_path or pretrained_hparams must be specified" |
| ) |
|
|
| |
| |
| |
| pdb_fns = kwargs["pdb_fns"] if "pdb_fns" in kwargs else None |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if pretrained_hparams is not None: |
| |
| pretrained_hparams["pdb_fns"] = pdb_fns |
| pretrained_model = Model[pretrained_hparams["model_name"]].cls( |
| **pretrained_hparams |
| ) |
| self.pretrained_hparams = pretrained_hparams |
| else: |
| |
| raise NotImplementedError( |
| "Loading pretrained weights from RosettaTask checkpoint not supported" |
| ) |
|
|
| layers = collections.OrderedDict() |
|
|
| |
| if backbone_cutoff is None: |
| layers["backbone"] = SequentialWithArgs( |
| *list(pretrained_model.model.children()) |
| ) |
| else: |
| layers["backbone"] = SequentialWithArgs( |
| *list(pretrained_model.model.children())[0:backbone_cutoff] |
| ) |
|
|
| if top_net_type == "sklearn": |
| |
| self.model = SequentialWithArgs(layers) |
| return |
|
|
| |
| if pred_layer_input_features is None: |
| |
| |
| |
| if backbone_cutoff is None: |
| |
| pred_layer_input_features = self.pretrained_hparams["num_tasks"] |
| elif backbone_cutoff == -1: |
| pred_layer_input_features = self.pretrained_hparams["final_hidden_size"] |
| elif backbone_cutoff == -2: |
| pred_layer_input_features = self.pretrained_hparams["embedding_len"] |
| elif backbone_cutoff == -3: |
| pred_layer_input_features = ( |
| self.pretrained_hparams["embedding_len"] * kwargs["aa_seq_len"] |
| ) |
| else: |
| raise ValueError( |
| "can't automatically determine pred_layer_input_features for given backbone_cutoff" |
| ) |
|
|
| layers["flatten"] = nn.Flatten(start_dim=1) |
|
|
| |
| if top_net_type == "linear": |
| |
| layers["prediction"] = nn.Linear( |
| in_features=pred_layer_input_features, out_features=1 |
| ) |
| elif top_net_type == "nonlinear": |
| |
| fc_block = FCBlock( |
| in_features=pred_layer_input_features, |
| num_hidden_nodes=top_net_hidden_nodes, |
| use_batchnorm=top_net_use_batchnorm, |
| use_dropout=top_net_use_dropout, |
| dropout_rate=top_net_dropout_rate, |
| ) |
|
|
| pred_layer = nn.Linear(in_features=top_net_hidden_nodes, out_features=1) |
|
|
| layers["prediction"] = SequentialWithArgs(fc_block, pred_layer) |
| else: |
| raise ValueError( |
| "Unexpected type of top net layer: {}".format(top_net_type) |
| ) |
|
|
| self.model = SequentialWithArgs(layers) |
|
|
    def forward(self, x, **kwargs):
        """Forward pass: delegate to the assembled sequential model, passing any
        extra keyword arguments (e.g. those consumed by relative-attention layers) through."""
        return self.model(x, **kwargs)
|
|
|
|
def get_activation_fn(activation, functional=True):
    """Map an activation name to its implementation.

    :param activation: one of "relu", "gelu", "silu" / "silo" / "swish",
        "leaky_relu" / "lrelu". "silo" is kept as a legacy alias of silu
        (it was the only accepted spelling before this fix).
    :param functional: if True return the torch.nn.functional function,
        otherwise return a torch.nn module instance
    :raises RuntimeError: for unrecognized activation names
    """
    if activation == "relu":
        return F.relu if functional else nn.ReLU()
    elif activation == "gelu":
        return F.gelu if functional else nn.GELU()
    elif activation in ("silu", "silo", "swish"):
        # bug fix: the canonical name "silu" was previously rejected due to the
        # "silo" typo; accept all three spellings (backward compatible)
        return F.silu if functional else nn.SiLU()
    elif activation in ("leaky_relu", "lrelu"):
        return F.leaky_relu if functional else nn.LeakyReLU()
    else:
        raise RuntimeError("unknown activation: {}".format(activation))
|
|
|
|
class Model(enum.Enum):
    """Registry of model architectures.

    Each member carries a payload: ``member.cls`` is the model class (defined
    elsewhere in this file) and ``member.transfer_model`` flags whether the
    architecture wraps a pretrained backbone. Used as ``Model[name].cls(**hparams)``.
    """

    def __new__(cls, *args, **kwds):
        # auto-assign sequential enum values so the member tuples below can be
        # used entirely as payload for __init__ instead of as the enum value
        value = len(cls.__members__) + 1
        obj = object.__new__(cls)
        obj._value_ = value
        return obj

    def __init__(self, cls, transfer_model):
        # cls: the model class for this architecture
        # transfer_model: True when the architecture is a transfer-learning wrapper
        self.cls = cls
        self.transfer_model = transfer_model

    linear = LRModel, False
    fully_connected = FCModel, False
    cnn = ConvModel, False
    cnn2 = ConvModel2, False
    transformer_encoder = AttnModel, False
    transfer_model = TransferModel, True
|
|
|
|
def main():
    """Placeholder entry point; this module is primarily used as a library."""
    return None
|
|
|
|
# allow running this module directly as a script (currently a no-op)
if __name__ == "__main__":
    main()
# NOTE(review): the string below looks like a stray module docstring left over
# from concatenating source files; as a bare expression statement it has no
# runtime effect, but it should probably be converted to a comment or removed.
""" Encodes data in different formats """
|
|
|
|
class Encoding(Enum):
    """Output formats supported by DataEncoder."""

    INT_SEQS = auto()  # integer index per character, shape (n_seqs, seq_len)
    ONE_HOT = auto()  # one-hot vectors, shape (n_seqs, seq_len, num_chars), float32
|
|
|
|
class DataEncoder:
    """Encodes amino acid sequences and variant specifications as integer or
    one-hot numpy arrays, depending on the configured Encoding."""

    # vocabulary: stop character "*" plus the 20 standard amino acids
    chars = [
        "*", "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
        "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
    ]
    num_chars = len(chars)
    # character -> integer index
    mapping = {c: i for i, c in enumerate(chars)}

    def __init__(self, encoding: Encoding = Encoding.INT_SEQS):
        self.encoding = encoding

    def _encode_from_int_seqs(self, seq_ints):
        """Convert an (n_seqs, seq_len) integer array to the configured output encoding."""
        if self.encoding == Encoding.INT_SEQS:
            return seq_ints
        elif self.encoding == Encoding.ONE_HOT:
            # index rows of the identity matrix to one-hot encode
            one_hot = np.eye(self.num_chars)[seq_ints]
            return one_hot.astype(np.float32)
        else:
            # bug fix: previously fell through and silently returned None for an
            # unknown encoding; fail fast instead
            raise ValueError(f"Unsupported encoding: {self.encoding}")

    def encode_sequences(self, char_seqs):
        """Encode full character sequences (all must have the same length)."""
        seq_ints = [[self.mapping[c] for c in char_seq] for char_seq in char_seqs]
        seq_ints = np.array(seq_ints).astype(int)
        return self._encode_from_int_seqs(seq_ints)

    def encode_variants(self, wt, variants):
        """Encode variants relative to the wild-type sequence wt.

        Each variant is a comma-separated list of mutations of the form
        <wt_aa><position><new_aa> (e.g. "A23G"); positions index directly into
        wt (0-based). The special variant "_wt" denotes the unmodified wild type.
        """
        # encode the wild-type sequence once
        wt_int = np.zeros(len(wt), dtype=np.uint8)
        for i, c in enumerate(wt):
            wt_int[i] = self.mapping[c]

        # one copy of the wild type per variant, mutated in place below
        seq_ints = np.tile(wt_int, (len(variants), 1))

        for i, variant in enumerate(variants):
            # wild-type "variant": leave the row untouched
            if variant == "_wt":
                continue

            for mutation in variant.split(","):
                position = int(mutation[1:-1])
                replacement = self.mapping[mutation[-1]]
                seq_ints[i, position] = replacement

        seq_ints = seq_ints.astype(int)
        return self._encode_from_int_seqs(seq_ints)
|
|
|
|
# checkpoint UUID -> download URL (all checkpoints hosted on Zenodo record 14908509)
UUID_URL_MAP = {
    # METL-G (global) source models
    "D72M9aEp": "https://zenodo.org/records/14908509/files/METL-G-20M-1D-D72M9aEp.pt?download=1",
    "Nr9zCKpR": "https://zenodo.org/records/14908509/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1",
    "auKdzzwX": "https://zenodo.org/records/14908509/files/METL-G-50M-1D-auKdzzwX.pt?download=1",
    "6PSAzdfv": "https://zenodo.org/records/14908509/files/METL-G-50M-3D-6PSAzdfv.pt?download=1",
    # METL-L (local) source models, per protein
    "8gMPQJy4": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1",
    "Hr4GNHws": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1",
    "8iFoiYw2": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1",
    "kt5DdWTa": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1",
    "DMfkjVzT": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1",
    "epegcFiH": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1",
    "kS3rUS7h": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1",
    "X7w83g6S": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1",
    "UKebCQGz": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1",
    "2rr8V4th": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1",
    "PREhfC22": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1",
    "9ASvszux": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1",
    "HscFFkAb": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1",
    "H48oiNZN": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1",
    "CEMSx7ZC": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-PTEN-CEMSx7ZC.pt?download=1",
    "PjxR5LW7": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-PTEN-PjxR5LW7.pt?download=1",
    # METL-BIND models (GB1)
    "K6mw24Rg": "https://zenodo.org/records/14908509/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1",
    "Bo5wn2SG": "https://zenodo.org/records/14908509/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1",
    # fine-tuned METL-L GFP design models
    "YoQkzoLD": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1",
    "PEkeRuxb": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1",
    # fine-tuned METL-G models
    "4Rh3WCbG": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-DLG4_2022-ABUNDANCE-4Rh3WCbG.pt?download=1",
    "4xbuC5y7": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-DLG4_2022-BINDING-4xbuC5y7.pt?download=1",
    "dAndZfJ4": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-GB1-dAndZfJ4.pt?download=1",
    "PeT2D92j": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-GFP-PeT2D92j.pt?download=1",
    "HenDpDWe": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-GRB2-ABUNDANCE-HenDpDWe.pt?download=1",
    "cvnycE5Q": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-GRB2-BINDING-cvnycE5Q.pt?download=1",
    "ho54gxzv": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-Pab1-ho54gxzv.pt?download=1",
    "UEuMtmfx": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-PTEN-ABUNDANCE-UEuMtmfx.pt?download=1",
    "U3X8mSeT": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-PTEN-ACTIVITY-U3X8mSeT.pt?download=1",
    "ELL4GGQq": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-TEM-1-ELL4GGQq.pt?download=1",
    "BAWw23vW": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-1D-Ube4b-BAWw23vW.pt?download=1",
    "RBtqxzvu": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-DLG4_2022-ABUNDANCE-RBtqxzvu.pt?download=1",
    "BuvxgE2x": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-DLG4_2022-BINDING-BuvxgE2x.pt?download=1",
    "9vSB3DRM": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-GB1-9vSB3DRM.pt?download=1",
    "6JBzHpkQ": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-GFP-6JBzHpkQ.pt?download=1",
    "dDoCCvfr": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-GRB2-ABUNDANCE-dDoCCvfr.pt?download=1",
    "jYesS9Ki": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-GRB2-BINDING-jYesS9Ki.pt?download=1",
    "jhbL2FeB": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-Pab1-jhbL2FeB.pt?download=1",
    "eJPPQYEW": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-PTEN-ABUNDANCE-eJPPQYEW.pt?download=1",
    "4gqYnW6V": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-PTEN-ACTIVITY-4gqYnW6V.pt?download=1",
    "K6BjsWXm": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-TEM-1-K6BjsWXm.pt?download=1",
    "G9piq6WH": "https://zenodo.org/records/14908509/files/FT-METL-G-20M-3D-Ube4b-G9piq6WH.pt?download=1",
    # fine-tuned METL-L models
    "RMFA6dnX": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-DLG4_2022-ABUNDANCE-RMFA6dnX.pt?download=1",
    "YdzBYWHs": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-DLG4_2022-BINDING-YdzBYWHs.pt?download=1",
    "Pgcseywk": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-GB1-Pgcseywk.pt?download=1",
    "HaUuRwfE": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-GFP-HaUuRwfE.pt?download=1",
    "VNpi9Zjt": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-GRB2-ABUNDANCE-VNpi9Zjt.pt?download=1",
    "Z59BhUaE": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-GRB2-BINDING-Z59BhUaE.pt?download=1",
    "TdjCzoQQ": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-Pab1-TdjCzoQQ.pt?download=1",
    "64ncFxBR": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-TEM-1-64ncFxBR.pt?download=1",
    "e9uhhnAv": "https://zenodo.org/records/14908509/files/FT-METL-L-1D-Ube4b-e9uhhnAv.pt?download=1",
    "oUScGeHo": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-1D-PTEN-ABUNDANCE-oUScGeHo.pt?download=1",
    "m9UsG7dq": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-1D-PTEN-ACTIVITY-m9UsG7dq.pt?download=1",
    "DhuasDEr": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-3D-PTEN-ABUNDANCE-DhuasDEr.pt?download=1",
    "8Vi7ENcC": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-3D-PTEN-ACTIVITY-8Vi7ENcC.pt?download=1",
    "V3uTtXVe": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-DLG4_2022-ABUNDANCE-V3uTtXVe.pt?download=1",
    "iu6ZahPw": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-DLG4_2022-BINDING-iu6ZahPw.pt?download=1",
    "UvMMdsq4": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-GB1-UvMMdsq4.pt?download=1",
    "LWEY95Yb": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-GFP-LWEY95Yb.pt?download=1",
    "PqBMjXkA": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-GRB2-ABUNDANCE-PqBMjXkA.pt?download=1",
    "VwcRN6UB": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-GRB2-BINDING-VwcRN6UB.pt?download=1",
    "5SjoLx3y": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-Pab1-5SjoLx3y.pt?download=1",
    "PncvgiJU": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-TEM-1-PncvgiJU.pt?download=1",
    "NfbZL7jK": "https://zenodo.org/records/14908509/files/FT-METL-L-3D-Ube4b-NfbZL7jK.pt?download=1",
}
|
|
# human-readable model identifier -> checkpoint UUID (keys must be lowercase;
# see get_from_ident, which lowercases before lookup)
IDENT_UUID_MAP = {
    # METL-G (global) source models
    "metl-g-20m-1d": "D72M9aEp",
    "metl-g-20m-3d": "Nr9zCKpR",
    "metl-g-50m-1d": "auKdzzwX",
    "metl-g-50m-3d": "6PSAzdfv",
    # METL-L (local) source models, per protein
    "metl-l-2m-1d-gfp": "8gMPQJy4",
    "metl-l-2m-3d-gfp": "Hr4GNHws",
    "metl-l-2m-1d-dlg4_2022": "8iFoiYw2",
    "metl-l-2m-3d-dlg4_2022": "kt5DdWTa",
    "metl-l-2m-1d-gb1": "DMfkjVzT",
    "metl-l-2m-3d-gb1": "epegcFiH",
    "metl-l-2m-1d-grb2": "kS3rUS7h",
    "metl-l-2m-3d-grb2": "X7w83g6S",
    "metl-l-2m-1d-pab1": "UKebCQGz",
    "metl-l-2m-3d-pab1": "2rr8V4th",
    "metl-l-2m-1d-pten": "CEMSx7ZC",
    "metl-l-2m-3d-pten": "PjxR5LW7",
    "metl-l-2m-1d-tem-1": "PREhfC22",
    "metl-l-2m-3d-tem-1": "9ASvszux",
    "metl-l-2m-1d-ube4b": "HscFFkAb",
    "metl-l-2m-3d-ube4b": "H48oiNZN",
    # METL-BIND models (GB1)
    "metl-bind-2m-3d-gb1-standard": "K6mw24Rg",
    "metl-bind-2m-3d-gb1-binding": "Bo5wn2SG",
    # fine-tuned GFP design models
    "metl-l-2m-1d-gfp-ft-design": "YoQkzoLD",
    "metl-l-2m-3d-gfp-ft-design": "PEkeRuxb",
}
|
|
|
|
def download_checkpoint(uuid):
    """Download (or load from the torch.hub cache) the checkpoint for the given
    UUID and return its (state_dict, hyper_parameters) pair."""
    checkpoint = torch.hub.load_state_dict_from_url(
        UUID_URL_MAP[uuid], map_location="cpu", file_name=f"{uuid}.pt"
    )
    return checkpoint["state_dict"], checkpoint["hyper_parameters"]
|
|
|
|
def _get_data_encoding(hparams):
    """Determine the data Encoding from a checkpoint's hyperparameters.

    A missing "encoding" key is treated as "auto"; "auto" resolves to integer
    sequences for the transformer_encoder model. Anything else is rejected.
    """
    encoding_name = hparams.get("encoding", "auto")

    if encoding_name == "int_seqs":
        return Encoding.INT_SEQS
    if encoding_name == "one_hot":
        return Encoding.ONE_HOT
    if encoding_name == "auto" and hparams["model_name"] in ["transformer_encoder"]:
        return Encoding.INT_SEQS

    raise ValueError("Detected unsupported encoding in hyperparameters")
|
|
|
|
def load_model_and_data_encoder(state_dict, hparams):
    """Instantiate the model named in hparams, load its weights, and pair it
    with a DataEncoder matching the checkpoint's encoding."""
    model_cls = Model[hparams["model_name"]].cls
    model = model_cls(**hparams)
    model.load_state_dict(state_dict)

    return model, DataEncoder(_get_data_encoding(hparams))
|
|
|
|
def get_from_uuid(uuid):
    """Download and load the model + data encoder for a known checkpoint UUID."""
    if uuid not in UUID_URL_MAP:
        raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP")

    state_dict, hparams = download_checkpoint(uuid)
    return load_model_and_data_encoder(state_dict, hparams)
|
|
|
|
def get_from_ident(ident):
    """Download and load the model + data encoder for a human-readable
    identifier (case-insensitive; see IDENT_UUID_MAP)."""
    ident = ident.lower()
    if ident not in IDENT_UUID_MAP:
        raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP")

    state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
    return load_model_and_data_encoder(state_dict, hparams)
|
|
|
|
def get_from_checkpoint(ckpt_fn):
    """Load a model + data encoder from a local checkpoint file.

    NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
    """
    checkpoint = torch.load(ckpt_fn, map_location="cpu")
    return load_model_and_data_encoder(
        checkpoint["state_dict"], checkpoint["hyper_parameters"]
    )
|
|
|
|
| """ implementation of transformer encoder with relative attention |
| references: |
| - https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a |
| - https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer |
| - https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py |
| - https://github.com/jiezouguihuafu/ClassicalModelreproduced/blob/main/Transformer/transfor_rpe.py |
| """ |
|
|
|
|
class RelativePosition3D(nn.Module):
    """Contact map-based relative position embeddings.

    Relative positions are shortest-path distances (in hops) on a residue
    contact graph built from a PDB structure, clipped at clipping_threshold.
    One bucket matrix is computed and cached per unique PDB file.
    """

    def __init__(
        self,
        embedding_len: int,
        contact_threshold: int,
        clipping_threshold: int,
        pdb_fns: Optional[Union[str, list, tuple]] = None,
        default_pdb_dir: str = "data/pdb_files",
    ):
        """
        embedding_len: length of each relative position embedding vector
        contact_threshold: distance (in the units of the distance matrix,
            presumably angstroms — confirm against cbeta_distance_matrix) below
            which two residues count as being in contact
        clipping_threshold: maximum graph distance; all farther pairs share one bucket
        pdb_fns: PDB file(s) to pre-process at init; others are computed on the fly
        default_pdb_dir: directory searched when a bare PDB filename is not found locally
        """
        super().__init__()
        self.embedding_len = embedding_len
        self.clipping_threshold = clipping_threshold
        self.contact_threshold = contact_threshold
        self.default_pdb_dir = default_pdb_dir

        # empty non-persistent buffer used only to track the module's current device
        self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)

        # buckets are graph distances 0..clipping_threshold (farther distances
        # are merged into the last bucket)
        num_embeddings = clipping_threshold + 1

        # one learned embedding vector per bucket
        self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)

        # cache of per-PDB bucket matrices; a plain dict, so it is NOT part of
        # the state_dict and must be moved between devices manually
        # (see _move_bucket_mtxs / _get_bucket_mtx)
        self.bucket_mtxs = {}
        self.bucket_mtxs_device = self.dummy_buffer.device
        self._init_pdbs(pdb_fns)

    def forward(self, pdb_fn):
        # look up this structure's bucket matrix and embed each bucket index
        embeddings = self.embeddings_table(self._get_bucket_mtx(pdb_fn))
        return embeddings

    def _move_bucket_mtxs(self, device):
        # keep the cached bucket matrices on the same device as the module
        for k, v in self.bucket_mtxs.items():
            self.bucket_mtxs[k] = v.to(device)
        self.bucket_mtxs_device = device

    def _get_bucket_mtx(self, pdb_fn):
        """retrieve a bucket matrix given the pdb_fn.
        if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
        retrieved from the bucket_mtxs dictionary. else, it will be computed now on-the-fly
        """

        # re-sync the cache device in case the module was moved since the last call
        if self.bucket_mtxs_device != self.dummy_buffer.device:
            self._move_bucket_mtxs(self.dummy_buffer.device)

        pdb_attr = self._pdb_key(pdb_fn)
        if pdb_attr in self.bucket_mtxs:
            return self.bucket_mtxs[pdb_attr]
        else:
            # not cached yet: process this PDB now and cache the result
            self._init_pdb(pdb_fn)
            return self.bucket_mtxs[pdb_attr]

    def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
        """store a bucket matrix in the bucket dict"""

        # store on whatever device the cache currently lives on
        bucket_mtx = bucket_mtx.to(self.bucket_mtxs_device)

        self.bucket_mtxs[self._pdb_key(pdb_fn)] = bucket_mtx

    @staticmethod
    def _pdb_key(pdb_fn):
        """return a unique key for the given pdb_fn, used to map unique PDBs"""
        # key on the basename without extension, so the same structure
        # referenced via different paths maps to a single cache entry
        return f"pdb_{basename(pdb_fn).split('.')[0]}"

    def _init_pdbs(self, pdb_fns):
        # precompute bucket matrices for the given PDB file(s), if any
        start = time.time()

        if pdb_fns is None:
            # nothing to precompute; matrices will be built lazily on first use
            return

        # accept a single filename as well as a list/tuple of filenames
        if not isinstance(pdb_fns, list) and not isinstance(pdb_fns, tuple):
            pdb_fns = [pdb_fns]

        for pdb_fn in pdb_fns:
            self._init_pdb(pdb_fn)

        print("Initialized PDB bucket matrices in: {:.3f}".format(time.time() - start))

    def _init_pdb(self, pdb_fn):
        """process a pdb file for use with structure-based relative attention"""
        # bare filename: fall back to the default PDB directory if not found locally
        if dirname(pdb_fn) == "":
            if not isfile(pdb_fn):
                pdb_fn = join(self.default_pdb_dir, pdb_fn)

        # residue distance matrix -> contact graph thresholded at contact_threshold
        cbeta_mtx = cbeta_distance_matrix(pdb_fn)
        structure_graph = dist_thresh_graph(cbeta_mtx, self.contact_threshold)

        # bucket matrix of clipped pairwise graph distances
        bucket_mtx = self._compute_bucket_mtx(structure_graph)

        self._set_bucket_mtx(pdb_fn, bucket_mtx)

    def _compute_bucketed_neighbors(self, structure_graph, source_node):
        """gets the bucketed neighbors from the given source node and structure graph"""
        if self.clipping_threshold < 0:
            raise ValueError("Clipping threshold must be >= 0")

        # {graph distance: [nodes at that distance]} from the source node
        sspl = _inv_dict(
            nx.single_source_shortest_path_length(structure_graph, source_node)
        )

        if self.clipping_threshold is not None:
            # merge all distances >= clipping_threshold into the last bucket
            num_buckets = 1 + self.clipping_threshold
            sspl = _combine_d(sspl, self.clipping_threshold, num_buckets - 1)

        return sspl

    def _compute_bucket_mtx(self, structure_graph):
        """get the bucket_mtx for the given structure_graph
        calls _get_bucketed_neighbors for every node in the structure_graph"""
        num_residues = len(list(structure_graph))

        # bucket_mtx[i, j] = clipped graph distance from residue i to residue j;
        # unreachable pairs keep the default bucket 0 from torch.zeros
        bucket_mtx = torch.zeros(num_residues, num_residues, dtype=torch.long)

        for node_num in sorted(list(structure_graph)):
            bucketed_neighbors = self._compute_bucketed_neighbors(
                structure_graph, node_num
            )

            for bucket_num, neighbors in bucketed_neighbors.items():
                bucket_mtx[node_num, neighbors] = bucket_num

        return bucket_mtx
|
|
|
|
class RelativePosition(nn.Module):
    """Sequence-based relative position embeddings.

    Creates the embedding lookup table E_r and, for a given query/key length,
    computes the matrix R of relative position embeddings. An empty registered
    buffer is used to track the module's device.
    """

    def __init__(self, embedding_len: int, clipping_threshold: int):
        """
        embedding_len: the length of the embedding, may be d_model, or d_model // num_heads for multihead
        clipping_threshold: the maximum relative position, referred to as k by Shaw et al.
        """
        super().__init__()
        self.embedding_len = embedding_len
        self.clipping_threshold = clipping_threshold

        # relative positions span [-k, k], i.e. 2k + 1 distinct buckets
        self.embeddings_table = nn.Embedding(2 * clipping_threshold + 1, embedding_len)

        # empty buffer whose .device always matches the module's device
        self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)

    def forward(self, length_q, length_k):
        device = self.dummy_buffer.device
        q_positions = torch.arange(length_q, device=device)
        k_positions = torch.arange(length_k, device=device)

        # pairwise relative distances (key index minus query index),
        # clipped to [-k, k] and shifted into the table range [0, 2k]
        rel_dist = k_positions[None, :] - q_positions[:, None]
        rel_dist = torch.clamp(
            rel_dist, -self.clipping_threshold, self.clipping_threshold
        )
        index_mtx = (rel_dist + self.clipping_threshold).long()

        # (length_q, length_k, embedding_len)
        return self.embeddings_table(index_mtx)
|
|
|
|
class RelativeMultiHeadAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout,
        pos_encoding,
        clipping_threshold,
        contact_threshold,
        pdb_fns,
    ):
        """
        Multi-head attention with relative position embeddings. Input data should be in batch_first format.
        :param embed_dim: aka d_model, aka hid_dim
        :param num_heads: number of heads
        :param dropout: how much dropout for scaled dot product attention

        :param pos_encoding: what type of positional encoding to use, relative or relative3D
        :param clipping_threshold: clipping threshold for relative position embedding
        :param contact_threshold: for relative_3D, the threshold in angstroms for the contact map
        :param pdb_fns: pdb file(s) to set up the relative position object

        """
        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # model dimensions; each head works on a head_dim-sized slice of embed_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # relative position embedding configuration
        self.pos_encoding = pos_encoding
        self.clipping_threshold = clipping_threshold
        self.contact_threshold = contact_threshold
        if pdb_fns is not None and not isinstance(pdb_fns, list):
            pdb_fns = [pdb_fns]
        self.pdb_fns = pdb_fns

        # separate relative position tables for keys and values,
        # sized per head (head_dim)
        if pos_encoding == "relative":
            self.relative_position_k = RelativePosition(
                self.head_dim, self.clipping_threshold
            )
            self.relative_position_v = RelativePosition(
                self.head_dim, self.clipping_threshold
            )
        elif pos_encoding == "relative_3D":
            self.relative_position_k = RelativePosition3D(
                self.head_dim,
                self.contact_threshold,
                self.clipping_threshold,
                self.pdb_fns,
            )
            self.relative_position_v = RelativePosition3D(
                self.head_dim,
                self.contact_threshold,
                self.clipping_threshold,
                self.pdb_fns,
            )
        else:
            raise ValueError("unrecognized pos_encoding: {}".format(pos_encoding))

        # query/key/value input projections (embed_dim -> embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # output projection applied after the heads are concatenated back together
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # dropout applied to the softmaxed attention weights
        self.dropout = nn.Dropout(dropout)

        # sqrt(head_dim) scaling for scaled dot-product attention; registered as
        # a buffer so it moves with the module between devices
        scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
        self.register_buffer("scale", scale)

        # when need_weights is True, forward() also returns the attention
        # weights, optionally averaged over heads
        self.need_weights = False
        self.average_attn_weights = True

    def _compute_attn_weights(self, query, key, len_q, len_k, batch_size, mask, pdb_fn):
        """computes the attention weights (a "compatability function" of queries with corresponding keys)"""

        # term 1: content-based attention (standard per-head dot product)
        # r_q1 / r_k1: (batch, num_heads, seq_len, head_dim)
        r_q1 = query.view(batch_size, len_q, self.num_heads, self.head_dim).permute(
            0, 2, 1, 3
        )
        r_k1 = key.view(batch_size, len_k, self.num_heads, self.head_dim).permute(
            0, 2, 1, 3
        )
        # attn1: (batch, num_heads, len_q, len_k)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))

        # term 2: position-based attention (queries against the relative
        # position embeddings of the keys)
        # r_q2: (len_q, batch * num_heads, head_dim)
        r_q2 = (
            query.permute(1, 0, 2)
            .contiguous()
            .view(len_q, batch_size * self.num_heads, self.head_dim)
        )

        # look up the relative position embeddings for the keys
        if self.pos_encoding == "relative":
            # sequence-based: depends only on the query/key lengths
            rel_pos_k = self.relative_position_k(len_q, len_k)
        elif self.pos_encoding == "relative_3D":
            # structure-based: depends on the PDB structure
            rel_pos_k = self.relative_position_k(pdb_fn)
        else:
            raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))

        # reshape the position term to align with attn1:
        # (batch, num_heads, len_q, len_k)
        attn2 = torch.matmul(r_q2, rel_pos_k.transpose(1, 2)).transpose(0, 1)
        attn2 = attn2.contiguous().view(batch_size, self.num_heads, len_q, len_k)

        # combine both terms and apply the sqrt(head_dim) scaling
        attn_weights = (attn1 + attn2) / self.scale

        # mask out disallowed positions before the softmax
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e10)

        # normalize over the key dimension, then apply attention dropout
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights

    def _compute_avg_val(
        self, value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn
    ):
        # term 1: attention-weighted average of the value vectors
        # r_v1: (batch, num_heads, len_v, head_dim)
        r_v1 = value.view(batch_size, len_v, self.num_heads, self.head_dim).permute(
            0, 2, 1, 3
        )
        avg1 = torch.matmul(attn_weights, r_v1)

        # term 2: attention-weighted average of the values' relative position embeddings
        if self.pos_encoding == "relative":
            rel_pos_v = self.relative_position_v(len_q, len_v)
        elif self.pos_encoding == "relative_3D":
            rel_pos_v = self.relative_position_v(pdb_fn)
        else:
            raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))

        # reshape weights to (len_q, batch * num_heads, len_k) for the position matmul
        r_attn_weights = (
            attn_weights.permute(2, 0, 1, 3)
            .contiguous()
            .view(len_q, batch_size * self.num_heads, len_k)
        )
        avg2 = torch.matmul(r_attn_weights, rel_pos_v)
        # back to (batch, num_heads, len_q, head_dim)
        avg2 = (
            avg2.transpose(0, 1)
            .contiguous()
            .view(batch_size, self.num_heads, len_q, self.head_dim)
        )

        # combine both terms and concatenate heads back into embed_dim
        x = avg1 + avg2
        x = x.permute(
            0, 2, 1, 3
        ).contiguous()
        # final shape: (batch, len_q, embed_dim)
        x = x.view(batch_size, len_q, self.embed_dim)

        return x

    def forward(self, query, key, value, pdb_fn=None, mask=None):
        # expects batch_first inputs: (batch, seq_len, embed_dim)
        batch_size = query.shape[0]
        len_k, len_q, len_v = (key.shape[1], query.shape[1], value.shape[1])

        # input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # attention weights: (batch, num_heads, len_q, len_k)
        attn_weights = self._compute_attn_weights(
            query, key, len_q, len_k, batch_size, mask, pdb_fn
        )

        # weighted values: (batch, len_q, embed_dim)
        attn_output = self._compute_avg_val(
            value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn
        )

        # final output projection
        attn_output = self.out_proj(attn_output)

        if self.need_weights:
            # also return the attention weights, optionally averaged over heads
            if self.average_attn_weights:
                attn_weights = attn_weights.sum(dim=1) / self.num_heads
            return {"attn_output": attn_output, "attn_weights": attn_weights}
        else:
            return attn_output
|
|
|
|
class RelativeTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer using relative multi-head attention.

    d_model: the number of expected features in the input (required).
    nhead: the number of heads in the MultiHeadAttention models (required).
    pos_encoding: "relative" (sequence-based) or "relative_3D" (structure-based)
    clipping_threshold: the clipping threshold for relative position embeddings
    contact_threshold: for relative_3D, the contact map distance threshold (angstroms)
    pdb_fns: pdb file(s) used to set up structure-based relative attention
    dim_feedforward: the dimension of the feedforward network model (default=2048).
    dropout: the dropout value (default=0.1).
    activation: the activation function of the intermediate layer, can be a string
        ("relu" or "gelu") or a unary callable. Default: relu
    layer_norm_eps: the eps value in layer normalization components (default=1e-5).
    norm_first: if ``True``, layer norm is done prior to attention and feedforward
        operations, respectively. Otherwise, it's done after. Default: ``False`` (after).
    """

    # inputs are always batch_first for this layer
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model,
        nhead,
        pos_encoding="relative",
        clipping_threshold=3,
        contact_threshold=7,
        pdb_fns=None,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
        norm_first=False,
    ) -> None:

        # plain attribute set before super().__init__() (original ordering preserved)
        self.batch_first = True

        super(RelativeTransformerEncoderLayer, self).__init__()

        self.self_attn = RelativeMultiHeadAttention(
            d_model,
            nhead,
            dropout,
            pos_encoding,
            clipping_threshold,
            contact_threshold,
            pdb_fns,
        )

        # feed-forward sub-block
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # accept either an activation name or a callable
        if isinstance(activation, str):
            self.activation = get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
        """pdb_fn is forwarded to the attention block for structure-based
        (relative_3D) position embeddings."""
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), pdb_fn=pdb_fn)
            x = x + self._ff_block(self.norm2(x))
        else:
            # bug fix: pdb_fn was previously dropped in this (post-norm) branch,
            # breaking relative_3D attention whenever norm_first=False
            x = self.norm1(x + self._sa_block(x, pdb_fn=pdb_fn))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor, pdb_fn=None) -> Tensor:
        x = self.self_attn(x, x, x, pdb_fn=pdb_fn)
        if isinstance(x, dict):
            # self_attn returns a dict when need_weights is enabled;
            # keep only the attention output here
            x = x["attn_output"]
        return self.dropout1(x)

    # feed-forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
|
|
|
|
class RelativeTransformerEncoder(nn.Module):
    """Stack of deep-copied encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
        super(RelativeTransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

        # re-initialize all parameters so the cloned layers don't share
        # identical initial weights
        if reset_params:
            self.apply(reset_parameters_helper)

    def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
        output = src
        for layer in self.layers:
            output = layer(output, pdb_fn=pdb_fn)

        return output if self.norm is None else self.norm(output)
|
|
|
|
| def _get_clones(module, num_clones): |
| return nn.ModuleList([copy.deepcopy(module) for _ in range(num_clones)]) |
|
|
|
|
| def _inv_dict(d): |
| """helper function for contact map-based position embeddings""" |
| inv = dict() |
| for k, v in d.items(): |
| |
| inv.setdefault(v, list()).append(k) |
| for k, v in inv.items(): |
| |
| inv[k] = sorted(v) |
| return inv |
|
|
|
|
| def _combine_d(d, threshold, combined_key): |
| """helper function for contact map-based position embeddings |
| d is a dictionary with ints as keys and lists as values. |
| for all keys >= threshold, this function combines the values of those keys into a single list |
| """ |
| out_d = {} |
| for k, v in d.items(): |
| if k < threshold: |
| out_d[k] = v |
| elif k >= threshold: |
| if combined_key not in out_d: |
| out_d[combined_key] = v |
| else: |
| out_d[combined_key] += v |
| if combined_key in out_d: |
| out_d[combined_key] = sorted(out_d[combined_key]) |
| return out_d |
|
|
|
|
class GraphType(Enum):
    """Ways of building a residue graph from a structure's distance matrix."""

    LINEAR = auto()  # chain graph: each residue connected to its sequence neighbor
    COMPLETE = auto()  # every residue connected to every other residue
    DISCONNECTED = auto()  # nodes only, no edges
    DIST_THRESH = auto()  # edges between residues within a distance threshold
    DIST_THRESH_SHUFFLED = auto()  # distance-threshold graph with shuffled node labels (see shuffle_nodes)
|
|
|
|
def save_graph(g, fn):
    """Saves graph to file in GEXF format"""
    nx.write_gexf(g, fn)
|
|
|
|
def load_graph(fn):
    """Loads graph from a GEXF file, coercing node labels to ints"""
    g = nx.read_gexf(fn, node_type=int)
    return g
|
|
|
|
def shuffle_nodes(g, seed=7):
    """Shuffles the nodes of the given graph and returns a copy of the shuffled graph"""
    original_nodes = g.nodes()

    # seed numpy's global RNG so the permutation is reproducible
    np.random.seed(seed)
    permuted_nodes = np.random.permutation(original_nodes)

    # relabel each node to its shuffled counterpart, on a copy of the graph
    relabel_map = dict(zip(original_nodes, permuted_nodes))
    return nx.relabel_nodes(g, relabel_map, copy=True)
|
|
|
|
def linear_graph(num_residues):
    """Creates a linear graph where each node is connected to its sequence neighbor in order"""
    g = nx.Graph()
    g.add_nodes_from(np.arange(0, num_residues))
    g.add_edges_from((i, i + 1) for i in range(num_residues - 1))
    return g
|
|
|
|
def complete_graph(num_residues):
    """Creates a graph where each node is connected to all other nodes"""
    return nx.complete_graph(num_residues)
|
|
|
|
def disconnected_graph(num_residues):
    """Creates a graph with num_residues nodes and no edges"""
    graph = nx.Graph()
    graph.add_nodes_from(np.arange(0, num_residues))
    return graph
|
|
|
|
def dist_thresh_graph(dist_mtx, threshold):
    """Creates undirected graph connecting residue pairs closer than threshold"""
    g = nx.Graph()
    num_residues = dist_mtx.shape[0]
    g.add_nodes_from(np.arange(0, num_residues))

    for rn1 in range(num_residues):
        # residues within the distance threshold of rn1 (includes rn1 itself)
        close_residues = np.where(dist_mtx[rn1] < threshold)[0]
        # connect rn1 to each close residue, skipping self-loops
        g.add_edges_from((rn1, rn2) for rn2 in close_residues if rn1 != rn2)

    return g
|
|
|
|
def ordered_adjacency_matrix(g):
    """returns the adjacency matrix ordered by node label in increasing order as a numpy array"""
    node_order = sorted(g.nodes())
    # nx.to_numpy_matrix was deprecated and removed in networkx 3.0;
    # to_numpy_array returns a plain ndarray with the same contents
    adj_mtx = nx.to_numpy_array(g, nodelist=node_order)
    return adj_mtx.astype(np.float32)
|
|
|
|
def cbeta_distance_matrix(pdb_fn, start=0, end=None):
    """Compute the pairwise residue distance matrix for a PDB structure.

    Distances are measured between each residue's beta carbon (CB), falling
    back to the alpha carbon (CA) for residues that lack a CB (e.g. glycine).

    Args:
        pdb_fn: path to the PDB file to parse.
        start: 0-based index (over residues in residue_number order) of the
            first residue to include.
        end: one-past-the-last residue index to include; None means all
            residues from start onward.

    Returns:
        An (n, n) numpy array of Euclidean distances between residue coords.

    Raises:
        ValueError: if a residue has neither a CB nor a CA atom.
    """
    ppdb = PandasPdb().read_pdb(pdb_fn)

    # group atoms by residue; sort=True orders groups by residue_number
    grouped = ppdb.df["ATOM"].groupby("residue_number", sort=True)

    # hoisted out of the loop: this bound is invariant for the whole call
    # (the original recomputed it every iteration)
    end_index = len(grouped) if end is None else end

    coords = []
    for i, (residue_number, residue_group) in enumerate(grouped):
        if i < start:
            continue
        if i >= end_index:
            # groups iterate in sorted order, so nothing later can qualify
            break

        # prefer the beta carbon; fall back to the alpha carbon
        atom_names = residue_group["atom_name"]
        if "CB" in atom_names.values:
            atom_name = "CB"
        elif "CA" in atom_names.values:
            atom_name = "CA"
        else:
            raise ValueError(
                "Couldn't find CB or CA for residue {}".format(residue_number)
            )

        # xyz coordinates of the chosen atom for this residue
        # (residue_group is the group yielded by iteration — no need for the
        # redundant grouped.get_group() lookup the original performed)
        coords.append(
            residue_group[residue_group["atom_name"] == atom_name][
                ["x_coord", "y_coord", "z_coord"]
            ].values[0]
        )

    # (num_residues, 3) coordinate array -> pairwise Euclidean distances
    coords = np.stack(coords)
    return cdist(coords, coords, metric="euclidean")
|
|
|
|
def get_neighbors(g, nodes):
    """Return the sorted union of the neighbors of all the given nodes."""
    neighbor_set = set()
    for node in nodes:
        neighbor_set.update(g.neighbors(node))
    return sorted(neighbor_set)
|
|
|
|
def gen_graph(
    graph_type,
    res_dist_mtx,
    dist_thresh=7,
    shuffle_seed=7,
    graph_save_dir=None,
    save=False,
):
    """Generate the requested structure graph from a residue distance matrix.

    graph_type selects the builder; dist_thresh and shuffle_seed only apply
    to the distance-threshold variants. When save is True, the graph is
    written under graph_save_dir (unless a file already exists there).
    """
    num_residues = len(res_dist_mtx)

    if graph_type is GraphType.LINEAR:
        g = linear_graph(num_residues)
        save_basename = "linear.graph"
    elif graph_type is GraphType.COMPLETE:
        g = complete_graph(num_residues)
        save_basename = "complete.graph"
    elif graph_type is GraphType.DISCONNECTED:
        g = disconnected_graph(num_residues)
        save_basename = "disconnected.graph"
    elif graph_type is GraphType.DIST_THRESH:
        g = dist_thresh_graph(res_dist_mtx, dist_thresh)
        save_basename = "dist_thresh_{}.graph".format(dist_thresh)
    elif graph_type is GraphType.DIST_THRESH_SHUFFLED:
        g = shuffle_nodes(dist_thresh_graph(res_dist_mtx, dist_thresh), seed=shuffle_seed)
        save_basename = "dist_thresh_{}_shuffled_r{}.graph".format(
            dist_thresh, shuffle_seed
        )
    else:
        raise ValueError("Graph type {} is not implemented".format(graph_type))

    if save:
        save_fn = os.path.join(graph_save_dir, save_basename)
        # never overwrite an existing graph file; warn and skip instead
        if isfile(save_fn):
            print(
                "err: graph already exists: {}. to overwrite, delete the existing file first".format(
                    save_fn
                )
            )
        else:
            os.makedirs(graph_save_dir, exist_ok=True)
            save_graph(g, save_fn)

    return g
|
|
|
|
| |
|
|
|
|
class METLConfig(PretrainedConfig):
    """HuggingFace config for METL models.

    Holds a single `id` that selects which pretrained METL model to load,
    either a human-readable ident or a UUID (see the two maps below).
    """

    # lookup maps defined earlier in this file: ident -> UUID and
    # UUID -> download URL (presumably; confirm against their definitions)
    IDENT_UUID_MAP = IDENT_UUID_MAP
    UUID_URL_MAP = UUID_URL_MAP
    model_type = "METL"

    def __init__(
        self,
        id: Optional[str] = None,
        **kwargs,
    ):
        # identifier of the METL model to load; may be set later via
        # METLModel.load_from_uuid / load_from_ident
        self.id = id
        super().__init__(**kwargs)
|
|
|
|
class METLModel(PreTrainedModel):
    """HuggingFace PreTrainedModel wrapper around a METL model.

    The wrapped torch model and its sequence encoder are not created in
    __init__; they are populated by load_from_uuid, load_from_ident, or
    get_from_checkpoint, which delegate to helpers defined elsewhere in
    this file.
    """

    config_class = METLConfig

    def __init__(self, config: METLConfig):
        super().__init__(config)
        # underlying torch model; set by one of the loader methods below
        self.model = None
        # sequence encoder paired with the model; set alongside self.model
        self.encoder = None
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the wrapped model on input X.

        pdb_fn, when provided, is forwarded to the underlying model —
        presumably needed by structure-conditioned variants; confirm
        against the wrapped model's signature.
        """
        if pdb_fn:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def load_from_uuid(self, id):
        """Load model + encoder by UUID (a key of config.UUID_URL_MAP).

        If id is falsy, reloads using the id already stored in the config.
        """
        if id:
            assert (
                id in self.config.UUID_URL_MAP
            ), "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
            self.config.id = id

        self.model, self.encoder = get_from_uuid(self.config.id)

    def load_from_ident(self, id):
        """Load model + encoder by ident (case-insensitive key of
        config.IDENT_UUID_MAP). Falsy id reloads from the stored config id.
        """
        if id:
            id = id.lower()
            assert (
                id in self.config.IDENT_UUID_MAP
            ), "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
            self.config.id = id

        self.model, self.encoder = get_from_ident(self.config.id)

    def get_from_checkpoint(self, checkpoint_path):
        """Load model + encoder from a local checkpoint file."""
        self.model, self.encoder = get_from_checkpoint(checkpoint_path)
|
|