| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import dataclasses |
| import warnings |
| from dataclasses import dataclass, MISSING |
| from functools import partial |
| from typing import Optional, Dict, Any |
|
|
| from .transformers_4_44_2__configuration_llama import LlamaConfig |
| from .transformers_4_44_2__modeling_rope_utils import \ |
| rope_config_validation |
|
|
|
|
class DeciLMConfig(LlamaConfig):
    """``LlamaConfig`` variant whose hidden layers are described one by one.

    Each hidden layer gets its own ``BlockConfig`` (attention + FFN subblock),
    stored in ``self.block_configs``. The global ``intermediate_size`` and
    ``num_key_value_heads`` set by the parent initializer are reset to ``None``.
    Presumably this is because those values are carried per layer by
    ``FFNConfig.ffn_mult`` / ``AttentionConfig.n_heads_in_group``; TODO confirm
    against the modeling code.
    """
    model_type = "nemotron-nas"

    def __init__(
            self,
            block_configs: list[dict] | list["BlockConfig"] | None = None,
            **kwargs,
    ):
        """
        Args:
            block_configs: one entry per hidden layer. Each entry is either a
                ``BlockConfig`` or a ``dict`` of ``BlockConfig`` keyword
                arguments; dict entries are converted to ``BlockConfig``.
                ``None`` leaves ``self.block_configs`` as ``None``.
            **kwargs: forwarded unchanged to ``LlamaConfig.__init__``.

        Raises:
            AssertionError: if ``block_configs`` does not have exactly
                ``num_hidden_layers`` entries.
        """
        super().__init__(**kwargs)
        self.intermediate_size = None
        self.num_key_value_heads = None

        if block_configs is not None:
            # NOTE(review): assert is stripped under ``python -O``; consider an
            # explicit ValueError if callers do not rely on AssertionError.
            assert len(block_configs) == self.num_hidden_layers, (
                f"Expected {self.num_hidden_layers} block configs (one per hidden layer), "
                f"got {len(block_configs)}"
            )
            # Convert every entry on its own, not just based on the first one.
            # A mixed dict/BlockConfig list is then fully normalized, and an
            # empty list (0 layers) no longer raises IndexError.
            block_configs = [
                BlockConfig(**conf) if isinstance(conf, dict) else conf
                for conf in block_configs
            ]
        self.block_configs: list[BlockConfig] = block_configs

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, turning each ``BlockConfig`` into nested dicts."""
        self_dict = super().to_dict()
        if self.block_configs is not None:
            # Entries that are already dicts (e.g. assigned after construction)
            # are passed through; dataclasses.asdict would reject them.
            self_dict["block_configs"] = [
                conf if isinstance(conf, dict) else dataclasses.asdict(conf)
                for conf in self.block_configs
            ]
        return self_dict
|
|
|
|
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
class AttentionConfig:
    """Frozen, hashable, orderable description of one layer's attention subblock.

    Fields:
        no_op: when true, ``n_heads_in_group`` is forced to ``None``.
        replace_with_linear: when true, ``n_heads_in_group`` is forced to ``None``.
        n_heads_in_group: must be provided (not ``None``) unless one of the
            two flags above is set.

    ``no_op`` and ``replace_with_linear`` may not both be true.
    """
    no_op: bool = False
    replace_with_linear: bool = False
    n_heads_in_group: Optional[int] = None

    def __post_init__(self):
        # The two flags are mutually exclusive.
        assert not self.no_op or not self.replace_with_linear
        flagged = self.no_op or self.replace_with_linear
        if not flagged:
            assert self.n_heads_in_group is not None
            return
        # Instances are frozen, so bypass the generated __setattr__.
        object.__setattr__(self, 'n_heads_in_group', None)
|
|
|
|
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
class FFNConfig:
    """Frozen, hashable, orderable description of one layer's feed-forward subblock.

    Fields:
        no_op: when true, ``ffn_mult`` is forced to ``None``.
        replace_with_linear: when true, ``ffn_mult`` is forced to ``None``.
        ffn_mult: must be provided (not ``None``) unless one of the two flags
            above is set.

    ``no_op`` and ``replace_with_linear`` may not both be true.
    """
    no_op: bool = False
    replace_with_linear: bool = False
    ffn_mult: Optional[float] = None

    def __post_init__(self):
        # The two flags are mutually exclusive.
        assert not self.no_op or not self.replace_with_linear
        flagged = self.no_op or self.replace_with_linear
        if not flagged:
            assert self.ffn_mult is not None
            return
        # Instances are frozen, so bypass the generated __setattr__.
        object.__setattr__(self, 'ffn_mult', None)
|
|
|
|
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
class BlockConfig:
    """Frozen configuration of one hidden layer: an attention and an FFN subblock.

    Both fields are required (a default of ``MISSING`` means "no default" to
    ``dataclass``). Either field may be passed as a plain ``dict``;
    ``__post_init__`` converts it into the field's declared dataclass type.
    """
    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """
        Init subblock dataclasses from dicts.

        Keys that are not fields of the target subblock dataclass are dropped,
        and a warning lists them.
        """
        for fld in dataclasses.fields(self):
            raw = getattr(self, fld.name)
            if not isinstance(raw, dict):
                continue
            subblock_cls = fld.type
            known = {sub_field.name for sub_field in dataclasses.fields(subblock_cls)}
            # Keep the dict's key order so the warning lists keys as given.
            extra = [key for key in raw if key not in known]
            if extra:
                warnings.warn(f"Removed unsupported fields {extra} from {subblock_cls.__name__}")
                raw = {key: value for key, value in raw.items() if key not in extra}
            # Frozen dataclass: assign through object.__setattr__.
            object.__setattr__(self, fld.name, subblock_cls(**raw))
|
|