Nandi-Mini-600M-Early-Checkpoint / configuration_nandi.py
vishesh-t27's picture
Update configuration_nandi.py
29e0f98 verified
from transformers.configuration_utils import PretrainedConfig
class NandiConfig(PretrainedConfig):
    r"""
    Configuration class for the Nandi model.

    Every constructor argument is stored on the instance under the same name.
    Only the arguments listed below are normalised or validated here; the
    remaining ones (`vocab_size`, `hidden_size`, `intermediate_size`,
    `num_hidden_layers`, `hidden_act`, `max_position_embeddings`,
    `initializer_range`, `rms_norm_eps`, `use_cache`, `pretraining_tp`,
    `attention_bias`, `attention_dropout`, `mlp_bias`, `factorized_embedding`,
    `layer_sharing`, `qk_norm`, `shared_kv`) are kept unchanged and are
    interpreted by the modeling code, which is not part of this file.

    Args:
        num_attention_heads: Must be positive and must divide `hidden_size`.
        num_key_value_heads: Falls back to `num_attention_heads` when `None`.
        head_dim: Falls back to `hidden_size // num_attention_heads` when `None`.
        rope_parameters: Falls back to `{"rope_theta": 1000000.0}` when `None`.
        embedding_rank: Must be positive when `factorized_embedding` is truthy.
        layer_sharing_repeats: Coerced with `int(...)` and clamped to a minimum
            of 1; falsy values (`None`, `0`) become 1.
        kv_cache_mode: Either `"shared"` or `"vanilla"`.
        pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, **kwargs:
            Forwarded to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `num_attention_heads` is not positive, if
            `kv_cache_mode` is not one of the accepted values, if
            `embedding_rank` is not positive while `factorized_embedding` is
            enabled, or if `hidden_size` is not divisible by
            `num_attention_heads`.

    Example:

    ```python
    >>> from transformers import AutoConfig, AutoModelForCausalLM
    >>> configuration = AutoConfig.from_pretrained("Rta-AILabs/Nandi-500M-remote", trust_remote_code=True)
    >>> model = AutoModelForCausalLM.from_pretrained("Rta-AILabs/Nandi-500M-remote", trust_remote_code=True)
    >>> configuration = model.config
    ```
    """

    model_type = "nandi"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Maps sub-module name patterns to a sharding style ("colwise"/"rowwise").
    # NOTE(review): presumably consumed by the transformers tensor-parallel
    # loader -- confirm against the installed transformers version.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=1248,
        intermediate_size=3556,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.008,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=0,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_parameters=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        factorized_embedding=False,
        embedding_rank=768,
        layer_sharing=False,
        layer_sharing_repeats=1,
        qk_norm=True,
        shared_kv=True,
        kv_cache_mode="shared",
        **kwargs,
    ):
        # Reject a non-positive head count up front: it would otherwise
        # surface as an opaque ZeroDivisionError (or a negative head_dim) in
        # the arithmetic below.
        if num_attention_heads <= 0:
            raise ValueError(
                f"`num_attention_heads` must be positive, got {num_attention_heads}."
            )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # With no explicit KV head count, use one KV head per attention head.
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp
        self.rope_parameters = rope_parameters if rope_parameters is not None else {"rope_theta": 1000000.0}
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.factorized_embedding = factorized_embedding
        self.embedding_rank = embedding_rank
        self.layer_sharing = layer_sharing
        # Clamp instead of raising: None/0/negative values all collapse to 1,
        # so no separate ">= 1" validation is needed afterwards.
        self.layer_sharing_repeats = max(1, int(layer_sharing_repeats or 1))
        self.qk_norm = qk_norm
        self.shared_kv = shared_kv

        if kv_cache_mode not in ("shared", "vanilla"):
            raise ValueError(
                f"`kv_cache_mode` must be 'shared' or 'vanilla', got {kv_cache_mode!r}."
            )
        self.kv_cache_mode = kv_cache_mode

        if self.factorized_embedding and self.embedding_rank <= 0:
            raise ValueError(
                f"`embedding_rank` must be positive when `factorized_embedding=True`, got {self.embedding_rank}."
            )

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"`hidden_size` ({self.hidden_size}) must be divisible by "
                f"`num_attention_heads` ({self.num_attention_heads})."
            )

        # The base-class initialiser runs last, after all model-specific
        # attributes have been set.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# Names exported by `from <this module> import *`.
__all__ = ["NandiConfig"]