| """Apriel2 HuggingFace configuration.""" |
|
|
| import logging |
| from typing import Optional |
|
|
| from transformers import PretrainedConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class Apriel2TextConfig(PretrainedConfig): |
| model_type = "apriel2_text" |

    def __init__(
        self,
        hidden_size: int = 4096,
        vocab_size: int = 32000,
        decoder: Optional[dict] = None,
        embeddings: Optional[dict] = None,
        head: Optional[dict] = None,
        tie_word_embeddings: bool = False,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: Optional[int] = None,
        use_cache: bool = True,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.use_cache = use_cache

        # Fall back to the built-in defaults for any section not provided.
        self.decoder = decoder or self._default_decoder_config()
        self.embeddings = embeddings or self._default_embeddings_config()
        self.head = head or self._default_head_config()

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _default_decoder_config(self) -> dict:
        # Default: 32 identical attention blocks with one KV group per head
        # (i.e. vanilla multi-head attention) and gated-SiLU MLPs.
        return {
            "type": "fixed",
            "num_blocks": 32,
            "block": {
                "mixer": {
                    "type": "attention",
                    "heads": 32,
                    "head_groups": 32,
                    "head_size": self.hidden_size // 32,
                    "rotary": {"type": "default", "theta": 10000.0},
                    "add_linear_biases": False,
                },
                "mlp": {
                    "type": "mlp",
                    "intermediate_size": self.hidden_size * 4,
                    "activation": "silu",
                    "gated": True,
                    "add_linear_biases": False,
                },
                "normalization": {"type": "rms_norm", "epsilon": 1e-5},
            },
        }
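
    # A hedged sketch of overriding these defaults: callers may pass a
    # `decoder` dict of the same shape to `__init__`. Every value below is
    # illustrative, not required (e.g. `head_groups=4` with `heads=16`
    # sketches a grouped-query attention layout):
    #
    #     decoder = {
    #         "type": "fixed",
    #         "num_blocks": 16,
    #         "block": {
    #             "mixer": {
    #                 "type": "attention",
    #                 "heads": 16,
    #                 "head_groups": 4,
    #                 "head_size": 128,
    #                 "rotary": {"type": "default", "theta": 10000.0},
    #                 "add_linear_biases": False,
    #             },
    #             "mlp": {
    #                 "type": "mlp",
    #                 "intermediate_size": 8192,
    #                 "activation": "silu",
    #                 "gated": True,
    #                 "add_linear_biases": False,
    #             },
    #             "normalization": {"type": "rms_norm", "epsilon": 1e-5},
    #         },
    #     }
    #     config = Apriel2TextConfig(hidden_size=2048, decoder=decoder)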

    def _default_embeddings_config(self) -> dict:
        return {
            "max_position_embeddings": 2048,
        }

    def _default_head_config(self) -> dict:
        return {
            "normalization": {"type": "rms_norm", "epsilon": 1e-5},
        }

    def get_text_config(self, decoder: bool = False):
        # This config already is the text config; the `decoder` flag mirrors
        # the `PretrainedConfig.get_text_config` signature.
        return self

    def get_block_name(self, layer_idx: int) -> str:
        """Return the name of the block used at `layer_idx`."""
        decoder_type = self.decoder.get("type", "fixed")

        if decoder_type == "fixed":
            return "block"
        elif decoder_type == "pattern":
            pattern = self.decoder.get("pattern", [])
            if not pattern:
                raise ValueError("Pattern decoder requires 'pattern' field")
            return pattern[layer_idx % len(pattern)]
        else:
            raise ValueError(f"Unknown decoder type: {decoder_type}")

    def get_block_config(self, layer_idx: int) -> dict:
        """Return the block config dict used at `layer_idx`."""
        decoder_type = self.decoder.get("type", "fixed")

        if decoder_type == "fixed":
            return self.decoder.get("block", {})
        elif decoder_type == "pattern":
            blocks = self.decoder.get("blocks", {})
            pattern = self.decoder.get("pattern", [])
            if not blocks or not pattern:
                raise ValueError("Pattern decoder requires 'blocks' and 'pattern' fields")
            block_name = pattern[layer_idx % len(pattern)]
            if block_name not in blocks:
                raise ValueError(f"Pattern references undefined block {block_name!r}")
            return blocks[block_name]
        else:
            raise ValueError(f"Unknown decoder type: {decoder_type}")
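
    # Hedged sketch of the "pattern" layout handled above. Block names are
    # arbitrary keys chosen by the caller ("local"/"global" here are
    # hypothetical), and the pattern repeats modulo its length:
    #
    #     block = Apriel2TextConfig()._default_decoder_config()["block"]
    #     config = Apriel2TextConfig(decoder={
    #         "type": "pattern",
    #         "num_blocks": 4,
    #         "pattern": ["local", "global"],
    #         "blocks": {"local": block, "global": block},
    #     })
    #     config.get_block_name(3)     # -> "global" (3 % 2 == 1)
    #     config.get_block_config(3)   # -> the "global" block dict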


class Apriel2Config(Apriel2TextConfig):
    """Configuration for the multimodal Apriel2 model: the text config plus an
    optional vision encoder."""

    model_type = "apriel2"

    def __init__(
        self,
        hidden_size: int = 4096,
        vocab_size: int = 32000,
        decoder: Optional[dict] = None,
        embeddings: Optional[dict] = None,
        head: Optional[dict] = None,
        vision_encoder: Optional[dict] = None,
        image_token_index: Optional[int] = None,
        tie_word_embeddings: bool = False,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: Optional[int] = None,
        use_cache: bool = True,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            decoder=decoder,
            embeddings=embeddings,
            head=head,
            tie_word_embeddings=tie_word_embeddings,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            use_cache=use_cache,
            **kwargs,
        )

        # Multimodal extras; both stay None for a text-only checkpoint.
        self.vision_encoder = vision_encoder
        self.image_token_index = image_token_index
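

# Minimal usage sketch, hedged: the field names inside `vision_encoder` are
# placeholders (this class stores the dict as-is without validating a schema).
if __name__ == "__main__":
    config = Apriel2Config(
        hidden_size=1024,
        vision_encoder={"type": "vit", "hidden_size": 768},  # hypothetical fields
        image_token_index=10,
    )
    print(config.model_type)                             # apriel2
    print(config.get_block_name(0))                      # block (fixed decoder)
    print(config.get_block_config(0)["mixer"]["type"])   # attention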