from typing import Any

from transformers import AutoConfig, GenerationConfig, PretrainedConfig


class MagicBERTConfig(PretrainedConfig):
    """Configuration for the MagicBERT encoder model.

    Stores the architecture hyperparameters (layer count, width, dropout,
    special-token ids) consumed by the model implementation. Inherits the
    standard serialization/round-trip machinery from
    ``transformers.PretrainedConfig``.
    """

    model_type = "magicBERT"

    def __init__(
        self,
        *,
        attention_dropout: float = 0.15,
        d_model: int = 768,
        # None (or 0) means "derive from d_model" — see the fallback below.
        dim_feed_forward: int | None = 3072,
        embedding_dropout: float = 0.15,
        mask_token_id: int = 0,
        num_attention_heads: int = 8,
        num_encoder_layers: int = 4,
        pad_token_id: int = 1,
        seq_len: int = 100,
        tie_embeddings: bool = True,
        vocab_size: int = 35000,
        **kwargs,
    ):
        """Build the config.

        Args:
            attention_dropout: Dropout applied inside attention layers.
            d_model: Hidden size of the transformer.
            dim_feed_forward: Width of the feed-forward sublayer. When falsy
                (``None`` or ``0``) it defaults to ``int(d_model * 8 / 3)``,
                the SwiGLU-style expansion ratio.
            embedding_dropout: Dropout applied after the embedding layer.
            mask_token_id: Token id used for masked-LM targets.
            num_attention_heads: Attention heads per encoder layer.
            num_encoder_layers: Number of stacked encoder layers.
            pad_token_id: Token id used for padding.
            seq_len: Maximum sequence length.
            tie_embeddings: Whether input and output embeddings share weights;
                mirrored into the standard ``tie_word_embeddings`` kwarg so
                Hugging Face tooling sees it too.
            vocab_size: Size of the token vocabulary.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        # Mirror the local flag into the name PretrainedConfig understands,
        # without clobbering an explicitly supplied value.
        kwargs.setdefault("tie_word_embeddings", tie_embeddings)
        super().__init__(**kwargs)
        self.attention_dropout = attention_dropout
        self.d_model = d_model
        # Falsy dim_feed_forward -> derive from the model width.
        self.dim_feed_forward = dim_feed_forward or int(d_model * 8 / 3)
        self.embedding_dropout = embedding_dropout
        self.num_attention_heads = num_attention_heads
        self.mask_token_id = mask_token_id
        self.num_encoder_layers = num_encoder_layers
        self.seq_len = seq_len
        self.tie_embeddings = tie_embeddings
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id


class MagicBERTGenerationConfig(GenerationConfig):
    """Generation config for MagicBERT, extended with a ``cards`` payload.

    ``cards`` is an arbitrary list of dicts carried alongside the standard
    generation parameters; semantics are defined by the caller.
    """

    model_type = MagicBERTConfig.model_type

    def __init__(self, *, cards: list[dict[str, Any]] | None = None, **kwargs):
        """Build the generation config.

        Args:
            cards: Optional list of card dicts; defaults to an empty list.
            **kwargs: Forwarded to ``GenerationConfig.__init__``.
        """
        super().__init__(**kwargs)
        # Avoid sharing a mutable default across instances.
        self.cards = cards or []


# Make AutoConfig able to resolve "magicBERT" to this class.
MagicBERTConfig.register_for_auto_class(AutoConfig)