vishesh-t27 commited on
Commit
29e0f98
·
verified ·
1 Parent(s): e8e8af0

Update configuration_nandi.py

Browse files
Files changed (1) hide show
  1. configuration_nandi.py +3 -31
configuration_nandi.py CHANGED
@@ -1,17 +1,3 @@
1
- # Copyright 2026 RTA AI Labs. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
  from transformers.configuration_utils import PretrainedConfig
16
 
17
 
@@ -97,26 +83,12 @@ class NandiConfig(PretrainedConfig):
97
  self.factorized_embedding = factorized_embedding
98
  self.embedding_rank = embedding_rank
99
  self.layer_sharing = layer_sharing
100
- # Smoltron training loops over `layer_sharing_repeats` unconditionally
101
- # (it does NOT check `layer_sharing`). Preserve the raw value here so
102
- # the modeling code can honor it; the `layer_sharing` bool is now just
103
- # metadata describing intent.
104
  self.layer_sharing_repeats = max(1, int(layer_sharing_repeats or 1))
105
  self.qk_norm = qk_norm
106
- # `shared_kv` records that V was tied to K at pretraining time. In the
107
- # HF model V is recomputed from `k_proj` at runtime (no `v_proj` module
108
- # is materialised); see `NandiAttention.forward`.
109
  self.shared_kv = shared_kv
110
- # `kv_cache_mode` controls the inference-time K/V cache strategy when
111
- # `shared_kv=True`. Both modes produce identical outputs (numerical
112
- # round-off only); they trade memory for compute:
113
- # "shared" -> cache ONLY raw K (single tensor per layer). Each
114
- # decode step re-applies k_norm + RoPE to the full
115
- # cached raw K. Halves KV-cache memory.
116
- # "vanilla" -> cache post-norm post-RoPE K AND raw V (two tensors
117
- # per layer). k_norm + RoPE are applied only to the
118
- # current step's tokens. Standard HF behavior.
119
- # Ignored when `shared_kv=False`. Defaults to "shared".
120
  if kv_cache_mode not in ("shared", "vanilla"):
121
  raise ValueError(
122
  f"`kv_cache_mode` must be 'shared' or 'vanilla', got {kv_cache_mode!r}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers.configuration_utils import PretrainedConfig
2
 
3
 
 
83
  self.factorized_embedding = factorized_embedding
84
  self.embedding_rank = embedding_rank
85
  self.layer_sharing = layer_sharing
86
+
 
 
 
87
  self.layer_sharing_repeats = max(1, int(layer_sharing_repeats or 1))
88
  self.qk_norm = qk_norm
89
+
 
 
90
  self.shared_kv = shared_kv
91
+
 
 
 
 
 
 
 
 
 
92
  if kv_cache_mode not in ("shared", "vanilla"):
93
  raise ValueError(
94
  f"`kv_cache_mode` must be 'shared' or 'vanilla', got {kv_cache_mode!r}."