| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .utils import nearest_power_of_two |
|
|
| try: |
| from flash_attn import flash_attn_func as fa2 |
| except ImportError as e: |
    print(
        f"Unable to import FlashAttention-2 kernels (flash_attn): {e}. No alternative currently available."
    )
| |
|
|
| class Attention(nn.Module): |
| def __init__(self, config): |
| super(Attention, self).__init__() |
| if isinstance(config.torch_dtype, str): |
| torch_dtype = getattr(torch, config.torch_dtype) |
| else: |
| torch_dtype = config.torch_dtype |
| assert torch.cuda.is_available(), "CUDA is required." |
| assert config.n_embd % config.n_heads == 0 |
| self.n_heads = config.n_heads |
|
|
| self.device = torch.device("cuda") |
| self.bsz = config.bsz |
| self.c_attn = nn.Linear( |
| config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype |
| ) |
| self.c_proj = nn.Linear( |
| config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype |
| ) |
| self.c_proj.SCALE_INIT = 1 |
| self.dropout = config.dropout |
| self.resid_dropout = nn.Dropout(self.dropout) |
| self.alibi_slopes = self._get_alibi_slopes(self.n_heads) |
| self.window_size = config.window_size |
| self.softcap = config.softcap |
|
|
| def _generate_slopes(self, n: int): |
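        # Closed-form ALiBi slopes: start = 2^(-8/n), slope_i = start^(i + 1).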
| start = 2 ** (-(2 ** -(math.log2(n) - 3))) |
| return [start * (start**i) for i in range(n)] |
|
|
| def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25): |
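        # One slope per attention head; consumed by flash_attn via `alibi_slopes`.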
| |
| if math.log2(n_heads).is_integer(): |
| slopes = self._generate_slopes(n_heads) |
| else: |
| |
| n = nearest_power_of_two(n_heads, round_up=False) |
| slopes_power_of_two = self._generate_slopes(n) |
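            # For the remaining heads, interleave every other slope from the 2n
            # sequence, following the original ALiBi reference implementation.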
| |
| |
| extra_slopes = self._generate_slopes(2 * n) |
| extra_slopes_trunc = extra_slopes[0::2][: n_heads - n] |
| slopes = slopes_power_of_two + extra_slopes_trunc |
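        # Scale by the interpolation factor and return float32 slopes on the GPU,
        # as flash_attn expects.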
| slopes = torch.tensor(slopes, device=self.device) |
| slopes = slopes * interpolation_factor |
| return slopes.to(torch.float32) |
|
|
|
|
| def forward(self, x): |
| bsz, seq_len, d_in = x.size() |
|
|
| qkv = self.c_attn(x) |
| q, k, v = torch.chunk(qkv, 3, dim=2) |
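        # flash_attn_func expects inputs shaped (bsz, seq_len, n_heads, head_dim),
        # so no transpose is needed after the head split below.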
|
|
| q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads) |
| k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads) |
| v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads) |
| y = fa2( |
| q, |
| k, |
| v, |
| dropout_p=self.dropout if self.training else 0.0, |
| causal=True, |
| window_size=(self.window_size, 0), |
| alibi_slopes=self.alibi_slopes, |
| softcap=self.softcap, |
| ) |
| y = y.contiguous().view(bsz, seq_len, d_in) |
| y = self.resid_dropout(self.c_proj(y)) |
| return y |
|
|
| class AttentionSDPA(nn.Module): |
| def __init__(self, config): |
        super(AttentionSDPA, self).__init__()
| if isinstance(config.torch_dtype, str): |
| torch_dtype = getattr(torch, config.torch_dtype) |
| else: |
| torch_dtype = config.torch_dtype |
| assert torch.cuda.is_available(), "CUDA is required." |
| assert config.n_embd % config.n_heads == 0 |
| self.n_heads = config.n_heads |
|
|
| self.device = torch.device("cuda") |
| self.bsz = config.bsz |
        self.c_attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.c_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )
| self.dropout = config.dropout |
| self.resid_dropout = nn.Dropout(self.dropout) |
|
|
| def forward(self, x): |
| bsz, seq_len, d_in = x.size() |
|
|
| qkv = self.c_attn(x) |
| q, k, v = torch.chunk(qkv, 3, dim=2) |
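        # F.scaled_dot_product_attention expects (bsz, n_heads, seq_len, head_dim),
        # hence the transpose after the head split below.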
|
|
| q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2) |
| k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2) |
| v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2) |
|
|
| y = F.scaled_dot_product_attention( |
| q, k, v, |
| is_causal=True, |
| dropout_p=self.dropout if self.training else 0.0 |
| ) |
|
|
| y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in) |
|
|
| y = self.resid_dropout(self.c_proj(y)) |
| return y |
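

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module).
    # It assumes a config object exposing the attributes read in __init__ above;
    # `SimpleNamespace` and the values below are stand-ins, not the project's real
    # config class. Requires a CUDA device (asserted in __init__).
    from types import SimpleNamespace

    config = SimpleNamespace(
        torch_dtype="bfloat16",
        n_embd=256,
        n_heads=8,
        bsz=2,
        bias=False,
        dropout=0.0,
        window_size=128,
        softcap=0.0,
    )
    attn = AttentionSDPA(config).to("cuda")
    x = torch.randn(config.bsz, 64, config.n_embd, device="cuda", dtype=torch.bfloat16)
    print(attn(x).shape)  # expected: torch.Size([2, 64, 256])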
|
|