| from dataclasses import dataclass |
| from dtypes import DType |
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class Model:
    """Model-shape configuration record.

    A plain data holder: every field is required, fields are positional in
    the order declared below, and no validation or cross-field consistency
    check is performed on construction.
    """

    # Vocabulary / embedding table size.
    vocab_size: int
    # Network depth and layer widths.
    num_layers: int
    hidden_dim: int
    intermediate_size: int
    # Whether the embeddings are weight-tied (per the field name).
    weight_tied_embeddings: bool
    # Mixture-of-experts settings.
    # NOTE(review): ``is_moe`` is stored independently of the two expert
    # counts and is never checked against them, so the three can disagree
    # -- confirm which one downstream consumers treat as authoritative.
    active_experts: int
    total_experts: int
    is_moe: bool
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class Parallelism:
    """Parallelism-layout configuration record.

    A plain data holder: every field is required, fields are positional in
    the order declared below, and no validation is performed.
    """

    # Degrees for each parallelism dimension (per the field names).
    tensor_parallelism: int
    pipeline_parallelism: int
    context_parallelism: int
    expert_parallelism: int
    # FSDP settings.
    # NOTE(review): ``fsdp_parallelism`` and ``fsdp_strategy`` are stored
    # regardless of ``fsdp_enabled``; the accepted ``fsdp_strategy`` strings
    # are not defined or checked here -- confirm against the consumers.
    fsdp_enabled: bool
    fsdp_parallelism: int
    fsdp_strategy: str
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class Training:
    """Training-run configuration record.

    A plain data holder: every field is required, fields are positional in
    the order declared below, and no validation is performed. The dtype
    fields are typed with ``DType`` from the project ``dtypes`` module.
    """

    # Input shape settings.
    # NOTE(review): whether ``batch_size`` is per-device or global is not
    # determined here -- confirm against the consumers.
    sequence_length: int
    batch_size: int
    # Memory / step-structure toggles.
    # NOTE(review): ``grad_accumulation`` is only an on/off flag; no
    # accumulation step count is stored on this record.
    gradient_checkpointing: bool
    grad_accumulation: bool
    # Numeric precision settings.
    # NOTE(review): the relationship between ``precision``,
    # ``mixed_precision``, ``param_dtype`` and ``reduce_dtype`` is not
    # enforced here; presumably the last two apply when mixed precision is
    # on -- confirm.
    precision: DType
    mixed_precision: bool
    param_dtype: DType
    reduce_dtype: DType
|
|