from functools import partial
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np

import deepspeed
import torch
import torch.nn as nn
from scipy.stats import truncnorm

from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
    _chunk_slice,
)


def _prod(nums):
    out = 1
    for n in nums:
        out = out * n
    return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    f = _calculate_fan(shape, fan)
    scale = scale / max(1, f)
    a = -2
    b = 2
    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
    size = _prod(shape)
    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
    samples = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))
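
# A quick sketch of the scaling above (illustrative only, not used by the
# module): with fan="fan_in" and scale=1.0 the target variance is 1 / fan_in,
# and dividing by truncnorm.std(a=-2, b=2) compensates for the variance lost
# by truncating the normal at +/-2 standard deviations.
#
#     w = torch.empty(256, 128)          # hypothetical [fan_out, fan_in] weight
#     trunc_normal_init_(w)              # LeCun-style fan-in init
#     print(w.var().item())              # should land close to 1 / 128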


def lecun_normal_init_(weights):
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        # softplus(0.5413...) == 1, so the post-softplus point weights start at 1
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4 of the AlphaFold 2 supplement, plus
    some additional ones found in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        if init_fn is not None:
            init_fn(self.weight, self.bias)
        else:
            if init == "default":
                lecun_normal_init_(self.weight)
            elif init == "relu":
                he_normal_init_(self.weight)
            elif init == "glorot":
                glorot_uniform_init_(self.weight)
            elif init == "gating":
                gating_init_(self.weight)
                if bias:
                    with torch.no_grad():
                        self.bias.fill_(1.0)
            elif init == "normal":
                normal_init_(self.weight)
            elif init == "final":
                final_init_(self.weight)
            else:
                raise ValueError("Invalid init string.")


class LayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps,
                )
        else:
            out = nn.functional.layer_norm(
                x,
                self.c_in,
                self.weight,
                self.bias,
                self.eps,
            )

        return out


@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    d = t.dtype
    if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s


def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
    # [*, H, Q, C_hidden]
    query = permute_final_dims(query, (1, 0, 2))

    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 2, 0))

    # [*, H, K, C_hidden]
    value = permute_final_dims(value, (1, 0, 2))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    for b in biases:
        a += b

    a = softmax(a, -1)

    # [*, H, Q, C_hidden]
    a = torch.matmul(a, value)

    # [*, Q, H, C_hidden]
    a = a.transpose(-2, -3)

    return a
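
# Shape sketch for _attention (sizes below are purely illustrative): inputs
# arrive head-separated as q [*, Q, H, C_hidden] and k/v [*, K, H, C_hidden],
# and every bias must broadcast against the [*, H, Q, K] logit tensor.
#
#     q = torch.randn(1, 32, 4, 16)             # [batch, Q, heads, c_hidden]
#     k = torch.randn(1, 48, 4, 16)             # [batch, K, heads, c_hidden]
#     v = torch.randn(1, 48, 4, 16)
#     mask_bias = torch.zeros(1, 1, 1, 48)      # broadcasts to [*, H, Q, K]
#     out = _attention(q, k, v, [mask_bias])    # [1, 32, 4, 16]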


@torch.jit.ignore
def _attention_chunked_trainable(
    query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
    if(checkpoint and len(biases) > 2):
        raise ValueError(
            "Checkpointed version permits only two bias terms"
        )

    def _checkpointable_attention(q, k, v, b1, b2):
        bs = [b for b in [b1, b2] if b is not None]
        return _attention(q, k, v, bs)

    o_chunks = []
    checkpoint_fn = get_checkpoint_fn()
    count = query.shape[chunk_dim]
    for start in range(0, count, chunk_size):
        end = start + chunk_size
        idx = [slice(None)] * len(query.shape)
        idx[chunk_dim] = slice(start, end)
        idx_tup = tuple(idx)
        q_chunk = query[idx_tup]
        k_chunk = key[idx_tup]
        v_chunk = value[idx_tup]

        def _slice_bias(b):
            idx[chunk_dim] = (
                slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
            )
            return b[tuple(idx)]

        if(checkpoint):
            bias_1_chunk, bias_2_chunk = [
                _slice_bias(b) if b is not None else None
                for b in (biases + [None, None])[:2]
            ]

            o_chunk = checkpoint_fn(_checkpointable_attention,
                q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
            )
        else:
            bias_chunks = [
                _slice_bias(b) for b in biases
            ]

            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)

        o_chunks.append(o_chunk)

    o = torch.cat(o_chunks, dim=chunk_dim)
    return o
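
# Usage sketch (hypothetical sizes). Chunking runs over an arbitrary tensor
# dimension given by chunk_dim (biases that are singleton along that dimension
# are broadcast rather than sliced), and with checkpoint=True each chunk is
# recomputed during the backward pass, at the cost of accepting at most two
# bias terms.
#
#     # q/k/v: [N_chunked, N_res, H, C_hidden]; chunk over dim 0 in pieces of 4
#     # o = _attention_chunked_trainable(
#     #     q, k, v, [mask_bias], chunk_size=4, chunk_dim=0, checkpoint=True,
#     # )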


class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        # [*, Q/K, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self,
        o: torch.Tensor,
        q_x: torch.Tensor
    ) -> torch.Tensor:
        if(self.linear_g is not None):
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_lma: bool = False,
        q_chunk_size: Optional[int] = None,
        kv_chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_lma:
                Whether to use low-memory attention
            q_chunk_size:
                Query chunk size (for LMA)
            kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns:
            [*, Q, C_q] attention update
        """
        if(biases is None):
            biases = []
        if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
            raise ValueError(
                "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                "be provided"
            )

        q, k, v = self._prep_qkv(q_x, kv_x)

        if(use_lma):
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                for b in biases
            ]

            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
        else:
            o = _attention(q, k, v, biases)

        o = self._wrap_up(o, q_x)

        return o
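
# Usage sketch (all sizes hypothetical). Masking is expressed through the bias
# list: a tensor of zeros and large negative values that broadcasts against
# the [*, H, Q, K] attention logits.
#
#     attn = Attention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)
#     q_x = torch.randn(2, 32, 64)
#     kv_x = torch.randn(2, 48, 64)
#     mask = torch.ones(2, 48)
#     mask_bias = (1e9 * (mask - 1))[..., None, None, :]   # [2, 1, 1, 48]
#     out = attn(q_x, kv_x, biases=[mask_bias])            # [2, 32, 64]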


class GlobalAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(
            c_in, c_hidden * no_heads, bias=False, init="glorot"
        )

        self.linear_k = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_v = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()

    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # mask-weighted mean over the attended dimension: one pooled query
        # per "row" of m
        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1)[..., None] + self.eps
        )

        # [*, H * C_hidden]
        q = self.linear_q(q)
        q *= (self.c_hidden ** (-0.5))

        # [*, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, K, C_hidden] (keys and values are shared across heads)
        k = self.linear_k(m)
        v = self.linear_v(m)

        # [*, H, K]
        a = torch.matmul(
            q,
            k.transpose(-1, -2),
        )
        bias = (self.inf * (mask - 1))[..., :, None, :]
        a += bias
        a = softmax(a)

        # [*, H, C_hidden]
        o = torch.matmul(
            a,
            v,
        )

        # [*, K, H * C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, K, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # [*, K, H, C_hidden]
        o = o.unsqueeze(-3) * g

        # [*, K, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, K, C_in]
        m = self.linear_o(o)

        return m
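
# Usage sketch (hypothetical sizes). GlobalAttention pools the queries into a
# single mask-weighted query per head, in the style of AlphaFold's global
# column attention over the extra MSA stack.
#
#     ga = GlobalAttention(c_in=64, c_hidden=16, no_heads=4, inf=1e9, eps=1e-10)
#     m = torch.randn(2, 128, 10, 64)     # e.g. [batch, N_res, N_seq, c_in]
#     mask = torch.ones(2, 128, 10)
#     out = ga(m, mask)                   # [2, 128, 10, 64]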


def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int,
):
    no_q, no_kv = q.shape[-3], k.shape[-3]

    # [*, Q, H, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
        large_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            # [*, H, Q_chunk, K_chunk]
            a = torch.einsum(
                "...qhd,...khd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            # [*, Q_chunk, H, K_chunk]
            a = a.transpose(-2, -3)

            # running softmax statistics for this key/value chunk
            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        # renormalize every key/value chunk against the global maximum before
        # summing the partial numerators and denominators
        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values *= max_diffs.unsqueeze(-1)
        chunk_weights *= max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out

    return o
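
# Note on _lma ("low-memory attention"): it computes the same result as
# _attention but streams over query and key/value chunks, renormalizing the
# running softmax with per-chunk maxima, so the full [*, H, Q, K] logit matrix
# is never materialized. A quick equivalence check, assuming small hypothetical
# shapes that divide evenly into the chunk sizes:
#
#     q = torch.randn(1, 64, 4, 16)
#     k = torch.randn(1, 96, 4, 16)
#     v = torch.randn(1, 96, 4, 16)
#     ref = _attention(q, k, v, [])
#     lma = _lma(q, k, v, [], q_chunk_size=16, kv_chunk_size=32)
#     assert torch.allclose(ref, lma, atol=1e-5)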