# Factorized Scaled Ternary: Config E Equations

## The Core Identity

```
W = S ⊙ T    where S = |W|, T = sign(W)
```

This is an identity for every real number: |w| × sign(w) = w. No approximation. No information loss at active positions.

---

## Per-Config Forward Pass Equations

### Config A (FP32 Baseline)

```
y = xW + b
```

Standard linear. No ternarization. Full precision weights.

### Config C (Learned-S)

```
T = sign(W) · (|W| > θ)      ← ternarize with hard threshold
S = learned scalar           ← nn.Parameter, trained via Adam
W_eff = S · T                ← all surviving weights share one scale
y = xW_eff + b
```

Problem: one scalar S cannot capture per-element magnitude variation. At θ=0.05, only ~30% of weights survive, and they all get the same scale regardless of actual magnitude. S converges to ~0.30.

### Config D (Computed-S per-layer / BitNet absmean)

```
T = sign(W) · (|W| > θ)      ← same ternarize
S = mean(|W|)                ← single scalar computed from weight stats
W_eff = S · T                ← uniform scale like BitNet
y = xW_eff + b
```

S tracks actual weight magnitude dynamically. Better than C because S is always correct for the current weight distribution. Still one scalar per layer, so the granularity is coarse.

### Config E (Factorized, per-element)

```
T = sign(W) · (|W| > θ)      ← ternary mask: {-1, 0, +1}
S = |W|                      ← per-element magnitude tensor (same shape as W)
W_eff = S ⊙ T                ← element-wise product
y = xW_eff + b
```

At every active position (where |W[i,j]| > θ):

```
W_eff[i,j] = |W[i,j]| × sign(W[i,j]) = W[i,j]
```

Exact reconstruction. The only information loss is at positions where |W[i,j]| ≤ θ, which are zeroed out. Those positions had near-zero contribution anyway. Measured gap: ~5% vs FP32 at 5000 steps.

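
The effective-weight formulas for Configs C, D, and E differ only in how S is chosen, which a few lines of NumPy can make concrete. The following is a minimal sketch, not the training code: the function names `ternarize` and `effective_weights`, the example matrix, and the use of NumPy instead of the `nn.Parameter`-based setup mentioned above are all assumptions made for illustration.

```python
import numpy as np


def ternarize(W, theta):
    """Ternary mask T in {-1, 0, +1}: sign(W) where |W| > theta, else 0."""
    return np.sign(W) * (np.abs(W) > theta)


def effective_weights(W, theta, config, learned_s=None):
    """Return W_eff for config 'C', 'D', or 'E' (hypothetical helper)."""
    T = ternarize(W, theta)
    if config == "C":
        return learned_s * T             # one learned scalar for the layer
    if config == "D":
        return np.mean(np.abs(W)) * T    # one absmean scalar for the layer
    if config == "E":
        return np.abs(W) * T             # per-element magnitude, S = |W|
    raise ValueError(f"unknown config: {config!r}")


# Hypothetical 2x3 weight matrix, chosen so some entries fall below theta.
W = np.array([[0.30, -0.02, -0.12],
              [0.04,  0.08, -0.50]])
theta = 0.05
active = np.abs(W) > theta

W_e = effective_weights(W, theta, "E")

# Config E reproduces W exactly at active positions and zeros the rest.
assert np.array_equal(W_e[active], W[active])
assert np.all(W_e[~active] == 0)
```

The asserts check the identity claimed for Config E: active entries of `W_eff` are bit-for-bit equal to `W`, because multiplying a magnitude by ±1.0 is exact in floating point. Calling the same helper with `"C"` or `"D"` returns a matrix whose active entries all share one magnitude.
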

---

## STE Backward Pass

The TernarizeSTE function:

```
Forward:  T = sign(W) · (|W| > θ)
Backward: ∂L/∂W = ∂L/∂W_eff · ∂W_eff/∂W
```

The STE approximation treats the forward as the identity within the active zone:

```
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
```

For Config E specifically, since W_eff = |W| ⊙ T, the same expression applies:

```
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
```

Because W_eff = W at active positions and 0 elsewhere, ∂W_eff/∂W is 1 above the threshold and 0 below it, so for Config E the masked gradient matches the true derivative everywhere except at |W| = θ.

The gradient flows through to W directly. The optimizer updates W via standard addition (Adam: W ← W - lr × update). The magnitude |W| and direction sign(W) emerge naturally from this process.

No separate S gradient needed. S = |W| is a deterministic function of W, not a learned parameter.

---

## Compute Comparison

| Operation            | Config A (FP32) | Config C         | Config D         | Config E            |
|----------------------|-----------------|------------------|------------------|---------------------|
| Weight storage       | full FP32       | FP32 + 1 scalar  | FP32 only        | FP32 only           |
| Effective weights    | W               | S·T (uniform)    | S·T (uniform)    | S⊙T (per-element)   |
| Forward matmul       | x × W           | x × (S·T)        | x × (S·T)        | x × (S⊙T)           |
| STE backward         | N/A             | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff    |
| Extra compute / step | 0               | 1 scalar mul     | 1 absmean        | 1 abs + element mul |
| BPW (inference)      | 32.0            | 1.58 + overhead  | 1.58             | 1.58                |
| Measured val loss    | 2.165           | 2.895            | 2.515            | 2.273               |
| vs FP32              | 1.000×          | 1.337×           | 1.162×           | 1.050×              |

---

## Why Config E Wins

1. **Identity reconstruction**: At active positions, W_eff = W exactly. No quantization error. No information bottleneck from uniform scaling.

2. **No extra parameter**: S = |W| is computed, not learned. Fewer parameters means less BPW overhead and simpler architecture.

3. **Natural magnitude evolution**: As training progresses, weights that matter grow in magnitude (larger S), while weights that don't shrink toward zero (absorbed by the threshold). This is standard gradient descent; we just observe it through the S ⊙ T lens.

4. **Sparsity emerges organically**: The threshold θ = 0.05 naturally kills ~38% of weights in fc1 and ~63% in fc2 by step 5000. These are genuine zeros: structural sparsity available for inference.

---

## Implications for Inference Serialization

At inference time, Config E's effective weights can be decomposed at any granularity:

| Granularity   | Format                  | BPW   | Accuracy |
|---------------|-------------------------|-------|----------|
| Per-layer S   | T (2-bit) + 1 FP16 S    | ~1.58 | Like D   |
| Per-group S   | T (2-bit) + group FP16  | ~1.58 | Better   |
| Per-element S | T (2-bit) + per-elem FP | ~17   | Full E   |

The ~1.58 BPW entries refer to the information content of a ternary value, log2(3) ≈ 1.585 bits; a 2-bit packing stores 2 bits per weight before any scale overhead.

The I2_S format (BitNet.cpp) packs 16 ternary weights in 32 bits (2 bits each) with one FP16 group scale. This is the natural serialization target for Config E: group the magnitudes, pack the ternary patterns.

Training uses full FP32 W with STE. Inference uses packed T + S. The bridge between them is the fact that W_eff at active positions equals W, so any group quantization of |W| is just standard weight quantization applied to the surviving weights.

---

## The Question: Is Ternary Actually Applied?

Config E's effective forward weight is W_eff = |W| ⊙ T. At active positions this equals W itself. So is ternarization even happening?

Yes, but subtly:

1. **Forward pass**: The STE ternarizes W into T ∈ {-1, 0, +1}, then re-scales by |W|. The result at active positions is W. But the **sparsity mask** (which positions are zero) IS the ternary structure. T determines shape, S determines magnitude.

2. **Gradient pass**: The STE mask blocks gradient to small weights. This IS ternary training: the gradient is zeroed for positions below threshold. This is the mechanism that creates sparsity.

3. **Inference packing**: T (the ternary pattern) can be packed at 2 bits per weight. S (magnitudes) can be quantized per-group. This IS ternary inference, just with group scales attached.

4. **What Config E is NOT**: It is not a ternary-only matmul in training.
During training, the matmul uses full-precision W_eff values (which equal W at active positions). The ternary advantage is realized at inference, not during training.

**The ternary IS the sparsity mask.** The magnitude IS the weight. Training maintains both in one parameter. Inference separates them.
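
To make the "inference separates them" step concrete, here is a hedged sketch of packing one group of 16 weights into a 32-bit word of 2-bit ternary codes plus one FP16 scale, following the description of I2_S given earlier. The code assignment (0 for zero, 1 for +1, 2 for -1), the bit order, the choice of the mean surviving magnitude as the group scale, and the helper names `to_fp16`, `pack_group`, and `unpack_group` are all assumptions for illustration; BitNet.cpp's actual memory layout is not reproduced here.

```python
import struct

GROUP = 16  # ternary weights per 32-bit word (2 bits each)

# Assumed 2-bit codes for this sketch; not taken from BitNet.cpp.
ENCODE = {0: 0b00, 1: 0b01, -1: 0b10}
DECODE = {0b00: 0, 0b01: 1, 0b10: -1}


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def pack_group(weights, theta):
    """Return (word, scale): 16 ternary codes in one int plus an FP16 scale."""
    if len(weights) != GROUP:
        raise ValueError(f"expected {GROUP} weights, got {len(weights)}")
    word = 0
    magnitudes = []
    for i, w in enumerate(weights):
        # Ternarize: zero at or below the threshold, otherwise the sign of w.
        t = 0 if abs(w) <= theta else (1 if w > 0 else -1)
        if t != 0:
            magnitudes.append(abs(w))
        word |= ENCODE[t] << (2 * i)
    # Assumed group scale: mean magnitude of the surviving weights.
    scale = sum(magnitudes) / len(magnitudes) if magnitudes else 0.0
    return word, to_fp16(scale)


def unpack_group(word, scale):
    """Decode the 16 ternary codes and apply the shared group scale."""
    return [DECODE[(word >> (2 * i)) & 0b11] * scale for i in range(GROUP)]


# Hypothetical group of 16 weights; entries at or below theta become zeros.
weights = [0.30, -0.02, -0.12, 0.04] * 4
word, scale = pack_group(weights, theta=0.05)
restored = unpack_group(word, scale)
```

Under this sketch's accounting, a group occupies 32 + 16 = 48 bits for 16 weights, which is 3 bits per weight; sharing one scale across a larger group would lower that overhead. Decoding preserves the signs and zeros of T exactly, while the magnitudes within a group collapse to the shared scale.
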