from __future__ import annotations

from typing import TYPE_CHECKING, Any

from transformers import PretrainedConfig

if TYPE_CHECKING:
    from circuit_sparsity.gpt import GPTConfig


class CircuitGPTConfig(PretrainedConfig):
    """
    Minimal Hugging Face config wrapper around the circuit_sparsity GPTConfig.
    Only the fields exercised by the Neuronpedia runs are exposed.
    """

    model_type = "circuitgpt"

    def __init__(
        self,
        vocab_size: int = 2048,
        block_size: int = 256,
        n_layer: int = 8,
        n_head: int = 8,
        d_model: int = 1024,
        d_mlp: int | None = None,
        d_head: int | None = None,
        dropout: float = 0.0,
        bias: bool = True,
        ln_bias: bool = True,
        rms_norm: bool = True,
        activation_type: str = "gelu",
        residual_activation_type: str = "identity",
        tied_unembed: bool = False,
        unembed_rank: int | None = None,
        afrac: float | None = None,
        afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out",
        flash: bool = True,
        use_position_embeddings: bool = False,
        sink: bool = False,
        enable_bigram_table: bool = False,
        learnable_bigram_table: bool = False,
        bigram_table_rank: int | None = None,
        dropout_cat_pos_emb: bool = False,
        sinusoidal_cat_pos_emb: bool = False,
        d_pos_emb: int | None = None,
        auto_map: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
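        # Drop training-only fields (STE/muP/kernel toggles): they mean nothing
        # at inference time, and to_circuit_config pins them back to
        # inference-safe defaults anyway.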
        for key in [
            "afrac_ste",
            "afrac_ste_only_non_neurons",
            "afrac_approx",
            "rtopk",
            "mup",
            "mup_width_multiplier",
            "grad_checkpointing",
            "enable_fp8_linear",
            "scale_invariance",
            "cat_pos_emb",
        ]:
            kwargs.pop(key, None)

        # Standard GPT sizing defaults: 4x MLP expansion, evenly split heads.
        d_mlp = d_mlp or 4 * d_model
        d_head = d_head or d_model // n_head

        # Token ids are not stored on the circuit config; default eos to the
        # last vocab entry unless the caller overrides it.
        bos_token_id = kwargs.pop("bos_token_id", None)
        eos_token_id = kwargs.pop("eos_token_id", vocab_size - 1)
        pad_token_id = kwargs.pop("pad_token_id", None)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )
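
        # Mirror every architecture field as a plain attribute so it lands in
        # config.json and survives save_pretrained/from_pretrained round trips.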
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.d_model = d_model
        self.d_mlp = d_mlp
        self.d_head = d_head
        self.dropout = dropout
        self.bias = bias
        self.ln_bias = ln_bias
        self.rms_norm = rms_norm
        self.activation_type = activation_type
        self.residual_activation_type = residual_activation_type
        self.tied_unembed = tied_unembed
        self.unembed_rank = unembed_rank
        self.afrac = afrac
        self.afrac_loctypes = afrac_loctypes
        self.flash = flash
        self.use_position_embeddings = use_position_embeddings
        self.d_pos_emb = d_pos_emb
        self.sink = sink
        self.enable_bigram_table = enable_bigram_table
        self.learnable_bigram_table = learnable_bigram_table
        self.bigram_table_rank = bigram_table_rank
        self.dropout_cat_pos_emb = dropout_cat_pos_emb
        self.sinusoidal_cat_pos_emb = sinusoidal_cat_pos_emb
        self.is_decoder = True

        # Point AutoConfig/AutoModelForCausalLM at the custom classes so the
        # checkpoint loads from a repo with trust_remote_code=True.
        self.auto_map = auto_map or {
            "AutoConfig": "config.CircuitGPTConfig",
            "AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM",
        }

    @classmethod
    def from_circuit_config(cls, circuit_config: GPTConfig) -> CircuitGPTConfig:
        """Build a config from a circuit_sparsity GPTConfig.

        Fields absent on older GPTConfig versions fall back to safe defaults.
        """
        config_dict: dict[str, Any] = {
            "vocab_size": circuit_config.vocab_size,
            "block_size": circuit_config.block_size,
            "n_layer": circuit_config.n_layer,
            "n_head": circuit_config.n_head,
            "d_model": circuit_config.d_model,
            "d_mlp": circuit_config.d_mlp,
            "d_head": circuit_config.d_head,
            "dropout": circuit_config.dropout,
            "bias": circuit_config.bias,
            "ln_bias": circuit_config.ln_bias,
            "rms_norm": circuit_config.rms_norm,
            "activation_type": circuit_config.activation_type,
            "residual_activation_type": circuit_config.residual_activation_type,
            "tied_unembed": circuit_config.tied_unembed,
            "unembed_rank": circuit_config.unembed_rank,
            "afrac": circuit_config.afrac,
            "afrac_loctypes": circuit_config.afrac_loctypes,
            "flash": circuit_config.flash,
            "use_position_embeddings": getattr(circuit_config, "d_pos_emb", None) is not None,
            "d_pos_emb": getattr(circuit_config, "d_pos_emb", None),
            "sink": getattr(circuit_config, "sink", False),
            "enable_bigram_table": getattr(circuit_config, "enable_bigram_table", False),
            "learnable_bigram_table": getattr(circuit_config, "learnable_bigram_table", False),
            "bigram_table_rank": getattr(circuit_config, "bigram_table_rank", None),
            "dropout_cat_pos_emb": getattr(circuit_config, "dropout_cat_pos_emb", False),
            "sinusoidal_cat_pos_emb": getattr(circuit_config, "sinusoidal_cat_pos_emb", False),
        }
        return cls(**config_dict)

    def to_circuit_config(self) -> GPTConfig:
        """Reconstruct the circuit_sparsity GPTConfig this config mirrors."""
        # Imported lazily so the config can be loaded without circuit_sparsity
        # installed; only this conversion requires it.
        from circuit_sparsity.gpt import GPTConfig as CircuitConfig

        config_kwargs: dict[str, Any] = dict(
            vocab_size=self.vocab_size,
            block_size=self.block_size,
            n_layer=self.n_layer,
            n_head=self.n_head,
            d_model=self.d_model,
            dropout=self.dropout,
            bias=self.bias,
            ln_bias=self.ln_bias,
            rms_norm=self.rms_norm,
            activation_type=self.activation_type,
            residual_activation_type=self.residual_activation_type,
            tied_unembed=self.tied_unembed,
            unembed_rank=self.unembed_rank,
            afrac=self.afrac,
            afrac_loctypes=self.afrac_loctypes,
            flash=self.flash,
            # Training-only switches are pinned off for inference.
            afrac_ste=False,
            afrac_ste_only_non_neurons=False,
            afrac_approx=False,
            rtopk=False,
            mup=False,
            mup_width_multiplier=None,
            grad_checkpointing=False,
            enable_fp8_linear=False,
            scale_invariance=False,
            d_mlp=self.d_mlp,
            d_head=self.d_head,
            enable_sparse_kernels=False,
            enable_bigram_table=self.enable_bigram_table,
            learnable_bigram_table=self.learnable_bigram_table,
            bigram_table_rank=self.bigram_table_rank,
            d_pos_emb=self.d_pos_emb
            if self.d_pos_emb is not None
            else (self.d_model if self.use_position_embeddings else None),
            sink=self.sink,
            dropout_cat_pos_emb=self.dropout_cat_pos_emb,
            sinusoidal_cat_pos_emb=self.sinusoidal_cat_pos_emb,
        )
        return CircuitConfig(**config_kwargs)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a dict, restating the architecture fields explicitly
        on top of the base PretrainedConfig serialization."""
        base = super().to_dict()
        data = {
            "vocab_size": self.vocab_size,
            "block_size": self.block_size,
            "n_layer": self.n_layer,
            "n_head": self.n_head,
            "d_model": self.d_model,
            "d_mlp": self.d_mlp,
            "d_head": self.d_head,
            "dropout": self.dropout,
            "bias": self.bias,
            "ln_bias": self.ln_bias,
            "rms_norm": self.rms_norm,
            "activation_type": self.activation_type,
            "residual_activation_type": self.residual_activation_type,
            "tied_unembed": self.tied_unembed,
            "unembed_rank": self.unembed_rank,
            "flash": self.flash,
            "afrac": self.afrac,
            "afrac_loctypes": self.afrac_loctypes,
            "use_position_embeddings": self.use_position_embeddings,
            "d_pos_emb": self.d_pos_emb,
            "sink": self.sink,
            "enable_bigram_table": self.enable_bigram_table,
            "learnable_bigram_table": self.learnable_bigram_table,
            "bigram_table_rank": self.bigram_table_rank,
            "dropout_cat_pos_emb": self.dropout_cat_pos_emb,
            "sinusoidal_cat_pos_emb": self.sinusoidal_cat_pos_emb,
            "auto_map": self.auto_map,
        }
        base.update(data)
        return base
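

# Minimal round-trip sketch (hypothetical sizes; `to_circuit_config` needs the
# optional circuit_sparsity dependency, hence the __main__ guard):
if __name__ == "__main__":
    config = CircuitGPTConfig(n_layer=4, n_head=4, d_model=512)
    circuit_config = config.to_circuit_config()
    restored = CircuitGPTConfig.from_circuit_config(circuit_config)
    assert restored.d_model == config.d_model
    assert restored.to_dict()["n_layer"] == 4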