File size: 7,646 Bytes
8b41845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from typing import Type, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import use_fused_attn

from spectre.models.layers.rotary_pos_embed import rope_apply


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(

            self,

            dim: int,

            num_heads: int = 8,

            mode: str = "mha",

            q_proj_dim: Optional[int] = None,

            kv_proj_dim: Optional[int] = None,

            qkv_bias: bool = False,

            qk_norm: bool = False,

            proj_bias: bool = True,

            attn_drop: float = 0.,

            proj_drop: float = 0.,

            norm_layer: Type[nn.Module] = nn.LayerNorm,

    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.mode = mode.lower()
        assert self.mode in ["mha", "mqa", "mla"], "Attention mode must be 'mha', 'mqa', or 'mla'"
        assert not (self.mode == "mla" and kv_proj_dim is None), "kv_proj_dim must be provided for 'mla' mode"
        assert not (self.mode == "mla" and q_proj_dim is None), "q_proj_dim must be provided for 'mla' mode"
        
        if self.mode == "mha":
            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias)  # Key and value pair for every head
            self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        elif self.mode == "mqa":
            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.kv = nn.Linear(dim, 2 * self.head_dim, bias=qkv_bias)  # Key and value pair shared across heads
            self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        elif self.mode == "mla":
            self.q_proj = nn.Linear(dim, q_proj_dim, bias=qkv_bias)  # Projected query for every head
            self.kv_proj = nn.Linear(dim, kv_proj_dim, bias=qkv_bias)  # Projected key and value pair for every head
            self.q_norm = norm_layer(q_proj_dim) if qk_norm else nn.Identity()
            self.kv_norm = norm_layer(kv_proj_dim) if qk_norm else nn.Identity()
            self.q = nn.Linear(q_proj_dim, dim, bias=qkv_bias)  # Query for every head
            self.kv = nn.Linear(kv_proj_dim, 2 * dim, bias=qkv_bias)  # Key and value pair for every head
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rotary_pos_emb(

        self, 

        q: torch.Tensor, 

        k: torch.Tensor, 

        rope: Tuple[torch.Tensor, torch.Tensor],

    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Apply RoPE to the query and key tensors.

        Args:

            q (torch.Tensor): Query tensor of shape (B, num_heads, N, head_dim)

            k (torch.Tensor): Key tensor of shape (B, num_heads, N, head_dim)

            rope (Tuple[torch.Tensor, torch.Tensor]): Tuple of (sin, cos) tensors for RoPE application.

                Sin and cos can be of shape 

        """
        
        # Match dtype to rope for numeric stability
        q_dtype, k_dtype = q.dtype, k.dtype
        sin, cos = rope

        if sin.ndim == 2:
            sin = sin.unsqueeze(0).unsqueeze(0)
            cos = cos.unsqueeze(0).unsqueeze(0)
        elif sin.ndim == 3:
            sin = sin.unsqueeze(1)
            cos = cos.unsqueeze(1)
        else:
            raise ValueError("RoPE sin/cos must be of shape [N, head_dim] or [B, N, head_dim]")
        
        rope_dtype = sin.dtype

        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)

        N = q.shape[-2]             # total tokens per sample
        N_spatial = sin.shape[-2]   # number of spatial tokens covered by rope
        prefix = N - N_spatial      # e.g., [cls] or [reg] tokens at the front
        assert prefix >= 0, "RoPE sin/cos length exceeds sequence length"

        if prefix > 0:
            q_prefix = q[:, :, :prefix, :]
            k_prefix = k[:, :, :prefix, :]
            q_spatial = q[:, :, prefix:, :]
            k_spatial = k[:, :, prefix:, :]
        else:
            q_prefix = k_prefix = None
            q_spatial, k_spatial = q, k

        # Apply RoPE on the spatial tail
        q_spatial = rope_apply(q_spatial, sin, cos)
        k_spatial = rope_apply(k_spatial, sin, cos)

        # Stitch back
        if prefix > 0:
            q = torch.cat((q_prefix, q_spatial), dim=-2)
            k = torch.cat((k_prefix, k_spatial), dim=-2)
        else:
            q, k = q_spatial, k_spatial

        # Cast back to original dtypes
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)

        return q, k

    def compute_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, N, _ = x.shape
        if self.mode == "mha":
            q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
        elif self.mode == "mqa":
            q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv = self.kv(x).reshape(B, N, 2, 1, self.head_dim).permute(2, 0, 3, 1, 4)
            kv = kv.expand(-1, -1, self.num_heads, -1, -1)  # Expand to match num_heads
            k, v = kv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
        elif self.mode == "mla":
            q = self.q_proj(x)
            kv = self.kv_proj(x)
            q, kv = self.q_norm(q), self.kv_norm(kv)  # Normalization on projections
            q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv = self.kv(kv).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)
        return q, k, v
        
    def compute_attention(

        self, 

        q: torch.Tensor, 

        k: torch.Tensor, 

        v: torch.Tensor,

    ) -> torch.Tensor:
        B, _, N, _ = q.shape
        C = self.num_heads * self.head_dim

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        return x.transpose(1, 2).reshape(B, N, C)

    def forward(

        self, 

        x: torch.Tensor, 

        rope = None,

    ) -> torch.Tensor:

        q, k, v = self.compute_qkv(x)

        if rope is not None:
            if isinstance(rope, list):
                rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
            q, k = self.apply_rotary_pos_emb(q, k, rope)
        
        x = self.compute_attention(q, k, v)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x