# magicBERT/config.py — uploaded via huggingface_hub by nishtahir (revision 7cf5414)
from typing import Any
from transformers import AutoConfig, GenerationConfig, PretrainedConfig
class MagicBERTConfig(PretrainedConfig):
    """Model configuration for the magicBERT encoder.

    All hyperparameters are keyword-only; defaults mirror the trained
    checkpoint. Extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "magicBERT"

    def __init__(
        self,
        *,
        attention_dropout: float = 0.15,
        d_model: int = 768,
        dim_feed_forward: int = 3072,
        embedding_dropout: float = 0.15,
        mask_token_id: int = 0,
        num_attention_heads: int = 8,
        num_encoder_layers: int = 4,
        pad_token_id: int = 1,
        seq_len: int = 100,
        tie_embeddings: bool = True,
        vocab_size: int = 35000,
        **kwargs,
    ):
        # Keep HF's canonical `tie_word_embeddings` flag in sync with the
        # local `tie_embeddings` unless the caller set it explicitly.
        kwargs.setdefault("tie_word_embeddings", tie_embeddings)
        super().__init__(**kwargs)

        # Core transformer dimensions.
        self.d_model = d_model
        self.num_attention_heads = num_attention_heads
        self.num_encoder_layers = num_encoder_layers
        # NOTE(review): with the default of 3072 the fallback below only
        # fires when a falsy value (0/None) is passed explicitly; the
        # 8/3 ratio is presumably the SwiGLU-style FFN sizing — confirm.
        self.dim_feed_forward = dim_feed_forward if dim_feed_forward else int(d_model * 8 / 3)

        # Regularization and sequence/vocab settings.
        self.attention_dropout = attention_dropout
        self.embedding_dropout = embedding_dropout
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.tie_embeddings = tie_embeddings

        # Special-token ids.
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
class MagicBERTGenerationConfig(GenerationConfig):
    """Generation configuration that additionally carries card metadata.

    ``cards`` is stored as a plain list; any falsy value (``None`` or an
    empty list) is normalized to a fresh empty list.
    """

    model_type = MagicBERTConfig.model_type

    def __init__(self, *, cards: list[dict[str, Any]] | None = None, **kwargs):
        super().__init__(**kwargs)
        if cards:
            self.cards = cards
        else:
            self.cards = []
# Register the custom config with the AutoConfig machinery so that
# `AutoConfig.from_pretrained(...)` can resolve `model_type="magicBERT"`
# when this file ships alongside the checkpoint (trust_remote_code flow).
MagicBERTConfig.register_for_auto_class(AutoConfig)