# Factorized Scaled Ternary — Config E Equations
## The Core Identity
```
W = S ⊙ T where S = |W|, T = sign(W)
```
The identity |w| × sign(w) = w holds for every real number, including zero (sign(0) = 0).
No approximation. With T = sign(W) nothing is lost; once a threshold is applied, the identity still holds at every active position.
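The identity can be checked numerically. The following is a minimal sketch, assuming NumPy and a random toy matrix (the shape, seed, and variable names are illustrative, not taken from the training code). Multiplying by ±1 or 0 is exact in floating point, so the comparison uses exact equality.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))   # toy weights, illustrative only
W[0, 0] = 0.0                 # include an exact zero: sign(0) = 0

S = np.abs(W)                 # magnitudes
T = np.sign(W)                # directions in {-1, 0, +1}

# Element-wise product reconstructs W exactly.
identity_holds = np.array_equal(S * T, W)
print(identity_holds)
```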
---
## Per-Config Forward Pass Equations
### Config A (FP32 Baseline)
```
y = xW + b
```
Standard linear. No ternarization. Full precision weights.
### Config C (Learned-S)
```
T = sign(W) · (|W| > θ) ← ternarize with hard threshold
S = learned scalar ← nn.Parameter, trained via Adam
W_eff = S · T ← all surviving weights share one scale
y = xW_eff + b
```
Problem: one scalar S cannot capture per-element magnitude variation.
At θ=0.05, only ~30% of weights survive, and they all get the same
scale regardless of actual magnitude. S converges to ~0.30.
### Config D (Computed-S per-layer / BitNet absmean)
```
T = sign(W) · (|W| > θ) ← same ternarize
S = mean(|W|) ← single scalar computed from weight stats
W_eff = S · T ← uniform scale like BitNet
y = xW_eff + b
```
S tracks actual weight magnitude dynamically. Better than C because
S is recomputed from the current weights rather than learned, so it
cannot drift away from them. Still one scalar per layer, so granularity stays coarse.
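The two uniform-scale variants (Configs C and D) can be sketched side by side. This is an illustration under assumptions: it uses NumPy rather than the actual training framework, the function names are hypothetical, and in Config C the learned scalar is simply passed in as a number. It shows that every surviving weight ends up with the same magnitude.

```python
import numpy as np

def ternarize(W, theta):
    """T = sign(W) * (|W| > theta), with values in {-1, 0, +1}."""
    return np.sign(W) * (np.abs(W) > theta)

def w_eff_learned_scalar(W, theta, S):
    """Config C: S is one learned scalar, supplied here by the caller."""
    return S * ternarize(W, theta)

def w_eff_absmean(W, theta):
    """Config D: S is recomputed from the weights as mean(|W|)."""
    return np.mean(np.abs(W)) * ternarize(W, theta)

# Toy weights, illustrative only.
W = np.array([[0.40, -0.02, -0.10],
              [0.01,  0.20, -0.30]])

W_c = w_eff_learned_scalar(W, theta=0.05, S=0.30)
W_d = w_eff_absmean(W, theta=0.05)
```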
### Config E (Factorized — per-element)
```
T = sign(W) · (|W| > θ) ← ternary mask: {-1, 0, +1}
S = |W| ← per-element magnitude tensor (same shape as W)
W_eff = S ⊙ T ← element-wise product
y = xW_eff + b
```
At every active position (where |W[i,j]| > θ):
```
W_eff[i,j] = |W[i,j]| × sign(W[i,j]) = W[i,j]
```
Exact reconstruction. The only information loss is at positions where
|W[i,j]| ≤ θ, which are zeroed out. Each of those weights had magnitude
at most θ, so its contribution was small. Measured gap: val loss ~5% above FP32 at 5000 steps (2.273 vs 2.165).
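The Config E equations can be sketched as a small forward function. This is a minimal sketch, assuming NumPy, a weight matrix of shape (in, out) so that y = xW + b applies directly, and toy values chosen for illustration; `config_e_forward` is a hypothetical name. The checks that follow confirm exact reconstruction at active positions and zeros elsewhere.

```python
import numpy as np

def config_e_forward(x, W, b, theta=0.05):
    """Config E forward pass: y = x @ (|W| * T) + b."""
    T = np.sign(W) * (np.abs(W) > theta)   # ternary mask in {-1, 0, +1}
    S = np.abs(W)                          # per-element magnitudes
    W_eff = S * T                          # element-wise product
    return x @ W_eff + b, W_eff

# Toy values, illustrative only.
W = np.array([[0.40, -0.02, -0.10],
              [0.01,  0.20, -0.30]])
x = np.array([[1.0, 2.0]])
b = np.zeros(3)

y, W_eff = config_e_forward(x, W, b)
active = np.abs(W) > 0.05

# Active positions reproduce W exactly; the rest are zero.
exact_at_active = np.array_equal(W_eff[active], W[active])
zero_elsewhere = bool(np.all(W_eff[~active] == 0))
```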
---
## STE Backward Pass
The TernarizeSTE function:
```
Forward: T = sign(W) · (|W| > θ)
Backward: ∂L/∂W = ∂L/∂W_eff · ∂W_eff/∂W
```
Where the STE approximation treats the forward as identity within
the active zone:
```
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
```
For Config E specifically, W_eff = |W| ⊙ T equals W at active positions and 0 elsewhere, so the same masked gradient results:
```
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
```
The gradient flows through to W directly. The optimizer updates W
with a standard additive step (Adam: W ← W - lr × update). The magnitude
|W| and direction sign(W) emerge naturally from this process.
No separate S gradient needed. S = |W| is a deterministic function
of W, not a learned parameter.
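The masked-gradient rule can be written out by hand. The sketch below assumes NumPy, a made-up upstream gradient, and a plain SGD step standing in for Adam; `ste_backward` is a hypothetical helper, not the TernarizeSTE implementation itself. Note that the third weight crosses zero after the update, so its sign changes without any separate handling of S or T.

```python
import numpy as np

def ste_backward(grad_W_eff, W, theta=0.05):
    """STE rule from the text: pass the gradient only where |W| > theta."""
    return grad_W_eff * (np.abs(W) > theta)

# Toy weights and a made-up upstream gradient, illustrative only.
W = np.array([0.40, -0.02, -0.10, 0.01])
grad_W_eff = np.array([1.0, 1.0, -2.0, 3.0])

grad_W = ste_backward(grad_W_eff, W)

# Plain SGD step used as a stand-in for Adam (assumption).
lr = 0.1
W_new = W - lr * grad_W

# Magnitude and sign are read off the updated parameter.
S_new = np.abs(W_new)
sign_new = np.sign(W_new)
```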
---
## Compute Comparison
| Operation | FP32 | Config C | Config D | Config E |
|------------------------|-----------------|-------------------|-------------------|-----------------------|
| Weight storage | full FP32 | FP32 + 1 scalar | FP32 only | FP32 only |
| Effective weights | W | S·T (uniform) | S·T (uniform) | S⊙T (per-element) |
| Forward matmul | x × W | x × (S·T) | x × (S·T) | x × (S⊙T) |
| STE backward | N/A | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff |
| Extra compute / step | 0 | 1 scalar mul | 1 absmean | 1 abs + element mul |
| BPW (inference)        | 32.0            | 1.58 + overhead   | 1.58              | 1.58 (T only)         |
| Measured val loss | 2.165 | 2.895 | 2.515 | 2.273 |
| vs FP32 | 1.000× | 1.337× | 1.162× | 1.050× |
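The "vs FP32" row is each validation loss divided by the FP32 loss. A short check of that arithmetic, using only the values from the table:

```python
# Validation losses copied from the table above.
val_loss = {"FP32": 2.165, "Config C": 2.895, "Config D": 2.515, "Config E": 2.273}

# Ratio of each config's loss to the FP32 baseline, rounded to 3 decimals.
ratios = {name: round(loss / val_loss["FP32"], 3) for name, loss in val_loss.items()}
print(ratios)
```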
---
## Why Config E Wins
1. **Identity reconstruction**: At active positions, W_eff = W exactly.
No quantization error. No information bottleneck from uniform scaling.
2. **No extra parameter**: S = |W| is computed, not learned. Fewer
parameters means less BPW overhead and simpler architecture.
3. **Natural magnitude evolution**: As training progresses, weights
that matter grow in magnitude (larger S), weights that don't
shrink toward zero (absorbed by threshold). This is standard
gradient descent — we just observe it through the S ⊙ T lens.
4. **Sparsity emerges organically**: The threshold θ = 0.05 naturally
kills ~38% of weights in fc1 and ~63% in fc2 by step 5000. These
are genuine zeros — structural sparsity available for inference.
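Per-layer sparsity figures like those in point 4 can be measured as the fraction of weights at or below the threshold. A minimal sketch, assuming NumPy and a toy matrix (the `sparsity` helper and the values are illustrative; the fc1 and fc2 numbers quoted above come from the actual run, not from this code):

```python
import numpy as np

def sparsity(W, theta=0.05):
    """Fraction of weights with |W| <= theta, i.e. zeroed by the ternary mask."""
    return float(np.mean(np.abs(W) <= theta))

# Toy weights, illustrative only.
W = np.array([[0.40, -0.02, -0.10, 0.05],
              [0.01,  0.20, -0.30, 0.00]])

frac = sparsity(W)
print(frac)
```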
---
## Implications for Inference Serialization
At inference time, Config E's effective weights can be decomposed
at any granularity:
| Granularity | Format | BPW | Accuracy |
|----------------|--------------------------|--------|-------------|
| Per-layer S | T (2-bit) + 1 FP16 S | ~1.58 | Like D |
| Per-group S | T (2-bit) + group FP16 | ~1.58 | Better |
| Per-element S | T (2-bit) + per-elem FP | ~17 | Full E |
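The BPW column can be worked out under explicit assumptions. The sketch below assumes FP16 (16-bit) scales and a plain 2-bit code per ternary weight; the helper name and the group sizes are illustrative. One ternary symbol carries log2(3) ≈ 1.585 bits of information, so the ~1.58 entries are best read as the ternary payload alone. A 2-bit packing costs 2 bits per weight plus the scale amortized over its group.

```python
import math

def packed_bpw(bits_per_ternary, scale_bits, group_size):
    """Bits per weight: ternary code plus one scale shared by group_size weights."""
    return bits_per_ternary + scale_bits / group_size

entropy_bits = math.log2(3)                # information content of one ternary symbol

per_layer = packed_bpw(2, 16, 1_000_000)   # one scale for a large layer (size assumed)
per_group_16 = packed_bpw(2, 16, 16)       # one FP16 scale per 16 weights
per_element = packed_bpw(2, 16, 1)         # one FP16 magnitude per weight

print(round(entropy_bits, 3), per_layer, per_group_16, per_element)
```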
The I2_S format (BitNet.cpp) packs 16 ternary weights in 32 bits
(2 bits each) with one FP16 group scale. This is the natural
serialization target for Config E — group the magnitudes, pack
the ternary patterns.
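Packing 16 ternary values into one 32-bit word at 2 bits each can be sketched as follows. The code assignment (-1 → 0, 0 → 1, +1 → 2) and the little-end-first bit order are assumptions made for illustration; they are not claimed to match the actual I2_S layout.

```python
def pack16(ternary):
    """Pack 16 values from {-1, 0, +1} into one 32-bit integer, 2 bits each."""
    if len(ternary) != 16:
        raise ValueError("expected exactly 16 ternary values")
    word = 0
    for i, t in enumerate(ternary):
        if t not in (-1, 0, 1):
            raise ValueError("values must be -1, 0, or +1")
        code = t + 1                 # hypothetical encoding: -1 -> 0, 0 -> 1, +1 -> 2
        word |= code << (2 * i)      # value i occupies bits 2i and 2i+1
    return word

def unpack16(word):
    """Inverse of pack16: recover the 16 ternary values."""
    return [((word >> (2 * i)) & 0b11) - 1 for i in range(16)]

vals = [-1, 0, 1, 1, 0, -1, 0, 0, 1, -1, -1, 0, 1, 0, 0, 1]
word = pack16(vals)
```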
Training uses full FP32 W with STE. Inference uses packed T + S.
The bridge between them is the fact that W_eff at active positions
equals W — so any group quantization of |W| is just standard
weight quantization applied to the surviving weights.
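One way to bridge training and inference is to split the trained W into ternary codes and one FP16 scale per group. The sketch below assumes NumPy, a group size of 4, and a scale rule of "mean |W| over the active entries of each group". That rule is an illustrative choice, not one stated in this document, and `group_quantize` is a hypothetical name.

```python
import numpy as np

def group_quantize(W, theta=0.05, group_size=4):
    """Split W into ternary codes plus one FP16 scale per group.

    Assumes W.size is divisible by group_size. The scale rule (mean |W| over
    the active entries of each group) is an illustrative assumption.
    """
    flat = W.reshape(-1, group_size)
    T = np.sign(flat) * (np.abs(flat) > theta)            # ternary codes
    active = T != 0
    counts = np.maximum(active.sum(axis=1), 1)            # avoid divide-by-zero
    scales = ((np.abs(flat) * active).sum(axis=1) / counts).astype(np.float16)
    W_hat = scales.astype(np.float64)[:, None] * T        # dequantized weights
    return T.reshape(W.shape), scales, W_hat.reshape(W.shape)

# Toy weights, illustrative only: the second group has no active entries.
W = np.array([0.40, -0.02, -0.20, 0.30, 0.01, 0.03, -0.04, 0.02])
T, scales, W_hat = group_quantize(W)
```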
---
## The Question: Is Ternary Actually Applied?
Config E's effective forward weight is W_eff = |W| ⊙ T.
At active positions this equals W itself. So is ternarization
even happening? Yes, but subtly:
1. **Forward pass**: The STE ternarizes W into T ∈ {-1, 0, +1},
then re-scales by |W|. The result at active positions is W.
But the **sparsity mask** (which positions are zero) IS the
ternary structure. T determines shape, S determines magnitude.
2. **Gradient pass**: The STE mask blocks gradient to small weights.
This IS ternary training — the gradient is zeroed for positions
below threshold. This is the mechanism that creates sparsity.
3. **Inference packing**: T (the ternary pattern) can be packed at
2 bits per weight. S (magnitudes) can be quantized per-group.
This IS ternary inference — just with group scales attached.
4. **What Config E is NOT**: It is not a ternary-only matmul in
training. During training, the matmul uses full-precision
W_eff values (which equal W at active positions). The ternary
advantage is realized at inference, not during training.
**The ternary IS the sparsity mask.** The magnitude IS the weight.
Training maintains both in one parameter. Inference separates them.