import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
class MultiHeadSDPA(nn.Module):
    """
    Multi-head cross-attention using torch.nn.functional.scaled_dot_product_attention,
    without causal masking. Suitable for set inputs and cross-attention.

    If qk_norm=True and qk_norm_type="l2", Q and K are L2-normalized per head before
    the dot product and then scaled by a learned per-head temperature (log_scale).
    This bounds each attention logit to [-1, +1] * exp(log_scale), preventing
    attention entropy collapse at large head_dim. With qk_norm_type="rms", Q and K
    are RMS-normalized and SDPA's default 1/sqrt(head_dim) scaling is kept.
    """

    def __init__(self, d_model: int, num_heads: int, kv_heads: int | None = None,
                 qk_norm: bool = False, qk_norm_type: str = "l2"):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_heads = kv_heads or num_heads
        assert self.num_heads % self.kv_heads == 0, "kv_heads must divide num_heads"

        self.head_dim = d_model // num_heads
        self.qk_norm = qk_norm
        self.qk_norm_type = qk_norm_type

        # K/V projections may use fewer heads than Q (grouped-query attention).
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)

        # Zero-init the output projection so the module's initial output is zero
        # (a no-op when used inside a residual branch).
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        nn.init.zeros_(self.out_proj.weight)

        # RMS qk-norm keeps SDPA's default scaling, so only the L2 variant needs a
        # learned per-head temperature, initialised so exp(log_scale) = sqrt(head_dim).
        if qk_norm and qk_norm_type != "rms":
            self.log_scale = nn.Parameter(
                torch.full((num_heads,), math.log(math.sqrt(self.head_dim))))
|
|
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        query: (B, Tq, d_model); key: (B, Tk, d_model), used for both keys and values.
        key_padding_mask: optional (B, Tk) bool tensor, True at padded key positions.
        Returns a (B, Tq, d_model) tensor.
        """
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(key)

        B, Tq, _ = q.shape
        _, Tk, _ = k.shape

        # Split into heads: (B, T, H * head_dim) -> (B, H, T, head_dim).
        q = q.view(B, Tq, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)

        # Grouped-query attention: repeat K/V heads to match the query heads.
        if self.kv_heads != self.num_heads:
            repeat = self.num_heads // self.kv_heads
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)
|
|
        # SDPA's boolean mask is True where attention is allowed; key_padding_mask
        # uses the opposite convention (True marks padding), hence the inversion.
        attn_mask = None
        if key_padding_mask is not None:
            attn_mask = ~key_padding_mask[:, None, None, :].to(dtype=torch.bool)

        sdpa_scale = None  # None -> SDPA's default 1 / sqrt(head_dim)
        if self.qk_norm:
            if self.qk_norm_type == "rms":
                # RMS-normalize Q and K per head; keep the default scaling.
                q = q * torch.rsqrt(q.square().mean(dim=-1, keepdim=True) + 1e-6)
                k = k * torch.rsqrt(k.square().mean(dim=-1, keepdim=True) + 1e-6)
            else:
                # L2-normalize Q and K, then apply the learned per-head temperature,
                # bounding each logit to +/- exp(log_scale).
                q = F.normalize(q, dim=-1)
                k = F.normalize(k, dim=-1)
                q = q * self.log_scale.exp().view(1, -1, 1, 1)
                sdpa_scale = 1.0

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
            scale=sdpa_scale,
        )
|
|
        # Merge heads: (B, H, Tq, head_dim) -> (B, Tq, d_model), then project out.
        attn_out = attn_out.transpose(1, 2).reshape(B, Tq, self.d_model)
        return self.out_proj(attn_out)
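

# Minimal usage sketch (the shapes, hyperparameters, and function name below are
# illustrative assumptions, not requirements of the module): a set of queries
# cross-attends to a padded key/value set, with key_padding_mask == True marking
# padded source positions.
def _example_cross_attention() -> torch.Tensor:
    attn = MultiHeadSDPA(d_model=256, num_heads=8, kv_heads=2, qk_norm=True)
    query = torch.randn(4, 16, 256)                  # (B, Tq, d_model)
    key = torch.randn(4, 32, 256)                    # (B, Tk, d_model)
    key_padding_mask = torch.zeros(4, 32, dtype=torch.bool)
    key_padding_mask[:, 24:] = True                  # last 8 key positions are padding
    return attn(query, key, key_padding_mask=key_padding_mask)  # (4, 16, 256)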
|
|
|
|
|
|
def _get_activation(name: str):
    """Look up an activation function by name. Supports 'relu_sq' for ReLU^2."""
    if name == "relu_sq":
        return lambda x: F.relu(x).square()
    return getattr(F, name)
|
|
|
|
class FeedForward(nn.Module):
    """
    Position-wise MLP block: linear -> activation -> linear.
    Supports 'gelu', 'relu', 'relu_sq', etc.
    """

    def __init__(self, d_model: int, dim_ff: int, activation: str = "gelu"):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_ff)
        # Zero-init the second linear so the block's initial output is zero
        # (a no-op when used inside a residual branch).
        self.linear2 = nn.Linear(dim_ff, d_model)
        nn.init.zeros_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
        self.activation = _get_activation(activation)
|
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        return self.linear2(self.activation(x))
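

# Quick smoke-test sketch (the dimensions are arbitrary example values): because
# linear2 is zero-initialised, a freshly constructed FeedForward maps any input to zeros.
if __name__ == "__main__":
    ff = FeedForward(d_model=256, dim_ff=1024, activation="relu_sq")
    x = torch.randn(4, 16, 256)
    out = ff(x)
    print(out.shape)                   # torch.Size([4, 16, 256])
    print(bool(torch.all(out == 0)))   # True at initialization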
|
|