OpenOneRec committed on
Commit 59003da · verified · 1 Parent(s): 6e72c5c

Delete modeling_qwen3sa.py

Files changed (1)
  1. modeling_qwen3sa.py +0 -1587
modeling_qwen3sa.py DELETED
@@ -1,1587 +0,0 @@
1
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
- # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
- # the file from the modular. If any change should be done, please apply the change to the
5
- # modular_qwen3.py file directly. One of our CI enforces this.
6
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
- # coding=utf-8
8
- # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
- #
10
- # Licensed under the Apache License, Version 2.0 (the "License");
11
- # you may not use this file except in compliance with the License.
12
- # You may obtain a copy of the License at
13
- #
14
- # http://www.apache.org/licenses/LICENSE-2.0
15
- #
16
- # Unless required by applicable law or agreed to in writing, software
17
- # distributed under the License is distributed on an "AS IS" BASIS,
18
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
- # See the License for the specific language governing permissions and
20
- # limitations under the License.
21
-
22
- from typing import Any, Callable, Optional, Union
23
-
24
- import torch
25
- from torch import nn
26
- import torch.nn.functional as F
27
- from torch.nn.attention import SDPBackend, sdpa_kernel
28
- from flash_attn import flash_attn_func
29
-
30
- from transformers.activations import ACT2FN
31
- from transformers.cache_utils import Cache, DynamicCache
32
- from transformers.generation import GenerationMixin
33
- from transformers.integrations import use_kernel_forward_from_hub
34
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
35
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
36
- from transformers.modeling_layers import (
37
- GenericForQuestionAnswering,
38
- GenericForSequenceClassification,
39
- GenericForTokenClassification,
40
- GradientCheckpointingLayer,
41
- )
42
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
43
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
- from transformers.processing_utils import Unpack
46
- from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
47
- from transformers.utils.deprecation import deprecate_kwarg
48
- from transformers.utils.generic import check_model_inputs
49
- from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
50
-
51
- from .summary_context import SummaryBatchContext, build_summary_context, build_summary_sliding_context
52
- from summary_attn import summary_attn_func
53
-
54
-
55
- def _parse_config_pattern(val):
56
- """Parse a config value that may be an int, list, or Python pattern string like '([4096]*1+[128]*3)*9'."""
57
- if isinstance(val, list):
58
- return val
59
- if isinstance(val, str):
60
- return eval(val)
61
- return val
62
-
63
-
64
- @use_kernel_forward_from_hub("RMSNorm")
65
- class Qwen3RMSNorm(nn.Module):
66
- def __init__(self, hidden_size, eps: float = 1e-6) -> None:
67
- """
68
- Qwen3RMSNorm is equivalent to T5LayerNorm
69
- """
70
- super().__init__()
71
- self.weight = nn.Parameter(torch.ones(hidden_size))
72
- self.variance_epsilon = eps
73
-
74
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
75
- input_dtype = hidden_states.dtype
76
- hidden_states = hidden_states.to(torch.float32)
77
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
78
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
79
- return self.weight * hidden_states.to(input_dtype)
80
-
81
- def extra_repr(self):
82
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
83
-
84
-
85
- class Qwen3RingBufferCache:
86
- """
87
- Ring buffer KV cache with summary support.
88
-
89
- Two strategies based on per-layer sliding_chunk_num:
90
- - Large window layers (is_large_window=True): append-only buffer storing only text KV.
91
- Summary KV is NOT stored since text tokens attend to all text KV directly.
92
- - Small window layers (is_large_window=False): single buffer layout:
93
-       [ scratch (S) | chunk_text (C) ←fill | ring_text (ws) | summaries →fill | headroom ]
94
- ^0 ^S ^S+C-cbl ^S+C ^S+C+ws ^S+C+ws+n_sum
95
- scratch: temp area for summary self-KV at chunk boundaries (avoids cat).
96
- chunk_text fills right-to-left; summaries append left-to-right.
97
- get_attention_kv returns a single contiguous slice.
98
- get_summary_attention_kv writes summary KV to scratch, returns [0 : S+C].
99
-
100
- RoPE position information is baked into KV, so physical order doesn't matter.
101
- """
102
-
103
- is_compileable = False
104
- _SUMMARY_INIT_CAP = 512
105
- _APPEND_HEADROOM = 1024
106
-
107
- def __init__(self, config: Qwen3Config, sliding_chunk_nums: list[int]):
108
- super().__init__()
109
- self.summary_chunk_size = getattr(config, "summary_chunk_size", 0)
110
- self.summary_token_num = getattr(config, "summary_token_num", 0)
111
- self.num_hidden_layers = config.num_hidden_layers
112
-
113
- self.sliding_chunk_nums = sliding_chunk_nums
114
- large_window_threshold = min(sliding_chunk_nums) * self.summary_chunk_size
115
- self.is_large_window = [sv * self.summary_chunk_size > large_window_threshold for sv in sliding_chunk_nums]
116
- self.window_sizes = [sv * self.summary_chunk_size for sv in sliding_chunk_nums]
117
-
118
- self.key_cache = [None for _ in range(config.num_hidden_layers)]
119
- self.value_cache = [None for _ in range(config.num_hidden_layers)]
120
-
121
- # Large window: append-only
122
- self._text_len = [0] * config.num_hidden_layers
123
- self._capacity = [0] * config.num_hidden_layers
124
-
125
-         # Small window: [ scratch (S) | chunk_text (C) ←fill | ring_text (ws) | summaries →fill | headroom ]
126
- self._window_write_ptr = [0] * config.num_hidden_layers
127
- self._n_valid_window = [0] * config.num_hidden_layers
128
- self._chunk_buf_len = [0] * config.num_hidden_layers
129
- self._n_summaries = [0] * config.num_hidden_layers # number of summaries stored
130
-
131
- # Common
132
- self.cur_chunk_sizes = [0] * config.num_hidden_layers
133
- self.true_tokens = [0] * config.num_hidden_layers
134
- self._total_chunks = [0] * config.num_hidden_layers # completed chunks count
135
- self._reorganized = False
136
-
137
- def __len__(self):
138
- return self.num_hidden_layers
139
-
140
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
141
- """Returns nonzero when cache is populated (used to detect prefill vs decode)."""
142
- if layer_idx >= self.num_hidden_layers:
143
- return 0
144
- if self.is_large_window[layer_idx]:
145
- return self._text_len[layer_idx]
146
- else:
147
- return self._n_valid_window[layer_idx] + self._chunk_buf_len[layer_idx] + self._n_summaries[layer_idx]
148
-
149
- def get_cur_chunk_size(self, layer_idx: Optional[int] = None) -> int:
150
- if layer_idx is None:
151
- layer_idx = self.num_hidden_layers - 1
152
- return self.cur_chunk_sizes[layer_idx]
153
-
154
- def get_true_token_num(self, layer_idx: Optional[int] = None) -> int:
155
- if layer_idx is None:
156
- layer_idx = self.num_hidden_layers - 1
157
- return self.true_tokens[layer_idx]
158
-
159
- # ── Prefill: standard append (before reorganize) ──
160
-
161
- def update(
162
- self,
163
- key_states: torch.Tensor,
164
- value_states: torch.Tensor,
165
- layer_idx: int,
166
- cache_kwargs: Optional[dict[str, Any]] = None,
167
- ) -> tuple[torch.Tensor, torch.Tensor]:
168
- """Append KV during prefill (before reorganize). Returns full KV for prefill attention."""
169
- add_len = key_states.shape[-2]
170
- cur_len = self._text_len[layer_idx]
171
- new_len = cur_len + add_len
172
-
173
- if self.key_cache[layer_idx] is None:
174
- cap = new_len + self._APPEND_HEADROOM
175
- bsz, heads, _, head_dim = key_states.shape
176
- self.key_cache[layer_idx] = torch.empty(
177
- bsz, heads, cap, head_dim, dtype=key_states.dtype, device=key_states.device)
178
- self.value_cache[layer_idx] = torch.empty(
179
- bsz, heads, cap, head_dim, dtype=value_states.dtype, device=value_states.device)
180
- self._capacity[layer_idx] = cap
181
- elif new_len > self._capacity[layer_idx]:
182
- cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
183
- old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
184
- bsz, heads, _, head_dim = old_k.shape
185
- new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
186
- new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
187
- new_k[:, :, :cur_len, :].copy_(old_k[:, :, :cur_len, :])
188
- new_v[:, :, :cur_len, :].copy_(old_v[:, :, :cur_len, :])
189
- self.key_cache[layer_idx] = new_k
190
- self.value_cache[layer_idx] = new_v
191
- self._capacity[layer_idx] = cap
192
-
193
- self.key_cache[layer_idx][:, :, cur_len:new_len, :].copy_(key_states)
194
- self.value_cache[layer_idx][:, :, cur_len:new_len, :].copy_(value_states)
195
- self._text_len[layer_idx] = new_len
196
-
197
- if self.summary_chunk_size > 0:
198
- if cache_kwargs and 'summary_mask' in cache_kwargs:
199
- text_count = add_len - cache_kwargs['summary_mask'][0].sum().item()
200
- else:
201
- text_count = add_len
202
- self.cur_chunk_sizes[layer_idx] += add_len
203
- self.true_tokens[layer_idx] += text_count
204
-
205
- return self.key_cache[layer_idx][:, :, :new_len, :], self.value_cache[layer_idx][:, :, :new_len, :]
206
-
207
- # ── Reorganize after prefill ──
208
-
209
- def reorganize_after_prefill(self, summary_mask: torch.Tensor):
210
- """Reorganize all layers from prefill block layout to ring buffer layout.
211
-
212
- Args:
213
- summary_mask: bool tensor [bsz, prefill_seq_len] where True = summary position.
214
- """
215
- if self._reorganized:
216
- return
217
- self._reorganized = True
218
-
219
- text_mask = ~summary_mask[0]
220
-
221
- for layer_idx in range(self.num_hidden_layers):
222
- prefill_len = self._text_len[layer_idx]
223
- prefill_k = self.key_cache[layer_idx][:, :, :prefill_len, :]
224
- prefill_v = self.value_cache[layer_idx][:, :, :prefill_len, :]
225
- bsz, heads, _, head_dim = prefill_k.shape
226
- device, dtype = prefill_k.device, prefill_k.dtype
227
-
228
- text_k = prefill_k[:, :, text_mask, :]
229
- text_v = prefill_v[:, :, text_mask, :]
230
- n_text = text_k.shape[2]
231
-
232
- if self.is_large_window[layer_idx]:
233
- # Large window: keep only text KV
234
- cap = n_text + self._APPEND_HEADROOM
235
- new_k = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
236
- new_v = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
237
- new_k[:, :, :n_text, :].copy_(text_k)
238
- new_v[:, :, :n_text, :].copy_(text_v)
239
- self.key_cache[layer_idx] = new_k
240
- self.value_cache[layer_idx] = new_v
241
- self._text_len[layer_idx] = n_text
242
- self._capacity[layer_idx] = cap
243
- else:
244
-                 # Small window: [ scratch (S) | chunk_text (C) ←fill | ring_text (ws) | summaries →fill | headroom ]
245
- summary_k = prefill_k[:, :, summary_mask[0], :]
246
- summary_v = prefill_v[:, :, summary_mask[0], :]
247
- n_summary = summary_k.shape[2]
248
-
249
- C = self.summary_chunk_size
250
- S = self.summary_token_num
251
- ws = self.window_sizes[layer_idx]
252
- scn = self.sliding_chunk_nums[layer_idx]
253
-
254
- # Split text into complete chunks + partial remainder
255
- n_complete_chunks = n_text // C
256
- n_partial = n_text % C
257
- n_complete_text = n_complete_chunks * C
258
-
259
- # Window: last scn complete chunks (or all if fewer)
260
- n_window_chunks = min(scn, n_complete_chunks)
261
- n_window_text = n_window_chunks * C
262
- window_start = n_complete_text - n_window_text
263
-
264
- # Layout: [ scratch (S) | chunk_text (C) | ring_text (ws) | summaries | headroom ]
265
- summary_headroom = max(self._SUMMARY_INIT_CAP, n_summary + 256)
266
- total_cap = S + C + ws + n_summary + summary_headroom
267
-
268
- new_k = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
269
- new_v = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
270
-
271
- # Copy partial chunk text to [S+C-n_partial : S+C] (left-filled)
272
- if n_partial > 0:
273
- new_k[:, :, S + C - n_partial:S + C, :].copy_(
274
- text_k[:, :, n_complete_text:, :])
275
- new_v[:, :, S + C - n_partial:S + C, :].copy_(
276
- text_v[:, :, n_complete_text:, :])
277
-
278
- # Copy window text to [S+C : S+C+n_window_text]
279
- if n_window_text > 0:
280
- new_k[:, :, S + C:S + C + n_window_text, :].copy_(
281
- text_k[:, :, window_start:n_complete_text, :])
282
- new_v[:, :, S + C:S + C + n_window_text, :].copy_(
283
- text_v[:, :, window_start:n_complete_text, :])
284
- self._n_valid_window[layer_idx] = n_window_text
285
- self._window_write_ptr[layer_idx] = n_window_text % ws
286
-
287
- # Copy summaries to [S+C+ws : S+C+ws+n_summary]
288
- if n_summary > 0:
289
- new_k[:, :, S + C + ws:S + C + ws + n_summary, :].copy_(summary_k)
290
- new_v[:, :, S + C + ws:S + C + ws + n_summary, :].copy_(summary_v)
291
-
292
- self.key_cache[layer_idx] = new_k
293
- self.value_cache[layer_idx] = new_v
294
- self._n_summaries[layer_idx] = n_summary
295
- self._capacity[layer_idx] = total_cap
296
- self._text_len[layer_idx] = 0
297
- self._chunk_buf_len[layer_idx] = n_partial
298
-
299
- block = self.summary_chunk_size + self.summary_token_num
300
- for layer_idx in range(self.num_hidden_layers):
301
- self.cur_chunk_sizes[layer_idx] = self.cur_chunk_sizes[layer_idx] % block
302
- self._total_chunks[layer_idx] = self._n_summaries[layer_idx] if not self.is_large_window[layer_idx] else (self.true_tokens[layer_idx] // self.summary_chunk_size)
303
-
304
- # ── Decode: text token update ──
305
-
306
- def update_text(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
307
- """Write a single text token KV during decode."""
308
- if self.is_large_window[layer_idx]:
309
- cur = self._text_len[layer_idx]
310
- new_len = cur + 1
311
- if new_len > self._capacity[layer_idx]:
312
- cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
313
- old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
314
- bsz, heads, _, head_dim = old_k.shape
315
- new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
316
- new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
317
- new_k[:, :, :cur, :].copy_(old_k[:, :, :cur, :])
318
- new_v[:, :, :cur, :].copy_(old_v[:, :, :cur, :])
319
- self.key_cache[layer_idx] = new_k
320
- self.value_cache[layer_idx] = new_v
321
- self._capacity[layer_idx] = cap
322
- self.key_cache[layer_idx][:, :, cur:new_len, :].copy_(key_states)
323
- self.value_cache[layer_idx][:, :, cur:new_len, :].copy_(value_states)
324
- self._text_len[layer_idx] = new_len
325
- else:
326
- # Write to chunk_text region, left-filled: position S+C-1-cbl
327
- C = self.summary_chunk_size
328
- S = self.summary_token_num
329
- cbl = self._chunk_buf_len[layer_idx]
330
- dst = S + C - 1 - cbl
331
- self.key_cache[layer_idx][:, :, dst:dst+1, :].copy_(key_states)
332
- self.value_cache[layer_idx][:, :, dst:dst+1, :].copy_(value_states)
333
- self._chunk_buf_len[layer_idx] = cbl + 1
334
-
335
- self.cur_chunk_sizes[layer_idx] += 1
336
- self.true_tokens[layer_idx] += 1
337
-
338
- # ── Decode: summary token update ──
339
-
340
- def update_summary(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
341
- """Write summary token KV during decode (chunk boundary).
342
-
343
- Large window: skip. Small window: flush chunk to ring, append summary rightward.
344
- """
345
- n_summary = key_states.shape[2]
346
-
347
- if self.is_large_window[layer_idx]:
348
- self.cur_chunk_sizes[layer_idx] += n_summary
349
- self._total_chunks[layer_idx] += n_summary
350
- return
351
-
352
- # ── Small window: boundary processing ──
353
- C = self.summary_chunk_size
354
- S = self.summary_token_num
355
- ws = self.window_sizes[layer_idx]
356
- cbl = self._chunk_buf_len[layer_idx]
357
- ptr = self._window_write_ptr[layer_idx]
358
-
359
- # Step A: Flush chunk_text to ring_text
360
- # chunk_text lives at [S+C-cbl : S+C], left-filled
361
- chunk_src = S + C - cbl
362
- if cbl > 0:
363
- ring_dst = S + C + ptr
364
- if ptr + cbl <= ws:
365
- self.key_cache[layer_idx][:, :, ring_dst:ring_dst + cbl, :].copy_(
366
- self.key_cache[layer_idx][:, :, chunk_src:chunk_src + cbl, :])
367
- self.value_cache[layer_idx][:, :, ring_dst:ring_dst + cbl, :].copy_(
368
- self.value_cache[layer_idx][:, :, chunk_src:chunk_src + cbl, :])
369
- else:
370
- first = ws - ptr
371
- self.key_cache[layer_idx][:, :, ring_dst:ring_dst + first, :].copy_(
372
- self.key_cache[layer_idx][:, :, chunk_src:chunk_src + first, :])
373
- self.value_cache[layer_idx][:, :, ring_dst:ring_dst + first, :].copy_(
374
- self.value_cache[layer_idx][:, :, chunk_src:chunk_src + first, :])
375
- rest = cbl - first
376
- self.key_cache[layer_idx][:, :, S + C:S + C + rest, :].copy_(
377
- self.key_cache[layer_idx][:, :, chunk_src + first:chunk_src + cbl, :])
378
- self.value_cache[layer_idx][:, :, S + C:S + C + rest, :].copy_(
379
- self.value_cache[layer_idx][:, :, chunk_src + first:chunk_src + cbl, :])
380
-
381
- self._window_write_ptr[layer_idx] = (ptr + cbl) % ws
382
- if self._n_valid_window[layer_idx] < ws:
383
- self._n_valid_window[layer_idx] = min(ws, self._n_valid_window[layer_idx] + cbl)
384
- self._chunk_buf_len[layer_idx] = 0
385
-
386
- # Step B: Append summary rightward
387
- n_sum = self._n_summaries[layer_idx]
388
- sum_dst = S + C + ws + n_sum
389
- if sum_dst + n_summary > self._capacity[layer_idx]:
390
- self._grow_buffer_right(layer_idx)
391
- sum_dst = S + C + ws + self._n_summaries[layer_idx]
392
-
393
- self.key_cache[layer_idx][:, :, sum_dst:sum_dst + n_summary, :].copy_(key_states)
394
- self.value_cache[layer_idx][:, :, sum_dst:sum_dst + n_summary, :].copy_(value_states)
395
- self._n_summaries[layer_idx] += n_summary
396
-
397
- self.cur_chunk_sizes[layer_idx] += n_summary
398
- self._total_chunks[layer_idx] += n_summary
399
-
400
- # ── Decode: get KV for attention ──
401
-
402
- def get_attention_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
403
- """Get full KV for text token attention.
404
-
405
- Large window: buffer[:text_len]
406
- Small window: always a single contiguous slice.
407
- - ring full: [S+C-cbl : S+C+ws+n_old_sum]
408
- - ring not full: [S+C-cbl : S+C+nv]
409
- """
410
- if self.is_large_window[layer_idx]:
411
- tl = self._text_len[layer_idx]
412
- return (self.key_cache[layer_idx][:, :, :tl, :],
413
- self.value_cache[layer_idx][:, :, :tl, :])
414
-
415
- C = self.summary_chunk_size
416
- S = self.summary_token_num
417
- ws = self.window_sizes[layer_idx]
418
- nv = self._n_valid_window[layer_idx]
419
- cbl = self._chunk_buf_len[layer_idx]
420
-
421
- start = S + C - cbl
422
-
423
- if nv >= ws:
424
- # Ring full: include old summaries (skip in-window ones)
425
- scn = self.sliding_chunk_nums[layer_idx]
426
- n_summaries = self._n_summaries[layer_idx]
427
- skip = min(scn * S, n_summaries)
428
- end = S + C + ws + (n_summaries - skip)
429
- else:
430
- # Ring not full: all summaries are in-window, skip them all
431
- end = S + C + nv
432
-
433
- return (self.key_cache[layer_idx][:, :, start:end, :],
434
- self.value_cache[layer_idx][:, :, start:end, :])
435
-
436
- def get_current_chunk_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
437
- """Get KV of the current chunk's C text tokens for summary token attention."""
438
- C = self.summary_chunk_size
439
- if self.is_large_window[layer_idx]:
440
- tl = self._text_len[layer_idx]
441
- return (self.key_cache[layer_idx][:, :, tl - C:tl, :],
442
- self.value_cache[layer_idx][:, :, tl - C:tl, :])
443
- else:
444
- S = self.summary_token_num
445
- cbl = self._chunk_buf_len[layer_idx]
446
- return (self.key_cache[layer_idx][:, :, S + C - cbl:S + C, :],
447
- self.value_cache[layer_idx][:, :, S + C - cbl:S + C, :])
448
-
449
- def get_summary_attention_kv(
450
- self,
451
- layer_idx: int,
452
- k_summary: torch.Tensor,
453
- v_summary: torch.Tensor,
454
- ) -> tuple[torch.Tensor, torch.Tensor]:
455
- """Write summary KV to scratch area [0:S], return contiguous [0 : S+C] for summary attention.
456
-
457
- This avoids a torch.cat by using the pre-reserved scratch region.
458
- For large window layers, falls back to cat (no scratch area).
459
- """
460
- C = self.summary_chunk_size
461
- S = self.summary_token_num
462
- if self.is_large_window[layer_idx]:
463
- tl = self._text_len[layer_idx]
464
- k_chunk = self.key_cache[layer_idx][:, :, tl - C:tl, :]
465
- v_chunk = self.value_cache[layer_idx][:, :, tl - C:tl, :]
466
- return (torch.cat([k_chunk, k_summary], dim=2),
467
- torch.cat([v_chunk, v_summary], dim=2))
468
- else:
469
- # Write summary KV into scratch area [0:S]
470
- self.key_cache[layer_idx][:, :, 0:S, :].copy_(k_summary)
471
- self.value_cache[layer_idx][:, :, 0:S, :].copy_(v_summary)
472
- # Return contiguous [scratch | chunk_text] = [0 : S+C]
473
- return (self.key_cache[layer_idx][:, :, 0:S + C, :],
474
- self.value_cache[layer_idx][:, :, 0:S + C, :])
475
-
476
- def _grow_buffer_right(self, layer_idx: int):
477
- """Grow buffer rightward when summary headroom is exhausted (doubling strategy).
478
-
479
- Only the tail (headroom) is extended; chunk_text and ring_text positions are unchanged.
480
- """
481
- old_k = self.key_cache[layer_idx]
482
- old_v = self.value_cache[layer_idx]
483
- bsz, heads, old_cap, head_dim = old_k.shape
484
-
485
- extra = max(self._SUMMARY_INIT_CAP, self._n_summaries[layer_idx])
486
- new_cap = max(old_cap + extra, old_cap * 2)
487
-
488
- new_k = torch.empty(bsz, heads, new_cap, head_dim, dtype=old_k.dtype, device=old_k.device)
489
- new_v = torch.empty(bsz, heads, new_cap, head_dim, dtype=old_v.dtype, device=old_v.device)
490
-
491
-         # Copy all existing data in place — positions unchanged
492
- new_k[:, :, :old_cap, :].copy_(old_k)
493
- new_v[:, :, :old_cap, :].copy_(old_v)
494
-
495
- self.key_cache[layer_idx] = new_k
496
- self.value_cache[layer_idx] = new_v
497
- self._capacity[layer_idx] = new_cap
498
-
499
- def reset_chunk_counter(self):
500
- """Reset chunk counters after a chunk boundary step completes."""
501
- block = self.summary_chunk_size + self.summary_token_num
502
- for layer_idx in range(self.num_hidden_layers):
503
- if self.cur_chunk_sizes[layer_idx] >= block:
504
- self.cur_chunk_sizes[layer_idx] %= block
505
-
506
-
507
- class Qwen3MLP(nn.Module):
508
- def __init__(self, config):
509
- super().__init__()
510
- self.config = config
511
- self.hidden_size = config.hidden_size
512
- self.intermediate_size = config.intermediate_size
513
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
514
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
515
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
516
- self.act_fn = ACT2FN[config.hidden_act]
517
-
518
- def forward(self, x):
519
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
520
- return down_proj
521
-
522
-
523
- def rotate_half(x):
524
- """Rotates half the hidden dims of the input."""
525
- x1 = x[..., : x.shape[-1] // 2]
526
- x2 = x[..., x.shape[-1] // 2 :]
527
- return torch.cat((-x2, x1), dim=-1)
528
-
529
-
530
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
531
- """Applies Rotary Position Embedding to the query and key tensors.
532
-
533
- Args:
534
- q (`torch.Tensor`): The query tensor.
535
- k (`torch.Tensor`): The key tensor.
536
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
537
- sin (`torch.Tensor`): The sine part of the rotary embedding.
538
- position_ids (`torch.Tensor`, *optional*):
539
- Deprecated and unused.
540
- unsqueeze_dim (`int`, *optional*, defaults to 1):
541
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
542
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
543
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
544
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
545
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
546
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
547
- Returns:
548
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
549
- """
550
- cos = cos.unsqueeze(unsqueeze_dim)
551
- sin = sin.unsqueeze(unsqueeze_dim)
552
- q_embed = (q * cos) + (rotate_half(q) * sin)
553
- k_embed = (k * cos) + (rotate_half(k) * sin)
554
- return q_embed, k_embed
555
-
556
-
557
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
558
- """
559
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
560
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
561
- """
562
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
563
- if n_rep == 1:
564
- return hidden_states
565
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
566
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
567
-
568
-
569
- def eager_attention_forward(
570
- module: nn.Module,
571
- query: torch.Tensor,
572
- key: torch.Tensor,
573
- value: torch.Tensor,
574
- attention_mask: Optional[torch.Tensor],
575
- scaling: float,
576
- dropout: float = 0.0,
577
- **kwargs: Unpack[TransformersKwargs],
578
- ):
579
- key_states = repeat_kv(key, module.num_key_value_groups)
580
- value_states = repeat_kv(value, module.num_key_value_groups)
581
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
582
- if attention_mask is not None:
583
- attn_weights = attn_weights + attention_mask
584
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
585
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
586
- attn_output = torch.matmul(attn_weights, value_states)
587
- attn_output = attn_output.transpose(1, 2).contiguous()
588
- return attn_output, attn_weights
589
-
590
-
591
- def _sdpa_attention_forward(
592
- module: nn.Module,
593
- query: torch.Tensor,
594
- key: torch.Tensor,
595
- value: torch.Tensor,
596
- attention_mask: Optional[torch.Tensor],
597
- scaling: float,
598
- dropout: float = 0.0,
599
- **kwargs: Unpack[TransformersKwargs],
600
- ):
601
- key_states = repeat_kv(key, module.num_key_value_groups)
602
- value_states = repeat_kv(value, module.num_key_value_groups)
603
- attn_output = F.scaled_dot_product_attention(
604
- query,
605
- key_states,
606
- value_states,
607
- attn_mask=None,
608
- dropout_p=dropout,
609
- is_causal=False,
610
- )
611
- attn_output = attn_output.transpose(1, 2).contiguous()
612
- return attn_output, None
613
-
614
-
615
-
616
- class Qwen3Attention(nn.Module):
617
- """Multi-headed attention from 'Attention Is All You Need' paper"""
618
-
619
- def __init__(self, config: Qwen3Config, layer_idx: int):
620
- super().__init__()
621
- self.config = config
622
- self.layer_idx = layer_idx
623
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
624
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
625
- self.scaling = self.head_dim**-0.5
626
- self.attention_dropout = config.attention_dropout
627
-
628
- self.q_proj = nn.Linear(
629
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
630
- )
631
- self.k_proj = nn.Linear(
632
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
633
- )
634
- self.v_proj = nn.Linear(
635
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
636
- )
637
- self.o_proj = nn.Linear(
638
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
639
- )
640
- self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
641
- self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
642
- self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
643
- if getattr(config, "_attn_implementation", None) == "eager":
644
- self._decode_attn_fn = eager_attention_forward
645
- else:
646
- self._decode_attn_fn = _sdpa_attention_forward
647
-
648
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
649
- def forward(
650
- self,
651
- hidden_states: torch.Tensor,
652
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
653
- attention_mask: Optional[torch.Tensor],
654
- past_key_values: Optional[Cache] = None,
655
- cache_position: Optional[torch.LongTensor] = None,
656
- **kwargs: Unpack[FlashAttentionKwargs],
657
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
658
- input_shape = hidden_states.shape[:-1]
659
- hidden_shape = (*input_shape, -1, self.head_dim)
660
-
661
-
662
- query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
663
- key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
664
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
665
-
666
- cos, sin = position_embeddings
667
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
668
-
669
- if past_key_values is not None:
670
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
671
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
672
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
673
-
674
- attn_output, attn_weights = self._decode_attn_fn(
675
- self,
676
- query_states,
677
- key_states,
678
- value_states,
679
- attention_mask,
680
- dropout=0.0 if not self.training else self.attention_dropout,
681
- scaling=self.scaling,
682
- sliding_window=self.sliding_window, # diff with Llama
683
- **kwargs,
684
- )
685
-
686
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
687
- attn_output = self.o_proj(attn_output)
688
- return attn_output, attn_weights
689
-
690
-
691
- class Qwen3SummaryAttention(Qwen3Attention):
692
- """
693
- Summary-aware variant of Qwen3Attention: uses a sliding summary mask.
694
- """
695
-
696
- def __init__(self, config: Qwen3Config, layer_idx: int):
697
- super().__init__(config, layer_idx)
698
- self.summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
699
- self.summary_token_num = getattr(self.config, "summary_token_num", 0)
700
-
701
- # Cache sliding_chunk_num to avoid eval() on every forward call
702
- val = getattr(config, "summary_sliding_chunk_num", 0) or 0
703
- val = _parse_config_pattern(val)
704
- if isinstance(val, list):
705
- self._sliding_chunk_num = val[layer_idx]
706
- else:
707
- self._sliding_chunk_num = int(val)
708
-
709
- if config.summary_independent_parameters and config.mix_coeff > 0:
710
- self.q_proj_summary = nn.Linear(
711
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
712
- )
713
- self.k_proj_summary = nn.Linear(
714
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
715
- )
716
- self.v_proj_summary = nn.Linear(
717
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
718
- )
719
-
720
- def _get_sliding_chunk_num(self):
721
- return self._sliding_chunk_num
722
-
723
- def get_query_key_value_tensors(self, hidden_states):
724
- input_shape = hidden_states.shape[:-1]
725
- hidden_shape = (*input_shape, -1, self.head_dim)
726
- query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
727
- key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
728
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
729
-
730
- return query_states, key_states, value_states
731
-
732
- def get_query_key_value_tensors_summary(self, hidden_states):
733
- input_shape = hidden_states.shape[:-1]
734
- hidden_shape = (*input_shape, -1, self.head_dim)
735
- query_states = self.q_norm(self.q_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2)
736
- key_states = self.k_norm(self.k_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2)
737
- value_states = self.v_proj_summary(hidden_states).view(hidden_shape).transpose(1, 2)
738
-
739
- return query_states, key_states, value_states
740
-
741
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
742
- def forward(
743
- self,
744
- hidden_states: torch.Tensor,
745
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
746
- attention_mask: Optional[torch.Tensor] = None,
747
- past_key_values: Optional[Cache] = None,
748
- cache_position: Optional[torch.LongTensor] = None,
749
- summary_ctx: Optional[SummaryBatchContext] = None,
750
- **kwargs,
751
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
752
- input_shape = hidden_states.shape[:-1]
753
- if hidden_states.size(0) != 1:
754
- raise ValueError("Summary sliding attention only supports batch size=1.")
755
-
756
- # Compute q/k/v for the full sequence once.
757
- if self.config.summary_independent_parameters:
758
- if summary_ctx is None:
759
- raise ValueError("summary_ctx is required when using summary_independent_parameters.")
760
- summary_mask = summary_ctx.summary_mask
761
- summary_pos = summary_mask[0]
762
- assert (summary_mask == summary_mask[0:1]).all()
763
-
764
- if self.config.mix_coeff == 0:
765
-                 # When mix_coeff=0, summary projections have no effect — skip clone + extra linear
766
- query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
767
- else:
768
- query, key, value = self.get_query_key_value_tensors(hidden_states)
769
-
770
- query_states = query.clone()
771
- key_states = key.clone()
772
- value_states = value.clone()
773
-
774
- hs_summary = hidden_states[:, summary_pos, :]
775
- if hs_summary.size(1) > 0:
776
- base_q_summary = query[:, :, summary_pos, :]
777
- base_k_summary = key[:, :, summary_pos, :]
778
- base_v_summary = value[:, :, summary_pos, :]
779
-
780
- q_s, k_s, v_s = self.get_query_key_value_tensors_summary(hs_summary)
781
-
782
- q_s = self.config.mix_coeff * q_s + (1.0 - self.config.mix_coeff) * base_q_summary
783
- k_s = self.config.mix_coeff * k_s + (1.0 - self.config.mix_coeff) * base_k_summary
784
- v_s = self.config.mix_coeff * v_s + (1.0 - self.config.mix_coeff) * base_v_summary
785
-
786
- query_states[:, :, summary_pos, :] = q_s
787
- key_states[:, :, summary_pos, :] = k_s
788
- value_states[:, :, summary_pos, :] = v_s
789
- else:
790
- query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
791
-
792
- cos, sin = position_embeddings
793
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
794
-
795
- query_len = query_states.shape[2]
796
- is_prefill = past_key_values is None or not past_key_values._reorganized
797
-
798
- if is_prefill:
799
- # Prefill: use standard append and summary_attn_func
800
- if past_key_values is not None:
801
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
802
- if summary_ctx is not None:
803
- cache_kwargs["summary_mask"] = summary_ctx.summary_mask
804
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
805
-
806
- with torch.cuda.device(query_states.device):
807
- attn_output, attn_weights = summary_attn_func(
808
- query_states.transpose(1,2).contiguous(),
809
- key_states.transpose(1,2).contiguous(),
810
- value_states.transpose(1,2).contiguous(),
811
- self.summary_chunk_size,
812
- self.summary_token_num,
813
- self._get_sliding_chunk_num(),
814
- summary_pos=summary_ctx.summary_mask.squeeze()
815
- )
816
- elif query_len == 1:
817
- # Single text token decode: write to cache, attend to full buffer
818
- past_key_values.update_text(key_states, value_states, self.layer_idx)
819
- k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
820
- attn_output, attn_weights = self._decode_attn_fn(
821
- self,
822
- query_states,
823
- k_full,
824
- v_full,
825
- None,
826
- dropout=0.0 if not self.training else self.attention_dropout,
827
- scaling=self.scaling,
828
- sliding_window=self.sliding_window,
829
- **kwargs,
830
- )
831
- else:
832
- # Chunk boundary: query = [text_token, summary_token(s)]
833
- # Split into text (first token) and summary (remaining tokens)
834
- q_text = query_states[:, :, :1, :]
835
- q_summary = query_states[:, :, 1:, :]
836
- k_text = key_states[:, :, :1, :]
837
- v_text = value_states[:, :, :1, :]
838
- k_summary = key_states[:, :, 1:, :]
839
- v_summary = value_states[:, :, 1:, :]
840
-
841
- # 1. Write text token to cache, get full KV, run text attention
842
- past_key_values.update_text(k_text, v_text, self.layer_idx)
843
- k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
844
- text_out, _ = self._decode_attn_fn(
845
- self,
846
- q_text,
847
- k_full,
848
- v_full,
849
- None,
850
- dropout=0.0 if not self.training else self.attention_dropout,
851
- scaling=self.scaling,
852
- sliding_window=self.sliding_window,
853
- **kwargs,
854
- )
855
-
856
- # 2. Summary attention: attend to current chunk's C text tokens + own KV (self-attention)
857
- # Uses scratch area [0:S] for summary self-KV, contiguous with chunk [S:S+C].
858
- k_chunk_with_self, v_chunk_with_self = past_key_values.get_summary_attention_kv(
859
- self.layer_idx, k_summary, v_summary)
860
- summary_out, _ = self._decode_attn_fn(
861
- self,
862
- q_summary,
863
- k_chunk_with_self,
864
- v_chunk_with_self,
865
- None,
866
- dropout=0.0 if not self.training else self.attention_dropout,
867
- scaling=self.scaling,
868
- sliding_window=self.sliding_window,
869
- **kwargs,
870
- )
871
-
872
- # 3. Write summary KV to cache
873
- past_key_values.update_summary(k_summary, v_summary, self.layer_idx)
874
-
875
- attn_output = torch.cat([text_out, summary_out], dim=2)
876
- attn_weights = None
877
-
878
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
879
- attn_output = self.o_proj(attn_output)
880
- return attn_output, attn_weights
881
-
882
-
883
- class Qwen3DecoderLayer(GradientCheckpointingLayer):
884
- def __init__(self, config: Qwen3Config, layer_idx: int):
885
- super().__init__()
886
- self.config = config
887
- self.hidden_size = config.hidden_size
888
-
889
- # Use SummaryAttention if enabled in config
890
- if getattr(config, "use_summary_attention", False) is True and config.summary_layer_freq[layer_idx] == 1:
891
- self.self_attn = Qwen3SummaryAttention(config=config, layer_idx=layer_idx)
892
- elif getattr(config, "use_summary_attention", False) is False and config.summary_layer_freq[layer_idx] == 0:
893
- self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
894
- else:
895
- raise ValueError(f'Check config.summary_layer_freq {config.summary_layer_freq} and config.use_summary_attention {config.use_summary_attention}')
896
-
897
- self.mlp = Qwen3MLP(config)
898
- self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
899
- if getattr(config, "summary_independent_attention_layernorm", False):
900
- self.input_layernorm_summary = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
901
- self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
902
- self.attention_type = config.layer_types[layer_idx]
903
-
904
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
905
- def forward(
906
- self,
907
- hidden_states: torch.Tensor,
908
- attention_mask: Optional[torch.Tensor] = None,
909
- position_ids: Optional[torch.LongTensor] = None,
910
- past_key_values: Optional[Cache] = None,
911
- use_cache: Optional[bool] = False,
912
- cache_position: Optional[torch.LongTensor] = None,
913
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
914
- summary_ctx: Optional[SummaryBatchContext] = None,
915
- **kwargs: Unpack[TransformersKwargs],
916
- ) -> torch.Tensor:
917
- residual = hidden_states
918
- if getattr(self.config, "summary_independent_attention_layernorm", False):
919
- summary_mask = summary_ctx.summary_mask
920
- assert (summary_mask == summary_mask[0:1]).all(), \
921
- "summary_mask must be identical across batch"
922
- hidden_states = self.input_layernorm(hidden_states)
923
- if summary_mask.any():
924
- hidden_summary = residual[:, summary_mask[0].to(residual.device), :]
925
- hidden_summary = self.input_layernorm_summary(hidden_summary)
926
- hidden_states[:, summary_mask[0], :] = hidden_summary
927
- else:
928
- hidden_states = self.input_layernorm(hidden_states)
929
-
930
- # Self Attention - pass summary_ctx if using summary attention
931
- attn_kwargs = {
932
- "hidden_states": hidden_states,
933
- "attention_mask": attention_mask,
934
- "position_ids": position_ids,
935
- "past_key_values": past_key_values,
936
- "use_cache": use_cache,
937
- "cache_position": cache_position,
938
- "position_embeddings": position_embeddings,
939
- **kwargs,
940
- }
941
- if isinstance(self.self_attn, Qwen3SummaryAttention):
942
- attn_kwargs["summary_ctx"] = summary_ctx
943
-
944
- hidden_states, _ = self.self_attn(**attn_kwargs)
945
- hidden_states = residual + hidden_states
946
-
947
- # Fully Connected
948
- residual = hidden_states
949
- hidden_states = self.post_attention_layernorm(hidden_states)
950
- hidden_states = self.mlp(hidden_states)
951
- hidden_states = residual + hidden_states
952
- return hidden_states
953
-
954
-
955
- @auto_docstring
956
- class Qwen3PreTrainedModel(PreTrainedModel):
957
- config: Qwen3Config
958
- base_model_prefix = "model"
959
- supports_gradient_checkpointing = True
960
- _no_split_modules = ["Qwen3DecoderLayer"]
961
- _skip_keys_device_placement = ["past_key_values"]
962
- _supports_flash_attn = True
963
- _supports_sdpa = True
964
- _supports_flex_attn = True
965
-
966
- _can_compile_fullgraph = True
967
- _supports_attention_backend = True
968
- _can_record_outputs = {
969
- "hidden_states": Qwen3DecoderLayer,
970
- "attentions": Qwen3Attention,
971
- }
972
-
973
-
974
- class Qwen3RotaryEmbedding(nn.Module):
975
- inv_freq: torch.Tensor # fix linting for `register_buffer`
976
-
977
- def __init__(self, config: Qwen3Config, device=None):
978
- super().__init__()
979
- self.max_seq_len_cached = config.max_position_embeddings
980
- self.original_max_seq_len = config.max_position_embeddings
981
-
982
- self.config = config
983
-
984
- self.rope_type = self.config.rope_parameters["rope_type"]
985
- rope_init_fn: Callable = self.compute_default_rope_parameters
986
- if self.rope_type != "default":
987
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
988
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
989
-
990
- self.register_buffer("inv_freq", inv_freq, persistent=False)
991
- self.original_inv_freq = inv_freq
992
-
993
- @staticmethod
994
- def compute_default_rope_parameters(
995
- config: Optional[Qwen3Config] = None,
996
- device: Optional["torch.device"] = None,
997
- seq_len: Optional[int] = None,
998
- ) -> tuple["torch.Tensor", float]:
999
- """
1000
- Computes the inverse frequencies according to the original RoPE implementation
1001
- Args:
1002
- config ([`~transformers.PreTrainedConfig`]):
1003
- The model configuration.
1004
- device (`torch.device`):
1005
- The device to use for initialization of the inverse frequencies.
1006
- seq_len (`int`, *optional*):
1007
- The current sequence length. Unused for this type of RoPE.
1008
- Returns:
1009
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
1010
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
1011
- """
1012
- base = config.rope_parameters["rope_theta"]
1013
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
1014
-
1015
- attention_factor = 1.0 # Unused in this type of RoPE
1016
-
1017
- # Compute the inverse frequencies
1018
- inv_freq = 1.0 / (
1019
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
1020
- )
1021
- return inv_freq, attention_factor
1022
-
1023
- @torch.no_grad()
1024
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
1025
- def forward(self, x, position_ids):
1026
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1027
- position_ids_expanded = position_ids[:, None, :].float()
1028
-
1029
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1030
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
1031
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1032
- emb = torch.cat((freqs, freqs), dim=-1)
1033
- cos = emb.cos() * self.attention_scaling
1034
- sin = emb.sin() * self.attention_scaling
1035
-
1036
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1037
-
1038
-
1039
- @auto_docstring
1040
- class Qwen3Model(Qwen3PreTrainedModel):
1041
- def __init__(self, config: Qwen3Config):
1042
- super().__init__(config)
1043
- self.padding_idx = config.pad_token_id
1044
- self.vocab_size = config.vocab_size
1045
- if not getattr(config, "summary_layer_freq", False):
1046
- if config.use_summary_attention:
1047
- config.summary_layer_freq = [1]*config.num_hidden_layers
1048
- else:
1049
- config.summary_layer_freq = [0]*config.num_hidden_layers
1050
- Warning(f'Please set config.summary_layer_freq, temp set summary_layer_freq = {config.num_hidden_layers}')
1051
- else:
1052
- config.summary_layer_freq = _parse_config_pattern(config.summary_layer_freq)
1053
-
1054
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1055
- self.layers = nn.ModuleList(
1056
- [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1057
- )
1058
- self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1059
- self.rotary_emb = Qwen3RotaryEmbedding(config=config)
1060
- self.gradient_checkpointing = False
1061
- self.has_sliding_layers = "sliding_attention" in self.config.layer_types
1062
-
1063
- # Cache per-layer sliding_chunk_nums for KV cache eviction
1064
- _sv = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0)
1065
- if isinstance(_sv, list):
1066
- self._sliding_chunk_nums = [int(v) for v in _sv]
1067
- else:
1068
- self._sliding_chunk_nums = [int(_sv)] * config.num_hidden_layers
1069
-
1070
- # Initialize weights and apply final processing
1071
- self.post_init()
1072
-
1073
- def _expand_input_with_summary_tokens(self, input_ids):
1074
- """Expand input_ids with summary tokens for prefill phase (vectorized).
1075
-
1076
- Returns:
1077
- Tuple of (expanded_input_ids, position_ids, text_only_mask)
1078
- """
1079
- summary_chunk = self.config.summary_chunk_size
1080
- summary_num = self.config.summary_token_num
1081
- summary_begin = self.config.summary_token_begin
1082
-
1083
- if summary_chunk == 0 or summary_num == 0:
1084
- return input_ids, None, None
1085
-
1086
- batch_size, seq_len = input_ids.shape
1087
- device = input_ids.device
1088
- dtype = input_ids.dtype
1089
- block = summary_chunk + summary_num
1090
-
1091
- # Number of full chunks and remainder
1092
- n_full_chunks = seq_len // summary_chunk
1093
- remainder = seq_len % summary_chunk
1094
- has_remainder = remainder > 0
1095
-
1096
- # Total expanded length: full_chunks * block + remainder
1097
- expanded_len = n_full_chunks * block + (remainder if has_remainder else 0)
1098
-
1099
- # --- Build expanded_input_ids ---
1100
- expanded_ids = torch.empty((batch_size, expanded_len), dtype=dtype, device=device)
1101
- text_only_mask = torch.zeros((batch_size, expanded_len), dtype=torch.bool, device=device)
1102
-
1103
- # Compute text positions: for chunk i, text goes to [i*block, i*block+summary_chunk)
1104
- # Summary positions: [i*block+summary_chunk, (i+1)*block)
1105
- if n_full_chunks > 0:
1106
- chunk_indices = torch.arange(n_full_chunks, device=device)
1107
- # Text source positions in original input_ids
1108
- text_src_offsets = (chunk_indices * summary_chunk).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk]
1109
- # Text dest positions in expanded
1110
- text_dst_offsets = (chunk_indices * block).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk]
1111
- # Summary dest positions
1112
- summary_dst_offsets = (chunk_indices * block + summary_chunk).unsqueeze(1) + torch.arange(summary_num, device=device).unsqueeze(0) # [n_full_chunks, summary_num]
1113
-
1114
- text_src_flat = text_src_offsets.reshape(-1)
1115
- text_dst_flat = text_dst_offsets.reshape(-1)
1116
- summary_dst_flat = summary_dst_offsets.reshape(-1)
1117
-
1118
- # Copy text tokens
1119
- expanded_ids[:, text_dst_flat] = input_ids[:, text_src_flat]
1120
- text_only_mask[:, text_dst_flat] = True
1121
-
1122
- # Fill summary tokens
1123
- summary_ids_val = torch.arange(summary_num, device=device, dtype=dtype) + summary_begin
1124
- expanded_ids[:, summary_dst_flat] = summary_ids_val.repeat(n_full_chunks).unsqueeze(0).expand(batch_size, -1)
1125
-
1126
- # Handle remainder (last partial chunk, no summary tokens)
1127
- if has_remainder:
1128
- rem_start_src = n_full_chunks * summary_chunk
1129
- rem_start_dst = n_full_chunks * block
1130
- rem_offsets = torch.arange(remainder, device=device)
1131
- expanded_ids[:, rem_start_dst + rem_offsets] = input_ids[:, rem_start_src + rem_offsets]
1132
- text_only_mask[:, rem_start_dst + rem_offsets] = True
1133
-
1134
- # --- Build position_ids ---
1135
- position_ids = torch.empty((batch_size, expanded_len), dtype=torch.long, device=device)
1136
-
1137
- if n_full_chunks > 0:
1138
- # Text position IDs
1139
- if self.config.summary_chunk_position_ids_type == 'origin':
1140
- text_pos = text_src_flat.unsqueeze(0).expand(batch_size, -1)
1141
- elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1142
- inner_pos = torch.arange(summary_chunk, device=device).repeat(n_full_chunks)
1143
- text_pos = inner_pos.unsqueeze(0).expand(batch_size, -1)
1144
- else:
1145
- raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1146
- position_ids[:, text_dst_flat] = text_pos
1147
-
1148
- # Summary position IDs
1149
- if self.config.summary_token_position_ids_type == 'zeros':
1150
- position_ids[:, summary_dst_flat] = 0
1151
- elif self.config.summary_token_position_ids_type in ('last_chunk_slice_left', 'last_chunk_slice_right'):
1152
- # Vectorized slice_ends computation for all chunks at once
1153
- if self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
1154
- idx = torch.arange(0, summary_num, device=device, dtype=torch.long)
1155
- else:
1156
- idx = torch.arange(1, summary_num + 1, device=device, dtype=torch.long)
1157
- # For each chunk i: prev_text_end = i * summary_chunk
1158
- prev_ends = (chunk_indices * summary_chunk).unsqueeze(1) # [n_full_chunks, 1]
1159
- slice_ends = prev_ends + (idx.unsqueeze(0) * summary_chunk) // summary_num - 1 # [n_full_chunks, summary_num]
1160
- slice_ends = slice_ends.clamp(min=0)
1161
- # Clamp per-chunk: min is prev_text_end for that chunk
1162
- slice_ends = torch.max(slice_ends, prev_ends)
1163
- position_ids[:, summary_dst_flat] = slice_ends.reshape(-1).unsqueeze(0).expand(batch_size, -1)
1164
- else:
1165
- raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
1166
-
1167
- # Remainder position IDs
1168
- if has_remainder:
1169
- if self.config.summary_chunk_position_ids_type == 'origin':
1170
- rem_pos = (rem_start_src + rem_offsets).unsqueeze(0).expand(batch_size, -1)
1171
- elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1172
- rem_pos = rem_offsets.unsqueeze(0).expand(batch_size, -1)
1173
- else:
1174
- raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1175
- position_ids[:, rem_start_dst + rem_offsets] = rem_pos
1176
-
1177
- return expanded_ids, position_ids, text_only_mask
1178
-
1179
- def _build_summary_context(self, input_ids, position_ids, is_prefill, use_cache):
1180
- """Build summary context for attention layers."""
1181
- summary_chunk = self.config.summary_chunk_size
1182
- summary_num = self.config.summary_token_num
1183
- summary_begin = self.config.summary_token_begin
1184
-
1185
- if summary_chunk > 0 and summary_num > 0:
1186
- return build_summary_sliding_context(
1187
- input_ids=input_ids,
1188
- position_ids=position_ids,
1189
- summary_token_num=summary_num,
1190
- summary_token_begin=summary_begin,
1191
- )
1192
- return None
1193
-
1194
- def _filter_summary_tokens(self, hidden_states, text_only_mask, use_summary, is_decode):
1195
- """Filter out summary tokens from output hidden states."""
1196
- if text_only_mask is not None:
1197
- # Prefill: vectorized filtering using boolean mask
1198
- batch_size, _, hidden_size = hidden_states.shape
1199
- text_length = text_only_mask[0].sum().item()
1200
- return hidden_states[text_only_mask.to(hidden_states.device)].reshape(batch_size, text_length, hidden_size)
1201
- elif use_summary and is_decode and hidden_states.size(1) > 1:
1202
- # Decode: if we have multiple tokens, only return the first (text token)
1203
- return hidden_states[:, :1, :]
1204
- return hidden_states
1205
-
1206
- @check_model_inputs()
1207
- @auto_docstring
1208
- def forward(
1209
- self,
1210
- input_ids: Optional[torch.LongTensor] = None,
1211
- attention_mask: Optional[torch.Tensor] = None,
1212
- position_ids: Optional[torch.LongTensor] = None,
1213
- past_key_values: Optional[Cache] = None,
1214
- inputs_embeds: Optional[torch.FloatTensor] = None,
1215
- use_cache: Optional[bool] = None,
1216
- cache_position: Optional[torch.LongTensor] = None,
1217
- summary_ctx: Optional[SummaryBatchContext] = None,
1218
- **kwargs: Unpack[TransformersKwargs],
1219
- ) -> BaseModelOutputWithPast:
1220
- if (input_ids is None) ^ (inputs_embeds is not None):
1221
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1222
- use_summary = getattr(self.config, "use_summary_attention", False)
1223
- is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
1224
-
1225
- # Prefill phase with summary attention: expand input_ids with summary tokens
1226
- text_only_mask = None
1227
- if use_summary and input_ids is not None and inputs_embeds is None and is_prefill:
1228
- input_ids, position_ids, text_only_mask = self._expand_input_with_summary_tokens(input_ids)
1229
-
1230
- if inputs_embeds is None:
1231
- inputs_embeds = self.embed_tokens(input_ids)
1232
-
1233
- # Initialize cache
1234
- if use_cache and past_key_values is None:
1235
- if use_summary:
1236
- past_key_values = Qwen3RingBufferCache(
1237
- config=self.config, sliding_chunk_nums=self._sliding_chunk_nums)
1238
- else:
1239
- past_key_values = DynamicCache(config=self.config)
1240
-
1241
- if cache_position is None:
1242
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1243
- cache_position = torch.arange(
1244
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1245
- )
1246
-
1247
- if position_ids is None:
1248
- position_ids = cache_position.unsqueeze(0)
1249
-
1250
- # Build summary context if needed
1251
- if use_summary and summary_ctx is None and input_ids is not None:
1252
- summary_ctx = self._build_summary_context(input_ids, position_ids, is_prefill, use_cache)
1253
-
1254
- causal_mask_mapping = attention_mask
1255
- if not isinstance(causal_mask_mapping, (dict, list)):
1256
- if summary_ctx and summary_ctx.enabled:
1257
- seq_len = inputs_embeds.shape[1]
1258
- # During prefill, Qwen3SummaryAttention uses summary_attn_func
1259
- # which does not need a dense mask. Skip expensive mask construction.
1260
- # During decode, prepare_inputs_for_generation already computed
1261
- # per-layer keep_indices and passed them as attention_mask (list).
1262
- # If we reach here with a non-list, it means no mask is needed.
1263
- causal_mask_mapping = None
1264
- else:
1265
- # Prepare mask arguments
1266
- mask_kwargs = {
1267
- "config": self.config,
1268
- "input_embeds": inputs_embeds,
1269
- "attention_mask": attention_mask,
1270
- "cache_position": cache_position,
1271
- "past_key_values": past_key_values,
1272
- "position_ids": position_ids,
1273
- }
1274
- # Create the standard causal masks (the summary-attention path skips this branch)
1275
- causal_mask_mapping = {
1276
- "full_attention": create_causal_mask(**mask_kwargs),
1277
- }
1278
- # The sliding window alternating layers are not always activated depending on the config
1279
- if self.has_sliding_layers:
1280
- causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1281
-
1282
- hidden_states = inputs_embeds
1283
-
1284
- # create position embeddings to be shared across the decoder layers
1285
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
1286
-
1287
- for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
1288
- if causal_mask_mapping is None:
1289
- layer_mask = None
1290
- elif isinstance(causal_mask_mapping, list):
1291
- layer_mask = causal_mask_mapping[layer_idx]
1292
- else:
1293
- layer_mask = causal_mask_mapping[decoder_layer.attention_type]
1294
- hidden_states = decoder_layer(
1295
- hidden_states,
1296
- attention_mask=layer_mask,
1297
- position_ids=position_ids,
1298
- past_key_values=past_key_values,
1299
- use_cache=use_cache,
1300
- cache_position=cache_position,
1301
- position_embeddings=position_embeddings,
1302
- summary_ctx=summary_ctx,
1303
- **kwargs,
1304
- )
1305
-
1306
- hidden_states = self.norm(hidden_states)
1307
-
1308
- # After prefill: reorganize cache to ring buffer layout
1309
- if use_cache and use_summary and past_key_values is not None and is_prefill:
1310
- if hasattr(past_key_values, 'reorganize_after_prefill') and summary_ctx is not None:
1311
- past_key_values.reorganize_after_prefill(summary_ctx.summary_mask)
1312
-
1313
- # After chunk boundary decode: reset chunk counters
1314
- if use_cache and use_summary and past_key_values is not None and not is_prefill:
1315
- if hasattr(past_key_values, 'reset_chunk_counter'):
1316
- past_key_values.reset_chunk_counter()
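# Hedged note: reorganize_after_prefill(summary_ctx.summary_mask) presumably separates the
# cached keys/values of summary tokens from those of ordinary text tokens so that
# Qwen3RingBufferCache can evict whole text chunks while keeping their summaries, and
# reset_chunk_counter() restarts the per-chunk token count once a chunk boundary has been
# decoded; both calls are guarded by hasattr so a plain cache is left untouched.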
1317
-
1318
- # Filter out summary tokens from output
1319
- hidden_states = self._filter_summary_tokens(hidden_states, text_only_mask, use_summary,
1320
- past_key_values is not None and past_key_values.get_seq_length() > 0)
1321
-
1322
- return BaseModelOutputWithPast(
1323
- last_hidden_state=hidden_states,
1324
- past_key_values=past_key_values if use_cache else None,
1325
- )
1326
-
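# Behavior sketch for the summary-attention path (sizes assumed): a prefill call with
# input_ids of shape [1, 256] and use_summary_attention=True is expanded internally with
# summary tokens, attended without a dense causal mask (summary_attn_func handles masking),
# and last_hidden_state is filtered back to [1, 256, hidden_size]; past_key_values is a
# Qwen3RingBufferCache instead of a DynamicCache.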
1327
-
1328
- @auto_docstring
1329
- class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
1330
- _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1331
- _tp_plan = {"lm_head": "colwise_rep"}
1332
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1333
-
1334
- def __init__(self, config):
1335
- super().__init__(config)
1336
- self.model = Qwen3Model(config)
1337
- self.vocab_size = config.vocab_size
1338
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1339
-
1340
- # Initialize weights and apply final processing
1341
- self.post_init()
1342
-
1343
- @can_return_tuple
1344
- @auto_docstring
1345
- def forward(
1346
- self,
1347
- input_ids: Optional[torch.LongTensor] = None,
1348
- attention_mask: Optional[torch.Tensor] = None,
1349
- position_ids: Optional[torch.LongTensor] = None,
1350
- past_key_values: Optional[Cache] = None,
1351
- inputs_embeds: Optional[torch.FloatTensor] = None,
1352
- labels: Optional[torch.LongTensor] = None,
1353
- use_cache: Optional[bool] = None,
1354
- cache_position: Optional[torch.LongTensor] = None,
1355
- logits_to_keep: Union[int, torch.Tensor] = 0,
1356
- summary_ctx: Optional[SummaryBatchContext] = None,
1357
- **kwargs: Unpack[TransformersKwargs],
1358
- ) -> CausalLMOutputWithPast:
1359
- r"""
1360
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1361
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1362
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1363
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1364
-
1365
- Example:
1366
-
1367
- ```python
1368
- >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
1369
-
1370
- >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
1371
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
1372
-
1373
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1374
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1375
-
1376
- >>> # Generate
1377
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1378
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1379
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1380
- ```"""
1381
- outputs: BaseModelOutputWithPast = self.model(
1382
- input_ids=input_ids,
1383
- attention_mask=attention_mask,
1384
- position_ids=position_ids,
1385
- past_key_values=past_key_values,
1386
- inputs_embeds=inputs_embeds,
1387
- use_cache=use_cache,
1388
- cache_position=cache_position,
1389
- summary_ctx=summary_ctx,
1390
- **kwargs,
1391
- )
1392
-
1393
- hidden_states = outputs.last_hidden_state
1394
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1395
- if isinstance(logits_to_keep, int) and logits_to_keep == 0 and labels is None:
1396
- # Inference: only need last token's logits to avoid OOM from [seq_len, vocab_size]
1397
- logits = self.lm_head(hidden_states[:, -1:, :])
1398
- else:
1399
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1400
- logits = self.lm_head(hidden_states[:, slice_indices, :])
1401
-
1402
- truncate_n = getattr(self.config, "truncate_predict_nums", 151936)
1403
- if truncate_n > 0:
1404
- logits = logits[..., :truncate_n]
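# Hedged note: truncate_predict_nums defaults to the base vocabulary size, so this slice
# presumably drops the logits of any ids appended beyond it (such as the summary-token
# ids), keeping them from ever being sampled during generation.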
1405
-
1406
- loss = None
1407
- if labels is not None:
1408
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs)
1409
-
1410
- return CausalLMOutputWithPast(
1411
- loss=loss,
1412
- logits=logits,
1413
- past_key_values=outputs.past_key_values,
1414
- hidden_states=outputs.hidden_states,
1415
- attentions=outputs.attentions,
1416
- )
1417
-
1418
- def _build_summary_attention_mask_for_generation(
1419
- self,
1420
- *,
1421
- input_ids: torch.LongTensor,
1422
- past_key_values: Optional[Cache],
1423
- attention_mask: Optional[torch.Tensor],
1424
- ) -> Optional[torch.Tensor]:
1425
- """Ring buffer cache handles attention internally β€” no mask needed for decode."""
1426
- if isinstance(past_key_values, Qwen3RingBufferCache):
1427
- return None
1428
- return attention_mask
1429
-
1430
- def prepare_inputs_for_generation(
1431
- self,
1432
- input_ids: torch.LongTensor,
1433
- past_key_values: Optional[Cache] = None,
1434
- attention_mask: Optional[torch.LongTensor] = None,
1435
- inputs_embeds: Optional[torch.FloatTensor] = None,
1436
- cache_position: Optional[torch.LongTensor] = None,
1437
- position_ids: Optional[torch.LongTensor] = None,
1438
- **kwargs,
1439
- ):
1440
- use_summary = getattr(self.config, "use_summary_attention", False)
1441
-
1442
- # If not using summary attention, use standard behavior
1443
- if not use_summary:
1444
- return super().prepare_inputs_for_generation(
1445
- input_ids=input_ids,
1446
- past_key_values=past_key_values,
1447
- attention_mask=attention_mask,
1448
- inputs_embeds=inputs_embeds,
1449
- cache_position=cache_position,
1450
- position_ids=position_ids,
1451
- **kwargs,
1452
- )
1453
-
1454
- # For summary attention: handle cache-based input slicing
1455
- summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
1456
- summary_token_num = getattr(self.config, "summary_token_num", 0)
1457
- summary_token_begin = getattr(self.config, "summary_token_begin", 0)
1458
-
1459
- # Prefill phase: pass full sequence, forward() will handle summary token insertion
1460
- if past_key_values is None or past_key_values.get_seq_length() == 0:
1461
- if cache_position is None:
1462
- cache_position = torch.arange(0, input_ids.shape[1], device=input_ids.device)
1463
-
1464
- return {
1465
- "input_ids": input_ids,
1466
- "attention_mask": attention_mask,
1467
- "position_ids": position_ids,
1468
- "past_key_values": past_key_values,
1469
- "cache_position": cache_position,
1470
- "use_cache": kwargs.get("use_cache"),
1471
- }
1472
-
1473
- # Decode phase: only pass new tokens not in cache
1474
- # Get current chunk size (number of text tokens in current chunk)
1475
- cur_chunk = past_key_values.get_cur_chunk_size() if hasattr(past_key_values, "get_cur_chunk_size") else 0
1476
- true_token_num = past_key_values.get_true_token_num()
1477
-
1478
- # Only take the new tokens that haven't been processed
1479
- if input_ids.shape[1] > 1:
1480
- # Slice to get only new tokens
1481
- new_token_count = input_ids.shape[1] - true_token_num
1482
- assert new_token_count > 0, f'new_token_count={new_token_count} should be greater than 0'
1483
- input_ids = input_ids[:, -new_token_count:]
1484
- device = input_ids.device
1485
- # Check if we need to insert summary tokens
1486
- # If this text token completes the current chunk (cur_chunk == summary_chunk_size - 1), append summary tokens after it
1487
- if cur_chunk == summary_chunk_size - 1:
1488
- # Insert summary tokens
1489
- batch_size = input_ids.shape[0]
1490
- summary_ids = (
1491
- torch.arange(summary_token_num, device=device, dtype=input_ids.dtype)
1492
- + summary_token_begin
1493
- ).unsqueeze(0).repeat(batch_size, 1)
1494
-
1495
- # Concatenate: [text_token, summary_tokens]
1496
- input_ids = torch.cat([input_ids, summary_ids], dim=1)
1497
-
1498
- # Position IDs: the text and summary positions follow the configured position-id types below
1499
- if self.config.summary_chunk_position_ids_type == 'origin':
1500
- text_pos = torch.full((batch_size, 1), past_key_values.get_true_token_num(), device=device, dtype=torch.long)
1501
- elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1502
- text_pos = torch.full((batch_size, 1), cur_chunk, device=device, dtype=torch.long)
1503
- else:
1504
- raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1505
-
1506
- if self.config.summary_token_position_ids_type == 'zeros':
1507
- summary_pos = torch.zeros((batch_size, summary_token_num), device=device, dtype=torch.long)
1508
- elif self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
1509
- # Split the chunk evenly into summary_token_num slices; each summary token takes the end of its slice
1510
- prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size
1511
- cur_text_end = past_key_values.get_true_token_num()+1
1512
- chunk_len = cur_text_end - prev_text_end
1513
-
1514
- idx = torch.arange(0, summary_token_num, device=device, dtype=torch.long)
1515
-
1516
- # ζ―δΈ€δ»½ηš„ζœ«ε°ΎοΌˆε…¨ε±€ positionοΌ‰
1517
- slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1
1518
- slice_ends = slice_ends.clamp(min=prev_text_end)
1519
-
1520
- summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0)
1521
- elif self.config.summary_token_position_ids_type == 'last_chunk_slice_right':
1522
- # Split the chunk evenly into summary_token_num slices; each summary token takes the end of its slice
1523
- prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size
1524
- cur_text_end = past_key_values.get_true_token_num()+1
1525
- chunk_len = cur_text_end - prev_text_end
1526
-
1527
- idx = torch.arange(1, summary_token_num + 1, device=device, dtype=torch.long)
1528
-
1529
- # ζ―δΈ€δ»½ηš„ζœ«ε°ΎοΌˆε…¨ε±€ positionοΌ‰
1530
- slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1
1531
- slice_ends = slice_ends.clamp(min=prev_text_end)
1532
-
1533
- summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0)
1534
-
1535
- else:
1536
- raise ValueError(f'Check config.summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
1537
-
1538
- position_ids = torch.cat([text_pos, summary_pos], dim=1)
1539
- else:
1540
- # Normal decode: just the new text token; its position follows summary_chunk_position_ids_type
1541
- if position_ids is None:
1542
- batch_size = input_ids.shape[0]
1543
- if self.config.summary_chunk_position_ids_type == 'origin':
1544
- position_ids = torch.full((batch_size, input_ids.shape[1]), past_key_values.get_true_token_num(), device=input_ids.device, dtype=torch.long)
1545
- elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1546
- position_ids = torch.full((batch_size, input_ids.shape[1]), cur_chunk, device=input_ids.device, dtype=torch.long)
1547
- else:
1548
- raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1549
- return {
1550
- "input_ids": input_ids,
1551
- "attention_mask": self._build_summary_attention_mask_for_generation(
1552
- input_ids=input_ids,
1553
- past_key_values=past_key_values,
1554
- attention_mask=attention_mask,
1555
- ),
1556
- "position_ids": position_ids,
1557
- "past_key_values": past_key_values,
1558
- "cache_position": cache_position,
1559
- "use_cache": kwargs.get("use_cache"),
1560
- }
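# Worked example for the slice_ends computation in prepare_inputs_for_generation above
# (all numbers assumed): at a chunk boundary with summary_chunk_size=128,
# summary_token_num=4 and get_true_token_num()=255, we get prev_text_end=128,
# cur_text_end=256 and chunk_len=128.

import torch

prev_text_end, chunk_len, summary_token_num = 128, 128, 4
idx_left = torch.arange(0, summary_token_num)                   # 'last_chunk_slice_left'
left = (prev_text_end + (idx_left * chunk_len) // summary_token_num - 1).clamp(min=prev_text_end)
idx_right = torch.arange(1, summary_token_num + 1)              # 'last_chunk_slice_right'
right = (prev_text_end + (idx_right * chunk_len) // summary_token_num - 1).clamp(min=prev_text_end)
print(left.tolist())   # [128, 159, 191, 223]
print(right.tolist())  # [159, 191, 223, 255]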
1561
-
1562
-
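# Minimal usage sketch (the repo id below is a placeholder, not prescribed by this file):
# summary attention is switched on purely through the config, so generation is the
# ordinary Hugging Face loop with the custom modeling code loaded from the repo.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "OpenOneRec/summary-attention-model"  # placeholder id for a repo shipping this file
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("A long document ...", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))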
1563
- class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
1564
- pass
1565
-
1566
-
1567
- class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
1568
- pass
1569
-
1570
-
1571
- class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
1572
- base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
1573
-
1574
-
1575
- __all__ = [
1576
- "Qwen3ForCausalLM",
1577
- "Qwen3ForQuestionAnswering",
1578
- "Qwen3PreTrainedModel",
1579
- "Qwen3Model",
1580
- "Qwen3ForSequenceClassification",
1581
- "Qwen3ForTokenClassification",
1582
- "Qwen3RingBufferCache",
1583
- "Qwen3SummaryAttention",
1584
- "SummaryBatchContext",
1585
- "build_summary_context",
1586
- "build_summary_sliding_context",
1587
- ]