shreyask commited on
Commit
162baf0
·
verified ·
1 Parent(s): 75700ed

Upload needle_torch/layers.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. needle_torch/layers.py +227 -0
needle_torch/layers.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Building-block nn.Modules for the Needle Simple Attention Network.
2
+
3
+ ZCRMSNorm — zero-centred RMSNorm: (1+γ)*x / RMS(x)
4
+ RoPE — pre-computed cos/sin freqs + static apply()
5
+ MultiHeadAttention — GQA, optional RoPE, optional past-KV caching
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .config import TransformerConfig
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # ZCRMSNorm
18
+ # ---------------------------------------------------------------------------
19
+
20
+ class ZCRMSNorm(nn.Module):
21
+ """Zero-centred RMSNorm.
22
+
23
+ Formula: (1 + γ) * x / RMS(x)
24
+ where γ is a learnable scale initialized to zero.
25
+
26
+ Matches Flax architecture.py ZCRMSNorm exactly.
27
+ """
28
+
29
+ def __init__(self, d: int, epsilon: float = 1e-6):
30
+ super().__init__()
31
+ self.epsilon = epsilon
32
+ # γ initialized to zero — param named "scale" to match Flax
33
+ self.scale = nn.Parameter(torch.zeros(d))
34
+
35
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
+ # Compute RMS in float32 for stability, then cast back
37
+ orig_dtype = x.dtype
38
+ x_f32 = x.float()
39
+ rms = torch.sqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + self.epsilon)
40
+ return ((1.0 + self.scale) * x_f32 / rms).to(orig_dtype)
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # RoPE
45
+ # ---------------------------------------------------------------------------
46
+
47
+ class RoPE(nn.Module):
48
+ """Pre-computed rotary position embeddings.
49
+
50
+ Buffers are NOT parameters (no gradient needed).
51
+ Exposes a static apply() helper for use inside MultiHeadAttention.
52
+ """
53
+
54
+ def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10000.0):
55
+ super().__init__()
56
+ # freqs: (head_dim//2,)
57
+ half = head_dim // 2
58
+ freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
59
+ t = torch.arange(max_seq_len).float()
60
+ angles = torch.outer(t, freqs) # (max_seq_len, half)
61
+ self.register_buffer("cos", torch.cos(angles), persistent=False)
62
+ self.register_buffer("sin", torch.sin(angles), persistent=False)
63
+
64
+ @staticmethod
65
+ def apply(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
66
+ """Apply RoPE to x of shape (B, num_heads, T, head_dim).
67
+
68
+ Matches Flax apply_rope():
69
+ x1 = x[..., :half] x2 = x[..., half:]
70
+ return cat([x1*cos - x2*sin, x2*cos + x1*sin], dim=-1)
71
+ """
72
+ T = x.shape[2]
73
+ # cos/sin are (max_seq_len, half); slice to T and broadcast
74
+ cos_t = cos[:T].unsqueeze(0).unsqueeze(0) # (1, 1, T, half)
75
+ sin_t = sin[:T].unsqueeze(0).unsqueeze(0)
76
+ half = x.shape[-1] // 2
77
+ x1, x2 = x[..., :half], x[..., half:]
78
+ return torch.cat([x1 * cos_t - x2 * sin_t,
79
+ x2 * cos_t + x1 * sin_t], dim=-1)
80
+
81
+ def get_cos_sin(self, seq_len: int):
82
+ """Return (cos, sin) buffers sliced to seq_len."""
83
+ return self.cos[:seq_len], self.sin[:seq_len]
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Helpers
88
+ # ---------------------------------------------------------------------------
89
+
90
+ def make_causal_mask(seq: int, past_seq: int = 0, device=None) -> torch.Tensor:
91
+ """Lower-triangular bool mask of shape (1, 1, seq, seq+past_seq).
92
+
93
+ Position i in the query can attend to positions 0..i+past_seq inclusive.
94
+ """
95
+ total = seq + past_seq
96
+ # rows = current positions (past_seq .. total-1 in the full KV sequence)
97
+ # columns = all KV positions (0 .. total-1)
98
+ row_idx = torch.arange(past_seq, total, device=device).unsqueeze(1) # (seq, 1)
99
+ col_idx = torch.arange(total, device=device).unsqueeze(0) # (1, total)
100
+ mask = row_idx >= col_idx # (seq, total)
101
+ return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq, total)
102
+
103
+
104
+ def make_padding_mask(tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
105
+ """Boolean padding mask: True where token != pad. Shape (B, 1, 1, T)."""
106
+ return (tokens != pad_token_id).unsqueeze(1).unsqueeze(2)
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # MultiHeadAttention
111
+ # ---------------------------------------------------------------------------
112
+
113
+ class MultiHeadAttention(nn.Module):
114
+ """Grouped-query attention with optional RoPE and past-KV caching.
115
+
116
+ Args:
117
+ config: TransformerConfig
118
+ is_cross_attn: if True, Q comes from one source and K/V from another.
119
+ is_causal: if True, applies causal mask (for decoder self-attention).
120
+ Also enables past_kv acceptance in forward().
121
+ """
122
+
123
+ def __init__(self, config: TransformerConfig, is_cross_attn: bool = False,
124
+ is_causal: bool = False):
125
+ super().__init__()
126
+ self.num_heads = config.num_heads
127
+ self.num_kv_heads = config.num_kv_heads
128
+ self.d_model = config.d_model
129
+ self.head_dim = config.d_model // config.num_heads
130
+ self.is_cross_attn = is_cross_attn
131
+ self.is_causal = is_causal
132
+ self.rope_keys_only = config.rope_keys_only
133
+ self.repeats = config.num_heads // config.num_kv_heads
134
+
135
+ kv_dim = config.num_kv_heads * self.head_dim
136
+
137
+ # Projections — no bias, matching Flax use_bias=False
138
+ self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
139
+ self.k_proj = nn.Linear(config.d_model, kv_dim, bias=False)
140
+ self.v_proj = nn.Linear(config.d_model, kv_dim, bias=False)
141
+ self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
142
+
143
+ # Per-head QK norms (applied after reshape, before GQA expand)
144
+ # q_norm operates on num_heads heads of head_dim
145
+ # k_norm operates on num_kv_heads heads of head_dim
146
+ self.q_norm = ZCRMSNorm(self.head_dim)
147
+ self.k_norm = ZCRMSNorm(self.head_dim)
148
+
149
+ self._scale = math.sqrt(self.head_dim)
150
+
151
+ def forward(
152
+ self,
153
+ q_input: torch.Tensor,
154
+ kv_input: torch.Tensor,
155
+ mask: torch.Tensor | None = None,
156
+ rope: tuple[torch.Tensor, torch.Tensor] | None = None,
157
+ past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
158
+ ):
159
+ """
160
+ Args:
161
+ q_input: (B, T_q, d_model)
162
+ kv_input: (B, T_kv, d_model)
163
+ mask: (B, 1, T_q, T_kv) bool — True = attend
164
+ rope: (cos, sin) tensors of shape (T, head_dim//2) each
165
+ past_kv: (k_cache, v_cache), each (B, num_kv_heads, past_T, head_dim)
166
+ Only used when is_causal=True (decoder self-attn).
167
+
168
+ Returns:
169
+ out: (B, T_q, d_model)
170
+ present_kv: (k, v) each (B, num_kv_heads, T_total, head_dim)
171
+ """
172
+ B, T_q, _ = q_input.shape
173
+
174
+ q = self.q_proj(q_input) # (B, T_q, d_model)
175
+ k = self.k_proj(kv_input) # (B, T_kv, kv_dim)
176
+ v = self.v_proj(kv_input) # (B, T_kv, kv_dim)
177
+
178
+ # Reshape to (B, num_heads/num_kv_heads, T, head_dim)
179
+ q = q.reshape(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)
180
+ T_kv = k.shape[1]
181
+ k = k.reshape(B, T_kv, self.num_kv_heads, self.head_dim).transpose(1, 2)
182
+ v = v.reshape(B, T_kv, self.num_kv_heads, self.head_dim).transpose(1, 2)
183
+
184
+ # QK norms (on the head_dim axis, before GQA expansion)
185
+ q = self.q_norm(q)
186
+ k = self.k_norm(k)
187
+
188
+ # RoPE application — applied to the CURRENT q and k only.
189
+ # Cached entries in past_kv already have RoPE at their original positions
190
+ # baked in; re-applying after cache-concat would double-rotate them.
191
+ if rope is not None:
192
+ cos, sin = rope
193
+ if not self.rope_keys_only:
194
+ q = RoPE.apply(q, cos, sin)
195
+ k = RoPE.apply(k, cos, sin)
196
+
197
+ # Concatenate past KV (decoder self-attn only).
198
+ # past_kv stores K with its original-position RoPE already applied.
199
+ if past_kv is not None:
200
+ k_past, v_past = past_kv
201
+ k = torch.cat([k_past, k], dim=2) # (B, num_kv_heads, past_T+T_kv, head_dim)
202
+ v = torch.cat([v_past, v], dim=2)
203
+
204
+ present_kv = (k, v)
205
+
206
+ # GQA expansion: repeat K and V to match num_heads
207
+ if self.repeats > 1:
208
+ k = k.repeat_interleave(self.repeats, dim=1) # (B, num_heads, T_total, head_dim)
209
+ v = v.repeat_interleave(self.repeats, dim=1)
210
+
211
+ # Scaled dot-product attention
212
+ # q: (B, num_heads, T_q, head_dim)
213
+ # k: (B, num_heads, T_total, head_dim)
214
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self._scale # (B, H, T_q, T_total)
215
+
216
+ if mask is not None:
217
+ # mask: True = attend, False = block
218
+ # Fill blocked positions with -inf
219
+ attn_weights = attn_weights.masked_fill(~mask, float("-inf"))
220
+
221
+ attn_weights = F.softmax(attn_weights.float(), dim=-1).to(q.dtype)
222
+
223
+ out = torch.matmul(attn_weights, v) # (B, num_heads, T_q, head_dim)
224
+ out = out.transpose(1, 2).reshape(B, T_q, self.d_model)
225
+ out = self.out_proj(out)
226
+
227
+ return out, present_kv