File size: 1,724 Bytes
e65ee65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from transformers import PretrainedConfig


class SlimMoEConfig(PretrainedConfig):
    """Configuration for a SlimMoE (mixture-of-experts transformer) model.

    Holds the architecture hyperparameters consumed by the model
    implementation (not visible in this file). Word-embedding / LM-head
    weight tying is always enabled (see ``__init__``).
    """

    model_type = "slim_moe"

    def __init__(
            self,
            vocab_size: int = 50257,
            dim: int = 768,
            num_hidden_layers: int = 12,
            num_heads: int = 12,
            hidden_dim: int = 2048,
            num_experts: int = 4,
            max_seq_len: int = 2048,
            dropout: float = 0.1,
            adaptive_routing: bool = True,
            **kwargs,
    ):
        """Build the config.

        Args:
            vocab_size: Size of the token vocabulary (GPT-2 default).
            dim: Hidden size of the transformer.
            num_hidden_layers: Number of transformer blocks.
            num_heads: Number of attention heads.
            hidden_dim: Per-expert feed-forward inner dimension.
            num_experts: Number of MoE experts per layer.
            max_seq_len: Maximum sequence length supported.
            dropout: Dropout probability.
            adaptive_routing: Enable adaptive expert routing
                (semantics defined by the model implementation).
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        self.vocab_size = vocab_size
        self.dim = dim
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.adaptive_routing = adaptive_routing

        # Always tie the input embeddings to the LM head.
        # PretrainedConfig.__init__ pops "tie_word_embeddings" out of kwargs
        # and assigns it to self, so a plain `self.tie_word_embeddings = True`
        # placed before the super() call is silently overwritten (and a caller
        # passing tie_word_embeddings=False would disable tying). Forcing the
        # flag through kwargs guarantees the framework records True.
        kwargs["tie_word_embeddings"] = True

        super().__init__(**kwargs)

    @classmethod
    def for_250m(cls, vocab_size: int = 50257, max_seq_len: int = 2048, dropout: float = 0.1):
        """Create a configuration targeting a ~250M-parameter model.

        Uses dim=768, num_hidden_layers=16, num_heads=12, hidden_dim=1536,
        num_experts=4. With the default 50257-token vocabulary this works out
        to roughly 230M parameters (~38.6M embeddings, tied with the LM head,
        plus ~12M per layer), i.e. under the 250M budget.
        NOTE(review): the exact count depends on the model implementation —
        confirm against the instantiated model.

        Args:
            vocab_size: Size of the token vocabulary.
            max_seq_len: Maximum sequence length supported.
            dropout: Dropout probability.

        Returns:
            A :class:`SlimMoEConfig` with the preset geometry above.
        """
        return cls(
            vocab_size=vocab_size,
            dim=768,
            num_hidden_layers=16,
            num_heads=12,
            hidden_dim=1536,
            num_experts=4,
            max_seq_len=max_seq_len,
            dropout=dropout,
        )