"""Apriel2 HuggingFace configuration."""
import logging
from typing import Optional
from transformers import PretrainedConfig
logger = logging.getLogger(__name__)
class Apriel2TextConfig(PretrainedConfig):
    """Configuration for the Apriel2 decoder-only text model.

    The layer stack is described by a nested ``decoder`` dict: either a single
    repeated block (``"type": "fixed"``) or a repeating cycle of named blocks
    (``"type": "pattern"``).
    """

    model_type = "apriel2_text"
def __init__(
self,
hidden_size: int = 4096,
vocab_size: int = 32000,
decoder: Optional[dict] = None,
embeddings: Optional[dict] = None,
head: Optional[dict] = None,
tie_word_embeddings: bool = False,
bos_token_id: int = 1,
eos_token_id: int = 2,
pad_token_id: Optional[int] = None,
use_cache: bool = True,
**kwargs,
):
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.use_cache = use_cache
self.decoder = decoder or self._default_decoder_config()
self.embeddings = embeddings or self._default_embeddings_config()
self.head = head or self._default_head_config()
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
    def _default_decoder_config(self) -> dict:
        """Fixed-decoder default: an identical attention + gated-MLP block at every layer."""
return {
"type": "fixed",
"num_blocks": 32,
"block": {
"mixer": {
"type": "attention",
"heads": 32,
"head_groups": 32,
"head_size": self.hidden_size // 32,
"rotary": {"type": "default", "theta": 10000.0},
"add_linear_biases": False,
},
"mlp": {
"type": "mlp",
"intermediate_size": self.hidden_size * 4,
"activation": "silu",
"gated": True,
"add_linear_biases": False,
},
"normalization": {"type": "rms_norm", "epsilon": 1e-5},
},
}
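    # A "pattern" decoder, by contrast, names its blocks and repeats them in a
    # cycle (see get_block_name / get_block_config below). A hypothetical
    # sketch of that shape; the block names and contents are illustrative
    # only, not part of this module:
    #
    #     {
    #         "type": "pattern",
    #         "num_blocks": 32,
    #         "pattern": ["full", "full", "sliding"],
    #         "blocks": {
    #             "full": {"mixer": {"type": "attention", ...}, ...},
    #             "sliding": {"mixer": {"type": "attention", ...}, ...},
    #         },
    #     }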
def _default_embeddings_config(self) -> dict:
return {
"max_position_embeddings": 2048,
}
def _default_head_config(self) -> dict:
return {
"normalization": {"type": "rms_norm", "epsilon": 1e-5},
}
    def get_text_config(self, decoder: bool = False):
        """Return the text config (this object itself); mirrors the transformers API."""
        return self
    def get_block_name(self, layer_idx: int) -> str:
        """Return the name of the block used at ``layer_idx``."""
decoder_type = self.decoder.get("type", "fixed")
if decoder_type == "fixed":
return "block"
elif decoder_type == "pattern":
pattern = self.decoder.get("pattern", [])
if not pattern:
raise ValueError("Pattern decoder requires 'pattern' field")
return pattern[layer_idx % len(pattern)]
else:
raise ValueError(f"Unknown decoder type: {decoder_type}")
    def get_block_config(self, layer_idx: int) -> dict:
        """Return the block config dict for ``layer_idx``, resolving the pattern if present."""
decoder_type = self.decoder.get("type", "fixed")
if decoder_type == "fixed":
return self.decoder.get("block", {})
elif decoder_type == "pattern":
blocks = self.decoder.get("blocks", {})
pattern = self.decoder.get("pattern", [])
if not blocks or not pattern:
raise ValueError("Pattern decoder requires 'blocks' and 'pattern' fields")
            block_name = pattern[layer_idx % len(pattern)]
            if block_name not in blocks:
                raise ValueError(f"Pattern references undefined block: {block_name!r}")
            return blocks[block_name]
else:
raise ValueError(f"Unknown decoder type: {decoder_type}")
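# Example of the per-layer lookup on a pattern decoder. A minimal sketch,
# assuming only the schema read by get_block_name / get_block_config above
# (block names and contents are illustrative):
#
#     cfg = Apriel2TextConfig(
#         decoder={
#             "type": "pattern",
#             "num_blocks": 4,
#             "pattern": ["a", "b"],
#             "blocks": {
#                 "a": {"mixer": {"type": "attention"}},
#                 "b": {"mixer": {"type": "attention"}},
#             },
#         }
#     )
#     cfg.get_block_name(0)    # -> "a"
#     cfg.get_block_name(3)    # -> "b" (pattern repeats: 3 % 2 == 1)
#     cfg.get_block_config(3)  # -> {"mixer": {"type": "attention"}}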
class Apriel2Config(Apriel2TextConfig):
    """Multimodal Apriel2 configuration: the text config plus an optional vision
    encoder and the token index used as the image placeholder."""

    model_type = "apriel2"
def __init__(
self,
hidden_size: int = 4096,
vocab_size: int = 32000,
decoder: Optional[dict] = None,
embeddings: Optional[dict] = None,
head: Optional[dict] = None,
vision_encoder: Optional[dict] = None,
image_token_index: Optional[int] = None,
tie_word_embeddings: bool = False,
bos_token_id: int = 1,
eos_token_id: int = 2,
pad_token_id: Optional[int] = None,
use_cache: bool = True,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
vocab_size=vocab_size,
decoder=decoder,
embeddings=embeddings,
head=head,
tie_word_embeddings=tie_word_embeddings,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
use_cache=use_cache,
**kwargs,
)
self.vision_encoder = vision_encoder
self.image_token_index = image_token_index
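if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: build the
    # default config and exercise the per-layer block lookup helpers.
    config = Apriel2Config()
    assert config.get_block_name(0) == "block"
    assert config.get_block_config(0)["mixer"]["heads"] == 32
    print(config.to_json_string())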