# SPECTRE-Large: spectre/models/layers/attention.py
# Author: cclaess — commit 8b41845 ("Initial commit", verified)
from typing import Type, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import use_fused_attn
from spectre.models.layers.rotary_pos_embed import rope_apply
class Attention(nn.Module):
    """Self-attention block supporting MHA, MQA and MLA-style projections.

    Modes (``mode`` argument, case-insensitive):

    * ``"mha"``: multi-head attention. Queries, keys and values each have
      ``num_heads`` heads of size ``dim // num_heads``.
    * ``"mqa"``: multi-query attention. Queries have ``num_heads`` heads, but a
      single key/value head is computed and shared by every query head.
    * ``"mla"``: queries and keys/values are each produced by a two-stage
      projection, ``dim -> q_proj_dim -> dim`` and
      ``dim -> kv_proj_dim -> 2 * dim``. When ``qk_norm`` is set, the
      normalization is applied to the intermediate projections.

    Input is ``(B, N, dim)``; output has the same shape. Optional rotary
    position embeddings are applied to the trailing tokens of q and k in
    :meth:`forward`.
    """

    # Whether to use F.scaled_dot_product_attention; set from timm's use_fused_attn().
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mode: str = "mha",
        q_proj_dim: Optional[int] = None,
        kv_proj_dim: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        """Build the projections for the selected attention mode.

        Args:
            dim: Model (embedding) dimension; must be divisible by ``num_heads``.
            num_heads: Number of query heads.
            mode: One of ``"mha"``, ``"mqa"``, ``"mla"`` (case-insensitive).
            q_proj_dim: Intermediate query projection size (required for ``"mla"``).
            kv_proj_dim: Intermediate key/value projection size (required for ``"mla"``).
            qkv_bias: Whether the q/kv (and MLA intermediate) linear layers use a bias.
            qk_norm: Whether to normalize queries/keys (per head for MHA/MQA,
                on the intermediate projections for MLA).
            proj_bias: Whether the output projection uses a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
            norm_layer: Normalization class, instantiated with a single size argument.

        Raises:
            AssertionError: on invalid ``dim``/``num_heads``/``mode`` combinations,
                or when ``"mla"`` is requested without ``q_proj_dim``/``kv_proj_dim``.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.mode = mode.lower()
        assert self.mode in ["mha", "mqa", "mla"], "Attention mode must be 'mha', 'mqa', or 'mla'"
        assert not (self.mode == "mla" and kv_proj_dim is None), "kv_proj_dim must be provided for 'mla' mode"
        assert not (self.mode == "mla" and q_proj_dim is None), "q_proj_dim must be provided for 'mla' mode"

        # Submodules are registered in a fixed order (q, kv, norms, ...) so the
        # parameter order, which optimizer state dicts rely on, stays stable.
        if self.mode in ("mha", "mqa"):
            # MHA keeps one K/V head per query head; MQA a single K/V head shared by all.
            kv_heads = num_heads if self.mode == "mha" else 1
            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.kv = nn.Linear(dim, 2 * kv_heads * self.head_dim, bias=qkv_bias)
            self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        elif self.mode == "mla":
            # First stage: project the input into intermediate query / key-value spaces.
            self.q_proj = nn.Linear(dim, q_proj_dim, bias=qkv_bias)
            self.kv_proj = nn.Linear(dim, kv_proj_dim, bias=qkv_bias)
            self.q_norm = norm_layer(q_proj_dim) if qk_norm else nn.Identity()
            self.kv_norm = norm_layer(kv_proj_dim) if qk_norm else nn.Identity()
            # Second stage: expand to per-head queries and key/value pairs.
            self.q = nn.Linear(q_proj_dim, dim, bias=qkv_bias)
            self.kv = nn.Linear(kv_proj_dim, 2 * dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        rope: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to the trailing (spatial) tokens of the query and key tensors.

        Leading tokens not covered by ``rope`` (e.g. [cls] or register tokens)
        are passed through unchanged.

        Args:
            q: Query tensor of shape (B, num_heads, N, head_dim).
            k: Key tensor of shape (B, num_heads, N, head_dim).
            rope: ``(sin, cos)`` tensors, each either of shape (N_spatial, D),
                shared across the batch, or (B, N_spatial, D), one table per
                sample. ``N_spatial`` must not exceed N; the last ``N_spatial``
                tokens are rotated. D must be compatible with ``rope_apply``
                (presumably head_dim).

        Returns:
            ``(q, k)`` with RoPE applied, cast back to their original dtypes.

        Raises:
            ValueError: if ``sin`` is neither 2-D nor 3-D.
            AssertionError: if ``N_spatial`` exceeds the sequence length N.
        """
        q_dtype, k_dtype = q.dtype, k.dtype
        sin, cos = rope
        if sin.ndim == 2:
            # (N, D) -> (1, 1, N, D): broadcast over batch and heads.
            sin = sin.unsqueeze(0).unsqueeze(0)
            cos = cos.unsqueeze(0).unsqueeze(0)
        elif sin.ndim == 3:
            # (B, N, D) -> (B, 1, N, D): broadcast over heads.
            sin = sin.unsqueeze(1)
            cos = cos.unsqueeze(1)
        else:
            raise ValueError("RoPE sin/cos must be of shape [N, head_dim] or [B, N, head_dim]")

        # Rotate in the widest of the q/k/rope dtypes: high-precision rope tables
        # upcast low-precision q/k as before, but q/k are never downcast to a
        # lower-precision rope dtype.
        compute_dtype = torch.promote_types(torch.promote_types(q_dtype, k_dtype), sin.dtype)
        q, k = q.to(dtype=compute_dtype), k.to(dtype=compute_dtype)
        sin, cos = sin.to(dtype=compute_dtype), cos.to(dtype=compute_dtype)

        n_tokens = q.shape[-2]       # total tokens per sample
        n_spatial = sin.shape[-2]    # tokens covered by the rope tables
        prefix = n_tokens - n_spatial
        assert prefix >= 0, "RoPE sin/cos length exceeds sequence length"

        # Rotate only the spatial tail; with prefix == 0 the slice is the whole tensor.
        q_rot = rope_apply(q[..., prefix:, :], sin, cos)
        k_rot = rope_apply(k[..., prefix:, :], sin, cos)
        if prefix > 0:
            q_rot = torch.cat((q[..., :prefix, :], q_rot), dim=-2)
            k_rot = torch.cat((k[..., :prefix, :], k_rot), dim=-2)

        return q_rot.to(dtype=q_dtype), k_rot.to(dtype=k_dtype)

    def compute_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` of shape (B, N, dim) to per-head q, k, v.

        Returns:
            ``(q, k, v)``, each of shape (B, num_heads, N, head_dim). In MQA mode
            k and v are expanded (stride-0) views of a single shared head.
        """
        B, N, _ = x.shape
        if self.mode == "mha":
            q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
        elif self.mode == "mqa":
            q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv = self.kv(x).reshape(B, N, 2, 1, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)  # each (B, 1, N, head_dim)
            # Normalize the single shared key head once, then broadcast it to all heads.
            q, k = self.q_norm(q), self.k_norm(k)
            k = k.expand(-1, self.num_heads, -1, -1)
            v = v.expand(-1, self.num_heads, -1, -1)
        elif self.mode == "mla":
            # Normalization (if enabled) acts on the intermediate projections.
            q = self.q_norm(self.q_proj(x))
            kv = self.kv_norm(self.kv_proj(x))
            q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv = self.kv(kv).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)
        return q, k, v

    def compute_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Scaled dot-product attention over all tokens (no mask).

        Args:
            q, k, v: Tensors of shape (B, num_heads, N, head_dim).

        Returns:
            Tensor of shape (B, N, num_heads * head_dim) with heads merged.
        """
        B, _, N, _ = q.shape
        C = self.num_heads * self.head_dim
        if self.fused_attn:
            # SDPA's default scale is 1/sqrt(head_dim), matching self.scale.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        return x.transpose(1, 2).reshape(B, N, C)

    def forward(
        self,
        x: torch.Tensor,
        rope=None,
    ) -> torch.Tensor:
        """Run self-attention on ``x``.

        Args:
            x: Input tokens of shape (B, N, dim).
            rope: Optional rotary embedding: either a ``(sin, cos)`` tuple
                accepted by :meth:`apply_rotary_pos_emb`, or a list of
                ``(sin, cos)`` pairs that are stacked along a new leading
                (batch) dimension first.

        Returns:
            Tensor of shape (B, N, dim).
        """
        q, k, v = self.compute_qkv(x)
        if rope is not None:
            if isinstance(rope, list):
                # List of per-sample (sin, cos) -> (stacked sin, stacked cos).
                rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
            q, k = self.apply_rotary_pos_emb(q, k, rope)
        x = self.compute_attention(q, k, v)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x