File size: 4,255 Bytes
7acd624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
SeqCond HuggingFace configuration.
"""

from transformers import PretrainedConfig


class SeqCondConfig(PretrainedConfig):
    """
    Configuration class for SeqCond models.

    SeqCond is a hybrid recurrent-transformer architecture that interleaves
    SeqCond (sequential conditioning) blocks with standard Transformer decoder
    blocks. SeqCond blocks replace softmax attention with a closed-form
    complex-exponential accumulator, enabling O(1) per-token decoding.

    Args:
        d_model: Hidden dimension.
        d_ff: Feed-forward dimension (typically 3×d_model).
        num_layers: Total number of blocks (SeqCond + Transformer).
        vocab_size: Vocabulary size.
        maxlen: Maximum sequence length (also sets KV-cache size).
        dropout: Dropout rate (0.0 disables).
        tie_weights: Whether to tie embedding and LM-head weights. Unless
            ``tie_word_embeddings`` is passed explicitly through ``kwargs``,
            this value is also forwarded to
            ``PretrainedConfig.tie_word_embeddings`` so the generic
            Transformers weight-tying logic agrees with it.
        num_heads: Number of attention heads in Transformer blocks.
        num_kv_heads: Number of KV heads (GQA). None = full MHA.
        qk_norm: Whether to apply QK-normalization in Transformer blocks.
        qk_norm_eps: Epsilon for QK-norm.
        seqcond_heads: Number of SeqCond memory heads (K).
        num_query_heads: Number of query heads in SeqCond (K_q).
            NOTE(review): this was documented as "must divide K", but the
            defaults (K=32, K_q=6) do not satisfy K % K_q == 0; confirm the
            actual constraint against the model implementation.
        num_thetas: Number of frequency components per head (M).
        derivative_order: Unused — kept for checkpoint compatibility.
        num_anchor_heads: Number of anchor heads (no decay) in SeqCond.
        conv_kernel_size: Depthwise conv kernel size inside SeqCond.
        expand_factor: Inner expansion factor for SeqCond memory dimension.
        out_expand_factor: SwiGLU expansion factor in SeqCond.
        use_positional_embedding: Whether to add learnable positional embeddings.
        seqcond_ratio: Block interleaving ratio. Every (seqcond_ratio+1)-th
            block (1-indexed) is a Transformer block; the rest are SeqCond.
        chunk_size: Chunk size for chunked computation (unused in PyTorch path).
        use_square_matrix: Unused — kept for checkpoint compatibility.
        bos_token_id: Beginning-of-sequence token ID, forwarded to
            ``PretrainedConfig``.
        eos_token_id: End-of-sequence token ID, forwarded to
            ``PretrainedConfig``.
        pad_token_id: Padding token ID, forwarded to ``PretrainedConfig``.
        **kwargs: Any other arguments, forwarded unchanged to
            ``PretrainedConfig.__init__``.
    """

    model_type = "seqcond"

    def __init__(
        self,
        # Core
        d_model: int = 768,
        d_ff: int = 2304,
        num_layers: int = 12,
        vocab_size: int = 100300,
        maxlen: int = 768,
        dropout: float = 0.0,
        tie_weights: bool = True,
        # Transformer block params
        num_heads: int = 8,
        num_kv_heads=None,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        # SeqCond block params
        seqcond_heads: int = 32,
        num_query_heads: int = 6,
        num_thetas: int = 4,
        derivative_order: int = 0,
        num_anchor_heads: int = 0,
        conv_kernel_size: int = 4,
        expand_factor: float = 2.0,
        out_expand_factor: int = 3,
        use_positional_embedding: bool = False,
        seqcond_ratio: int = 5,
        chunk_size: int = 128,
        use_square_matrix: bool = False,
        # Special token IDs (filled in by convert_checkpoint.py)
        bos_token_id=None,
        eos_token_id=None,
        pad_token_id=None,
        **kwargs,
    ):
        # Core dimensions / training knobs.
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.maxlen = maxlen
        self.dropout = dropout
        self.tie_weights = tie_weights

        # Transformer (attention) block parameters.
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps

        # SeqCond block parameters.
        self.seqcond_heads = seqcond_heads
        self.num_query_heads = num_query_heads
        self.num_thetas = num_thetas
        self.derivative_order = derivative_order
        self.num_anchor_heads = num_anchor_heads
        self.conv_kernel_size = conv_kernel_size
        self.expand_factor = expand_factor
        self.out_expand_factor = out_expand_factor
        self.use_positional_embedding = use_positional_embedding
        self.seqcond_ratio = seqcond_ratio
        self.chunk_size = chunk_size
        self.use_square_matrix = use_square_matrix

        # PretrainedConfig defaults ``tie_word_embeddings`` to True, and the
        # library's generic weight tying (PreTrainedModel.tie_weights) reads
        # that flag rather than our custom ``tie_weights`` attribute. Mirror
        # ``tie_weights`` into it so ``tie_weights=False`` is honoured there
        # too; an explicitly supplied ``tie_word_embeddings`` still wins.
        kwargs.setdefault("tie_word_embeddings", tie_weights)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )