HemanthSai7 commited on
Commit
963fde0
·
verified ·
1 Parent(s): df3eeff

Upload 3 files

Browse files
Files changed (3) hide show
  1. configuration_nandi.py +111 -0
  2. modeling_nandi.py +484 -0
  3. tokenization_nandi.py +124 -0
configuration_nandi.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/nandi/modular_nandi.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_nandi.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from huggingface_hub.dataclasses import strict
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...modeling_rope_utils import RopeParameters
25
+
26
+
27
+ @strict(accept_kwargs=True)
28
+ class NandiConfig(PretrainedConfig):
29
+ r"""
30
+ Example:
31
+
32
+ ```python
33
+ >>> from transformers import NandiConfig, NandiForCausalLM
34
+
35
+ >>> # Initializing a Nandi style configuration
36
+ >>> configuration = NandiConfig()
37
+
38
+ >>> # Initializing a model from the Nandi style configuration
39
+ >>> model = NandiForCausalLM(configuration)
40
+
41
+ >>> # Accessing the model configuration
42
+ >>> configuration = model.config
43
+ ```"""
44
+
45
+ model_type = "nandi"
46
+ keys_to_ignore_at_inference = ["past_key_values"]
47
+
48
+ base_model_tp_plan = {
49
+ "layers.*.self_attn.q_proj": "colwise",
50
+ "layers.*.self_attn.k_proj": "colwise",
51
+ "layers.*.self_attn.v_proj": "colwise",
52
+ "layers.*.self_attn.o_proj": "rowwise",
53
+ "layers.*.mlp.gate_proj": "colwise",
54
+ "layers.*.mlp.up_proj": "colwise",
55
+ "layers.*.mlp.down_proj": "rowwise",
56
+ }
57
+
58
+ # Defaults from the provided Nanotron training config.
59
+ vocab_size: int = 131072
60
+ hidden_size: int = 832
61
+ intermediate_size: int = 2496
62
+ num_hidden_layers: int = 16
63
+ num_attention_heads: int = 16
64
+ num_key_value_heads: int | None = 4
65
+ head_dim: int | None = None
66
+ hidden_act: str = "silu"
67
+ max_position_embeddings: int = 2048
68
+ initializer_range: float = 0.008
69
+ rms_norm_eps: float = 1e-5
70
+ use_cache: bool = True
71
+ pad_token_id: int | None = None
72
+ bos_token_id: int | None = 1
73
+ eos_token_id: int | list[int] | None = 0
74
+ pretraining_tp: int | None = 1
75
+ tie_word_embeddings: bool = True
76
+ rope_parameters: RopeParameters | dict | None = None
77
+ attention_bias: bool = False
78
+ attention_dropout: float = 0.0
79
+ mlp_bias: bool = False
80
+
81
+ # Nandi-specific options.
82
+ factorized_embedding: bool = True
83
+ embedding_rank: int = 196
84
+ layer_sharing: bool = True
85
+ layer_sharing_repeats: int = 2
86
+
87
+ def __post_init__(self, **kwargs):
88
+ if self.num_key_value_heads is None:
89
+ self.num_key_value_heads = self.num_attention_heads
90
+ if self.head_dim is None:
91
+ self.head_dim = self.hidden_size // self.num_attention_heads
92
+ if self.rope_parameters is None:
93
+ self.rope_parameters = {"rope_theta": 100000.0}
94
+ if not self.layer_sharing:
95
+ self.layer_sharing_repeats = 1
96
+
97
+ if self.factorized_embedding and self.embedding_rank <= 0:
98
+ raise ValueError(
99
+ f"`embedding_rank` must be positive when `factorized_embedding=True`, got {self.embedding_rank}."
100
+ )
101
+ if self.hidden_size % self.num_attention_heads != 0:
102
+ raise ValueError(
103
+ f"`hidden_size` ({self.hidden_size}) must be divisible by `num_attention_heads` ({self.num_attention_heads})."
104
+ )
105
+ if self.layer_sharing_repeats < 1:
106
+ raise ValueError(f"`layer_sharing_repeats` must be >= 1, got {self.layer_sharing_repeats}.")
107
+
108
+ super().__post_init__(**kwargs)
109
+
110
+
111
+ __all__ = ["NandiConfig"]
modeling_nandi.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/nandi/modular_nandi.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_nandi.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from collections.abc import Callable
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ from ...activations import ACT2FN
27
+ from ...cache_utils import Cache, DynamicCache, DynamicLayer
28
+ from ...generation import GenerationMixin
29
+ from ...integrations import use_kernel_forward_from_hub
30
+ from ...masking_utils import create_causal_mask
31
+ from ...modeling_layers import GradientCheckpointingLayer
32
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
33
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
34
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
+ from ...processing_utils import Unpack
36
+ from ...utils import TransformersKwargs, auto_docstring
37
+ from ...utils.deprecation import deprecate_kwarg
38
+ from ...utils.generic import can_return_tuple, merge_with_config_defaults
39
+ from ...utils.output_capturing import capture_outputs
40
+ from .configuration_nandi import NandiConfig
41
+
42
+
43
+ @use_kernel_forward_from_hub("RMSNorm")
44
+ class NandiRMSNorm(nn.Module):
45
+ def __init__(self, hidden_size, eps=1e-6):
46
+ super().__init__()
47
+ self.weight = nn.Parameter(torch.ones(hidden_size))
48
+ self.variance_epsilon = eps
49
+
50
+ def forward(self, hidden_states):
51
+ input_dtype = hidden_states.dtype
52
+ hidden_states = hidden_states.to(torch.float32)
53
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
54
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
55
+ return self.weight * hidden_states.to(input_dtype)
56
+
57
+ def extra_repr(self):
58
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
59
+
60
+
61
+ class NandiRotaryEmbedding(nn.Module):
62
+ inv_freq: torch.Tensor
63
+
64
+ def __init__(self, config: NandiConfig, device=None):
65
+ super().__init__()
66
+ self.max_seq_len_cached = config.max_position_embeddings
67
+ self.original_max_seq_len = config.max_position_embeddings
68
+
69
+ self.config = config
70
+ self.rope_type = self.config.rope_parameters.get("rope_type", "default")
71
+ rope_init_fn: Callable = self.compute_default_rope_parameters
72
+ if self.rope_type != "default":
73
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
74
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
75
+
76
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
77
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
78
+
79
+ @staticmethod
80
+ def compute_default_rope_parameters(
81
+ config: NandiConfig | None = None,
82
+ device: torch.device | None = None,
83
+ seq_len: int | None = None,
84
+ ) -> tuple[torch.Tensor, float]:
85
+ del seq_len
86
+ base = config.rope_parameters["rope_theta"]
87
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
88
+ attention_factor = 1.0
89
+ inv_freq = 1.0 / (
90
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
91
+ )
92
+ return inv_freq, attention_factor
93
+
94
+ @torch.no_grad()
95
+ @dynamic_rope_update
96
+ def forward(self, x, position_ids):
97
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
98
+ position_ids_expanded = position_ids[:, None, :].float()
99
+
100
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
101
+ with torch.autocast(device_type=device_type, enabled=False):
102
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
103
+ emb = torch.cat((freqs, freqs), dim=-1)
104
+ cos = emb.cos() * self.attention_scaling
105
+ sin = emb.sin() * self.attention_scaling
106
+
107
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
108
+
109
+
110
+ def rotate_half(x):
111
+ """Rotates half the hidden dims of the input."""
112
+ x1 = x[..., : x.shape[-1] // 2]
113
+ x2 = x[..., x.shape[-1] // 2 :]
114
+ return torch.cat((-x2, x1), dim=-1)
115
+
116
+
117
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
118
+ del position_ids
119
+ cos = cos.unsqueeze(unsqueeze_dim)
120
+ sin = sin.unsqueeze(unsqueeze_dim)
121
+ q_embed = (q * cos) + (rotate_half(q) * sin)
122
+ k_embed = (k * cos) + (rotate_half(k) * sin)
123
+ return q_embed, k_embed
124
+
125
+
126
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
127
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
128
+ if n_rep == 1:
129
+ return hidden_states
130
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
131
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
132
+
133
+
134
+ def eager_attention_forward(
135
+ module: nn.Module,
136
+ query: torch.Tensor,
137
+ key: torch.Tensor,
138
+ value: torch.Tensor,
139
+ attention_mask: torch.Tensor | None,
140
+ scaling: float,
141
+ dropout: float = 0.0,
142
+ **kwargs: Unpack[TransformersKwargs],
143
+ ):
144
+ del kwargs
145
+ key_states = repeat_kv(key, module.num_key_value_groups)
146
+ value_states = repeat_kv(value, module.num_key_value_groups)
147
+
148
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
149
+ if attention_mask is not None:
150
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
151
+ attn_weights = attn_weights + causal_mask
152
+
153
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
154
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
155
+ attn_output = torch.matmul(attn_weights, value_states)
156
+ attn_output = attn_output.transpose(1, 2).contiguous()
157
+
158
+ return attn_output, attn_weights
159
+
160
+
161
+ class NandiAttention(nn.Module):
162
+ def __init__(self, config: NandiConfig, layer_idx: int):
163
+ super().__init__()
164
+ self.config = config
165
+ self.layer_idx = layer_idx
166
+ self.head_dim = config.head_dim
167
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
168
+ self.scaling = self.head_dim**-0.5
169
+ self.attention_dropout = config.attention_dropout
170
+ self.is_causal = True
171
+
172
+ self.q_proj = nn.Linear(
173
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
174
+ )
175
+ self.k_proj = nn.Linear(
176
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
177
+ )
178
+ self.v_proj = nn.Linear(
179
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
180
+ )
181
+ self.o_proj = nn.Linear(
182
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
183
+ )
184
+
185
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
186
+ def forward(
187
+ self,
188
+ hidden_states: torch.Tensor,
189
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
190
+ attention_mask: torch.Tensor | None,
191
+ past_key_values: Cache | None = None,
192
+ cache_position: torch.LongTensor | None = None,
193
+ **kwargs: Unpack[TransformersKwargs],
194
+ ) -> tuple[torch.Tensor, torch.Tensor]:
195
+ input_shape = hidden_states.shape[:-1]
196
+ hidden_shape = (*input_shape, -1, self.head_dim)
197
+
198
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
199
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
200
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
201
+
202
+ cos, sin = position_embeddings
203
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
204
+
205
+ if past_key_values is not None:
206
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
207
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
208
+
209
+ attention_interface: Callable = eager_attention_forward
210
+ if self.config._attn_implementation != "eager":
211
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
212
+
213
+ attn_output, attn_weights = attention_interface(
214
+ self,
215
+ query_states,
216
+ key_states,
217
+ value_states,
218
+ attention_mask,
219
+ dropout=0.0 if not self.training else self.attention_dropout,
220
+ scaling=self.scaling,
221
+ **kwargs,
222
+ )
223
+
224
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
225
+ attn_output = self.o_proj(attn_output)
226
+ return attn_output, attn_weights
227
+
228
+
229
+ class NandiMLP(nn.Module):
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
233
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
234
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
235
+ self.act_fn = ACT2FN[config.hidden_act]
236
+
237
+ def forward(self, x):
238
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
239
+
240
+
241
+ class NandiDecoderLayer(GradientCheckpointingLayer):
242
+ def __init__(self, config: NandiConfig, layer_idx: int):
243
+ super().__init__()
244
+ self.hidden_size = config.hidden_size
245
+ self.self_attn = NandiAttention(config=config, layer_idx=layer_idx)
246
+ self.mlp = NandiMLP(config)
247
+ self.input_layernorm = NandiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
248
+ self.post_attention_layernorm = NandiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
+
250
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
251
+ def forward(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ attention_mask: torch.Tensor | None = None,
255
+ position_ids: torch.LongTensor | None = None,
256
+ past_key_values: Cache | None = None,
257
+ use_cache: bool | None = False,
258
+ cache_position: torch.LongTensor | None = None,
259
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
260
+ **kwargs: Unpack[TransformersKwargs],
261
+ ) -> torch.Tensor:
262
+ residual = hidden_states
263
+ hidden_states = self.input_layernorm(hidden_states)
264
+
265
+ hidden_states, _ = self.self_attn(
266
+ hidden_states=hidden_states,
267
+ attention_mask=attention_mask,
268
+ position_ids=position_ids,
269
+ past_key_values=past_key_values,
270
+ use_cache=use_cache,
271
+ cache_position=cache_position,
272
+ position_embeddings=position_embeddings,
273
+ **kwargs,
274
+ )
275
+ hidden_states = residual + hidden_states
276
+
277
+ residual = hidden_states
278
+ hidden_states = self.post_attention_layernorm(hidden_states)
279
+ hidden_states = self.mlp(hidden_states)
280
+ hidden_states = residual + hidden_states
281
+ return hidden_states
282
+
283
+
284
+ class _VirtualLayerCache:
285
+ """Proxy that shifts cache layer indices by `offset` to give each repeat its own virtual slots."""
286
+
287
+ def __init__(self, cache: Cache, offset: int):
288
+ self._cache = cache
289
+ self._offset = offset
290
+
291
+ def __getattr__(self, name):
292
+ return getattr(self._cache, name)
293
+
294
+ def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
295
+ virtual_idx = layer_idx + self._offset
296
+ # grow the backing cache if generate() pre-allocated fewer slots than needed
297
+ while len(self._cache.layers) <= virtual_idx:
298
+ self._cache.layers.append(DynamicLayer())
299
+ return self._cache.update(key_states, value_states, virtual_idx, cache_kwargs)
300
+
301
+ def get_seq_length(self, layer_idx: int = 0) -> int:
302
+ return self._cache.get_seq_length(layer_idx + self._offset)
303
+
304
+
305
+ @auto_docstring
306
+ class NandiPreTrainedModel(PreTrainedModel):
307
+ config: NandiConfig
308
+ base_model_prefix = "model"
309
+ supports_gradient_checkpointing = True
310
+ _no_split_modules = ["NandiDecoderLayer"]
311
+ _skip_keys_device_placement = ["past_key_values"]
312
+ _supports_flash_attn = True
313
+ _supports_sdpa = True
314
+ _supports_flex_attn = True
315
+ _can_compile_fullgraph = True
316
+ _supports_attention_backend = True
317
+ _can_record_outputs = {
318
+ "hidden_states": NandiDecoderLayer,
319
+ "attentions": NandiAttention,
320
+ }
321
+
322
+ def __init__(self, config: NandiConfig):
323
+ super().__init__(config)
324
+
325
+
326
+ @auto_docstring
327
+ class NandiModel(NandiPreTrainedModel):
328
+ def __init__(self, config: NandiConfig):
329
+ super().__init__(config)
330
+ self.padding_idx = config.pad_token_id
331
+ self.vocab_size = config.vocab_size
332
+ embedding_dim = config.embedding_rank if config.factorized_embedding else config.hidden_size
333
+
334
+ self.embed_tokens = nn.Embedding(config.vocab_size, embedding_dim, self.padding_idx)
335
+ self.embedding_proj = (
336
+ nn.Linear(config.embedding_rank, config.hidden_size, bias=False) if config.factorized_embedding else None
337
+ )
338
+ self.layers = nn.ModuleList(
339
+ [NandiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
340
+ )
341
+ self.norm = NandiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
342
+ self.rotary_emb = NandiRotaryEmbedding(config=config)
343
+ self.gradient_checkpointing = False
344
+
345
+ self.post_init()
346
+
347
+ @merge_with_config_defaults
348
+ @capture_outputs
349
+ @auto_docstring
350
+ def forward(
351
+ self,
352
+ input_ids: torch.LongTensor | None = None,
353
+ attention_mask: torch.Tensor | None = None,
354
+ position_ids: torch.LongTensor | None = None,
355
+ past_key_values: Cache | None = None,
356
+ inputs_embeds: torch.FloatTensor | None = None,
357
+ use_cache: bool | None = None,
358
+ **kwargs: Unpack[TransformersKwargs],
359
+ ) -> BaseModelOutputWithPast:
360
+ if (input_ids is None) ^ (inputs_embeds is not None):
361
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
362
+
363
+ if inputs_embeds is None:
364
+ inputs_embeds = self.embed_tokens(input_ids)
365
+
366
+ if self.embedding_proj is not None:
367
+ inputs_embeds = self.embedding_proj(inputs_embeds)
368
+
369
+ repeats = self.config.layer_sharing_repeats if self.config.layer_sharing else 1
370
+
371
+ if use_cache and past_key_values is None:
372
+ # Use lazy DynamicCache (no config) so it grows to accommodate
373
+ # num_hidden_layers * repeats virtual slots for layer-sharing.
374
+ past_key_values = DynamicCache()
375
+
376
+ if position_ids is None:
377
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
378
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
379
+ position_ids = position_ids.unsqueeze(0)
380
+
381
+ causal_mask = create_causal_mask(
382
+ config=self.config,
383
+ inputs_embeds=inputs_embeds,
384
+ attention_mask=attention_mask,
385
+ past_key_values=past_key_values,
386
+ position_ids=position_ids,
387
+ )
388
+
389
+ hidden_states = inputs_embeds
390
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
391
+
392
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
393
+ for repeat_idx in range(repeats):
394
+ # Each repeat gets its own virtual cache slots offset by num_hidden_layers,
395
+ # so repeat 0 uses slots 0..N-1 and repeat 1 uses slots N..2N-1, etc.
396
+ repeat_cache = (
397
+ _VirtualLayerCache(past_key_values, repeat_idx * self.config.num_hidden_layers)
398
+ if (past_key_values is not None and repeat_idx > 0)
399
+ else past_key_values
400
+ )
401
+ hidden_states = decoder_layer(
402
+ hidden_states,
403
+ attention_mask=causal_mask,
404
+ position_embeddings=position_embeddings,
405
+ position_ids=position_ids,
406
+ past_key_values=repeat_cache,
407
+ use_cache=use_cache,
408
+ **kwargs,
409
+ )
410
+
411
+ hidden_states = self.norm(hidden_states)
412
+ return BaseModelOutputWithPast(
413
+ last_hidden_state=hidden_states,
414
+ past_key_values=past_key_values,
415
+ )
416
+
417
+
418
+ @auto_docstring
419
+ class NandiForCausalLM(NandiPreTrainedModel, GenerationMixin):
420
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
421
+ _tp_plan = {"lm_head": "colwise_gather_output"}
422
+ _pp_plan = {
423
+ "lm_head_proj": (["hidden_states"], ["hidden_states"]),
424
+ "lm_head": (["hidden_states"], ["logits"]),
425
+ }
426
+
427
+ def __init__(self, config):
428
+ super().__init__(config)
429
+ self.model = NandiModel(config)
430
+ self.vocab_size = config.vocab_size
431
+
432
+ lm_head_in_features = config.embedding_rank if config.factorized_embedding else config.hidden_size
433
+ self.lm_head_proj = (
434
+ nn.Linear(config.hidden_size, config.embedding_rank, bias=False) if config.factorized_embedding else None
435
+ )
436
+ self.lm_head = nn.Linear(lm_head_in_features, config.vocab_size, bias=False)
437
+
438
+ self.post_init()
439
+
440
+ @can_return_tuple
441
+ @auto_docstring
442
+ def forward(
443
+ self,
444
+ input_ids: torch.LongTensor | None = None,
445
+ attention_mask: torch.Tensor | None = None,
446
+ position_ids: torch.LongTensor | None = None,
447
+ past_key_values: Cache | None = None,
448
+ inputs_embeds: torch.FloatTensor | None = None,
449
+ labels: torch.LongTensor | None = None,
450
+ use_cache: bool | None = None,
451
+ logits_to_keep: int | torch.Tensor = 0,
452
+ **kwargs: Unpack[TransformersKwargs],
453
+ ) -> CausalLMOutputWithPast:
454
+ outputs: BaseModelOutputWithPast = self.model(
455
+ input_ids=input_ids,
456
+ attention_mask=attention_mask,
457
+ position_ids=position_ids,
458
+ past_key_values=past_key_values,
459
+ inputs_embeds=inputs_embeds,
460
+ use_cache=use_cache,
461
+ **kwargs,
462
+ )
463
+
464
+ hidden_states = outputs.last_hidden_state
465
+ if self.lm_head_proj is not None:
466
+ hidden_states = self.lm_head_proj(hidden_states)
467
+
468
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
469
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
470
+
471
+ loss = None
472
+ if labels is not None:
473
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
474
+
475
+ return CausalLMOutputWithPast(
476
+ loss=loss,
477
+ logits=logits,
478
+ past_key_values=outputs.past_key_values,
479
+ hidden_states=outputs.hidden_states,
480
+ attentions=outputs.attentions,
481
+ )
482
+
483
+
484
+ __all__ = ["NandiPreTrainedModel", "NandiModel", "NandiForCausalLM"]
tokenization_nandi.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tokenization classes for the Nandi family."""
15
+
16
+ from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
17
+ from tokenizers.models import BPE
18
+
19
+ from ...tokenization_utils_tokenizers import TokenizersBackend
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?(?:\p{L}\p{M}*)+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
26
+
27
+
28
+ class NandiTokenizer(TokenizersBackend):
29
+ model_input_names = ["input_ids", "attention_mask"]
30
+ model = BPE
31
+
32
+ def __init__(
33
+ self,
34
+ vocab: str | dict[str, int] | None = None,
35
+ merges: str | list[str] | None = None,
36
+ vocab_file=None,
37
+ merges_file=None,
38
+ unk_token: str = "<|endoftext|>",
39
+ bos_token: str = "<|im_start|>",
40
+ eos_token: str = "<|endoftext|>",
41
+ pad_token: str = "<|pad|>",
42
+ add_prefix_space: bool | None = None,
43
+ **kwargs,
44
+ ):
45
+ self._vocab = (
46
+ vocab
47
+ if vocab is not None
48
+ else {
49
+ "<|endoftext|>": 0,
50
+ }
51
+ )
52
+ self._merges = merges or []
53
+
54
+ self._tokenizer = Tokenizer(
55
+ BPE(
56
+ vocab=self._vocab,
57
+ merges=self._merges,
58
+ dropout=None,
59
+ unk_token=None,
60
+ continuing_subword_prefix="",
61
+ end_of_word_suffix="",
62
+ fuse_unk=False,
63
+ byte_fallback=False,
64
+ )
65
+ )
66
+ self._tokenizer.decoder = decoders.ByteLevel()
67
+ self._tokenizer.normalizer = normalizers.NFC()
68
+ self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
69
+ [
70
+ pre_tokenizers.Split(
71
+ Regex(PRETOKENIZE_REGEX),
72
+ behavior="isolated",
73
+ invert=False,
74
+ ),
75
+ pre_tokenizers.ByteLevel(
76
+ add_prefix_space=False,
77
+ trim_offsets=True,
78
+ use_regex=False
79
+ ),
80
+ ]
81
+ )
82
+
83
+ super().__init__(
84
+ vocab_file=vocab_file,
85
+ merges_file=merges_file,
86
+ unk_token=unk_token,
87
+ bos_token=bos_token,
88
+ eos_token=eos_token,
89
+ pad_token=pad_token,
90
+ add_prefix_space=add_prefix_space,
91
+ **kwargs,
92
+ )
93
+
94
+ def encode(
95
+ self,
96
+ text,
97
+ text_pair=None,
98
+ add_special_tokens: bool = True,
99
+ padding=False,
100
+ truncation=None,
101
+ max_length=None,
102
+ stride: int = 0,
103
+ padding_side=None,
104
+ return_tensors=None,
105
+ **kwargs,
106
+ ):
107
+ if isinstance(text, str):
108
+ # This is a temporary fix to match the behaviour of the training pipeline
109
+ text = " " + text
110
+ return super().encode(
111
+ text,
112
+ text_pair=text_pair,
113
+ add_special_tokens=add_special_tokens,
114
+ padding=padding,
115
+ truncation=truncation,
116
+ max_length=max_length,
117
+ stride=stride,
118
+ padding_side=padding_side,
119
+ return_tensors=return_tensors,
120
+ **kwargs,
121
+ )
122
+
123
+
124
+ __all__ = ["NandiTokenizer"]