| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import math
|
| from opt_einsum import contract as einsum
|
|
|
|
|
class FeedForwardLayer(nn.Module):
    """Position-wise feed-forward block: LayerNorm -> Linear -> ReLU -> Dropout -> Linear.

    The hidden width is d_model * r_ff; the output width matches the input,
    so the block can be used as a residual branch.
    """

    def __init__(self, d_model, r_ff, p_drop=0.1):
        super(FeedForwardLayer, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_model*r_ff)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_model*r_ff, d_model)

    def forward(self, src):
        """Apply the pre-norm feed-forward transform; returns a tensor shaped like src."""
        hidden = F.relu(self.linear1(self.norm(src)))
        return self.linear2(self.dropout(hidden))
|
|
|
class Attention(nn.Module):
    """Plain multi-head scaled dot-product attention (no pre-norm, no gating)."""

    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
        super(Attention, self).__init__()
        self.h = n_head
        self.dim = d_hidden

        # Separate projections for the query stream and the key/value stream.
        self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)

        self.to_out = nn.Linear(n_head*d_hidden, d_out)
        self.scaling = 1/math.sqrt(d_hidden)

    def forward(self, query, key, value):
        """query: (B, Q, d_query); key/value: (B, K, d_key). Returns (B, Q, d_out)."""
        n_batch, n_query = query.shape[:2]
        n_key = key.shape[1]

        # Project to per-head subspaces; fold the 1/sqrt(d) scaling into q.
        q = self.to_q(query).reshape(n_batch, n_query, self.h, self.dim) * self.scaling
        k = self.to_k(key).reshape(n_batch, n_key, self.h, self.dim)
        v = self.to_v(value).reshape(n_batch, n_key, self.h, self.dim)

        weights = F.softmax(einsum('bqhd,bkhd->bhqk', q, k), dim=-1)
        context = einsum('bhqk,bkhd->bqhd', weights, v)
        context = context.reshape(n_batch, n_query, self.h*self.dim)

        return self.to_out(context)
|
|
|
class AttentionWithBias(nn.Module):
    """Gated multi-head self-attention with an additive per-head pairwise bias.

    Both the input features and the bias features are layer-normalized; the
    attention output is modulated by a sigmoid gate computed from the input.
    """

    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super(AttentionWithBias, self).__init__()
        self.norm_in = nn.LayerNorm(d_in)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)   # pair features -> per-head bias
        self.to_g = nn.Linear(d_in, n_head*d_hidden)        # sigmoid gate
        self.to_out = nn.Linear(n_head*d_hidden, d_in)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, x, bias):
        """x: (B, L, d_in); bias: (B, L, L, d_bias). Returns (B, L, d_in)."""
        n_batch, n_res = x.shape[:2]

        x = self.norm_in(x)
        bias = self.norm_bias(bias)

        # Per-head projections; the 1/sqrt(d) scaling is folded into the keys.
        q = self.to_q(x).reshape(n_batch, n_res, self.h, self.dim)
        k = self.to_k(x).reshape(n_batch, n_res, self.h, self.dim) * self.scaling
        v = self.to_v(x).reshape(n_batch, n_res, self.h, self.dim)
        g = torch.sigmoid(self.to_g(x))

        # Logits are (B, q, k, h); the pair bias contributes one scalar per head.
        logits = einsum('bqhd,bkhd->bqkh', q, k) + self.to_b(bias)
        attn = F.softmax(logits, dim=-2)    # normalize over the key axis

        context = einsum('bqkh,bkhd->bqhd', attn, v).reshape(n_batch, n_res, -1)
        return self.to_out(g * context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RowAttentionWithBias(nn.Module):
    """Gated self-attention over MSA features, biased by pair features, with an
    optional per-residue mask that excludes padded positions from the softmax.
    """

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)

        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)   # pair features -> per-head bias
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)       # sigmoid gate
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, msa, pair, mask=None):
        """msa: (B, L, d_msa); pair: (B, L, L, d_pair); mask: optional (B, L) of 0/1.

        Returns a (B, L, d_msa) tensor.
        """
        n_batch, n_res = msa.shape[:2]

        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        # Per-head projections; the 1/sqrt(d) scaling is folded into the keys.
        q = self.to_q(msa).reshape(n_batch, n_res, self.h, self.dim)
        k = self.to_k(msa).reshape(n_batch, n_res, self.h, self.dim) * self.scaling
        v = self.to_v(msa).reshape(n_batch, n_res, self.h, self.dim)
        g = torch.sigmoid(self.to_g(msa))

        logits = einsum('bqhd,bkhd->bqkh', q, k) + self.to_b(pair)

        if mask is not None:
            # Outer product of the 1-D mask gives a (B, L, L, 1) pair mask;
            # excluded pairs are pushed to a large negative logit before softmax.
            col = mask.unsqueeze(2).type(torch.float32)
            row = mask.unsqueeze(1).type(torch.float32)
            pair_mask = torch.matmul(col, row)[..., None]
            logits = logits * pair_mask - 1e9 * (1 - pair_mask)

        attn = F.softmax(logits, dim=-2)    # normalize over the key axis

        context = einsum('bqkh,bkhd->bqhd', attn, v).reshape(n_batch, n_res, -1)
        return self.to_out(g * context)
|
|
|
class ColAttention(nn.Module):
    """Gated multi-head self-attention over a (B, L, d_msa) input, no pair bias.

    NOTE(review): unlike the sibling attention classes, the softmax below runs
    over dim=-3 — the *query* axis of the 'bqkh' logits — and the output einsum
    then reads the weights transposed ('bkqh'), so the normalized axis is also
    the one being summed over. The net effect is attention with the query/key
    roles swapped relative to the usual convention; confirm this asymmetry is
    intentional before changing either the softmax axis or the einsum.
    """

    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_msa = nn.LayerNorm(d_msa)

        # Per-head query/key/value projections (no bias terms).
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)   # sigmoid gate on the attention output
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)  # standard 1/sqrt(d) dot-product scaling
        self.h = n_head
        self.dim = d_hidden

    def forward(self, msa, mask = None):
        '''
        msa (B,L,d_node); mask: optional (B, L) of 0/1, zero positions excluded.
        Returns a (B, L, d_msa) tensor.
        '''
        B, L = msa.shape[:2]

        msa = self.norm_msa(msa)

        query = self.to_q(msa).reshape(B, L, self.h, self.dim)
        key = self.to_k(msa).reshape(B, L, self.h, self.dim)
        value = self.to_v(msa).reshape(B, L, self.h, self.dim)
        gate = torch.sigmoid(self.to_g(msa))

        query = query * self.scaling
        attn = einsum('bqhd,bkhd->bqkh', query, key)

        if mask is not None:
            # Outer product of the 1-D mask -> (B, L, L, 1) pair mask; excluded
            # pairs are driven to a large negative logit before the softmax.
            mask_re = torch.matmul(mask.unsqueeze(2).type(torch.float32), mask.unsqueeze(1).type(torch.float32))[...,None]
            attn = attn * mask_re - 1e9 * (1-mask_re)

        # Softmax over the FIRST residue axis of 'bqkh' — see class NOTE.
        attn = F.softmax(attn, dim=-3)

        # Weights are read transposed ('bkqh'), summing over the normalized axis.
        out = einsum('bkqh,bkhd->bqhd', attn, value).reshape(B, L, -1)
        out = gate * out

        out = self.to_out(out)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BiasedAxialAttention(nn.Module):
    """Tied axial gated attention over pair features, biased by a second
    pair-feature stack.

    The attention logits are summed ("tied") over the first residue axis, with
    an extra 1/sqrt(L) factor on the keys to compensate for that sum.  With
    is_row=True the pair/bias tensors are transposed on the way in and the
    output transposed back, so the same body serves both axes.
    (p_drop is accepted for interface compatibility but is unused here.)
    """

    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super().__init__()

        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)   # bias features -> per-head logit bias
        self.to_g = nn.Linear(d_pair, n_head*d_hidden)      # sigmoid gate
        self.to_out = nn.Linear(n_head*d_hidden, d_pair)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, pair, bias, mask=None):
        '''
        pair: (B, L, L, d_pair)
        bias: (B, L, L, d_bias)
        mask: (B, L); positions with 0 are excluded as attention keys
        '''
        B, L = pair.shape[:2]

        if self.is_row:
            pair = pair.permute(0,2,1,3)
            bias = bias.permute(0,2,1,3)

        pair = self.norm_pair(pair)
        bias = self.norm_bias(bias)

        query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
        key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
        value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
        bias = self.to_b(bias)
        gate = torch.sigmoid(self.to_g(pair))

        # Tied attention: logits sum over the first residue axis n; the extra
        # 1/sqrt(L) on the key compensates for the L terms in that sum.
        query = query * self.scaling
        key = key / math.sqrt(L)
        attn = einsum('bnihk,bnjhk->bijh', query, key)
        attn = attn + bias
        if mask is not None:
            # BUGFIX: this was `1e-9 * (mask - 1)`, a negligible perturbation
            # that left masked keys fully attended.  Use the same -1e9 penalty
            # as the other attention modules in this file so the softmax drives
            # masked keys to ~zero weight.
            mask_temp = -1e9 * (1 - mask.type(torch.float))
            attn = attn + mask_temp.unsqueeze(1).unsqueeze(-1)

        attn = F.softmax(attn, dim=-2)   # normalize over the key axis j

        # Apply the single tied attention map to every row k of the values.
        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
        out = gate * out

        out = self.to_out(out)
        if self.is_row:
            out = out.permute(0,2,1,3)
        return out
|
|
|
|
|