File size: 1,574 Bytes
7cf5414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Any

from transformers import AutoConfig, GenerationConfig, PretrainedConfig


class MagicBERTConfig(PretrainedConfig):
    """Model configuration for MagicBERT, a small BERT-style encoder.

    Extends :class:`transformers.PretrainedConfig`; all keyword arguments not
    listed below are forwarded to the base class.

    Args:
        attention_dropout: Dropout probability applied inside attention.
        d_model: Hidden size of the encoder.
        dim_feed_forward: Inner size of the feed-forward blocks. If falsy
            (``None`` or ``0``), defaults to ``int(d_model * 8 / 3)``.
        embedding_dropout: Dropout probability applied after the embeddings.
        mask_token_id: Token id used for masked-language-model masking.
        num_attention_heads: Number of attention heads per layer.
        num_encoder_layers: Number of encoder layers.
        pad_token_id: Token id used for padding.
        seq_len: Maximum sequence length.
        tie_embeddings: Whether input and output embeddings share weights.
            Mirrored onto HF's canonical ``tie_word_embeddings`` flag unless
            the caller sets that flag explicitly in ``kwargs``.
        vocab_size: Size of the vocabulary.
    """

    model_type = "magicBERT"

    def __init__(
        self,
        *,
        attention_dropout: float = 0.15,
        d_model: int = 768,
        # None/0 is accepted: the body falls back to int(d_model * 8 / 3).
        dim_feed_forward: int | None = 3072,
        embedding_dropout: float = 0.15,
        mask_token_id: int = 0,
        num_attention_heads: int = 8,
        num_encoder_layers: int = 4,
        pad_token_id: int = 1,
        seq_len: int = 100,
        tie_embeddings: bool = True,
        vocab_size: int = 35000,
        **kwargs,
    ):
        # Keep the project-local `tie_embeddings` name in sync with the
        # transformers-standard `tie_word_embeddings`, without clobbering an
        # explicit caller override.
        if "tie_word_embeddings" not in kwargs:
            kwargs["tie_word_embeddings"] = tie_embeddings
        super().__init__(**kwargs)
        self.attention_dropout = attention_dropout
        self.d_model = d_model
        # Falsy dim_feed_forward selects the ~8/3 * d_model expansion
        # (common for gated feed-forward variants).
        self.dim_feed_forward = dim_feed_forward or int(d_model * 8 / 3)
        self.embedding_dropout = embedding_dropout
        self.num_attention_heads = num_attention_heads
        self.mask_token_id = mask_token_id
        self.num_encoder_layers = num_encoder_layers
        self.seq_len = seq_len
        self.tie_embeddings = tie_embeddings
        self.vocab_size = vocab_size
        # Set after super().__init__ so it overrides the base class default.
        self.pad_token_id = pad_token_id


class MagicBERTGenerationConfig(GenerationConfig):
    """Generation configuration for MagicBERT.

    Thin wrapper around :class:`transformers.GenerationConfig` that carries an
    extra ``cards`` payload (a list of dicts; semantics defined by callers).
    """

    # Keep the model_type string identical to the model config's.
    model_type = MagicBERTConfig.model_type

    def __init__(self, *, cards: list[dict[str, Any]] | None = None, **kwargs):
        """Initialize, forwarding all other keyword arguments to the base class.

        Args:
            cards: Optional card metadata; a falsy value becomes an empty list.
        """
        super().__init__(**kwargs)
        self.cards = [] if not cards else cards


# Register the custom config with transformers' AutoConfig machinery so it can
# be resolved automatically (e.g. via AutoConfig) for the "magicBERT" model_type.
MagicBERTConfig.register_for_auto_class(AutoConfig)