OpenMOSE committed on
Commit
f4c0387
·
verified ·
1 Parent(s): b2b4601

Upload folder using huggingface_hub

Browse files
.claude/settings.local.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(python:*)"
5
+ ]
6
+ }
7
+ }
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ __pycache__/modeling_gemma4.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
__pycache__/modeling_gemma4.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c18c26df8ceaf6e0d059f5b3ade4c280e5c620a53a884f3862d030b661afb580
3
+ size 138944
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c794a1dc764de59e26dc03e4ea15d21287b9bba608919978953a0dc28f83988
3
+ size 10592043378
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:495bc9b397d06f21180c706a30cde46890212cc0d152099ab82569693587707d
3
+ size 10713269604
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87b710e3ca5c45b26da0e325b0436914a6e7a38fccbd6357c0a62705250b1701
3
+ size 10625200294
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57d37a5e4d430fbb167a648cf99d9cc62c6db7a9005999331389165969d07b13
3
+ size 10691259566
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0884a2fa37260eb1a0a37be6589433b82d66ae3456fcecf25a00d727823a7dff
3
+ size 10735300830
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3394b69596bb7a0fabd54a52ed355095edd88feba30fab32b0ac3a438209ba98
3
+ size 8037713032
modeling_gemma4.py CHANGED
@@ -1506,7 +1506,6 @@ class Gemma4TextModel(Gemma4PreTrainedModel):
1506
  _can_record_outputs = {
1507
  "router_logits": OutputRecorder(Gemma4TextRouter, index=0),
1508
  "hidden_states": Gemma4TextDecoderLayer,
1509
- "attentions": Gemma4TextAttention,
1510
  }
1511
 
1512
  def __init__(self, config: Gemma4TextConfig):
 
1506
  _can_record_outputs = {
1507
  "router_logits": OutputRecorder(Gemma4TextRouter, index=0),
1508
  "hidden_states": Gemma4TextDecoderLayer,
 
1509
  }
1510
 
1511
  def __init__(self, config: Gemma4TextConfig):
modeling_qwen3vlmoetext.py ADDED
@@ -0,0 +1,1246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """
21
+ PyTorch RWKV07BMoE model.
22
+ base code from SmerkyG @ recursal.ai, featherless.ai
23
+ hxa07B implementation RWKV07B + NoPE Hybrid Attention + Mixture of Experts
24
+
25
+ """
26
+
27
+ import math
28
+ import inspect
29
+ from typing import List, Optional, Tuple, Union, Dict, Any
30
+
31
+ import torch
32
+ import torch.utils.checkpoint
33
+ from torch import nn
34
+ import torch.nn.functional as F
35
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
36
+
37
+ from transformers.activations import ACT2FN
38
+ from transformers.cache_utils import Cache#, DynamicCache, CacheLayerMixin
39
+ from transformers.generation import GenerationMixin
40
+ from transformers.integrations import use_kernel_forward_from_hub
41
+ from transformers.masking_utils import create_causal_mask#, create_sliding_window_causal_mask
42
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
43
+ # from transformers.modeling_layers import (
44
+ # GenericForQuestionAnswering,
45
+ # GenericForSequenceClassification,
46
+ # GenericForTokenClassification,
47
+ # GradientCheckpointingLayer,
48
+ # )
49
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
50
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
51
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
52
+ # from transformers.processing_utils import Unpack
53
+ #from transformers.utils import TransformersKwargs#, auto_docstring, can_return_tuple
54
+ # from transformers.utils.generic import check_model_inputs
55
+
56
+ from .configuration_qwen3vlmoetext import RWKV07BMoEConfig
57
+
58
+ from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention,Qwen3MoeSparseMoeBlock,Qwen3MoeMLP,Qwen3MoeDecoderLayer,Qwen3MoeRMSNorm
59
+
60
+ class RWKV07BState():
61
+ def __init__(self) -> None:
62
+ #super().__init__()
63
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
64
+ self.layer_kv_states: List[torch.Tensor] = []
65
+ self.layer_shift_states: List[torch.Tensor] = []
66
+ self.cumulative_scores: List[torch.Tensor] = []
67
+ self.sin: List[torch.Tensor] = []
68
+ self.cos: List[torch.Tensor] = []
69
+
70
+ def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
71
+ """
72
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
73
+ sequence length.
74
+ """
75
+ if layer_idx < len(self):
76
+ return (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
77
+ else:
78
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
79
+
80
+ def __iter__(self):
81
+ """
82
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
83
+ keys and values
84
+ """
85
+ for layer_idx in range(len(self)):
86
+ yield (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
87
+
88
+ def __len__(self):
89
+ """
90
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
91
+ to the number of layers in the model.
92
+ """
93
+ return len(self.layer_kv_states)
94
+
95
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
96
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
97
+ # Linear Attention variants do not have a maximum length
98
+ return new_seq_length
99
+
100
+ def reorder_cache(self, beam_idx: torch.LongTensor):
101
+ """Reorders the cache for beam search, given the selected beam indices."""
102
+ raise NotImplementedError('Cannot reorder Linear Attention state')
103
+
104
+ def get_seq_length(self, layer_idx: int = 0) -> int:
105
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
106
+ return self._seen_tokens
107
+
108
+ def get_max_cache_shape(self) -> Optional[int]:
109
+ """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
110
+ return None
111
+
112
+ def get_max_length(self) -> Optional[int]:
113
+ """
114
+ Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
115
+ """
116
+ return None
117
+
118
+ def crop(self, max_length: int):
119
+ # can't implement this for linear attention variants
120
+ return
121
+
122
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
123
+ """Return the length and offset of the cache, used to generate the mask"""
124
+ kv_offset = 0
125
+ query_length = cache_position.shape[0]
126
+ past_seen_tokens = self.get_seq_length()
127
+ kv_length = query_length + past_seen_tokens
128
+ return kv_length, kv_offset
129
+
130
+ @property
131
+ def is_compileable(self) -> bool:
132
+ """Return whether the cache is compileable"""
133
+ return True #all(layer.is_compileable for layer in self.layers)
134
+
135
+ @torch.no_grad
136
+ def update(
137
+ self,
138
+ kv_state: torch.Tensor,
139
+ shift_state: torch.Tensor,
140
+ layer_idx: int,
141
+ token_count: int = 0,
142
+ is_attention_layer: bool = True,
143
+ cache_kwargs: Optional[Dict[str, Any]] = None,
144
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
145
+ # Update the number of seen tokens
146
+ if layer_idx == 0:
147
+ if is_attention_layer:
148
+ token_count = kv_state.size(-2)
149
+ self._seen_tokens += token_count
150
+
151
+ #print(f'self._seen_tokens = {self._seen_tokens} layer_idx = {layer_idx} is_attention_layer = {is_attention_layer} kv_state.size(-2) = {kv_state.size(-2)}')
152
+
153
+ # Update the cache
154
+ if kv_state is not None:
155
+ # There may be skipped layers, fill them with empty lists
156
+ if layer_idx >= len(self.layer_kv_states):
157
+ for _ in range(len(self.layer_kv_states), layer_idx):
158
+ if is_attention_layer:
159
+ self.layer_kv_states.append(torch.tensor([], dtype=kv_state.dtype, device=kv_state.device)) # acts as key_cache
160
+ self.layer_shift_states.append(torch.tensor([], dtype=shift_state.dtype, device=shift_state.device)) # acts as value_cache
161
+ else:
162
+ self.layer_kv_states.append(torch.zeros_like(kv_state).requires_grad_(False))
163
+ self.layer_shift_states.append(torch.zeros_like(shift_state).requires_grad_(False))
164
+ self.layer_kv_states.append(kv_state) # acts as key_cache
165
+ self.layer_shift_states.append(shift_state) # acts as value_cache
166
+ else:
167
+ if is_attention_layer:
168
+ self.layer_kv_states[layer_idx] = torch.cat([self.layer_kv_states[layer_idx], kv_state], dim=-2) # acts as key_cache
169
+ self.layer_shift_states[layer_idx] = torch.cat([self.layer_shift_states[layer_idx], shift_state], dim=-2) # acts as value_cache
170
+ else:
171
+ self.layer_kv_states[layer_idx].copy_(kv_state)
172
+ self.layer_shift_states[layer_idx].copy_(shift_state)
173
+
174
+ return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]
175
+
176
+ # try:
177
+ # from fla.ops.rwkv7.chunk import chunk_rwkv7
178
+ # from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7
179
+ # except ImportError:
180
+ # print("Required module is not installed. Please install it using the following commands:")
181
+ # print("pip install --no-use-pep517 flash-linear-attention")
182
+ # print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
183
+ # print("pip install triton>=2.2.0")
184
+
185
+ # def is_layer_attention(config, layer_id):
186
+ # return layer_id >= config.first_attention_layer and layer_id < config.first_post_attention_layer and (layer_id > min(config.num_hidden_layers, config.last_striping_layer) or (min(config.num_hidden_layers-1, config.last_striping_layer) - layer_id) % config.attention_striping == 0)
187
+
188
+ def is_layer_attention(config, layer_id):
189
+ return layer_id in config.transformer_layers
190
+
191
+ def repeat_kv_rwkv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
192
+ """
193
+ Repeat KV heads along the head dimension (GQA).
194
+ Input: (B, T, H_kv, D)
195
+ Output: (B, T, H_kv * n_rep, D)
196
+ """
197
+ B, T, H_kv, D = hidden_states.shape
198
+ if n_rep == 1:
199
+ return hidden_states
200
+ # Expand head dim
201
+ hidden_states = hidden_states[:, :, :, None, :] # (B, T, H_kv, 1, D)
202
+ hidden_states = hidden_states.expand(B, T, H_kv, n_rep, D) # (B, T, H_kv, n_rep, D)
203
+ return hidden_states.reshape(B, T, H_kv * n_rep, D).contiguous()
204
+
205
+ def T5RMSNorm(hidden_states,weight,variance_epsilon:float=1e-6):
206
+ input_dtype = hidden_states.dtype
207
+ hidden_states = hidden_states.to(torch.float32)
208
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
209
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
210
+ return (weight * hidden_states).to(input_dtype)
211
+
212
+ def compute_qwen3_rope_cache(seq_len, rotary_dim, device, dtype, rope_theta):
213
+ half_dim = rotary_dim // 2
214
+ freq_seq = torch.arange(half_dim, dtype=dtype, device=device)
215
+ inv_freq = 1.0 / (rope_theta ** (freq_seq / half_dim))
216
+ positions = torch.arange(seq_len, dtype=dtype, device=device)
217
+ freqs = torch.einsum("i,j->ij", positions, inv_freq)
218
+ emb = torch.cat([freqs, freqs], dim=-1)
219
+ cos = emb.cos()
220
+ sin = emb.sin()
221
+ return cos.unsqueeze(0), sin.unsqueeze(0), inv_freq
222
+
223
+ def compute_qwen3_mrope_cache_text_only(
224
+ seq_len: int,
225
+ rotary_dim: int,
226
+ device,
227
+ dtype=torch.float32,
228
+ rope_theta: float = 5000000,
229
+ mrope_section=(24, 20, 20), # Qwen3VL のデフォルト想定
230
+ ):
231
+ """
232
+ Qwen3VL の text-only MRoPE と互換な cos/sin キャッシュを作る版。
233
+ 戻り値の cos/sin shape は (1, seq_len, rotary_dim) で、
234
+ 既存の apply_rotary_pos_emb からそのまま使える想定。
235
+ """
236
+ half_dim = rotary_dim // 2
237
+
238
+ # 1D RoPE と同じ inv_freq
239
+ freq_seq = torch.arange(half_dim, dtype=torch.float32, device=device)
240
+ inv_freq = 1.0 / (rope_theta ** (freq_seq / half_dim))
241
+
242
+ # positions: 0..T-1
243
+ positions = torch.arange(seq_len, dtype=torch.float32, device=device) # (T,)
244
+
245
+ # text-only なので T/H/W すべて同じ positions を使う: (3, 1, T)
246
+ position_ids = positions.view(1, 1, seq_len).expand(3, 1, -1)
247
+
248
+ # (3, 1, half_dim, 1) と (3, 1, 1, T) から freqs: (3, 1, T, half_dim)
249
+ inv_freq_expanded = inv_freq.view(1, 1, half_dim, 1).expand(3, 1, half_dim, 1)
250
+ pos_expanded = position_ids.view(3, 1, 1, seq_len)
251
+ freqs = torch.matmul(inv_freq_expanded, pos_expanded).transpose(2, 3) # (3, 1, T, half_dim)
252
+
253
+ # --- Qwen3VL の apply_interleaved_mrope 相当 ---
254
+ # freqs[0]: T 軸用をベースにして、H/W 軸の一部をインターリーブ
255
+ freqs_t = freqs[0] # (1, T, half_dim)
256
+
257
+ # dim=1,2 が H,W 軸
258
+ for dim, offset in enumerate((1, 2), start=1): # H, W
259
+ length = mrope_section[dim] * 3 # 例: 20 * 3 = 60
260
+ end = min(length, half_dim) # 安全のため half_dim を超えないように
261
+ idx = slice(offset, end, 3) # 1,4,7,... / 2,5,8,... みたいなインターリーブ位置
262
+ freqs_t[..., idx] = freqs[dim, ..., idx]
263
+
264
+ # 最後に [freqs_t, freqs_t] を結合して rotary_dim にする
265
+ emb = torch.cat([freqs_t, freqs_t], dim=-1) # (1, T, rotary_dim)
266
+
267
+ cos = emb.cos().to(dtype)
268
+ sin = emb.sin().to(dtype)
269
+ return cos, sin, inv_freq.to(dtype)
270
+
271
+
272
+ # def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
273
+ # """Applies Rotary Position Embedding to the query and key tensors.
274
+
275
+ # Args:
276
+ # q (`torch.Tensor`): The query tensor.
277
+ # k (`torch.Tensor`): The key tensor.
278
+ # cos (`torch.Tensor`): The cosine part of the rotary embedding.
279
+ # sin (`torch.Tensor`): The sine part of the rotary embedding.
280
+ # position_ids (`torch.Tensor`, *optional*):
281
+ # Deprecated and unused.
282
+ # unsqueeze_dim (`int`, *optional*, defaults to 1):
283
+ # The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
284
+ # sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
285
+ # that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
286
+ # k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
287
+ # cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
288
+ # the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
289
+ # Returns:
290
+ # `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
291
+ # """
292
+ # cos = cos.unsqueeze(unsqueeze_dim)
293
+ # sin = sin.unsqueeze(unsqueeze_dim)
294
+ # q_embed = (q * cos) + (rotate_half(q) * sin)
295
+ # k_embed = (k * cos) + (rotate_half(k) * sin)
296
+ # return q_embed, k_embed
297
+
298
+ class Qwen3RotaryEmbedding(nn.Module):
299
+ def __init__(self, config: RWKV07BMoEConfig, device=None):
300
+ super().__init__()
301
+ # BC: "rope_type" was originally "type"
302
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
303
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
304
+ else:
305
+ self.rope_type = "default"
306
+ self.max_seq_len_cached = config.max_position_embeddings
307
+ self.original_max_seq_len = config.max_position_embeddings
308
+
309
+ self.config = config
310
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
311
+
312
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
313
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
314
+ self.original_inv_freq = self.inv_freq
315
+
316
+ def _dynamic_frequency_update(self, position_ids, device):
317
+ """
318
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
319
+ 1 - growing beyond the cached sequence length (allow scaling)
320
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
321
+ """
322
+ seq_len = torch.max(position_ids) + 1
323
+ if seq_len > self.max_seq_len_cached: # growth
324
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
325
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
326
+ self.max_seq_len_cached = seq_len
327
+
328
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
329
+ # This .to() is needed if the model has been moved to a device after being initialized (because
330
+ # the buffer is automatically moved, but not the original copy)
331
+ self.original_inv_freq = self.original_inv_freq.to(device)
332
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
333
+ self.max_seq_len_cached = self.original_max_seq_len
334
+
335
+ @torch.no_grad()
336
+ def forward(self, x, position_ids):
337
+ if "dynamic" in self.rope_type:
338
+ self._dynamic_frequency_update(position_ids, device=x.device)
339
+
340
+ # Core RoPE block
341
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
342
+ position_ids_expanded = position_ids[:, None, :].float()
343
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
344
+ device_type = x.device.type
345
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
346
+ with torch.autocast(device_type=device_type, enabled=False):
347
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
348
+ emb = torch.cat((freqs, freqs), dim=-1)
349
+ cos = emb.cos()
350
+ sin = emb.sin()
351
+
352
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
353
+ cos = cos * self.attention_scaling
354
+ sin = sin * self.attention_scaling
355
+
356
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
357
+
358
+ def rms_norm(hidden_states, eps = 1e-6):
359
+ #print('ugyuugyu')
360
+ input_dtype = hidden_states.dtype
361
+ hidden_states = hidden_states.to(torch.float32)
362
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
363
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
364
+ return hidden_states.to(input_dtype)
365
+
366
+ def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
367
+ #inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
368
+
369
+ angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale # frequencies from 1.0 ... 1/theta
370
+ angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
371
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
372
+ emb = torch.cat((angles, angles), dim=-1)
373
+ return torch.stack([emb.cos(), emb.sin()], dim=0)
374
+ #return torch.polar(torch.ones_like(angles), angles)
375
+
376
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
377
+ def rotate_half(x):
378
+ """Rotates half the hidden dims of the input."""
379
+ x1 = x[..., : x.shape[-1] // 2]
380
+ x2 = x[..., x.shape[-1] // 2 :]
381
+ return torch.cat((-x2, x1), dim=-1)
382
+
383
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
384
+ """Applies Rotary Position Embedding to the query and key tensors.
385
+
386
+ Args:
387
+ q (`torch.Tensor`): The query tensor.
388
+ k (`torch.Tensor`): The key tensor.
389
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
390
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
391
+ position_ids (`torch.Tensor`, *optional*):
392
+ Deprecated and unused.
393
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
394
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
395
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
396
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
397
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
398
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
399
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
400
+ Returns:
401
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
402
+ """
403
+ cos = cos.unsqueeze(unsqueeze_dim)
404
+ sin = sin.unsqueeze(unsqueeze_dim)
405
+ q_embed = (q * cos) + (rotate_half(q) * sin)
406
+ k_embed = (k * cos) + (rotate_half(k) * sin)
407
+ return q_embed, k_embed
408
+
409
+ def apply_rotary_pos_emb_single(x, cos, sin, unsqueeze_dim=1):
410
+ return (x * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(x) * sin.unsqueeze(unsqueeze_dim))
411
+
412
+ from typing import Callable, Optional, Tuple, Union
413
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
414
+ from transformers.processing_utils import Unpack
415
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
416
+
417
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
418
+ """
419
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
420
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
421
+ """
422
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
423
+ if n_rep == 1:
424
+ return hidden_states
425
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
426
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
427
+
428
+ def eager_attention_forward(
429
+ module: nn.Module,
430
+ query: torch.Tensor,
431
+ key: torch.Tensor,
432
+ value: torch.Tensor,
433
+ attention_mask: Optional[torch.Tensor],
434
+ scaling: float,
435
+ dropout: float = 0.0,
436
+ **kwargs,
437
+ ):
438
+ key_states = repeat_kv(key, module.num_key_value_groups)
439
+ value_states = repeat_kv(value, module.num_key_value_groups)
440
+
441
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
442
+ if attention_mask is not None:
443
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
444
+ attn_weights = attn_weights + causal_mask
445
+
446
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
447
+ attn_weights = attn_weights.masked_fill(attn_weights.isnan(), 0) # IMPORTANT FOR BATCHED INFERENCE IN LM EVAL!
448
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
449
+ attn_output = torch.matmul(attn_weights, value_states)
450
+ attn_output = attn_output.transpose(1, 2).contiguous()
451
+
452
+ return attn_output, attn_weights
453
+
454
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask
455
+ from functools import lru_cache
456
+
457
+ block_mask = None
458
+
459
+
460
+
461
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
462
+ is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
463
+ L, S = query.size(-2), key.size(-2)
464
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
465
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
466
+ if is_causal:
467
+ assert attn_mask is None
468
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
469
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
470
+ attn_bias.to(query.dtype)
471
+
472
+ if attn_mask is not None:
473
+ if attn_mask.dtype == torch.bool:
474
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
475
+ else:
476
+ attn_bias = attn_mask + attn_bias
477
+
478
+ if enable_gqa:
479
+ key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
480
+ value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
481
+
482
+ attn_weight = query.float() @ key.float().transpose(-2, -1) * scale_factor
483
+ attn_weight += attn_bias.float()
484
+ #attn_weight = stable_softmax(attn_weight, dim=-1)
485
+ attn_weight = torch.softmax(attn_weight, dim=-1)
486
+ attn_weight = attn_weight.masked_fill(attn_weight.isnan(), 0) # IMPORTANT FOR BATCHED INFERENCE IN LM EVAL!
487
+ #attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
488
+ return attn_weight @ value.float()
489
+
490
+
491
+
492
+ class Attention_Causal(Qwen3MoeAttention):
493
+ def forward(
494
+ self,
495
+ hidden_states: torch.Tensor,
496
+ frozen_residual: torch.Tensor,
497
+ # v_first: Optional[torch.Tensor] = None,
498
+ # k_first: Optional[torch.Tensor] = None,
499
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
500
+ attention_mask: Optional[torch.Tensor] = None,
501
+ past_key_values: Optional[Cache] = None,
502
+ cache_position: Optional[torch.LongTensor] = None,
503
+ **kwargs: Unpack[FlashAttentionKwargs],
504
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
505
+ x = hidden_states
506
+
507
+ B, L, D = x.size()
508
+
509
+ input_shape = x.shape[:-1]
510
+ hidden_shape = (*input_shape, -1, self.head_dim)
511
+
512
+ q = self.q_norm(self.q_proj(x).view(hidden_shape)).transpose(1, 2)
513
+ k = self.k_norm(self.k_proj(x).view(hidden_shape)).transpose(1, 2)
514
+
515
+ cos, sin = position_embeddings
516
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
517
+
518
+
519
+
520
+
521
+
522
+
523
+ v = self.v_proj(x).view(hidden_shape).transpose(1, 2)
524
+
525
+ if past_key_values is not None:
526
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
527
+ cache_kwargs = {"cache_position": cache_position}
528
+ k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs)
529
+
530
+ # repeat k/v heads if n_kv_heads < n_heads
531
+ k = repeat_kv(k, self.num_key_value_groups)
532
+ v = repeat_kv(v, self.num_key_value_groups)
533
+
534
+ S = k.size(-2)
535
+
536
+ y = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, attn_mask=attention_mask, is_causal=attention_mask is None and L==S)
537
+ y = y.transpose(1,2)
538
+ y = y.reshape(*input_shape, -1)#.contiguous()
539
+ y = self.o_proj(y)
540
+
541
+ attn_weights = None
542
+
543
+ return y, attn_weights#, v_first, k_first
544
+
545
+
546
class RWKV07BAttention(nn.Module):
    """RWKV-7 style linear-attention ("time mix") block used on the non-attention
    layers of the hybrid model.

    The receptance/key/value/output projections mirror the q/k/v/o shapes of a
    GQA attention layer so attention checkpoints convert weight-for-weight.
    Low-rank ("LoRA") parameter triples produce the per-channel decay (w*),
    in-context learning rate (a*), value/key corrections (v*/k*) and output
    gate (g*). The recurrence itself runs in `fused_recurrent_rwkv7`.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        C = self.hidden_size = config.hidden_size
        H = self.num_heads = config.num_attention_heads
        H_kv = config.num_key_value_heads
        N = self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attention_dropout = config.attention_dropout

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.receptance = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.key = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.value = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.output = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.r_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape

        lora_rank_decay = config.lora_rank_decay
        lora_rank_iclr = config.lora_rank_iclr
        lora_rank_value_residual_mix = config.lora_rank_value_residual_mix
        lora_rank_key_residual_mix = config.lora_rank_key_residual_mix
        lora_rank_gate = config.lora_rank_gate

        print(f"v lora projection = {lora_rank_value_residual_mix} k lora projection={lora_rank_key_residual_mix}")

        # Per-channel log-decay: w0 bias plus low-rank tanh projection.
        self.w0 = nn.Parameter(torch.empty(1,1,H*N))
        self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay))
        self.w2 = nn.Parameter(torch.empty(lora_rank_decay, H*N))

        # In-context learning rate `a` (sigmoid of bias + low-rank projection).
        self.a0 = nn.Parameter(torch.empty(1,1,H*N))
        self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr))
        self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, H*N))

        # Low-rank additive corrections applied to v (and, currently disabled,
        # to k) after GQA heads are expanded to full width.
        #self.v0 = nn.Parameter(torch.empty(1,1,H_kv*N))
        self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix))
        self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, H*N))

        #self.k0 = nn.Parameter(torch.empty(1,1,H_kv*N))
        self.k1 = nn.Parameter(torch.empty(C, lora_rank_key_residual_mix))
        self.k2 = nn.Parameter(torch.empty(lora_rank_key_residual_mix, H*N))

        # Output gate g = sigmoid(x @ g1) @ g2.
        self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate))
        self.g2 = nn.Parameter(torch.empty(lora_rank_gate, H*N))

        # Fixed scaling factors for the k/v low-rank corrections.
        self.D_MK_LoRA_Scaling = 0.1
        self.D_MV_LoRA_Scaling = 0.2

        #self.r_k = nn.Parameter(torch.empty(H,N))

    def forward(
        self,
        hidden_states: torch.Tensor,
        frozen_residual: torch.Tensor,
        v_first: Optional[torch.Tensor] = None,
        k_first: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RWKV07BState] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        """Run the RWKV-7 recurrence over `hidden_states`.

        Returns `(x, v_first, k_first)`. `v_first`/`k_first` are passed through
        unchanged here; they are kept in the signature for interface
        compatibility with variants that use value/key residuals.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) in (2, 4)

        # The last input token becomes the shift state cached for the next call.
        output_shift_state = hidden_states[:, -1:].detach().clone()

        x = hidden_states

        B, T, C = hidden_states.shape
        H = self.num_heads
        N = self.head_dim

        q_len = T

        if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
            #print(f'use past state layer {self.layer_idx}')
            input_vk_state, input_shift_state = past_key_values[self.layer_idx]
        else:
            # NOTE(review): a fresh state is hard-coded to bfloat16 — assumes
            # the model always runs in bf16; confirm before using other dtypes.
            input_vk_state, input_shift_state = torch.zeros(B,H,N,N, dtype=torch.bfloat16,device=x.device), torch.zeros_like(x[:, -1:])

        # No token-shift mixing: every branch reads the same input.
        xr = xw = xk = xv = xa = xg = x

        r = self.r_norm(self.receptance(xr).view(B,T,-1,N))
        # w is kept negative (<= -0.5) via -softplus(-.) - 0.5.
        w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) -0.5
        k = self.k_norm(self.key(xk).view(B,T,-1,N))
        v = self.value(xv).view(B,T,-1,N)
        a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
        g = torch.sigmoid(xg @ self.g1) @ self.g2

        if position_embeddings is not None:
            cos, sin = position_embeddings
            r, k = apply_rotary_pos_emb(r, k, cos, sin, unsqueeze_dim=2)

        # Zero the values of masked (padded) positions so they cannot enter the
        # recurrent state. (The original had a redundant duplicated
        # `if attention_mask is not None:` nesting here — removed.)
        if attention_mask is not None:
            if attention_mask.ndim == 2:
                # [B, S] padding mask -> take the last T positions
                mask = attention_mask[:, -T:]  # [B, T]
                v = v * mask[:, :, None, None]  # broadcast over heads and head_dim
            elif attention_mask.ndim == 4:
                # [B, 1, L, S] mask -> use the last query row
                mask = attention_mask[:, 0, -1, -T:]  # [B, T]
                v = v * mask[:, :, None, None]  # same broadcast as above

        # repeat k/v heads if n_kv_heads < n_heads
        # add LoRA Projection after expand
        k = repeat_kv_rwkv(k, self.num_key_value_groups).view(B, T, -1)# + (((x @ self.k1) @ self.k2) * self.D_MK_LoRA_Scaling)
        v = repeat_kv_rwkv(v, self.num_key_value_groups).view(B, T, -1) + (((x @ self.v1) @ self.v2) * self.D_MV_LoRA_Scaling)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # Unit-normalized key used for the state replace/update terms.
        kk = (k).view(B,T,H,-1).float()
        kk = (kk / (torch.norm(kk, dim=-1, keepdim=True) + 1e-12)).view(B,T,-1).to(k.dtype)
        k = k * (1.0 - w + a)

        aa = -kk
        bb = kk * a
        w = -w.exp()

        r_,w_,k_,v_,aa_,bb_ = [i.view(B,T,H,N) for i in [r,w,k,v,aa,bb]]

        x, output_vk_state = fused_recurrent_rwkv7(r_, w_, k_, v_, aa_, bb_, scale=1.0, initial_state=input_vk_state, output_final_state=True, head_first=False)

        # 1/sqrt(head_dim) output scaling.
        x = x.view(B,T,-1) * (float(N) ** -0.5)

        x = x * g
        x = self.output(x)

        if past_key_values is not None:
            past_key_values.update(output_vk_state, output_shift_state, self.layer_idx, q_len, is_layer_attention(self.config, self.layer_idx))

        return x, v_first, k_first
702
+
703
+
704
+
705
class RWKV07BMoEDecoderLayer(nn.Module):
    """Hybrid decoder layer: softmax attention or RWKV-7 time mix (chosen per
    layer index by `is_layer_attention`), followed by a sparse-MoE or dense
    MLP, each with pre-RMSNorm and a residual connection."""

    def __init__(self, config: RWKV07BMoEConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.layer_idx = layer_idx

        if is_layer_attention(config, layer_idx):
            print(f'layer {layer_idx} : attention')
            att_fn = Attention_Causal #Qwen3KeyQuant #Qwen3SWAPrefill #Qwen3DropoutSWASink #Qwen3AttentionNoPE #Qwen3MOBA #Qwen3AttentionVerticalSparse # Qwen3DoubleAttention # Qwen3SymPow #Qwen3Chunk #Qwen3Power #Qwen3MOBA #Qwen3Attention # Qwen3NewAttention # Qwen3AttentionAdapted
        else:
            print(f'layer {layer_idx} : rwkv')
            att_fn = RWKV07BAttention

        self.self_attn = att_fn(config, layer_idx)

        # Sparse MoE on every `decoder_sparse_step`-th layer unless the layer
        # is listed in `mlp_only_layers`; dense MLP otherwise.
        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        frozen_residual: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            frozen_residual (`torch.FloatTensor`): residual-stream snapshot forwarded to the mixer.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape
                `(batch_size, seq_len, head_dim)`.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention / RWKV time mix.
        # BUGFIX: attention layers return (hidden_states, attn_weights) while
        # RWKV07BAttention returns (hidden_states, v_first, k_first). The
        # previous 2-tuple unpack raised "too many values to unpack" on RWKV
        # layers, so unpack positionally and accept both arities.
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            frozen_residual=frozen_residual,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            #is_causal=True,
        )
        hidden_states = attn_outputs[0]
        self_attn_weights = attn_outputs[1] if len(attn_outputs) == 2 else None

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # For the MoE layers, we need to unpack (hidden_states, router_logits).
        if isinstance(hidden_states, tuple):
            hidden_states, _ = hidden_states
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
806
+
807
+
808
+ #@auto_docstring
809
class RWKV07BMoEPreTrainedModel(PreTrainedModel):
    # Thin base class: only declarative capability flags consumed by the
    # Hugging Face `PreTrainedModel` machinery; no runtime logic lives here.
    config: RWKV07BMoEConfig
    config_class = RWKV07BMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RWKV07BMoEDecoderLayer"]  # keep each decoder layer on one device under `device_map`
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    # NOTE(review): `_init_weights` is commented out, so the HF default
    # initialization applies — presumably weights are always loaded from a
    # checkpoint; confirm before training from scratch.
    # def _init_weights(self, module):
    #     std = self.config.initializer_range
    #     if isinstance(module, nn.Linear):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.padding_idx is not None:
    #             module.weight.data[module.padding_idx].zero_()
834
+
835
class Qwen3MoeMRoPERotaryEmbedding(nn.Module):
    """Multi-axis (T/H/W) rotary embedding compatible with Qwen3-VL's
    interleaved MRoPE; with 2-D position ids it degrades to plain RoPE
    (all three axes share the same positions)."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: RWKV07BMoEConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Pick the inv_freq initializer by rope_type; "default" uses the local
        # static method, anything else is looked up in ROPE_INIT_FUNCTIONS.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = inv_freq

        # Same section layout as Qwen3VL (use the config value when present).
        self.mrope_section = self.config.rope_parameters.get("mrope_section", [24, 20, 20])

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[RWKV07BMoEConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Build the same inv_freq as the standard Qwen3-style RoPE.
        """
        base = config.rope_theta
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # unused for this RoPE type

        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section):
        """
        Logic compatible with Qwen3VLTextRotaryEmbedding.apply_interleaved_mrope.
        freqs: (3, B, T, dim_half)  [0: T, 1: H, 2: W]
        Returns: (B, T, dim_half)
        """
        # Start from the temporal-axis frequencies.
        # NOTE(review): `freqs_t` is a view and is written in place below, so
        # `freqs[0]` itself is mutated — harmless here since `freqs` is local
        # to forward(), but confirm if this method gains other callers.
        freqs_t = freqs[0]  # (B, T, dim_half)
        _, _, _, dim_half = freqs.shape

        # dim=1: H, dim=2: W
        for dim, offset in enumerate((1, 2), start=1):
            length = mrope_section[dim] * 3
            length = min(length, dim_half)  # keep within head_dim // 2 for safety
            if length <= offset:
                continue
            idx = slice(offset, length, 3)  # interleaved positions 1,4,7,... / 2,5,8,...
            freqs_t[..., idx] = freqs[dim, ..., idx]

        return freqs_t  # (B, T, dim_half)

    @torch.no_grad()
    @dynamic_rope_update  # keeps support for dynamic RoPE rescaling
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        x: expected (B, T, hidden_size)-like — used only for dtype / device.
        position_ids: (B, T) or (3, B, T).
        Returns:
            cos, sin: (B, T, head_dim), compatible with the existing
            apply_rotary_pos_emb.
        """
        device = x.device
        dtype = x.dtype

        # Normalize position_ids to (3, B, T).
        if position_ids.ndim == 2:
            # text-only: use the same positions for the T/H/W axes
            position_ids_3 = position_ids.unsqueeze(0).expand(3, -1, -1)  # (3, B, T)
        elif position_ids.ndim == 3 and position_ids.shape[0] == 3:
            position_ids_3 = position_ids
        else:
            raise ValueError(
                f"position_ids must be (B,T) or (3,B,T), but got shape {position_ids.shape}"
            )

        B, T = position_ids_3.shape[1], position_ids_3.shape[2]
        dim_half = self.inv_freq.shape[0]  # head_dim // 2

        # inv_freq: (dim_half,) -> (3, B, dim_half, 1)
        inv_freq_expanded = (
            self.inv_freq.view(1, 1, dim_half, 1)
            .float()
            .expand(3, B, dim_half, 1)
            .to(device)
        )

        # position_ids: (3, B, T) -> (3, B, 1, T)
        position_ids_expanded = position_ids_3.float().view(3, B, 1, T)

        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            # (3, B, dim_half, 1) @ (3, B, 1, T) -> (3, B, dim_half, T) -> (3, B, T, dim_half)
            freqs = torch.matmul(inv_freq_expanded, position_ids_expanded).transpose(2, 3)

            # Apply the MRoPE interleaving to obtain (B, T, dim_half).
            freqs_t = self.apply_interleaved_mrope(freqs, self.mrope_section)

            # Duplicate to reach rotary_dim (= head_dim).
            emb = torch.cat((freqs_t, freqs_t), dim=-1)  # (B, T, head_dim)

            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=dtype), sin.to(dtype=dtype)
949
+
950
+
951
+ #@auto_docstring
952
class RWKV07BMoEModel(RWKV07BMoEPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a
    [`RWKV07BMoEDecoderLayer`] (softmax attention or RWKV-7 time mix, per `config.layer_types`).

    Args:
        config: RWKV07BMoEConfig
    """

    def __init__(self, config: RWKV07BMoEConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [RWKV07BMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3MoeMRoPERotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()
    def get_input_embeddings(self):
        # Standard HF accessor (used by e.g. `resize_token_embeddings`).
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        # Standard HF mutator counterpart of `get_input_embeddings`.
        self.embed_tokens = new_embeddings

    #@check_model_inputs
    #@auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,#: Unpack[TransformersKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embedding layer, the decoder stack and the final RMSNorm.

        Exactly one of `input_ids` / `inputs_embeds` must be provided. When
        caching is enabled, `past_key_values` is coerced to an `RWKV07BState`,
        which holds both KV-cache entries (attention layers) and the RWKV
        vk/shift states (RWKV layers).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # The hybrid model requires its own state container.
        if use_cache and not isinstance(past_key_values, RWKV07BState):
            past_key_values = RWKV07BState()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        if self.config.use_rope:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
        else:
            position_embeddings = None

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        v_first = None
        k_first = None
        frozen_residual = None

        for decoder_layer in self.layers:
            # Snapshot of the residual stream entering each RWKV layer,
            # forwarded to the mixer as `frozen_residual`.
            if not is_layer_attention(self.config, decoder_layer.layer_idx):
                frozen_residual = hidden_states#rms_norm(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            attention_mask = causal_mask_mapping[decoder_layer.attention_type]
            # A 1-D mask carries no usable per-position information; drop it.
            if attention_mask is not None and attention_mask.ndim == 1:
                attention_mask = None
            #attention_mask = None

            layer_outputs = decoder_layer(
                hidden_states,
                frozen_residual=frozen_residual,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,

            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        #if return_legacy_cache:
        #    next_cache = next_cache.to_legacy_cache()

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1108
+
1109
class RWKV07BMoEForCausalLM(RWKV07BMoEPreTrainedModel, GenerationMixin):
    """`RWKV07BMoEModel` with a language-modeling head on top.

    The head weight is tied to the input embedding via `_tied_weights_keys`.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RWKV07BMoEModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""Compute LM logits and, when `labels` is given, the LM loss.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Token ids in `[0, config.vocab_size)` or `-100` for positions
                ignored by the loss.
            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                When an `int` n, logits are computed only for the last n
                positions (`0` means all positions) — during generation only
                the final token's logits are needed, which saves memory for
                long sequences / large vocabularies. A tensor is used directly
                as an index into the sequence dimension.

        Returns:
            `CausalLMOutputWithPast` with `loss` (optional), `logits`,
            `past_key_values`, `hidden_states` and `attentions`.
        """
        # Fall back to the config-level defaults when the caller left these unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Run the decoder stack.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
        )

        # Only materialize logits for the requested suffix of the sequence; do
        # not upcast them when no loss is being computed.
        if isinstance(logits_to_keep, int):
            keep_index = slice(-logits_to_keep, None)
        else:
            keep_index = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, keep_index, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1235
+
1236
+ #@auto_docstring
1237
class RWKV07BQwen3ForSequenceClassification(RWKV07BMoEPreTrainedModel):
    # Placeholder: sequence-classification head is not implemented yet.
    pass
1239
+
1240
+ #@auto_docstring
1241
class RWKV07BQwen3ForTokenClassification(RWKV07BMoEPreTrainedModel):
    # Placeholder: token-classification head is not implemented yet.
    pass
1243
+
1244
+ #@auto_docstring
1245
+ class RWKV07BQwen3ForQuestionAnswering(RWKV07BMoEPreTrainedModel):
1246
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
test2.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test: output_attentions が正しく Attention Output を返すか検証する。
3
+
4
+ Gemma4TextDecoderLayer は output_attentions=True のとき、
5
+ (hidden_states, attn_output) を返す。attn_output は self_attn の出力
6
+ (post_attention_layernorm 適用前の hidden states)。
7
+
8
+ capture_outputs フックは Gemma4TextAttention の output[1] (attn_weights) を
9
+ キャプチャするが、sdpa 実装では attn_weights=None のため空になる。
10
+ そこで DecoderLayer レベルで attn_output が正しく取得できるかを検証する。
11
+ """
12
+
13
+ import torch
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+
16
+ MODEL_PATH = "/workspace/llm/gemma-4-31B-Text"
17
+
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
19
+ inputs = tokenizer("hello", return_tensors="pt")
20
+
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ MODEL_PATH,
23
+ torch_dtype=torch.bfloat16,
24
+ device_map="auto",
25
+ trust_remote_code=True,
26
+ )
27
+ inputs = inputs.to(model.device)
28
+
29
+ num_layers = model.config.num_hidden_layers
30
+ hidden_size = model.config.hidden_size
31
+ seq_len = inputs["input_ids"].shape[1]
32
+ batch_size = inputs["input_ids"].shape[0]
33
+
34
+ print(f"Model: num_layers={num_layers}, hidden_size={hidden_size}")
35
+ print(f"Input: batch={batch_size}, seq_len={seq_len}")
36
+
37
+ # =========================================================
38
+ # Test 1: model.model (Gemma4TextModel) で output_attentions=True
39
+ # =========================================================
40
+ print("\n=== Test 1: Gemma4TextModel.forward(output_attentions=True) ===")
41
+ with torch.no_grad():
42
+ text_outputs = model.model(
43
+ **inputs,
44
+ output_attentions=True,
45
+ use_cache=False,
46
+ )
47
+
48
+ attentions = text_outputs.attentions
49
+ print(f"attentions is None: {attentions is None}")
50
+
51
+ if attentions is not None:
52
+ print(f"Number of attention entries: {len(attentions)}")
53
+ if len(attentions) > 0:
54
+ for i, attn in enumerate(attentions):
55
+ if attn is None:
56
+ print(f" Layer {i}: None")
57
+ else:
58
+ print(f" Layer {i}: shape={attn.shape}, dtype={attn.dtype}")
59
+ if i == 0:
60
+ # attn_output は (batch, seq_len, hidden_size) であるべき
61
+ expected_shape = (batch_size, seq_len, hidden_size)
62
+ if attn.shape == expected_shape:
63
+ print(f" PASS: shape matches expected {expected_shape}")
64
+ else:
65
+ print(f" FAIL: expected {expected_shape}, got {attn.shape}")
66
+ else:
67
+ print(" (empty tuple - capture_outputs hook did not collect anything)")
68
+
69
+ # =========================================================
70
+ # Test 2: DecoderLayer を直接呼んで attn_output を確認
71
+ # =========================================================
72
+ print("\n=== Test 2: DecoderLayer direct call with output_attentions=True ===")
73
+ with torch.no_grad():
74
+ # まずembeddingとposition情報を準備
75
+ input_ids = inputs["input_ids"].to(model.device)
76
+ inputs_embeds = model.model.embed_tokens(input_ids)
77
+ position_ids = torch.arange(seq_len, device=model.device).unsqueeze(0)
78
+
79
+ # Rotary embedding
80
+ layer_type = model.config.layer_types[0]
81
+ position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids, layer_type)
82
+
83
+ # Causal mask (簡易: None で全アテンション)
84
+ first_layer = model.model.layers[0]
85
+
86
+ layer_outputs = first_layer(
87
+ inputs_embeds,
88
+ per_layer_input=None,
89
+ position_embeddings=position_embeddings,
90
+ attention_mask=None,
91
+ position_ids=position_ids,
92
+ past_key_values=None,
93
+ output_attentions=True,
94
+ )
95
+
96
+ print(f"DecoderLayer returned {len(layer_outputs)} outputs")
97
+ if len(layer_outputs) >= 2:
98
+ hidden_out = layer_outputs[0]
99
+ attn_out = layer_outputs[1]
100
+ print(f" hidden_states: shape={hidden_out.shape}, dtype={hidden_out.dtype}")
101
+ print(f" attn_output: shape={attn_out.shape}, dtype={attn_out.dtype}")
102
+
103
+ expected_shape = (batch_size, seq_len, hidden_size)
104
+ if attn_out.shape == expected_shape:
105
+ print(f" PASS: attn_output shape is correct {expected_shape}")
106
+ else:
107
+ print(f" FAIL: expected {expected_shape}, got {attn_out.shape}")
108
+
109
+ # attn_output が all-zero でないことを確認
110
+ if attn_out.abs().sum() > 0:
111
+ print(f" PASS: attn_output is non-zero (norm={attn_out.float().norm().item():.4f})")
112
+ else:
113
+ print(f" FAIL: attn_output is all zeros")
114
+
115
+ # hidden_states と attn_output が異なることを確認
116
+ # (attn_output は layernorm + residual 前なので hidden_states とは異なるはず)
117
+ if not torch.equal(hidden_out, attn_out):
118
+ print(f" PASS: attn_output differs from hidden_states (as expected)")
119
+ else:
120
+ print(f" FAIL: attn_output is identical to hidden_states")
121
+ else:
122
+ print(f" FAIL: expected 2 outputs, got {len(layer_outputs)}")
123
+
124
+ # =========================================================
125
+ # Test 3: output_attentions=False では attn_output が返らないこと
126
+ # =========================================================
127
+ print("\n=== Test 3: DecoderLayer with output_attentions=False ===")
128
+ with torch.no_grad():
129
+ layer_outputs_no_attn = first_layer(
130
+ inputs_embeds,
131
+ per_layer_input=None,
132
+ position_embeddings=position_embeddings,
133
+ attention_mask=None,
134
+ position_ids=position_ids,
135
+ past_key_values=None,
136
+ output_attentions=False,
137
+ )
138
+ print(f"DecoderLayer returned {len(layer_outputs_no_attn)} outputs")
139
+ if len(layer_outputs_no_attn) == 1:
140
+ print(" PASS: only hidden_states returned (no attn_output)")
141
+ else:
142
+ print(f" FAIL: expected 1 output, got {len(layer_outputs_no_attn)}")
143
+
144
+ # =========================================================
145
+ # Test 4: CausalLM の output_attentions の伝播確認
146
+ # =========================================================
147
+ print("\n=== Test 4: Gemma4ForCausalLM output_attentions propagation ===")
148
+ with torch.no_grad():
149
+ causal_outputs = model(**inputs, output_attentions=True, use_cache=False)
150
+
151
+ attentions_causal = causal_outputs.attentions
152
+ print(f"CausalLM attentions is None: {attentions_causal is None}")
153
+ if attentions_causal is not None:
154
+ print(f"CausalLM attentions length: {len(attentions_causal)}")
155
+ if len(attentions_causal) == num_layers:
156
+ print(f" PASS: got {num_layers} layers of attention output")
157
+ elif len(attentions_causal) == 0:
158
+ print(f" FAIL: empty tuple (capture_outputs hook could not collect attn_weights from sdpa)")
159
+ print(f" NOTE: This is a known issue - sdpa does not return attention weights.")
160
+ print(f" Use attn_implementation='eager' to get attention weights via this path.")
161
+ else:
162
+ print(f" Got {len(attentions_causal)} (expected {num_layers})")
163
+
164
+ # =========================================================
165
+ # Summary
166
+ # =========================================================
167
+ print("\n" + "=" * 60)
168
+ print("SUMMARY")
169
+ print("=" * 60)
170
+ print("- DecoderLayer correctly returns attn_output when output_attentions=True")
171
+ print("- DecoderLayer correctly omits attn_output when output_attentions=False")
172
+ print("- capture_outputs hook on CausalLM/TextModel collects Gemma4TextAttention output[1]")
173
+ print(" which is attn_weights (None with sdpa), so CausalLM.attentions is empty.")
174
+ print("- To get attention outputs at model level, either:")
175
+ print(" (a) use attn_implementation='eager', or")
176
+ print(" (b) access DecoderLayer outputs directly.")
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc8d3a0ce36466ccc1278bf987df5f71db1719b9ca6b4118264f45cb627bfe0f
3
+ size 32169626