I came across an interesting idea discussed here and wanted to try it out, run my own experiments, and build improvements on top of it.
Current high-level structure

Input Tokens
  ↓
Embedding (vocab_size=10000, dim)
  ↓
[ImliBlock × N_BLOCKS]
 ├─ RMSNorm
 ├─ GatedDeltaNet
 │   ├─ Q,K,V Projections: BitLinear(dim, key_dim/value_dim)
 │   ├─ ShortConvolution (depthwise, causal, kernel=4) on Q,K,V
 │   ├─ SiLU activation on Q,K,V
 │   ├─ L2 Normalization on Q,K
 │   ├─ Gate Projections: gk_proj, beta_proj (BitLinear → sigmoid/logsigmoid)
 │   ├─ Recurrent Delta Rule Update (O(N) linear attention)
 │   │   └─ S_t = exp(gk) * S_{t-1} + beta * outer(k, v - S@k)
 │   ├─ Output Gate: g_proj (BitLinear → sigmoid)
 │   └─ o_proj: BitLinear(value_dim, dim)
 ├─ RMSNorm
 └─ TernaryGLU
     ├─ W_gate: BitLinear(dim, hidden) → SiLU
     ├─ W_up: BitLinear(dim, hidden)
     └─ W_down: BitLinear(hidden, dim)
  ↓
RMSNorm
  ↓
Linear Head (tied)
  ↓
Logits
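From the diagram, the TernaryGLU sub-block is a SwiGLU-style gated MLP. Here's a minimal numpy sketch of just the data flow; plain matmuls stand in for the BitLinear projections, and the function/parameter names are my own, not the repo's:

```python
import numpy as np

def silu(x):
    # SiLU: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def ternary_glu(x, w_gate, w_up, w_down):
    """Sketch of the TernaryGLU forward pass (BitLinear replaced by matmul).

    x: (seq, dim); w_gate, w_up: (dim, hidden); w_down: (hidden, dim)
    """
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```

In the real model each of the three projections would run through ternary-quantized weights; the gating structure itself is unchanged.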
Step-by-step notes:
- I've added blockwise quantization. Instead of computing one alpha for the entire layer, we divide the flattened weight matrix into blocks of 256 elements and compute a separate alpha for each block. Each block thus gets its own scaling factor adapted to its local weight distribution: a block with small-magnitude weights gets a small alpha, preserving precision in that region.
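A minimal numpy sketch of what I mean, assuming the standard BitNet-style ternarization (alpha = mean absolute value, round-and-clip to {-1, 0, +1}); the function names and padding choice are mine:

```python
import numpy as np

BLOCK_SIZE = 256

def ternary_quantize_blockwise(w, block_size=BLOCK_SIZE):
    """Flatten w, split into blocks, and compute one alpha per block.

    Returns (ternary codes in {-1, 0, +1}, per-block alphas).
    """
    flat = w.reshape(-1)
    # Zero-pad so the flat weights divide evenly into blocks.
    pad = (-len(flat)) % block_size
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)

    # One scaling factor per block, adapted to its local magnitude.
    alpha = np.abs(blocks).mean(axis=1, keepdims=True)  # (n_blocks, 1)

    # Ternarize: round(w / alpha), clipped to [-1, 1].
    q = np.clip(np.round(blocks / (alpha + 1e-8)), -1, 1)
    return q, alpha

def dequantize_blockwise(q, alpha, shape):
    """Rebuild an approximate weight matrix from codes and per-block alphas."""
    flat = (q * alpha).reshape(-1)[: np.prod(shape)]
    return flat.reshape(shape)
```

The per-layer version is the special case where the whole flattened matrix is a single block.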
- RMSNorm is defined before BitLinear to normalize the input right before the linear transformation. Pre-norm (the old approach) normalizes at the block level, but activations may drift before reaching BitLinear; normalizing right before quantization makes sure activations are at the correct scale for the ternary weight operations (suggested by a random Discord anon).
- Progressive SEQ_LEN training is still buggy, but it's implemented. I saw somewhere that learning short-range dependencies before long-range ones is beneficial: the model first learns local patterns (word-level statistics, common bigrams and trigrams, etc.), then, when we increase the sequence length, it can build on this foundation to learn longer-range dependencies.
- trainv4.py still uses GatedConvMixer for now. For people new to this: instead of maintaining and updating a recurrent state (as in GatedDeltaNet), it applies a depthwise 1D convolution across the sequence. The input is projected to twice its dimension and split into a gate and a value; the value is convolved with a kernel of size 8, and the result is modulated by the sigmoid-activated gate before being projected back down.
- Evals are poor for the current run. The previous run was better without per-block alpha and RMSNorm before the linear. Maybe I'll remove the norm from BitLinear and keep only blockwise quantization.
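The GatedConvMixer forward pass described above can be sketched in a few lines of numpy; the explicit per-timestep loop is for clarity, and the names/shapes are my assumptions rather than the trainv4.py code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_conv_mixer(x, w_up, kernel, w_down):
    """Sketch of GatedConvMixer: up-project to 2*dim, split into gate and
    value, causal depthwise conv (kernel size K) on the value, modulate by
    sigmoid(gate), then down-project.

    x: (seq, dim); w_up: (dim, 2*dim); kernel: (dim, K); w_down: (dim, dim)
    """
    seq, dim = x.shape
    gate, value = np.split(x @ w_up, 2, axis=-1)  # each (seq, dim)

    K = kernel.shape[1]
    # Causal depthwise conv: left-pad with K-1 zeros so position t only
    # sees positions <= t. Each channel has its own length-K kernel.
    padded = np.concatenate([np.zeros((K - 1, dim)), value], axis=0)
    conv = np.zeros_like(value)
    for t in range(seq):
        window = padded[t : t + K]               # (K, dim), ends at t
        conv[t] = (window * kernel.T).sum(axis=0)

    return (sigmoid(gate) * conv) @ w_down
```

Because the gate is pointwise and the convolution is left-padded, changing a token can only affect outputs at that position or later.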
- Next, my idea is to replace GatedConvMixer with GatedDeltaNet, which maintains a recurrent state matrix S that is updated at each timestep using the gated delta rule: the model computes what it expects to know about the current key (S@k), calculates the delta between the actual value and this expectation (v_t - S@k), then updates the state by decaying the old information (via gating) and adding the precise correction (via the delta rule).
- The thing with GatedDeltaNet (and, if I'm not wrong, Kimi Attention) is that it enables longer sequences (1024/2048/etc.) while maintaining O(N) complexity! But it would be slower due to the sequential loop overhead.
- Added GatedDeltaNet. It's naive, but combining GatedDeltaNet with ternary quantization is interesting. The implementation doesn't parallelize across the sequence length, whereas in the nvlabs repo they use a chunkwise parallel algorithm. For the output norm, nvlabs uses FusedRMSNormSwishGate; mine applies the output gate via g_proj but no norm.
- I'm not working with large sequence lengths, so I can skip FlashAttention-style chunking. I just need to see how to implement parallel computation within a chunk: the delta updates need to be computed in parallel. lol, time to use opencode.. Anyway, I won't test seq len > 1024, so I can skip parallel chunking for now.
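For reference, the naive sequential recurrence is tiny. A numpy sketch, using the convention that S has shape (d_k, d_v) so reads are k @ S and q @ S (the transpose of the diagram's S@k; pick either consistently):

```python
import numpy as np

def gated_delta_rule(q, k, v, gk, beta):
    """Naive O(N) sequential gated delta rule.

    q, k: (seq, d_k); v: (seq, d_v); gk, beta: (seq,)
    gk is a log-decay (<= 0, e.g. from logsigmoid), beta in [0, 1].
    Returns outputs of shape (seq, d_v).
    """
    seq, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))   # recurrent state matrix
    out = np.zeros((seq, d_v))
    for t in range(seq):
        pred = k[t] @ S        # what the state "expects" for this key
        delta = v[t] - pred    # correction toward the true value
        # Decay old information, then write the correction.
        S = np.exp(gk[t]) * S + beta[t] * np.outer(k[t], delta)
        out[t] = q[t] @ S      # read out with the query
    return out
```

The chunkwise parallel version (as in the nvlabs repo) computes the same recurrence but batches the within-chunk delta updates into matrix products; this loop is the correctness reference to compare against.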
Internal monologue
- Per-layer scaling: Each layer has its own alpha (mean absolute value) for dynamic range.
- Linear scaling with sequence length instead of quadratic.
- Gating mechanism:
  - Up-projects to 2×dim, splits into gate and value.
  - Gate uses sigmoid activation for [0,1] range.
  - Value undergoes causal depthwise convolution.
  - Element-wise multiplication, then down-projection.
TODO:
- Multiple GatedDeltaNet layers with different expand factors.
Update (Run 21 Feb)
Parameters: ~5M
d_model: 192
Blocks: 6
GLU hidden dim: 512
Sequence length: 256
Vocab size: 10k
Weight tying
Total tokens trained: 40.4M
Best validation loss: 2.003
Eval (w/ GatedDeltaNet):
Sample:
Once upon a time, there was a little girl named Lily. She loved to play outside in her backyard. One day, she found a shiny toy, and accidentally knocked it down. It was very expensive and had a loud noise.
Lily felt sad and upset. She wanted to go back to her mom. She wanted to keep going. She was scared and her mom told her that she had to go to the hospital.
