maxchbx committed on
Commit
7acd624
·
verified ·
1 Parent(s): 0297069

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "seqcond",
3
+ "architectures": [
4
+ "SeqCondForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_seqcond.SeqCondConfig",
8
+ "AutoModelForCausalLM": "modeling_seqcond.SeqCondForCausalLM",
9
+ "AutoTokenizer": [
10
+ "tokenization_seqcond.SeqCondTokenizer",
11
+ null
12
+ ]
13
+ },
14
+ "transformers_version": "5.3.0",
15
+ "d_model": 1024,
16
+ "d_ff": 2730,
17
+ "num_layers": 24,
18
+ "vocab_size": 100300,
19
+ "maxlen": 4096,
20
+ "num_heads": 16,
21
+ "num_kv_heads": 4,
22
+ "qk_norm": true,
23
+ "qk_norm_eps": 1e-06,
24
+ "seqcond_heads": 16,
25
+ "num_query_heads": 16,
26
+ "num_thetas": 2,
27
+ "conv_kernel_size": 4,
28
+ "expand_factor": 2.0,
29
+ "out_expand_factor": 3,
30
+ "seqcond_ratio": 2,
31
+ "skip_low_rank": false,
32
+ "num_anchor_heads": 0,
33
+ "eos_token_id": 100279,
34
+ "pad_token_id": 100279,
35
+ "bos_token_id": null
36
+ }
configuration_seqcond.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SeqCond HuggingFace configuration.
3
+ """
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
class SeqCondConfig(PretrainedConfig):
    """
    Configuration class for SeqCond models.

    SeqCond is a hybrid recurrent-transformer architecture that interleaves
    SeqCond (sequential conditioning) blocks with standard Transformer decoder
    blocks. SeqCond blocks replace softmax attention with a closed-form
    complex-exponential accumulator, enabling O(1) per-token decoding.

    Every argument below is stored verbatim as an attribute of the same name;
    no validation or normalisation is performed in this class.

    Args:
        d_model: Hidden dimension.
        d_ff: Feed-forward dimension (typically 3×d_model).
        num_layers: Total number of blocks (SeqCond + Transformer).
        vocab_size: Vocabulary size.
        maxlen: Maximum sequence length (also sets KV-cache size).
        dropout: Dropout rate (0.0 disables).
        tie_weights: Whether to tie embedding and LM-head weights.
        num_heads: Number of attention heads in Transformer blocks.
        num_kv_heads: Number of KV heads (GQA). None = full MHA.
        qk_norm: Whether to apply QK-normalization in Transformer blocks.
        qk_norm_eps: Epsilon for QK-norm.
        seqcond_heads: Number of SeqCond memory heads (K).
        num_query_heads: Number of query heads in SeqCond (K_q, must divide K).
        num_thetas: Number of frequency components per head (M).
        derivative_order: Unused — kept for checkpoint compatibility.
        num_anchor_heads: Number of anchor heads (no decay) in SeqCond.
        conv_kernel_size: Depthwise conv kernel size inside SeqCond.
        expand_factor: Inner expansion factor for SeqCond memory dimension.
        out_expand_factor: SwiGLU expansion factor in SeqCond.
        use_positional_embedding: Whether to add learnable positional embeddings.
        seqcond_ratio: Block interleaving ratio. Every (seqcond_ratio+1)-th
            block (1-indexed) is a Transformer block; the rest are SeqCond.
        chunk_size: Chunk size for chunked computation (unused in PyTorch path).
        use_square_matrix: Unused — kept for checkpoint compatibility.
        bos_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        eos_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        pad_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        **kwargs: Any remaining keyword arguments, forwarded to
            ``PretrainedConfig.__init__``.

    NOTE(review): the defaults ``seqcond_heads=32`` and ``num_query_heads=6``
    do not satisfy the "K_q must divide K" constraint documented above
    (32 % 6 != 0). Presumably every real checkpoint overrides both values;
    confirm before relying on a default-constructed config.
    """

    model_type = "seqcond"

    def __init__(
        self,
        # Core
        d_model: int = 768,
        d_ff: int = 2304,
        num_layers: int = 12,
        vocab_size: int = 100300,
        maxlen: int = 768,
        dropout: float = 0.0,
        tie_weights: bool = True,
        # Transformer block params
        num_heads: int = 8,
        num_kv_heads=None,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        # SeqCond block params
        seqcond_heads: int = 32,
        num_query_heads: int = 6,
        num_thetas: int = 4,
        derivative_order: int = 0,
        num_anchor_heads: int = 0,
        conv_kernel_size: int = 4,
        expand_factor: float = 2.0,
        out_expand_factor: int = 3,
        use_positional_embedding: bool = False,
        seqcond_ratio: int = 5,
        chunk_size: int = 128,
        use_square_matrix: bool = False,
        # Special token IDs (filled in by convert_checkpoint.py)
        bos_token_id=None,
        eos_token_id=None,
        pad_token_id=None,
        **kwargs,
    ):
        # Model-specific fields are assigned before calling the base-class
        # initialiser, which receives only the special-token IDs and **kwargs.
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.maxlen = maxlen
        self.dropout = dropout
        self.tie_weights = tie_weights

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps

        self.seqcond_heads = seqcond_heads
        self.num_query_heads = num_query_heads
        self.num_thetas = num_thetas
        self.derivative_order = derivative_order
        self.num_anchor_heads = num_anchor_heads
        self.conv_kernel_size = conv_kernel_size
        self.expand_factor = expand_factor
        self.out_expand_factor = out_expand_factor
        self.use_positional_embedding = use_positional_embedding
        self.seqcond_ratio = seqcond_ratio
        self.chunk_size = chunk_size
        self.use_square_matrix = use_square_matrix

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )
generation_utils.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ generation_utils.py — High-level generation helpers for SeqCond models.
3
+
4
+ These functions drive the model's prefill()/step() decoding loop directly,
5
+ offering a more user-friendly interface that handles tokenization, chat
6
+ formatting, and streaming.
7
+
8
+ Example usage:
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ model = AutoModelForCausalLM.from_pretrained("path/to/model", trust_remote_code=True)
11
+ tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
12
+ model.eval().cuda()
13
+
14
+ text = generate(model, tokenizer, "What is 2 + 2?")
15
+ print(text)
16
+
17
+ # Batched
18
+ texts = generate_batch(model, tokenizer, ["What is 2+2?", "Name a planet."])
19
+ """
20
+
21
+ from typing import Iterator, List, Optional
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+
27
+ _SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] # power-of-2 for CUDA graphs
28
+
29
+
30
+ def _quantized_seq_len(pos: int) -> int:
31
+ needed = pos + 1
32
+ for s in _SEQ_LENS:
33
+ if s >= needed:
34
+ return s
35
+ return _SEQ_LENS[-1]
36
+
37
+
38
@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    use_chat_template: bool = True,
    use_triton: bool = False,
    strip_thinking: bool = False,
    max_thinking_tokens: Optional[int] = None,
) -> str:
    """
    Generate a single completion for *prompt*.

    The prompt is prefilled once via ``model.model.prefill`` and then decoded
    one token at a time via ``model.model.step``.

    Args:
        model: SeqCondForCausalLM instance.
        tokenizer: SeqCondTokenizer instance.
        prompt: Plain-text user prompt.
        max_new_tokens: Maximum tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        top_p: Nucleus sampling probability.
        top_k: Top-k filtering (0 = disabled). Values larger than the vocabulary
            size are clamped to the vocabulary size.
        repetition_penalty: Penalty for repeating tokens. Values > 1 make
            already-generated tokens less likely: positive logits are divided
            by the penalty and negative logits are multiplied by it.
        use_chat_template: If True, wrap prompt in <|im_start|>user…<|think_start|>.
        use_triton: If True, use Triton kernels for SeqCond steps.
        strip_thinking: If True, return only the text after <|think_end|>.
        max_thinking_tokens: If set, inject <|think_end|> after this many
            thinking tokens to cap reasoning length.

    Returns:
        Generated text (completion only, EOS stripped).
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id
    think_end_id = tokenizer.think_end_id

    if use_chat_template:
        ids = tokenizer.encode_chat(prompt, add_think_start=True)
    else:
        ids = tokenizer.encode(prompt)

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    logits, states = model.model.prefill(input_ids)
    logits = logits.squeeze(1)

    generated: List[int] = []
    token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
    seq_len = len(ids)

    # The chat template ends with a think-start marker, so decoding begins
    # inside the "thinking" segment only when the template is used.
    in_thinking = use_chat_template
    thinking_tokens = 0
    think_end_injected = False

    for _ in range(max_new_tokens):
        # Both branches produce a fresh tensor, so the in-place edits below
        # never touch the logits returned by the model.
        ls = logits[0] / max(temperature, 1e-8) if temperature > 0 else logits[0].clone()

        if repetition_penalty != 1.0 and generated:
            vocab_size = model.config.vocab_size
            seen = [t for t in set(generated) if 0 <= t < vocab_size]
            if seen:
                seen_idx = torch.tensor(seen, dtype=torch.long, device=ls.device)
                seen_logits = ls[seen_idx]
                # Dividing a negative logit by a penalty > 1 would move it
                # toward zero and *raise* its probability, so negative logits
                # are multiplied instead.
                ls[seen_idx] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

        if temperature == 0:
            next_token = int(torch.argmax(ls))
        else:
            if top_k > 0:
                # torch.topk raises if k exceeds the number of elements.
                k = min(top_k, ls.numel())
                kth = torch.topk(ls, k).values[-1]
                ls = ls.masked_fill(ls < kth, float("-inf"))
            if top_p < 1.0:
                sorted_ls, sorted_idx = torch.sort(ls, descending=True)
                cum = torch.cumsum(F.softmax(sorted_ls, dim=-1), dim=-1)
                remove = cum > top_p
                # Shift right by one so the first token crossing the
                # threshold is kept and the top token is never removed.
                remove[1:] = remove[:-1].clone()
                remove[0] = False
                ls[sorted_idx[remove]] = float("-inf")
            probs = F.softmax(ls, dim=-1)
            next_token = int(torch.multinomial(probs, 1))

        # Thinking budget
        if next_token == think_end_id:
            in_thinking = False
        if in_thinking:
            thinking_tokens += 1
        if (
            max_thinking_tokens is not None
            and in_thinking
            and thinking_tokens >= max_thinking_tokens
            and not think_end_injected
        ):
            # Budget exhausted: replace the sampled token with <|think_end|>.
            next_token = think_end_id
            in_thinking = False
            think_end_injected = True

        generated.append(next_token)
        if next_token == eos_id:
            break

        token_buf[0, 0] = next_token
        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    # Decode
    if generated and generated[-1] == eos_id:
        generated = generated[:-1]

    text = tokenizer.decode(generated)
    if strip_thinking and "<|think_end|>" in text:
        text = text.split("<|think_end|>", 1)[1].strip()
    return text
150
+
151
+
152
@torch.no_grad()
def generate_batch(
    model,
    tokenizer,
    prompts: List[str],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_chat_template: bool = True,
    use_triton: bool = False,
    strip_thinking: bool = False,
) -> List[str]:
    """
    Batched generation for a list of prompts.

    Each prompt is prefilled individually (no padding noise), then all
    sequences are decoded in lockstep with per-sample early stopping.

    Args:
        model: SeqCondForCausalLM instance.
        tokenizer: SeqCondTokenizer instance.
        prompts: Plain-text prompts. An empty list returns an empty list.
        max_new_tokens: Maximum tokens to generate per prompt.
        temperature: Sampling temperature (0 = greedy).
        use_chat_template: If True, encode each prompt with the chat template.
        use_triton: If True, use Triton kernels for SeqCond steps.
        strip_thinking: If True, return only the text after <|think_end|>.

    Returns a list of completion strings (EOS stripped).
    """
    # Nothing to prefill or concatenate; the code below indexes all_states[0].
    if not prompts:
        return []

    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id
    B = len(prompts)

    if use_chat_template:
        all_ids = [tokenizer.encode_chat(p, add_think_start=True) for p in prompts]
    else:
        all_ids = [tokenizer.encode(p) for p in prompts]

    # Individual prefills
    all_logits, all_states = [], []
    for ids in all_ids:
        inp = torch.tensor([ids], dtype=torch.long, device=device)
        lg, st = model.model.prefill(inp)
        all_logits.append(lg.squeeze(1))
        all_states.append(st)

    # Stack the per-prompt results along the batch dimension, block by block
    # and state tensor by state tensor.
    logits = torch.cat(all_logits, dim=0)
    num_blocks = len(all_states[0])
    states = [
        tuple(torch.cat([s[i][j] for s in all_states], dim=0) for j in range(len(all_states[0][i])))
        for i in range(num_blocks)
    ]

    generated = [[] for _ in range(B)]
    finished = [False] * B
    # active_map[bi] is the original prompt index of row bi in the live batch.
    active_map = list(range(B))
    token_buf = torch.zeros((B, 1), dtype=torch.long, device=device)
    seq_len = max(len(ids) for ids in all_ids)

    for _ in range(max_new_tokens):
        B_cur = len(active_map)
        if B_cur == 0:
            break

        if temperature == 0:
            next_tokens = torch.argmax(logits, dim=-1)
        else:
            probs = F.softmax(logits / max(temperature, 1e-8), dim=-1)
            next_tokens = torch.multinomial(probs, 1).squeeze(-1)

        # One host transfer for the whole batch instead of one per row.
        sampled = next_tokens.tolist()

        newly_done: set = set()
        for bi, tok in enumerate(sampled):
            oi = active_map[bi]
            generated[oi].append(tok)
            if tok == eos_id:
                finished[oi] = True
                newly_done.add(bi)
            else:
                token_buf[bi, 0] = tok

        if all(finished):
            break

        if newly_done:
            # Drop finished rows from the token buffer and every state tensor.
            keep = [bi for bi in range(B_cur) if bi not in newly_done]
            if not keep:
                break
            keep_idx = torch.tensor(keep, device=device)
            token_buf = token_buf[keep_idx].contiguous()
            states = [tuple(s[keep_idx].contiguous() for s in st) for st in states]
            active_map = [active_map[bi] for bi in keep]

        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    results = []
    for toks in generated:
        if toks and toks[-1] == eos_id:
            toks = toks[:-1]
        text = tokenizer.decode(toks)
        if strip_thinking and "<|think_end|>" in text:
            text = text.split("<|think_end|>", 1)[1].strip()
        results.append(text)
    return results
248
+
249
+
250
@torch.no_grad()
def stream(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_chat_template: bool = True,
    use_triton: bool = False,
) -> Iterator[str]:
    """
    Streaming token-by-token generation.

    Yields decoded text fragments as they are produced. Useful for interactive
    applications (e.g., a chat interface). Each token is decoded on its own;
    if decoding a token fails, an empty string is yielded for that token.
    Generation stops at the EOS token, which is never yielded.

    Args:
        model: SeqCondForCausalLM instance.
        tokenizer: SeqCondTokenizer instance.
        prompt: Plain-text user prompt.
        max_new_tokens: Maximum tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        use_chat_template: If True, encode the prompt with the chat template.
        use_triton: If True, use Triton kernels for SeqCond steps.

    Example:
        for fragment in stream(model, tokenizer, "Explain gravity."):
            print(fragment, end="", flush=True)
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id

    if use_chat_template:
        ids = tokenizer.encode_chat(prompt, add_think_start=True)
    else:
        ids = tokenizer.encode(prompt)

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    logits, states = model.model.prefill(input_ids)
    logits = logits.squeeze(1)

    token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
    seq_len = len(ids)

    for _ in range(max_new_tokens):
        if temperature == 0:
            next_token = int(torch.argmax(logits[0]))
        else:
            probs = F.softmax(logits[0] / max(temperature, 1e-8), dim=-1)
            next_token = int(torch.multinomial(probs, 1))

        if next_token == eos_id:
            break

        # Only the decode call is guarded. Keeping the yield outside the try
        # block ensures exceptions thrown into the generator by the consumer
        # propagate instead of being swallowed and answered with "".
        try:
            fragment = tokenizer.decode([next_token])
        except Exception:
            # Best-effort streaming: a token that cannot be decoded alone
            # yields an empty fragment rather than aborting the stream.
            fragment = ""
        yield fragment

        token_buf[0, 0] = next_token
        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff37a8ff2b5b7f7fbe456efae39a5c7f82460b31e27239acee0210e3f044a0dc
3
+ size 949771696
modeling_seqcond.py ADDED
@@ -0,0 +1,985 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SeqCond model — self-contained HuggingFace implementation.
3
+
4
+ All model code is embedded here so that trust_remote_code=True works without
5
+ any dependency on the original seqcond package.
6
+
7
+ Architecture:
8
+ - Hybrid recurrent-transformer: every (seqcond_ratio+1)-th block (1-indexed)
9
+ is a standard Transformer decoder block; the rest are SeqCond blocks.
10
+ - SeqCond blocks use complex-exponential accumulators (den_acc, re_acc, im_acc)
11
+ for O(1) per-token autoregressive decoding.
12
+ - Transformer blocks use GQA with RoPE and KV-cache for autoregressive decoding.
13
+ """
14
+
15
+ import math
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from transformers import PreTrainedModel
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+ from .configuration_seqcond import SeqCondConfig
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Optional Triton kernels (accelerates SeqCond step, not required)
29
+ # ---------------------------------------------------------------------------
30
+ try:
31
+ from .triton_kernels import (
32
+ gated_rmsnorm_triton,
33
+ seqcond_step_triton,
34
+ TRITON_AVAILABLE,
35
+ )
36
+ except ImportError:
37
+ gated_rmsnorm_triton = None
38
+ TRITON_AVAILABLE = False
39
+ seqcond_step_triton = None
40
+
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Normalisation layers
44
+ # ---------------------------------------------------------------------------
45
+
46
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The computation runs in float32 and the result is cast back to the input
    dtype.
    """

    def __init__(self, hidden_size: int, epsilon: float = 1e-5):
        super().__init__()
        self.epsilon = epsilon
        self.scale = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + self.epsilon)
        scaled = normed * self.scale.float()
        return scaled.to(input_dtype)
57
+
58
+
59
class GatedRMSNorm(nn.Module):
    """RMSNorm applied to a SiLU-gated input: rmsnorm(x * silu(residual)).

    The computation runs in float32 and the result is cast back to the dtype
    of ``x``.
    """

    def __init__(self, hidden_size: int, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        gate = F.silu(residual.float())
        gated = x.float() * gate
        mean_square = gated.pow(2).mean(dim=-1, keepdim=True)
        normed = gated * torch.rsqrt(mean_square + self.epsilon)
        return (normed * self.weight.float()).to(input_dtype)
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Rotary Position Embedding
76
+ # ---------------------------------------------------------------------------
77
+
78
def precompute_freqs(maxlen: int, head_dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the rotary-embedding cosine and sine tables.

    Returns a ``(cos, sin)`` pair of float32 tensors, each of shape
    ``(maxlen, head_dim // 2)``, using base 10000.
    """
    half_d = head_dim // 2
    positions = np.arange(maxlen)[:, None]
    channel = np.arange(half_d)[None, :]
    inv_freq = 1.0 / (10000 ** (channel / half_d))
    angles = positions * inv_freq
    cos_table = np.cos(angles).astype(np.float32)
    sin_table = np.sin(angles).astype(np.float32)
    return torch.from_numpy(cos_table), torch.from_numpy(sin_table)
86
+
87
+
88
def apply_rope(tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last dimension of ``tensor`` by (cos, sin).

    The last dimension is split into a first and a second half; ``cos`` and
    ``sin`` are truncated to the half width before being applied. The result
    has the same shape as ``tensor``.
    """
    half = tensor.shape[-1] // 2
    cos_half = cos[..., :half]
    sin_half = sin[..., :half]
    first = tensor[..., :half]
    second = tensor[..., half:]
    rotated_first = first * cos_half - second * sin_half
    rotated_second = second * cos_half + first * sin_half
    return torch.cat([rotated_first, rotated_second], dim=-1).view(tensor.shape)
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # Transformer decoder block (GQA + RoPE)
98
+ # ---------------------------------------------------------------------------
99
+
100
class RotarySelfAttention(nn.Module):
    """Causal self-attention with rotary position embedding and grouped KV heads.

    ``forward`` processes a whole sequence; ``step`` processes one new token
    against an existing key/value cache. All four projections are bias-free.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        dropout: float = 0.0,
        qk_norm: bool = False,
        qk_norm_eps: float = 1e-6,
    ):
        """
        Args:
            d_model: Model width; also the output width of ``q_proj``/``out_proj``.
            num_heads: Number of query heads.
            num_kv_heads: Number of key/value heads; ``None`` means ``num_heads``.
            dropout: Dropout probability applied to attention weights in
                ``forward`` while training.
            qk_norm: If True, RMS-normalise q and k over the head dimension
                (no learned scale) after RoPE.
            qk_norm_eps: Epsilon used by the q/k normalisation.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self._num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        # Integer division: no check that num_heads is a multiple of the KV
        # head count, or that d_model is a multiple of num_heads.
        self.num_groups = num_heads // self._num_kv_heads
        self.head_dim = d_model // num_heads
        self.dropout = dropout
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self._num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self._num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat each KV head ``num_groups`` times so the head axis has ``num_heads`` entries.

        Returns ``x`` unchanged when there is a single group.
        """
        if self.num_groups == 1:
            return x
        b, l = x.shape[:2]
        extra = x.shape[2:]
        # extra[0] is the KV-head axis; extra[1:] are the trailing dims kept as-is.
        x = x.view(b, l, self._num_kv_heads, 1, *extra[1:])
        x = x.expand(b, l, self._num_kv_heads, self.num_groups, *extra[1:])
        return x.reshape(b, l, self.num_heads, *extra[1:])

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_state: bool = False,
    ):
        """Full-sequence causal attention.

        Args:
            x: Input whose first two dims are (batch, length); reshaped to
                (batch, length, heads, head_dim) after projection.
            cos, sin: Rotary tables passed to ``apply_rope``.
                # assumes 4-D, broadcastable against (b, l, heads, head_dim // 2),
                # with the head axis at dim 2 — TODO confirm against caller
            mask: Accepted for interface compatibility; not referenced in this
                method. Only the built-in lower-triangular causal mask is applied.
            return_state: If True, also return ``(k, v)`` as computed before
                head repetition (k after RoPE and optional normalisation).

        Returns:
            ``out_proj(out)``, or ``(out_proj(out), (k_cache, v_cache))`` when
            ``return_state`` is True.
        """
        b, l = x.shape[0], x.shape[1]
        q = self.q_proj(x).reshape(b, l, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(b, l, self._num_kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(b, l, self._num_kv_heads, self.head_dim)

        q = apply_rope(q, cos, sin)
        # With fewer KV heads than query heads, trim the head axis of the
        # rotary tables to the KV head count before rotating k.
        cos_kv = cos[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else cos
        sin_kv = sin[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else sin
        k = apply_rope(k, cos_kv, sin_kv)

        if self.qk_norm:
            # RMS-normalise in float32, then cast back to the original dtype.
            q_f = q.float(); k_f = k.float()
            q = (q_f * torch.rsqrt(q_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(q.dtype)
            k = (k_f * torch.rsqrt(k_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(k.dtype)

        # Keep the un-repeated k/v for the returned state.
        k_cache = k; v_cache = v
        k = self._repeat_kv(k); v = self._repeat_kv(v)

        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) * scale
        causal = torch.tril(torch.ones(l, l, dtype=torch.bool, device=x.device)).unsqueeze(0).unsqueeze(0)
        # Future positions are filled with -1e4 here (step() uses -inf instead).
        scores = torch.where(causal, scores, torch.full_like(scores, -1e4))
        attn = F.softmax(scores.float(), dim=-1).to(v.dtype)
        if self.dropout > 0 and self.training:
            attn = F.dropout(attn, p=self.dropout)
        out = torch.einsum("bhql,blhd->bqhd", attn, v).reshape(b, l, self.d_model).to(x.dtype)

        if return_state:
            return self.out_proj(out), (k_cache, v_cache)
        return self.out_proj(out)

    def step(
        self,
        x_t: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        pos: torch.Tensor,
        cos_t: torch.Tensor,
        sin_t: torch.Tensor,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Tuple]:
        """Single-token attention against a key/value cache.

        Args:
            x_t: Input for the new token; first dim is batch.
                # presumably (batch, d_model), matching the (b, d_model) output
                # reshape below — TODO confirm
            kv_cache: ``(k_cache, v_cache)``; both are updated IN PLACE at
                position ``pos`` along dim 1.
                # assumes each is (batch, cache_len, kv_heads, head_dim) — TODO confirm
            pos: Per-sample write position; viewed as (b, 1, 1, 1).
            cos_t, sin_t: Rotary tables for the current position.
            seq_len: If given, only the first ``seq_len`` cache positions are
                attended to; otherwise the whole cache is used.

        Returns:
            ``(out_proj(out), (k_cache, v_cache))`` where the cache tensors are
            the same objects that were passed in.
        """
        b = x_t.shape[0]
        q = self.q_proj(x_t).reshape(b, 1, self.num_heads, self.head_dim)
        k_new = self.k_proj(x_t).reshape(b, 1, self._num_kv_heads, self.head_dim)
        v_new = self.v_proj(x_t).reshape(b, 1, self._num_kv_heads, self.head_dim)

        q = apply_rope(q, cos_t, sin_t)
        cos_kv = cos_t[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else cos_t
        sin_kv = sin_t[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else sin_t
        k_new = apply_rope(k_new, cos_kv, sin_kv)

        if self.qk_norm:
            q_f = q.float(); k_f = k_new.float()
            q = (q_f * torch.rsqrt(q_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(q.dtype)
            k_new = (k_f * torch.rsqrt(k_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(k_new.dtype)

        k_cache, v_cache = kv_cache
        # Write the new key/value into slot `pos` of each sample's cache.
        pos_idx = pos.long().view(b, 1, 1, 1).expand(-1, 1, k_new.size(2), k_new.size(3))
        k_cache.scatter_(1, pos_idx, k_new.to(k_cache.dtype))
        v_cache.scatter_(1, pos_idx, v_new.to(v_cache.dtype))

        if seq_len is not None:
            k_slice, v_slice = k_cache[:, :seq_len], v_cache[:, :seq_len]; L = seq_len
        else:
            k_slice, v_slice = k_cache, v_cache; L = k_cache.shape[1]

        k_r = self._repeat_kv(k_slice); v_r = self._repeat_kv(v_slice)
        # True for cache slots strictly after each sample's current position.
        mask = torch.arange(L, device=k_cache.device).view(1, 1, 1, L) > pos.long().view(b, 1, 1, 1)
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k_r) * scale
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores.float(), dim=-1).to(v_r.dtype)
        out = torch.einsum("bhqk,bkhd->bqhd", attn, v_r).reshape(b, self.d_model).to(x_t.dtype)
        return self.out_proj(out), (k_cache, v_cache)
215
+
216
+
217
class TransformerDecoderBlock(nn.Module):
    """Pre-norm decoder block: rotary self-attention followed by a gated feed-forward.

    Both sub-layers are wrapped in residual connections. ``forward`` handles a
    full sequence; ``step`` handles a single token with a KV cache and applies
    no dropout.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        num_kv_heads: Optional[int] = None,
        dropout: float = 0.0,
        norm_eps: float = 1e-6,
        qk_norm: bool = False,
        qk_norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.norm1 = RMSNorm(d_model, epsilon=norm_eps)
        self.attn = RotarySelfAttention(d_model, num_heads, num_kv_heads, dropout, qk_norm, qk_norm_eps)
        self.norm2 = RMSNorm(d_model, epsilon=norm_eps)
        # ff_in produces both halves of the gated unit in one projection.
        self.ff_in = nn.Linear(d_model, 2 * d_ff, bias=True)
        self.ff_out = nn.Linear(d_ff, d_model, bias=True)
        self.dropout = dropout

    def _gated_ffn(self, hidden):
        """Apply the gated feed-forward: ff_out(silu(second_half) * first_half)."""
        value, gate = self.ff_in(hidden).chunk(2, dim=-1)
        return self.ff_out(F.silu(gate) * value)

    def _maybe_dropout(self, tensor):
        """Apply dropout only when it is enabled and the module is training."""
        if self.dropout > 0 and self.training:
            return F.dropout(tensor, p=self.dropout)
        return tensor

    def forward(self, x, cos, sin, mask=None, return_state=False):
        """Run the block on a full sequence; optionally also return the attention KV state."""
        normed = self.norm1(x)
        kv = None
        if return_state:
            attn_out, kv = self.attn(normed, cos=cos, sin=sin, mask=mask, return_state=True)
        else:
            attn_out = self.attn(normed, cos=cos, sin=sin, mask=mask)
        x = x + self._maybe_dropout(attn_out)

        ffn_out = self._maybe_dropout(self._gated_ffn(self.norm2(x)))
        out = x + ffn_out
        if return_state:
            return out, kv
        return out

    def step(self, x_t, kv_cache, pos, cos_t, sin_t, seq_len=None):
        """Run the block on one token; returns (output, updated KV cache)."""
        attn_out, new_kv = self.attn.step(self.norm1(x_t), kv_cache, pos, cos_t, sin_t, seq_len=seq_len)
        x_t = x_t + attn_out
        return x_t + self._gated_ffn(self.norm2(x_t)), new_kv
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # SeqCond attention block
265
+ # ---------------------------------------------------------------------------
266
+
267
+ class SeqCondAttention(nn.Module):
268
+ def __init__(
269
+ self,
270
+ d_model: int,
271
+ num_heads: int = 12,
272
+ num_query_heads: int = 6,
273
+ num_anchor_heads: int = 0,
274
+ num_thetas: int = 1,
275
+ conv_kernel_size: int = 4,
276
+ expand_factor: int = 1,
277
+ out_expand_factor: int = 3,
278
+ dropout: float = 0.0,
279
+ maxlen: Optional[int] = None,
280
+ **kwargs,
281
+ ):
282
+ super().__init__()
283
+ assert num_heads % num_query_heads == 0
284
+
285
+ self.d_model = d_model
286
+ self.K = num_heads
287
+ self.K_q = num_query_heads
288
+ self.n_rep = num_heads // num_query_heads
289
+ self.M = num_thetas
290
+ self.num_decay_heads = num_heads - num_anchor_heads
291
+ self.num_anchor_heads = num_anchor_heads
292
+ self.conv_kernel_size = conv_kernel_size
293
+ self.dropout_rate = dropout
294
+ self.maxlen = maxlen
295
+
296
+ d_inner = int(d_model * expand_factor)
297
+ self.H = max(1, d_inner // (self.K * self.M))
298
+ self.dim_memory = self.K * self.H
299
+ self.dim_query_head = self.H * self.M * 2
300
+ self.dim_query_total = self.K_q * self.dim_query_head
301
+ self.dim_expand = self.H * out_expand_factor
302
+ self.dim_swiglu_head = self.dim_expand * 2
303
+ self.dim_swiglu_total = self.K * self.dim_swiglu_head
304
+ self.dim_mem_total = self.dim_memory + self.K
305
+ self.dim_conv_total = self.dim_mem_total + self.dim_query_total
306
+
307
+ self.in_proj = nn.Linear(d_model, self.dim_conv_total, bias=False)
308
+ self.conv_weight = nn.Parameter(torch.empty(self.dim_conv_total, 1, conv_kernel_size))
309
+ nn.init.kaiming_normal_(self.conv_weight)
310
+
311
+ # Cached buffers (computed lazily)
312
+ self.register_buffer("_conv_kernel_t", None)
313
+ self.register_buffer("_theta_cached", None)
314
+ self.register_buffer("_w_int_cached", None)
315
+ self.register_buffer("_decay_slopes_cached", None)
316
+ self.register_buffer("_anchor_slopes_cached", None)
317
+ self.register_buffer("_phase_scale_b", None)
318
+ self.register_buffer("_score_scale_b", None)
319
+ self.register_buffer("_score_bias_b", None)
320
+ self._triton_out_re_buffer = None
321
+ self._triton_out_im_buffer = None
322
+ self._triton_norm_buffer = None
323
+
324
+ if self.M == 1:
325
+ init_theta = np.geomspace(0.001, 3.0, self.K).reshape(1, 1, self.K, 1, 1)
326
+ init_theta = np.tile(init_theta, (1, 1, 1, self.H, 1))
327
+ x = np.clip((init_theta - 0.001) / 2.999, 1e-4, 1 - 1e-4)
328
+ self.theta_raw = nn.Parameter(torch.from_numpy((np.log(x) - np.log(1 - x)).astype(np.float32)))
329
+ self.w_int_raw = nn.Parameter(torch.zeros(1, 1, self.K_q, self.n_rep, self.H, 1))
330
+ else:
331
+ init_vals = np.geomspace(0.001, 3.0, self.M).reshape(1, 1, 1, 1, self.M)
332
+ init_vals = np.tile(init_vals, (1, 1, self.K, self.H, 1))
333
+ self.theta_d_raw = nn.Parameter(torch.from_numpy(np.log(np.exp(init_vals) - 1.0 + 1e-4).astype(np.float32)))
334
+ self.w_int_raw = nn.Parameter(torch.zeros(1, 1, self.K_q, self.n_rep, self.H, self.M))
335
+
336
+ if self.num_decay_heads > 0:
337
+ self.decay_slopes = nn.Parameter(
338
+ torch.from_numpy(np.log(np.exp(np.geomspace(0.001, 0.1, self.num_decay_heads)) - 1).astype(np.float32))
339
+ )
340
+ if self.num_anchor_heads > 0:
341
+ self.anchor_slopes = nn.Parameter(
342
+ torch.from_numpy(np.log(np.exp(np.geomspace(0.01, 0.1, self.num_anchor_heads)) - 1).astype(np.float32))
343
+ )
344
+
345
+ self.score_scale = nn.Parameter(torch.ones(self.K))
346
+ self.score_bias = nn.Parameter(torch.zeros(self.K))
347
+ self.phase_scale = nn.Parameter(torch.ones(self.K))
348
+ self.gate_proj = nn.Linear(d_model, self.K * 2 * self.H, bias=False)
349
+ self.gated_norm = GatedRMSNorm(self.K * 2 * self.H)
350
+ self.W_readout = nn.Parameter(torch.empty(self.K, 2 * self.H, self.dim_swiglu_head))
351
+ nn.init.xavier_uniform_(self.W_readout)
352
+ self.out_proj = nn.Linear(self.dim_swiglu_total // 2, d_model, bias=False)
353
+
354
def forward(self, x: torch.Tensor, mask=None, return_state: bool = False):
    """Run the full-sequence (parallel) SeqCond pass over ``x``.

    The input is projected, passed through a left-padded depthwise
    convolution followed by SiLU, and split into memory values, per-head
    scores and a complex query.  Every position adds
    ``k * p_w * (cos(phi) + i*sin(phi))`` to a running sum along the
    sequence; the weight-normalised running sum is multiplied by the
    conjugate of the query, mixed over the ``M`` frequencies, passed through
    the gated norm and read out through ``W_readout`` and ``out_proj``.

    Args:
        x: Input tensor of shape ``(B, L, d_model)``.
        mask: Accepted for interface compatibility; not used by this method.
        return_state: When True, also return the recurrent state reached
            after the last position so decoding can continue with ``step``.

    Returns:
        ``output`` of shape ``(B, L, d_model)``, or ``(output, state)`` when
        ``return_state`` is True.  ``state`` is the tuple
        ``(den_acc, re_acc, im_acc, pos, conv_buf)`` with shapes
        ``(B, K)``, ``(B, K, H, M)``, ``(B, K, H, M)``, ``(B,)`` and
        ``(B, conv_kernel_size - 1, dim_conv_total)``.  The state tensors own
        their storage, so in-place updates performed by ``step`` never touch
        (or keep alive) the full-sequence buffers built here.
    """
    B, L, _ = x.shape

    # Keep the pre-convolution projection around: its tail is exactly the
    # rolling convolution buffer needed for step-wise decoding, so it does
    # not have to be projected a second time.
    z_pre = self.in_proj(x)
    z_conv_t = F.pad(z_pre.transpose(1, 2), (self.conv_kernel_size - 1, 0))
    z_conv = F.silu(F.conv1d(z_conv_t, self.conv_weight, groups=self.dim_conv_total).transpose(1, 2))

    # Channel layout: [memory values | scores | query (re/im interleaved last)].
    z_mem = z_conv[..., : self.dim_mem_total]
    q_raw = z_conv[..., self.dim_mem_total :]
    k_val = z_mem[..., : self.dim_memory].reshape(B, L, self.K, self.H)
    s_raw = z_mem[..., self.dim_memory :]
    q_raw = q_raw.reshape(B, L, self.K_q, 1, self.H, self.M, 2)
    q_re, q_im = q_raw[..., 0], q_raw[..., 1]

    # Frequencies are constrained to [0.001, 3.0].
    if self.M == 1:
        theta = 0.001 + 2.999 * torch.sigmoid(self.theta_raw)
    else:
        # Positive increments -> monotonically increasing grid, rescaled to the range.
        theta_d = F.softplus(self.theta_d_raw) + 1e-4
        theta_accum = torch.cumsum(theta_d, dim=-1)
        theta = 0.001 + (theta_accum / theta_accum[..., -1:]) * 2.999

    # Mixing weights over the M frequencies, normalised to (almost) sum to one.
    w_int = torch.exp(self.w_int_raw)
    w_int = w_int / (w_int.sum(dim=-1, keepdim=True) + 1e-6)

    # Position-dependent log-weights for the decay / anchor heads.
    pos = torch.arange(L, dtype=torch.float32, device=x.device)
    log_w_list = []
    if self.num_decay_heads > 0:
        slopes = F.softplus(self.decay_slopes).view(1, 1, -1)
        dist = torch.clamp((self.maxlen or L) - 1 - pos, min=0.0).view(1, L, 1)
        log_w_list.append(-slopes * dist)
    if self.num_anchor_heads > 0:
        log_w_list.append(-F.softplus(self.anchor_slopes).view(1, 1, -1) * pos.view(1, L, 1))
    log_tw = torch.cat(log_w_list, dim=2) if log_w_list else torch.zeros(1, L, self.K, device=x.device)

    # Per-position, per-head positive weight (accumulated in float32).
    score_raw = self.score_scale.view(1, 1, -1) * s_raw.float() + self.score_bias.view(1, 1, -1)
    p_w = (F.softplus(score_raw) * torch.exp(log_tw)).clamp(1e-4, 5000.0)

    k_f32 = k_val.float().unsqueeze(-1)
    p_w_b = p_w.unsqueeze(-1).unsqueeze(-1)
    phase_scale_b = self.phase_scale.view(1, 1, self.K, 1, 1)
    k_scaled = k_f32 * phase_scale_b
    # k / (1 + |k|) squashes the value into (-1, 1) before scaling by theta.
    phi = (k_scaled / (1.0 + k_scaled.abs())) * theta
    kvw = k_f32 * p_w_b
    re = kvw * torch.cos(phi)
    im = kvw * torch.sin(phi)

    # One cumulative sum over the concatenated [weights | real | imag] channels.
    flat_size = self.K * self.H * self.M
    stack = torch.cat([p_w.float(), re.reshape(B, L, -1), im.reshape(B, L, -1)], dim=-1)
    cumsum = torch.cumsum(stack, dim=1)
    den_acc = cumsum[..., : self.K]
    re_acc = cumsum[..., self.K : self.K + flat_size].reshape(B, L, self.K, self.H, self.M)
    im_acc = cumsum[..., self.K + flat_size :].reshape(B, L, self.K, self.H, self.M)

    inv_den = (1.0 / torch.clamp(den_acc, min=1e-4)).unsqueeze(-1).unsqueeze(-1)
    state_re_g = (re_acc * inv_den).reshape(B, L, self.K_q, self.n_rep, self.H, self.M)
    state_im_g = (im_acc * inv_den).reshape(B, L, self.K_q, self.n_rep, self.H, self.M)

    # state * conj(query): (a + ib)(c - id) = (ac + bd) + i(bc - ad).
    scale = 1.0 / (self.H ** 0.5)
    match_re = ((state_re_g * q_re + state_im_g * q_im) * scale).float()
    match_im = ((state_im_g * q_re - state_re_g * q_im) * scale).float()
    out_re = ((match_re * w_int.float()).sum(dim=-1)).reshape(B, L, self.K, self.H).to(x.dtype)
    out_im = ((match_im * w_int.float()).sum(dim=-1)).reshape(B, L, self.K, self.H).to(x.dtype)
    out_complex = self.gated_norm(torch.cat([out_re, out_im], dim=-1).reshape(B, L, -1), self.gate_proj(x))
    out_complex = out_complex.reshape(B, L, self.K, 2 * self.H)

    # Per-head readout followed by a sigmoid-gated value.
    y_raw = torch.einsum("blkf,kfn->blkn", out_complex, self.W_readout.to(out_complex.dtype))
    y_val, y_gate = y_raw.chunk(2, dim=-1)
    output = self.out_proj((y_val * torch.sigmoid(y_gate)).reshape(B, L, -1).to(x.dtype))

    if not return_state:
        return output

    buf_sz = self.conv_kernel_size - 1
    if L >= buf_sz:
        # Use an explicit start index: `z_pre[:, -buf_sz:]` would select the
        # WHOLE sequence when buf_sz == 0 (kernel size 1) because -0 == 0.
        conv_buf = z_pre[:, L - buf_sz :].clone()
    else:
        # Prompt shorter than the kernel history: left-pad with zeros, which
        # matches the zero padding used by the convolution above.
        conv_pad = torch.zeros(B, buf_sz - L, self.dim_conv_total, device=x.device, dtype=z_pre.dtype)
        conv_buf = torch.cat([conv_pad, z_pre], dim=1)

    # Clone the last-position slices so the state does not alias the
    # (B, L, ...) cumsum buffer: `step` updates these tensors in place.
    state = (
        p_w.sum(dim=1),
        re_acc[:, -1].clone(),
        im_acc[:, -1].clone(),
        torch.full((B,), L, dtype=torch.float32, device=x.device),
        conv_buf,
    )
    return output, state
436
+
437
def step(self, x_t: torch.Tensor, state: Tuple, use_triton: bool = False) -> Tuple:
    """Advance the recurrence by a single token.

    Args:
        x_t: Current input of shape ``(B, d_model)``.
        state: Tuple ``(den_acc, re_acc, im_acc, pos, conv_buffer)`` as
            produced by ``forward(..., return_state=True)``.
        use_triton: Request the fused Triton kernels; they are only used on
            CUDA inputs when the kernels are available (and, for the
            accumulator kernel, when ``n_rep == 1``).

    Returns:
        ``(out, state)`` where ``out`` has shape ``(B, d_model)``.  The
        returned state holds the same tensor objects that were passed in:
        ``pos`` and ``conv_buffer`` are always updated in place, and so are
        the accumulators on the PyTorch path.
        NOTE(review): on the Triton path the accumulators are handed to the
        kernel; presumably it updates them in place - confirm in triton_kernels.
    """
    B, D = x_t.shape
    den_acc, re_acc, im_acc, pos, conv_buffer = state

    projected = self.in_proj(x_t)

    # Depthwise kernel laid out as (kernel_size, channels); built lazily and
    # rebuilt if the module moved to another device.
    kernel_t = self._conv_kernel_t
    if kernel_t is None or kernel_t.device != projected.device:
        kernel_t = self.conv_weight[:, 0, :].t().contiguous()
        self._conv_kernel_t = kernel_t

    # The causal convolution for one step is a weighted sum over the window
    # [history..., current projection].
    window = torch.cat([conv_buffer, projected.unsqueeze(1)], dim=1)
    activated = F.silu((window * kernel_t).sum(dim=1))

    mem_part = activated[..., : self.dim_mem_total]
    query = activated[..., self.dim_mem_total :]
    k_val = mem_part[..., : self.dim_memory].reshape(B, self.K, self.H)
    s_raw = mem_part[..., self.dim_memory :]
    query = query.reshape(B, self.K_q, 1, self.H, self.M, 2)
    q_re = query[..., 0]
    q_im = query[..., 1]

    # Parameter-derived constants are computed once and reused on later steps.
    if self._theta_cached is None:
        if self.M == 1:
            theta_full = 0.001 + 2.999 * torch.sigmoid(self.theta_raw)
        else:
            increments = F.softplus(self.theta_d_raw) + 1e-4
            running = torch.cumsum(increments, dim=-1)
            theta_full = 0.001 + (running / running[..., -1:]) * 2.999
        self._theta_cached = theta_full[0, 0]
        mix = torch.exp(self.w_int_raw)
        self._w_int_cached = (mix / (mix.sum(dim=-1, keepdim=True) + 1e-6))[0, 0]
    theta = self._theta_cached
    w_int = self._w_int_cached

    has_decay = self.num_decay_heads > 0
    has_anchor = self.num_anchor_heads > 0
    if has_decay and self._decay_slopes_cached is None:
        self._decay_slopes_cached = F.softplus(self.decay_slopes).view(1, -1)
    if has_anchor and self._anchor_slopes_cached is None:
        self._anchor_slopes_cached = F.softplus(self.anchor_slopes).view(1, -1)
    if self._score_scale_b is None:
        self._score_scale_b = self.score_scale.view(1, -1)
        self._score_bias_b = self.score_bias.view(1, -1)
        self._phase_scale_b = self.phase_scale.view(1, self.K, 1, 1)

    # Position-dependent log-weights, decay heads first, then anchor heads.
    time_terms = []
    if has_decay:
        remaining = (self.maxlen or 2048) - 1 - pos.unsqueeze(-1)
        time_terms.append(-self._decay_slopes_cached * remaining.clamp(min=0.0))
    if has_anchor:
        time_terms.append(-self._anchor_slopes_cached * pos.unsqueeze(-1))
    if time_terms:
        log_tw = torch.cat(time_terms, dim=1)
    else:
        log_tw = torch.zeros(B, self.K, device=x_t.device)

    fused_accumulate = (
        use_triton
        and x_t.is_cuda
        and self.n_rep == 1
        and TRITON_AVAILABLE
        and seqcond_step_triton is not None
    )
    if fused_accumulate:
        # Reuse the output scratch buffers across steps while batch size and
        # device stay the same.
        scratch = self._triton_out_re_buffer
        if (
            scratch is None
            or scratch.shape != (B, self.K, self.H)
            or scratch.device != x_t.device
        ):
            self._triton_out_re_buffer = torch.empty(
                B, self.K, self.H, device=x_t.device, dtype=torch.float32
            )
            self._triton_out_im_buffer = torch.empty_like(
                self._triton_out_re_buffer
            )
        out_re, out_im = seqcond_step_triton(
            k_val,
            s_raw,
            q_re.squeeze(2),
            q_im.squeeze(2),
            re_acc,
            im_acc,
            den_acc,
            theta,
            w_int,
            self.phase_scale,
            self.score_scale,
            self.score_bias,
            log_tw,
            out_re_buffer=self._triton_out_re_buffer,
            out_im_buffer=self._triton_out_im_buffer,
        )
        out_complex = torch.cat([out_re, out_im], dim=-1)
    else:
        score_raw = self._score_scale_b * s_raw.float() + self._score_bias_b
        p_w = (F.softplus(score_raw) * torch.exp(log_tw)).clamp(1e-4, 5000.0)
        k_f32 = k_val.float().unsqueeze(-1)
        k_scaled = k_f32 * self._phase_scale_b
        phi = (k_scaled / (1.0 + k_scaled.abs())) * theta
        kvw = k_f32 * p_w.unsqueeze(-1).unsqueeze(-1)
        re = kvw * torch.cos(phi)
        im = kvw * torch.sin(phi)

        # Fold the current token into the running sums (in place).
        den_acc.add_(p_w)
        re_acc.add_(re)
        im_acc.add_(im)

        inv_den = (1.0 / torch.clamp(den_acc, min=1e-4)).unsqueeze(-1).unsqueeze(-1)
        grouped = (B, self.K_q, self.n_rep, self.H, self.M)
        state_re_g = (re_acc * inv_den).reshape(grouped)
        state_im_g = (im_acc * inv_den).reshape(grouped)

        # state * conj(query), then mix over the M frequencies.
        scale = 1.0 / (self.H ** 0.5)
        match_re = ((state_re_g * q_re + state_im_g * q_im) * scale).float()
        match_im = ((state_im_g * q_re - state_re_g * q_im) * scale).float()
        out_re = ((match_re * w_int.float()).sum(-1)).reshape(B, self.K, self.H).to(x_t.dtype)
        out_im = ((match_im * w_int.float()).sum(-1)).reshape(B, self.K, self.H).to(x_t.dtype)
        out_complex = torch.cat([out_re, out_im], dim=-1)

    out_complex = out_complex.reshape(B, self.K, 2 * self.H)
    flat = out_complex.reshape(B, -1)
    gate = self.gate_proj(x_t)

    if use_triton and x_t.is_cuda and gated_rmsnorm_triton is not None:
        norm_scratch = self._triton_norm_buffer
        if (
            norm_scratch is None
            or norm_scratch.shape != flat.shape
            or norm_scratch.device != x_t.device
        ):
            self._triton_norm_buffer = torch.empty(
                flat.shape,
                device=x_t.device,
                dtype=torch.float32,
            )
        normed = gated_rmsnorm_triton(
            flat,
            gate,
            self.gated_norm.weight,
            self.gated_norm.epsilon,
            out_buffer=self._triton_norm_buffer,
        )
    else:
        normed = self.gated_norm(flat, gate)

    out_complex = normed.to(x_t.dtype).reshape(B, self.K, 2 * self.H)
    y_raw = torch.einsum("bkf,kfn->bkn", out_complex, self.W_readout.to(out_complex.dtype))
    y_val, y_gate = y_raw.chunk(2, dim=-1)
    gated = (y_val * torch.sigmoid(y_gate)).reshape(B, -1).to(x_t.dtype)
    out = self.out_proj(gated)

    # Advance the position counter, saturating at the last valid position.
    pos.add_(1).clamp_(max=(self.maxlen or 2048) - 1)

    # Roll the convolution history: drop the oldest row, append the newest
    # (un-activated) projection.  With kernel size 1 there is no history.
    ksize = self.conv_kernel_size
    if ksize > 2:
        conv_buffer[:, :-1, :].copy_(conv_buffer[:, 1:, :].clone())
    if ksize > 1:
        conv_buffer[:, -1, :].copy_(projected)

    return out, (den_acc, re_acc, im_acc, pos, conv_buffer)
577
+
578
+
579
class SeqCondBlock(nn.Module):
    """Pre-norm residual wrapper around ``SeqCondAttention``.

    Both entry points compute ``x + attn(norm(x))``; ``forward`` works on a
    whole sequence, ``step`` on a single token with an explicit state.
    """

    def __init__(self, d_model: int, norm_eps: float = 1e-6, **kwargs):
        """Build the norm and the attention module.

        Args:
            d_model: Hidden size, used for both sub-modules.
            norm_eps: Epsilon handed to ``RMSNorm``.
            **kwargs: Forwarded unchanged to ``SeqCondAttention``.
        """
        super().__init__()
        self.norm = RMSNorm(d_model, epsilon=norm_eps)
        self.attn = SeqCondAttention(d_model=d_model, **kwargs)

    def forward(self, x, mask=None, return_state=False):
        """Apply the block to a full sequence.

        Returns ``x + attn(norm(x))``, or that value paired with the state
        reported by the attention module when ``return_state`` is True.
        """
        normed = self.norm(x)
        if not return_state:
            return x + self.attn(normed, mask=mask)
        delta, state = self.attn(normed, mask=mask, return_state=True)
        return x + delta, state

    def step(self, x_t, state, use_triton=False):
        """Apply the block to one token; returns ``(x_t + delta, new_state)``."""
        delta, next_state = self.attn.step(self.norm(x_t), state, use_triton=use_triton)
        return x_t + delta, next_state
594
+
595
+
596
+ # ---------------------------------------------------------------------------
597
+ # Core SeqCond language model
598
+ # ---------------------------------------------------------------------------
599
+
600
class SeqCondModel(nn.Module):
    """Core SeqCond model (no HF wrapper). Used internally by SeqCondForCausalLM.

    The stack interleaves two block kinds: every ``(seqcond_ratio + 1)``-th
    layer is a ``TransformerDecoderBlock``; all other layers are
    ``SeqCondBlock``s.  ``block_types`` records the kind of each layer so the
    per-layer call conventions (rotary tables vs. recurrent state) can be
    selected at run time.
    """

    def __init__(self, config: SeqCondConfig):
        super().__init__()
        self.d_model = config.d_model
        self.d_ff = config.d_ff
        self.num_layers = config.num_layers
        self.vocab_size = config.vocab_size
        self.maxlen = config.maxlen
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads if config.num_kv_heads is not None else config.num_heads
        self.seqcond_ratio = config.seqcond_ratio

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        self.use_positional_embedding = config.use_positional_embedding
        if config.use_positional_embedding:
            self.position_embedding = nn.Embedding(config.maxlen, config.d_model)

        # Rotary tables consumed by the transformer blocks.
        head_dim = config.d_model // config.num_heads
        cos, sin = precompute_freqs(config.maxlen, head_dim)
        self.register_buffer("cos_emb", cos)
        self.register_buffer("sin_emb", sin)

        self.blocks = nn.ModuleList()
        self.block_types = []
        for i in range(config.num_layers):
            if (i + 1) % (config.seqcond_ratio + 1) == 0:
                block = TransformerDecoderBlock(
                    d_model=config.d_model,
                    num_heads=config.num_heads,
                    d_ff=config.d_ff,
                    num_kv_heads=self.num_kv_heads,
                    dropout=config.dropout,
                    qk_norm=config.qk_norm,
                    qk_norm_eps=config.qk_norm_eps,
                )
                self.block_types.append("transformer")
            else:
                block = SeqCondBlock(
                    d_model=config.d_model,
                    num_heads=config.seqcond_heads,
                    num_query_heads=config.num_query_heads,
                    num_anchor_heads=config.num_anchor_heads,
                    num_thetas=config.num_thetas,
                    conv_kernel_size=config.conv_kernel_size,
                    expand_factor=config.expand_factor,
                    out_expand_factor=config.out_expand_factor,
                    dropout=config.dropout,
                    maxlen=config.maxlen,
                )
                self.block_types.append("seqcond")
            self.blocks.append(block)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_weights:
            self.lm_head.weight = self.embedding.weight

    def _embed_sequence(self, input_ids: torch.Tensor):
        """Embed ``input_ids`` and build the rotary tables for positions ``0..L-1``.

        Returns ``(x, cos, sin)``; ``cos``/``sin`` are expanded to
        ``(B, L, num_heads, -1)``.
        """
        B, L = input_ids.shape
        x = self.embedding(input_ids)
        if self.use_positional_embedding:
            x = x + self.position_embedding(torch.arange(L, device=input_ids.device))
        cos = self.cos_emb[:L].unsqueeze(0).unsqueeze(2).expand(B, L, self.num_heads, -1)
        sin = self.sin_emb[:L].unsqueeze(0).unsqueeze(2).expand(B, L, self.num_heads, -1)
        return x, cos, sin

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(B, L, vocab_size)`` for ``input_ids`` of shape ``(B, L)``."""
        x, cos, sin = self._embed_sequence(input_ids)
        for block, bt in zip(self.blocks, self.block_types):
            x = block(x, cos, sin) if bt == "transformer" else block(x)
        return self.lm_head(x)

    def prefill(self, input_ids: torch.Tensor, return_all_logits: bool = False):
        """Process a prompt in parallel and collect per-layer decoding states.

        Args:
            input_ids: Prompt token ids of shape ``(B, L)``.
            return_all_logits: Return logits for every position instead of
                only the last one.

        Returns:
            ``(logits, states)``.  ``logits`` has shape ``(B, 1, vocab_size)``
            by default or ``(B, L, vocab_size)`` when ``return_all_logits`` is
            True.  ``states`` has one entry per layer: a ``(k_cache, v_cache)``
            pair pre-allocated to ``maxlen`` positions for transformer layers,
            or the state tuple returned by the SeqCond block.
        """
        B, L = input_ids.shape
        device = input_ids.device
        x, cos, sin = self._embed_sequence(input_ids)
        head_dim = self.d_model // self.num_heads

        states = []
        for block, bt in zip(self.blocks, self.block_types):
            if bt == "transformer":
                x, (k, v) = block(x, cos, sin, return_state=True)
                # Fixed-size caches so decoding can write new positions in place.
                k_cache = torch.zeros(B, self.maxlen, self.num_kv_heads, head_dim, device=device, dtype=k.dtype)
                v_cache = torch.zeros_like(k_cache)
                k_cache[:, :L] = k
                v_cache[:, :L] = v
                states.append((k_cache, v_cache))
            else:
                x, state = block(x, return_state=True)
                states.append(state)

        if return_all_logits:
            return self.lm_head(x), states
        # Only the final position is needed to start decoding: projecting the
        # last hidden state alone avoids materialising (B, L, vocab) logits.
        return self.lm_head(x[:, -1:, :]), states

    def init_state(self, batch_size: int, device: torch.device, dtype: Optional[torch.dtype] = None) -> List:
        """Create empty per-layer states for decoding from scratch.

        Args:
            batch_size: Number of sequences.
            device: Device for all state tensors.
            dtype: Dtype of the KV caches and convolution buffers; defaults
                to the dtype of the embedding weights so the states match the
                model's compute dtype.  The SeqCond accumulators and position
                counters are always float32, as in ``prefill``.
        """
        if dtype is None:
            dtype = self.embedding.weight.dtype
        head_dim = self.d_model // self.num_heads

        states = []
        for block, bt in zip(self.blocks, self.block_types):
            if bt == "transformer":
                k = torch.zeros(batch_size, self.maxlen, self.num_kv_heads, head_dim, device=device, dtype=dtype)
                states.append((k, torch.zeros_like(k)))
            else:
                a = block.attn
                states.append((
                    torch.zeros(batch_size, a.K, device=device, dtype=torch.float32),
                    torch.zeros(batch_size, a.K, a.H, a.M, device=device, dtype=torch.float32),
                    torch.zeros(batch_size, a.K, a.H, a.M, device=device, dtype=torch.float32),
                    torch.zeros(batch_size, device=device, dtype=torch.float32),
                    torch.zeros(batch_size, a.conv_kernel_size - 1, a.dim_conv_total, device=device, dtype=dtype),
                ))
        return states

    def step(self, token_id: torch.Tensor, states: List, pos=None, seq_len=None, use_triton=False):
        """Decode one token for every sequence in the batch.

        Args:
            token_id: Token ids of shape ``(B, 1)``.
            states: Per-layer states from ``prefill``/``init_state`` or a
                previous ``step``.
            pos: Per-sequence positions.  When omitted, the position counter
                stored in the first SeqCond state is used (zeros if the model
                has no SeqCond layer).
            seq_len: Forwarded to the transformer blocks' ``step``.
            use_triton: Forwarded to the SeqCond blocks' ``step``.

        Returns:
            ``(logits, new_states)`` with ``logits`` of shape ``(B, vocab_size)``.
        """
        B = token_id.size(0)
        if pos is None:
            for state, bt in zip(states, self.block_types):
                if bt == "seqcond":
                    pos = state[3]
                    break
        if pos is None:
            pos = torch.zeros(B, device=token_id.device, dtype=torch.long)

        x = self.embedding(token_id).squeeze(1)
        # `clamp` returns a new tensor, so the SeqCond position counter that
        # `pos` may alias is not modified here.
        pos = pos.clamp(max=self.maxlen - 1)
        pos_idx = pos.long()
        if self.use_positional_embedding:
            x = x + torch.index_select(self.position_embedding.weight, 0, pos_idx)

        cos_t = torch.index_select(self.cos_emb, 0, pos_idx).unsqueeze(1).unsqueeze(1).expand(B, 1, self.num_heads, -1)
        sin_t = torch.index_select(self.sin_emb, 0, pos_idx).unsqueeze(1).unsqueeze(1).expand(B, 1, self.num_heads, -1)

        new_states = []
        for block, bt, state in zip(self.blocks, self.block_types, states):
            if bt == "transformer":
                x, ns = block.step(x, state, pos, cos_t, sin_t, seq_len=seq_len)
            else:
                x, ns = block.step(x, state, use_triton=use_triton)
            new_states.append(ns)

        return self.lm_head(x), new_states
739
+
740
+
741
+ # ---------------------------------------------------------------------------
742
+ # HuggingFace wrapper
743
+ # ---------------------------------------------------------------------------
744
+
745
class SeqCondPreTrainedModel(PreTrainedModel):
    """Hugging Face base class for SeqCond models.

    Declares the config class and the weight prefix, and provides the default
    weight initialisation.  Gradient checkpointing is not supported.
    """

    config_class = SeqCondConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, 0.02).

        Linear biases, when present, are zeroed.  Other module types are
        left untouched.
        """
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
757
+
758
+
759
class SeqCondForCausalLM(SeqCondPreTrainedModel):
    """
    SeqCond causal language model, HuggingFace-compatible.

    Supports:
    - Standard HF forward() for training / perplexity evaluation.
    - Custom generate() using state-based decoding (one ``model.step`` per token).
    - generate_batch() for batched generation with per-sample early stopping.
    """

    def __init__(self, config: SeqCondConfig):
        super().__init__(config)
        self.model = SeqCondModel(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, value):
        self.model.embedding = value

    def get_output_embeddings(self):
        return self.model.lm_head

    def set_output_embeddings(self, value):
        self.model.lm_head = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Standard forward pass (used for training / perplexity).

        Note: attention_mask is accepted for API compatibility but is not used
        in the forward pass — SeqCond is always causal.

        Args:
            input_ids: Token ids of shape ``(B, L)``; required.
            attention_mask: Ignored.
            labels: Optional targets of shape ``(B, L)``.  They are shifted
                by one position internally; entries equal to ``-100`` are
                ignored by ``F.cross_entropy``'s default ``ignore_index``.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("SeqCondForCausalLM.forward requires input_ids")

        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor, sequences: List[List[int]], penalty: float) -> torch.Tensor:
        """Penalise, in place, every token id already present in each row's sequence.

        Positive logits are divided by ``penalty`` and negative logits are
        multiplied by it, so a penalty > 1 always lowers the token's
        probability regardless of the logit's sign.
        """
        vocab = logits.size(-1)
        for bi, toks in enumerate(sequences):
            seen = [t for t in set(toks) if 0 <= t < vocab]
            if not seen:
                continue
            idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
            picked = logits[bi, idx]
            logits[bi, idx] = torch.where(picked < 0, picked * penalty, picked / penalty)
        return logits

    @staticmethod
    def _filter_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Mask logits outside the top-k set and outside the top-p nucleus with ``-inf``."""
        if top_k > 0:
            # Clamp so that a top_k larger than the vocabulary is a no-op
            # instead of an error in torch.topk.
            k = min(top_k, logits.size(-1))
            kth = torch.topk(logits, k, dim=-1).values[:, -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p < 1.0:
            sorted_ls, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_ls, dim=-1), dim=-1)
            sorted_remove = cum_probs > top_p
            # Shift right so the first token crossing the threshold is kept,
            # and never remove the most likely token.
            sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
            sorted_remove[:, 0] = False
            remove = torch.zeros_like(sorted_remove)
            remove.scatter_(1, sorted_idx, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        use_triton: bool = False,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Autoregressive generation with state-based decoding.

        Args:
            input_ids: Prompt ids of shape ``(B, L)``; all rows share length ``L``.
            max_new_tokens: Maximum number of tokens sampled per row.
            temperature: Softmax temperature; ``0`` selects greedy decoding.
            top_p: Nucleus threshold (``>= 1.0`` disables it).
            top_k: Keep only the ``top_k`` most likely tokens (``0`` disables it).
            repetition_penalty: Penalty applied to tokens already present in
                the prompt or in the generated text (``1.0`` disables it).
            eos_token_id: End-of-sequence id; defaults to ``config.eos_token_id``.
            use_triton: Forwarded to ``model.step``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            The full sequences (prompt + generated tokens) as a LongTensor.
            A row stops growing once it has emitted EOS (the EOS token is
            kept); shorter rows are right-padded with ``config.pad_token_id``
            (``0`` if unset).
        """
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        device = input_ids.device
        B = input_ids.size(0)
        pad_id = self.config.pad_token_id or 0

        # Prefill
        logits, states = self.model.prefill(input_ids)
        logits = logits.squeeze(1)  # (B, vocab)

        generated = input_ids.tolist()
        finished = [False] * B
        done = torch.zeros(B, dtype=torch.bool, device=device)
        token_buf = torch.zeros((B, 1), dtype=torch.long, device=device)
        seq_len = input_ids.size(1)

        for step_idx in range(max_new_tokens):
            # Temperature scaling (always work on a fresh tensor: the penalty
            # below writes in place).
            ls = logits / temperature if temperature > 0 else logits.clone()

            if repetition_penalty != 1.0:
                ls = self._apply_repetition_penalty(ls, generated, repetition_penalty)

            # Sampling
            if temperature == 0:
                next_tokens = torch.argmax(ls, dim=-1)
            else:
                ls = self._filter_top_k_top_p(ls, top_k, top_p)
                probs = F.softmax(ls, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Rows that already emitted EOS keep stepping (the batch decodes
            # in lockstep) but are fed padding and record nothing.
            next_tokens = next_tokens.masked_fill(done, pad_id)
            token_buf[:, 0] = next_tokens

            # A single device->host transfer per step instead of one per row.
            for bi, tok in enumerate(next_tokens.tolist()):
                if finished[bi]:
                    continue
                generated[bi].append(tok)
                if eos_token_id is not None and tok == eos_token_id:
                    finished[bi] = True
            if eos_token_id is not None:
                done |= next_tokens == eos_token_id

            # No further logits are needed after the last sampled token.
            if all(finished) or step_idx == max_new_tokens - 1:
                break

            seq_len += 1
            logits, states = self.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

        max_len = max(len(g) for g in generated)
        out = torch.full((B, max_len), pad_id, dtype=torch.long, device=device)
        for bi, g in enumerate(generated):
            out[bi, : len(g)] = torch.tensor(g, dtype=torch.long, device=device)
        return out

    @torch.no_grad()
    def generate_batch(
        self,
        input_ids_list: List[torch.LongTensor],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        eos_token_id: Optional[int] = None,
        use_triton: bool = False,
    ) -> List[List[int]]:
        """
        Batched generation: each prompt is prefilled independently, then
        decoded in lockstep with per-sample early stopping.

        Args:
            input_ids_list: List of 1D LongTensors, one per prompt.
            max_new_tokens: Maximum number of tokens sampled per prompt.
            temperature: Softmax temperature; ``0`` selects greedy decoding.
            eos_token_id: End-of-sequence id; defaults to ``config.eos_token_id``.
            use_triton: Forwarded to ``model.step``.
        Returns:
            List of generated token id lists (completion only, EOS stripped).
            An empty input list yields an empty result.
        """
        if not input_ids_list:
            return []

        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        device = input_ids_list[0].device
        B = len(input_ids_list)

        # Per-sample prefill (prompts may have different lengths).
        all_logits, all_states = [], []
        for ids in input_ids_list:
            lg, st = self.model.prefill(ids.unsqueeze(0))
            all_logits.append(lg.squeeze(1))
            all_states.append(st)

        logits = torch.cat(all_logits, dim=0)
        # Stack states: concatenate every state tensor of every layer along
        # the batch dimension.
        num_blocks = len(all_states[0])
        states = [
            tuple(torch.cat([s[i][j] for s in all_states], dim=0) for j in range(len(all_states[0][i])))
            for i in range(num_blocks)
        ]

        generated = [[] for _ in range(B)]
        finished = [False] * B
        # active_map[row] -> index of the original prompt decoded in that row.
        active_map = list(range(B))
        token_buf = torch.zeros((B, 1), dtype=torch.long, device=device)
        seq_len = max(ids.size(0) for ids in input_ids_list)

        for step_idx in range(max_new_tokens):
            B_cur = len(active_map)
            if B_cur == 0:
                break

            if temperature == 0:
                next_tokens = torch.argmax(logits, dim=-1)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Rows finishing this step are dropped below, so writing their
            # token into the buffer as well is harmless.
            token_buf[:, 0] = next_tokens

            newly_done = set()
            for bi, tok in enumerate(next_tokens.tolist()):
                oi = active_map[bi]
                generated[oi].append(tok)
                if eos_token_id is not None and tok == eos_token_id:
                    finished[oi] = True
                    newly_done.add(bi)

            # No further logits are needed after the last sampled token.
            if all(finished) or step_idx == max_new_tokens - 1:
                break

            if newly_done:
                # Compact the batch: keep only the rows still generating.
                keep = [bi for bi in range(B_cur) if bi not in newly_done]
                if not keep:
                    break
                keep_idx = torch.tensor(keep, device=device)
                token_buf = token_buf[keep_idx].contiguous()
                states = [tuple(s[keep_idx].contiguous() for s in st) for st in states]
                active_map = [active_map[bi] for bi in keep]

            seq_len += 1
            logits, states = self.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

        results = []
        for toks in generated:
            if toks and toks[-1] == eos_token_id:
                toks = toks[:-1]
            results.append(toks)
        return results
tokenization_seqcond.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SeqCond tokenizer — tiktoken cl100k_base with 4 additional special tokens.
3
+
4
+ Special tokens (assigned in order after the base vocab):
5
+ <|im_start|> — marks the start of a chat turn
6
+ <|im_end|> — marks the end of a chat turn (also used as EOS)
7
+ <|think_start|> — marks the start of chain-of-thought reasoning
8
+ <|think_end|> — marks the end of chain-of-thought reasoning
9
+
10
+ Chat template:
11
+ <|im_start|>user
12
+ {prompt}
13
+ <|im_end|><|im_start|>assistant
14
+ <|think_start|>{thinking}<|think_end|>
15
+ {answer}
16
+ <|im_end|>
17
+ """
18
+
19
+ import os
20
+ from typing import Dict, List, Optional, Tuple
21
+
22
+ from transformers import PreTrainedTokenizer
23
+
24
# Chat / reasoning markers; SeqCondTokenizer.__init__ registers the ones that
# are not already used as eos/bos/unk/pad as additional special tokens.
_SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>", "<|think_start|>", "<|think_end|>"]
# Fixed id of every special token; passed as `special_tokens` to the tiktoken
# Encoding built in _build_tiktoken_enc and used for the name <-> id lookups.
_SPECIAL_TOKEN_IDS = {
    "<|im_start|>": 100278,
    "<|im_end|>": 100279,
    "<|think_start|>": 100280,
    "<|think_end|>": 100281,
    "<|endoftext|>": 100282,
    "<|fim_prefix|>": 100283,
    "<|fim_middle|>": 100284,
    "<|fim_suffix|>": 100285,
    "<|endofprompt|>": 100286,
}
# NOTE(review): presumably the number of BPE ranks in cl100k_base - confirm.
_BASE_VOCAB_SIZE = 100256
# One past the highest special-token id; reported by SeqCondTokenizer.vocab_size.
_VOCAB_SIZE = max(_SPECIAL_TOKEN_IDS.values()) + 1
38
+
39
+
40
def _build_tiktoken_enc():
    """Return a tiktoken ``Encoding`` named "seqcond".

    The encoding reuses the split pattern and merge ranks of ``cl100k_base``
    and replaces its special tokens with ``_SPECIAL_TOKEN_IDS``.

    Raises:
        ImportError: If the ``tiktoken`` package is not installed.
    """
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError("tiktoken is required: pip install tiktoken") from exc

    cl100k = tiktoken.get_encoding("cl100k_base")
    encoding_kwargs = {
        "name": "seqcond",
        "pat_str": cl100k._pat_str,
        "mergeable_ranks": cl100k._mergeable_ranks,
        "special_tokens": _SPECIAL_TOKEN_IDS,
    }
    return tiktoken.Encoding(**encoding_kwargs)
54
+
55
+
56
class SeqCondTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for SeqCond models, backed by tiktoken cl100k_base.

    This is a slow tokenizer that wraps tiktoken. Tokens are represented
    internally as their stringified integer IDs (e.g. "42", "100256").
    This avoids building a full vocab dict while remaining compatible with
    HuggingFace's PreTrainedTokenizer interface.

    Regular BPE ids are shifted by +1 relative to tiktoken (id 0 is
    reserved); special-token ids are used unchanged.

    Requires: pip install tiktoken
    """

    vocab_files_names: Dict[str, str] = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        eos_token: str = "<|im_end|>",
        bos_token: Optional[str] = None,
        unk_token: Optional[str] = None,
        pad_token: str = "<|im_end|>",
        add_bos_token: bool = False,
        **kwargs,
    ):
        # The encoder and lookup tables must exist before super().__init__,
        # which may already call the token <-> id conversion hooks.
        self._enc = _build_tiktoken_enc()
        self._id_to_special: Dict[int, str] = {idx: tok for tok, idx in _SPECIAL_TOKEN_IDS.items()}
        self._special_to_id: Dict[str, int] = dict(_SPECIAL_TOKEN_IDS)

        # Register special tokens before calling super().__init__
        reserved = (eos_token, bos_token, unk_token, pad_token)
        kwargs.setdefault("additional_special_tokens", [t for t in _SPECIAL_TOKENS if t not in reserved])

        super().__init__(
            eos_token=eos_token,
            bos_token=bos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return _VOCAB_SIZE

    # ------------------------------------------------------------------
    # Core token ↔ id mappings
    # ------------------------------------------------------------------

    def _encode_shifted(self, text: str) -> List[int]:
        """Encode ``text`` with tiktoken and apply the +1 shift to regular BPE ids.

        Special tokens appearing in ``text`` are recognised
        (``allowed_special="all"``) and keep their ids.  The shift matches
        the convectors.Tiktokenize offset used during training (ID 0 reserved).
        """
        specials = self._id_to_special
        return [i if i in specials else i + 1 for i in self._enc.encode(text, allowed_special="all")]

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Encode text into a list of token-id strings."""
        return [str(i) for i in self._encode_shifted(text)]

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token string (or id-string) to an integer id.

        Tokens that are neither a known special token nor an integer string
        map to the reserved id 0.
        """
        if token in self._special_to_id:
            return self._special_to_id[token]
        try:
            return int(token)
        except ValueError:
            return 0

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an integer id to its token string."""
        if index in self._id_to_special:
            return self._id_to_special[index]
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Decode a list of token strings back to text.

        Tokens that cannot be decoded are skipped instead of raising: strings
        that are neither special tokens nor integers, the reserved id 0, and
        ids that fall outside both the (shifted) BPE range and the
        special-token table.
        """
        real_ids = []
        for t in tokens:
            if t in self._special_to_id:
                real_ids.append(self._special_to_id[t])
                continue
            try:
                i = int(t)
            except ValueError:
                continue
            if i in self._id_to_special:
                real_ids.append(i)
            elif 1 <= i <= _BASE_VOCAB_SIZE:
                # Reverse the +1 BPE shift before decoding.
                real_ids.append(i - 1)
        return self._enc.decode(real_ids)

    def get_vocab(self) -> Dict[str, int]:
        """
        Return a vocab dict. Only special tokens are included with their names;
        regular BPE tokens are included as their id-string representation.
        (Building a full 100k-entry reverse BPE map is expensive and rarely needed.)
        """
        vocab = {str(i): i for i in range(self.vocab_size)}
        vocab.update(self._special_to_id)
        return vocab

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str, ...]:
        """
        No vocabulary file is needed — the tiktoken encoding is fetched from
        the tiktoken package at runtime. Returns an empty tuple.
        """
        return ()

    # ------------------------------------------------------------------
    # Convenience helpers
    # ------------------------------------------------------------------

    @property
    def im_start_id(self) -> int:
        return self._special_to_id["<|im_start|>"]

    @property
    def im_end_id(self) -> int:
        return self._special_to_id["<|im_end|>"]

    @property
    def think_start_id(self) -> int:
        return self._special_to_id["<|think_start|>"]

    @property
    def think_end_id(self) -> int:
        return self._special_to_id["<|think_end|>"]

    def encode_chat(self, prompt: str, add_think_start: bool = True) -> List[int]:
        """
        Format and encode a user prompt using the standard chat template.

        Args:
            prompt: The user's message (plain text).
            add_think_start: If True (default), append <|think_start|> so the
                model begins generating its chain-of-thought immediately.

        Returns:
            List of token ids (prompt already encoded, ready for prefill).
        """
        text = f"<|im_start|>user\n{prompt}\n<|im_end|><|im_start|>assistant\n"
        if add_think_start:
            text += "<|think_start|>"
        return self._encode_shifted(text)

    def apply_chat_template(self, conversation, add_generation_prompt: bool = True, **kwargs) -> List[int]:
        """
        Minimal chat template support for HF pipeline compatibility.

        Expects conversation as a list of {"role": ..., "content": ...} dicts.
        Every message is rendered, in order, as
        ``<|im_start|>{role}\\n{content}\\n<|im_end|>``; a missing role
        defaults to "user" and missing content to "".  When
        ``add_generation_prompt`` is True the assistant header followed by
        ``<|think_start|>`` is appended.  Extra keyword arguments are ignored
        and the result is always a list of token ids.
        """
        parts = []
        for msg in conversation:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            parts.append(f"<|im_start|>{role}\n{content}\n<|im_end|>")
        if add_generation_prompt:
            parts.append("<|im_start|>assistant\n<|think_start|>")
        return self._encode_shifted("".join(parts))
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "SeqCondTokenizer",
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_seqcond.SeqCondTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "model_max_length": 4096,
10
+ "eos_token": "<|im_end|>",
11
+ "bos_token": null,
12
+ "unk_token": null,
13
+ "pad_token": "<|im_end|>",
14
+ "additional_special_tokens": [
15
+ "<|im_start|>",
16
+ "<|think_start|>",
17
+ "<|think_end|>"
18
+ ]
19
+ }
triton_kernels.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ try:
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ TRITON_AVAILABLE = True
8
+ except ImportError:
9
+ TRITON_AVAILABLE = False
10
+ triton = None
11
+ tl = None
12
+
13
+
14
if TRITON_AVAILABLE:
    def _select_seqcond_launch_config(H: int, M: int) -> tuple[int, int]:
        """Choose the ``(BLOCK_M, BLOCK_H)`` tile sizes for the fused kernel.

        ``BLOCK_M`` is the smallest of 1, 2, 4, 8 that is >= ``M`` (16 when
        ``M`` exceeds 8). ``BLOCK_H`` is the largest power of two in
        ``[1, 64]`` that does not exceed ``H`` (1 when ``H`` < 2).
        """
        block_m = next((size for size in (1, 2, 4, 8) if M <= size), 16)
        block_h = next((size for size in (64, 32, 16, 8, 4, 2) if H >= size), 1)
        return block_m, block_h

    @triton.jit
    def _seqcond_fully_fused_kernel_impl(
        k_ptr,
        s_raw_ptr,
        q_re_ptr,
        q_im_ptr,
        re_acc_ptr,
        im_acc_ptr,
        den_acc_ptr,
        theta_ptr,
        w_int_ptr,
        phase_scale_ptr,
        score_scale_ptr,
        score_bias_ptr,
        log_tw_ptr,
        out_re_ptr,
        out_im_ptr,
        K: tl.constexpr,
        H: tl.constexpr,
        M: tl.constexpr,
        stride_k_b,
        stride_k_k,
        stride_k_h,
        stride_acc_b,
        stride_acc_k,
        stride_acc_h,
        stride_acc_m,
        stride_theta_k,
        stride_theta_h,
        stride_theta_m,
        stride_q_b,
        stride_q_k,
        stride_q_h,
        stride_q_m,
        stride_w_k,
        stride_w_h,
        stride_w_m,
        stride_out_b,
        stride_out_k,
        stride_out_h,
        BLOCK_M: tl.constexpr,
        BLOCK_H: tl.constexpr,
    ):
        """Fused single-step state update and read-out.

        One program handles one ``(b, k)`` pair and one ``BLOCK_H``-wide tile
        of the ``H`` axis, looping over ``M`` in ``BLOCK_M`` chunks. For its
        tile it:

        1. computes ``p_w = softplus(score_scale*s_raw + score_bias) *
           exp(log_tw)``, clamped to ``[1e-4, 5000]``;
        2. adds ``p_w`` to the denominator accumulator;
        3. adds ``k * p_w * cos(phi)`` / ``k * p_w * sin(phi)`` to the
           real/imaginary accumulators in place, where
           ``phi = theta * ks / (1 + |ks|)`` and ``ks = k * phase_scale``;
        4. writes ``sum_m (state * conj(q)) * w / sqrt(H)`` to the outputs,
           with ``state = acc / max(den, 1e-4)``.

        Addressing, as used below: ``s_raw``, ``log_tw`` and ``den_acc`` are
        read at flat offset ``b*K + k``; ``phase_scale``, ``score_scale`` and
        ``score_bias`` at offset ``k``. ``re_acc``/``im_acc`` share the
        ``stride_acc_*`` strides, ``q_re``/``q_im`` share ``stride_q_*`` and
        ``out_re``/``out_im`` share ``stride_out_*``.
        """
        # Decompose the flat program id into (b, k, H-tile).
        pid = tl.program_id(0)
        tiles_per_k = (H + BLOCK_H - 1) // BLOCK_H
        programs_per_batch = K * tiles_per_k
        b = pid // programs_per_batch
        local_pid = pid % programs_per_batch
        k = local_pid // tiles_per_k
        h_block = local_pid % tiles_per_k

        # Per-(b, k) and per-k scalars.
        bk = b * K + k
        s_raw = tl.load(s_raw_ptr + bk)
        score_scale = tl.load(score_scale_ptr + k)
        score_bias = tl.load(score_bias_ptr + k)
        log_tw = tl.load(log_tw_ptr + bk)
        phase_scale = tl.load(phase_scale_ptr + k)

        # Softplus, switching to the identity above 20 so exp() is not
        # evaluated for large scores.
        score = score_scale * s_raw + score_bias
        softplus = tl.where(score > 20.0, score, tl.log(1.0 + tl.exp(score)))
        p_w = softplus * tl.exp(log_tw)
        p_w = tl.minimum(tl.maximum(p_w, 1e-4), 5000.0)

        # NOTE(review): every H-tile program of the same (b, k) loads
        # den_acc, but only tile 0 stores the update. With more than one
        # tile there appears to be no ordering between that store and the
        # other tiles' loads -- confirm p_w cannot be counted twice.
        den_ptr = den_acc_ptr + bk
        new_den = tl.load(den_ptr) + p_w
        if h_block == 0:
            tl.store(den_ptr, new_den)

        h_idx = h_block * BLOCK_H + tl.arange(0, BLOCK_H)
        h_mask = h_idx < H
        k_val = tl.load(
            k_ptr + b * stride_k_b + k * stride_k_k + h_idx * stride_k_h,
            mask=h_mask,
            other=0.0,
        )
        k_scaled = k_val * phase_scale
        phi_base = k_scaled / (1.0 + tl.abs(k_scaled))
        kvw = k_val * p_w
        inv_den = 1.0 / tl.maximum(new_den, 1e-4)
        scale = 1.0 / tl.sqrt(float(H))

        # Row pointers that do not depend on the M chunk.
        h_col = h_idx[:, None]
        row_mask = h_mask[:, None]
        acc_base = b * stride_acc_b + k * stride_acc_k
        re_rows = re_acc_ptr + acc_base + h_col * stride_acc_h
        im_rows = im_acc_ptr + acc_base + h_col * stride_acc_h
        q_base = b * stride_q_b + k * stride_q_k
        q_re_rows = q_re_ptr + q_base + h_col * stride_q_h
        q_im_rows = q_im_ptr + q_base + h_col * stride_q_h
        theta_base = k * stride_theta_k
        theta_rows = theta_ptr + theta_base + h_col * stride_theta_h
        w_base = k * stride_w_k
        w_rows = w_int_ptr + w_base + h_col * stride_w_h

        sum_re = tl.zeros((BLOCK_H,), dtype=tl.float32)
        sum_im = tl.zeros((BLOCK_H,), dtype=tl.float32)
        offs_m = tl.arange(0, BLOCK_M)

        for m_start in range(0, M, BLOCK_M):
            m_row = (m_start + offs_m)[None, :]
            tile_mask = row_mask & (m_row < M)

            theta_vals = tl.load(theta_rows + m_row * stride_theta_m, mask=tile_mask, other=0.0)
            phi = phi_base[:, None] * theta_vals
            cos_phi = tl.cos(phi)
            sin_phi = tl.sin(phi)

            # In-place accumulator update.
            re_ptrs = re_rows + m_row * stride_acc_m
            im_ptrs = im_rows + m_row * stride_acc_m
            old_re = tl.load(re_ptrs, mask=tile_mask, other=0.0)
            old_im = tl.load(im_ptrs, mask=tile_mask, other=0.0)
            new_re = old_re + kvw[:, None] * cos_phi
            new_im = old_im + kvw[:, None] * sin_phi
            tl.store(re_ptrs, new_re, mask=tile_mask)
            tl.store(im_ptrs, new_im, mask=tile_mask)

            q_re_vals = tl.load(q_re_rows + m_row * stride_q_m, mask=tile_mask, other=0.0)
            q_im_vals = tl.load(q_im_rows + m_row * stride_q_m, mask=tile_mask, other=0.0)
            w_vals = tl.load(w_rows + m_row * stride_w_m, mask=tile_mask, other=0.0)

            # Normalised state times conj(q), weighted and reduced over M.
            state_re = new_re * inv_den
            state_im = new_im * inv_den
            match_re = (state_re * q_re_vals + state_im * q_im_vals) * scale
            match_im = (state_im * q_re_vals - state_re * q_im_vals) * scale
            sum_re += tl.sum(match_re * w_vals, axis=1)
            sum_im += tl.sum(match_im * w_vals, axis=1)

        out_base = b * stride_out_b + k * stride_out_k
        tl.store(out_re_ptr + out_base + h_idx * stride_out_h, sum_re, mask=h_mask)
        tl.store(out_im_ptr + out_base + h_idx * stride_out_h, sum_im, mask=h_mask)
189
+
190
+
191
def seqcond_step_triton(
    k_val: torch.Tensor,
    s_raw: torch.Tensor,
    q_re: torch.Tensor,
    q_im: torch.Tensor,
    re_acc: torch.Tensor,
    im_acc: torch.Tensor,
    den_acc: torch.Tensor,
    theta: torch.Tensor,
    w_int: torch.Tensor,
    phase_scale: torch.Tensor,
    score_scale: torch.Tensor,
    score_bias: torch.Tensor,
    log_time_weight: torch.Tensor,
    out_re_buffer: torch.Tensor | None = None,
    out_im_buffer: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused SeqCond step kernel for one decoding step.

    ``re_acc``, ``im_acc`` and ``den_acc`` are updated **in place** by the
    kernel; all other inputs are cast to float32 (copying only when needed).

    Args:
        k_val: ``(B, K, H)`` tensor.
        s_raw, log_time_weight: read by the kernel at flat offset
            ``b*K + k`` (made contiguous here).
        q_re, q_im: 4-D tensors whose second dimension must equal ``K``.
            They are passed with a single set of strides, so they are made
            contiguous when their layouts differ.
        re_acc, im_acc: 4-D state accumulators sharing one layout.
        den_acc: contiguous denominator accumulator, addressed flat as
            ``b*K + k``.
        theta: 3-D tensor; ``M`` is taken from ``theta.shape[2]``.
        w_int: 3-D tensor, or 4-D with a squeezable dimension 1.
        phase_scale, score_scale, score_bias: read at offset ``k``.
        out_re_buffer, out_im_buffer: optional float32 ``(B, K, H)`` buffers
            on ``k_val``'s device to reuse for the outputs.

    Returns:
        ``(out_re, out_im)``, both float32 of shape ``(B, K, H)``.

    Raises:
        RuntimeError: if Triton is not available.
        AssertionError: if ``q_re.shape[1] != K`` (n_rep > 1 is unsupported).
        ValueError: if the in-place state tensors have layouts the kernel
            cannot address correctly.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not available")

    B, K, H = k_val.shape
    M = theta.shape[2]
    K_q = q_re.shape[1]
    # Raised explicitly (same exception type as before) so the check is not
    # stripped under ``python -O``.
    if K_q != K:
        raise AssertionError(
            f"Triton kernel requires n_rep==1 (K_q==K), got K_q={K_q}, K={K}. "
            f"Use PyTorch path for n_rep>1."
        )

    def _prep_f32(t: torch.Tensor) -> torch.Tensor:
        # Cast to float32; strides are forwarded to the kernel as-is.
        return t if t.dtype == torch.float32 else t.float()

    def _prep_f32_contiguous(t: torch.Tensor) -> torch.Tensor:
        # For tensors the kernel addresses with flat offsets.
        t = _prep_f32(t)
        return t if t.is_contiguous() else t.contiguous()

    def _same_layout(a: torch.Tensor, b: torch.Tensor) -> bool:
        # Strides of size-1 dimensions never affect addressing, so skip them.
        return a.shape == b.shape and all(
            sa == sb
            for n, sa, sb in zip(a.shape, a.stride(), b.stride())
            if n > 1
        )

    def _output_buffer(buf: torch.Tensor | None) -> torch.Tensor:
        # Reuse the caller's buffer only when it matches what the kernel writes.
        reusable = (
            buf is not None
            and buf.shape == (B, K, H)
            and buf.device == k_val.device
            and buf.dtype == torch.float32
        )
        if reusable:
            return buf
        return torch.empty(B, K, H, device=k_val.device, dtype=torch.float32)

    # The accumulators are mutated in place, so they cannot be copied into a
    # compatible layout -- reject layouts the kernel would mis-address.
    if not _same_layout(re_acc, im_acc):
        raise ValueError(
            "seqcond_step_triton expects re_acc and im_acc to share shape and strides, "
            f"got shapes {tuple(re_acc.shape)}/{tuple(im_acc.shape)} and "
            f"strides {re_acc.stride()}/{im_acc.stride()}"
        )
    if not den_acc.is_contiguous():
        raise ValueError("seqcond_step_triton expects a contiguous den_acc")

    k_val = _prep_f32(k_val)
    s_raw = _prep_f32_contiguous(s_raw)
    q_re = _prep_f32(q_re)
    q_im = _prep_f32(q_im)
    if not _same_layout(q_re, q_im):
        # Only q_re's strides reach the kernel; align both layouts.
        q_re = q_re.contiguous()
        q_im = q_im.contiguous()
    theta = _prep_f32(theta)
    phase_scale = _prep_f32_contiguous(phase_scale)
    score_scale = _prep_f32_contiguous(score_scale)
    score_bias = _prep_f32_contiguous(score_bias)
    log_time_weight = _prep_f32_contiguous(log_time_weight)
    if w_int.dim() == 4:
        w_int = w_int.squeeze(1)
    w_int = _prep_f32(w_int)

    out_re = _output_buffer(out_re_buffer)
    out_im = _output_buffer(out_im_buffer)
    if not _same_layout(out_re, out_im):
        # Only out_re's strides reach the kernel; give out_im the same layout.
        out_im = torch.empty_strided(
            tuple(out_re.shape),
            out_re.stride(),
            device=out_re.device,
            dtype=torch.float32,
        )

    block_m, block_h = _select_seqcond_launch_config(H, M)
    # One program per (b, k, H-tile).
    grid = (B * K * ((H + block_h - 1) // block_h),)
    _seqcond_fully_fused_kernel_impl[grid](
        k_val,
        s_raw,
        q_re,
        q_im,
        re_acc,
        im_acc,
        den_acc,
        theta,
        w_int,
        phase_scale,
        score_scale,
        score_bias,
        log_time_weight,
        out_re,
        out_im,
        K,
        H,
        M,
        *k_val.stride(),
        re_acc.stride(0),
        re_acc.stride(1),
        re_acc.stride(2),
        re_acc.stride(3),
        theta.stride(0),
        theta.stride(1),
        theta.stride(2),
        q_re.stride(0),
        q_re.stride(1),
        q_re.stride(2),
        q_re.stride(3),
        w_int.stride(0),
        w_int.stride(1),
        w_int.stride(2),
        out_re.stride(0),
        out_re.stride(1),
        out_re.stride(2),
        BLOCK_M=block_m,
        BLOCK_H=block_h,
    )
    return out_re, out_im
304
+
305
+
306
if TRITON_AVAILABLE:
    def _select_rmsnorm_block_size(n_cols: int) -> int:
        """Return the smallest power of two >= ``n_cols``, capped at 4096.

        Values of ``n_cols`` below 1 yield 1.

        NOTE(review): ``_gated_rmsnorm_kernel`` processes a single
        ``BLOCK_N``-wide tile per row, so with this cap rows wider than 4096
        columns are not fully covered by one launch.
        """
        pow2 = 1 << max(n_cols - 1, 0).bit_length()
        return min(pow2, 4096)

    @triton.jit
    def _gated_rmsnorm_kernel(
        x_ptr,
        residual_ptr,
        weight_ptr,
        out_ptr,
        n_cols,
        stride_x_row,
        stride_residual_row,
        stride_out_row,
        epsilon,
        BLOCK_N: tl.constexpr,
    ):
        """SiLU-gated RMSNorm over one row per program.

        Computes ``g = x * (residual * sigmoid(residual))`` and stores
        ``g * rsqrt(mean(g**2) + epsilon) * weight``. Columns are addressed
        with unit stride, and only the first ``BLOCK_N`` columns of the row
        are visited.
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_N)
        in_bounds = cols < n_cols

        # Masked-out lanes load 0 and therefore add nothing to the sum below.
        x_row = tl.load(x_ptr + row * stride_x_row + cols, mask=in_bounds, other=0.0).to(tl.float32)
        gate_in = tl.load(
            residual_ptr + row * stride_residual_row + cols, mask=in_bounds, other=0.0
        ).to(tl.float32)
        scale_w = tl.load(weight_ptr + cols, mask=in_bounds, other=0.0).to(tl.float32)

        gated = x_row * (gate_in * tl.sigmoid(gate_in))
        mean_sq = tl.sum(gated * gated, axis=0) / n_cols
        normed = gated * tl.rsqrt(mean_sq + epsilon) * scale_w
        tl.store(out_ptr + row * stride_out_row + cols, normed, mask=in_bounds)
337
+
338
+
339
def gated_rmsnorm_triton(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    out_buffer: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the gated RMSNorm Triton kernel on a 2-D input.

    Inputs are cast to contiguous float32 before the launch; one kernel
    program is launched per row.

    Args:
        x: 2-D tensor ``(rows, n_cols)``.
        residual: gate input with the same shape as ``x``.
        weight: 1-D tensor of length ``n_cols``.
        epsilon: value added to the mean square before ``rsqrt``.
        out_buffer: optional float32 buffer of ``x``'s shape on ``x``'s
            device. It is reused only if its last dimension has unit
            stride, because the kernel writes columns with unit stride.

    Returns:
        float32 tensor of shape ``(rows, n_cols)``.

    Raises:
        RuntimeError: if Triton is not available.
        ValueError: on shape mismatches, or when ``n_cols`` exceeds the
            largest block size the kernel is launched with.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not available")
    if x.dim() != 2 or residual.dim() != 2:
        raise ValueError(
            f"gated_rmsnorm_triton expects 2D tensors, got x.shape={tuple(x.shape)} residual.shape={tuple(residual.shape)}"
        )
    if x.shape != residual.shape:
        raise ValueError(
            f"gated_rmsnorm_triton expects matching x/residual shapes, got {tuple(x.shape)} and {tuple(residual.shape)}"
        )
    if weight.dim() != 1 or weight.shape[0] != x.shape[1]:
        raise ValueError(
            f"gated_rmsnorm_triton expects weight.shape == ({x.shape[1]},), got {tuple(weight.shape)}"
        )

    rows, n_cols = x.shape
    block_n = _select_rmsnorm_block_size(n_cols)
    if n_cols > block_n:
        # The kernel handles a single block_n-wide tile per row; a wider row
        # would leave trailing columns unwritten and the mean square wrong.
        raise ValueError(
            f"gated_rmsnorm_triton supports at most {block_n} columns, got {n_cols}"
        )

    def _prep_f32_contiguous(t: torch.Tensor) -> torch.Tensor:
        if t.dtype != torch.float32:
            t = t.float()
        if not t.is_contiguous():
            t = t.contiguous()
        return t

    x = _prep_f32_contiguous(x)
    residual = _prep_f32_contiguous(residual)
    weight = _prep_f32_contiguous(weight)

    reusable = (
        out_buffer is not None
        and out_buffer.shape == x.shape
        and out_buffer.device == x.device
        and out_buffer.dtype == torch.float32
        # The kernel only receives the row stride and assumes unit column stride.
        and out_buffer.stride(-1) == 1
    )
    out = out_buffer if reusable else torch.empty_like(x, dtype=torch.float32)

    if rows == 0 or n_cols == 0:
        # Nothing to normalise; skip the launch.
        return out

    _gated_rmsnorm_kernel[(rows,)](
        x,
        residual,
        weight,
        out,
        n_cols,
        x.stride(0),
        residual.stride(0),
        out.stride(0),
        epsilon,
        BLOCK_N=block_n,
    )
    return out