| # Factorized Scaled Ternary — Config E Equations |
|
|
| ## The Core Identity |
|
|
| ``` |
| W = S ⊙ T where S = |W|, T = sign(W) |
| ``` |
|
|
| This is an identity for every real number w: |w| × sign(w) = w (with sign(0) = 0). |
| No approximation. No information loss at active positions. |
|
|
| --- |
|
|
| ## Per-Config Forward Pass Equations |
|
|
| ### Config A (FP32 Baseline) |
|
|
| ``` |
| y = xW + b |
| ``` |
|
|
| Standard linear layer. No ternarization. Full-precision weights. |
|
|
| ### Config C (Learned-S) |
|
|
| ``` |
| T = sign(W) · (|W| > θ) ← ternarize with hard threshold |
| S = learned scalar ← nn.Parameter, trained via Adam |
| W_eff = S · T ← all surviving weights share one scale |
| y = xW_eff + b |
| ``` |
|
|
| Problem: one scalar S cannot capture per-element magnitude variation. |
| At θ=0.05, only ~30% of weights survive, and they all get the same |
| scale regardless of actual magnitude. S converges to ~0.30. |
|
|
| ### Config D (Computed-S per-layer / BitNet absmean) |
|
|
| ``` |
| T = sign(W) · (|W| > θ) ← same ternarize |
| S = mean(|W|) ← single scalar computed from weight stats |
| W_eff = S · T ← uniform scale like BitNet |
| y = xW_eff + b |
| ``` |
|
|
| S is recomputed from the current weights, so it tracks the actual |
| weight magnitude as training moves. This beats C because S never lags |
| behind the weight distribution. It is still one scalar per layer, so |
| the granularity stays coarse. |
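
To make the coarse-granularity point concrete, here is a minimal NumPy sketch of the Config D scaling. The helper names (`ternarize`, `config_d_weights`), the random weights, and the seed are illustrative assumptions, not the experiment's code.

```python
import numpy as np

def ternarize(W, theta=0.05):
    """Hard-threshold ternarization: sign(W) where |W| > theta, else 0."""
    return np.sign(W) * (np.abs(W) > theta)

def config_d_weights(W, theta=0.05):
    """Config D: one scalar S = mean(|W|) scales every surviving weight."""
    T = ternarize(W, theta)
    S = np.abs(W).mean()
    return S * T

rng = np.random.default_rng(0)          # illustrative random weights
W = rng.normal(0.0, 0.1, size=(4, 6))
W_eff = config_d_weights(W)

# Every nonzero entry of W_eff has the same magnitude S.
magnitudes = np.unique(np.round(np.abs(W_eff[W_eff != 0]), 12))
print("distinct nonzero magnitudes:", magnitudes.size)
```

Because every surviving weight is mapped to ±S, the nonzero entries of `W_eff` collapse to a single magnitude, whatever spread the original weights had.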
|
|
| ### Config E (Factorized — per-element) |
|
|
| ``` |
| T = sign(W) · (|W| > θ) ← ternary mask: {-1, 0, +1} |
| S = |W| ← per-element magnitude tensor (same shape as W) |
| W_eff = S ⊙ T ← element-wise product |
| y = xW_eff + b |
| ``` |
|
|
| At every active position (where |W[i,j]| > θ): |
|
|
| ``` |
| W_eff[i,j] = |W[i,j]| × sign(W[i,j]) = W[i,j] |
| ``` |
|
|
| Exact reconstruction. The only information loss is at positions where |
| |W[i,j]| ≤ θ, which are zeroed out. Each of those weights has magnitude |
| at most θ, so its individual contribution was small anyway. Measured gap: |
| val loss 2.273 vs 2.165 for FP32 at 5000 steps (~5%). |
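
The exact-reconstruction claim can be checked numerically. The sketch below is a NumPy illustration under an assumed helper name (`config_e_weights`) and synthetic weights; it is not the training implementation.

```python
import numpy as np

def config_e_weights(W, theta=0.05):
    """Config E: W_eff = |W| * T, with T = sign(W) * (|W| > theta)."""
    active = np.abs(W) > theta
    T = np.sign(W) * active          # ternary pattern in {-1, 0, +1}
    S = np.abs(W)                    # per-element magnitude, same shape as W
    return S * T, active

rng = np.random.default_rng(0)       # illustrative random weights
W = rng.normal(0.0, 0.1, size=(4, 6))
W_eff, active = config_e_weights(W)

# Active positions reproduce W; inactive positions are zero.
max_err_active = np.abs(W_eff[active] - W[active]).max()
print("max error at active positions:", max_err_active)
print("fraction zeroed:", 1.0 - active.mean())
```

The only entries that differ from `W` are the ones the threshold removed, and each of those is bounded in magnitude by θ.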
|
|
| --- |
|
|
| ## STE Backward Pass |
|
|
| The TernarizeSTE function: |
|
|
| ``` |
| Forward: T = sign(W) · (|W| > θ) |
| Backward: ∂L/∂W = ∂L/∂W_eff · ∂W_eff/∂W |
| ``` |
|
|
| Where the STE approximation treats the forward as identity within |
| the active zone: |
|
|
| ``` |
| ∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ) |
| ``` |
|
|
| For Config E specifically, W_eff = |W| ⊙ T equals W at active positions |
| and 0 elsewhere, so ∂W_eff/∂W is 1 inside the active zone and 0 outside: |
| |
| ``` |
| ∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ) |
| ``` |
| |
| The gradient flows through to W directly. The optimizer updates W |
| with its usual step (Adam: W ← W - lr × update). The magnitude |
| |W| and direction sign(W) emerge naturally from this process. |
| |
| No separate S gradient needed. S = |W| is a deterministic function |
| of W, not a learned parameter. |
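
A minimal NumPy sketch of the masked backward pass for one Config E linear layer follows. It writes the gradient by hand instead of using an autograd `TernarizeSTE`, and the function names, shapes, and the loss L = sum(y) are assumptions chosen for illustration.

```python
import numpy as np

def forward(x, W, b, theta=0.05):
    """Config E forward: y = x @ (|W| * T) + b."""
    mask = np.abs(W) > theta
    W_eff = np.abs(W) * np.sign(W) * mask
    return x @ W_eff + b, mask

def backward(x, grad_y, mask):
    """Masked STE backward: dL/dW = (x^T @ dL/dy) * mask."""
    grad_W_eff = x.T @ grad_y        # gradient of the plain linear layer
    return grad_W_eff * mask         # zero gradient at or below the threshold

rng = np.random.default_rng(0)       # illustrative shapes and data
x = rng.normal(size=(3, 4))
W = rng.normal(0.0, 0.1, size=(4, 2))
b = np.zeros(2)

y, mask = forward(x, W, b)
grad_y = np.ones_like(y)             # corresponds to the loss L = sum(y)
grad_W = backward(x, grad_y, mask)
print(grad_W)
```

Away from the threshold, W_eff is just W times a fixed 0/1 mask, so this masked gradient should agree with a finite-difference estimate of ∂L/∂W.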
| |
| --- |
| |
| ## Compute Comparison |
| |
| | Operation              | Config A (FP32) | Config C          | Config D          | Config E              | |
| |------------------------|-----------------|-------------------|-------------------|-----------------------| |
| | Weight storage | full FP32 | FP32 + 1 scalar | FP32 only | FP32 only | |
| | Effective weights | W | S·T (uniform) | S·T (uniform) | S⊙T (per-element) | |
| | Forward matmul | x × W | x × (S·T) | x × (S·T) | x × (S⊙T) | |
| | STE backward | N/A | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff | |
| | Extra compute / step | 0 | 1 scalar mul | 1 absmean | 1 abs + element mul | |
| | BPW (inference)        | 32.0            | 1.58 + overhead   | 1.58              | 1.58 (grouped S)      | |
| | Measured val loss | 2.165 | 2.895 | 2.515 | 2.273 | |
| | vs FP32 | 1.000× | 1.337× | 1.162× | 1.050× | |
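
The "vs FP32" row is each validation loss divided by the FP32 loss. A short check of that arithmetic, using only the numbers in the table:

```python
# Validation losses copied from the table above.
val_loss = {"FP32": 2.165, "Config C": 2.895, "Config D": 2.515, "Config E": 2.273}

# Ratio of each config's loss to the FP32 baseline.
ratios = {name: loss / val_loss["FP32"] for name, loss in val_loss.items()}
for name, r in ratios.items():
    print(f"{name}: {r:.3f}x")
```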
| |
| --- |
| |
| ## Why Config E Wins |
| |
| 1. **Identity reconstruction**: At active positions, W_eff = W exactly. |
| No quantization error. No information bottleneck from uniform scaling. |
| |
| 2. **No extra parameter**: S = |W| is computed, not learned. Fewer |
| parameters mean less BPW overhead and a simpler architecture. |
| |
| 3. **Natural magnitude evolution**: As training progresses, weights |
| that matter grow in magnitude (larger S), weights that don't |
| shrink toward zero (absorbed by threshold). This is standard |
| gradient descent — we just observe it through the S ⊙ T lens. |
| |
| 4. **Sparsity emerges organically**: The threshold θ = 0.05 naturally |
| kills ~38% of weights in fc1 and ~63% in fc2 by step 5000. These |
| are exact zeros in W_eff: element-wise sparsity available for inference. |
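
Per-layer sparsity of this kind can be measured directly from the weights. The sketch below assumes a hypothetical helper `sparsity` and synthetic stand-in matrices; it does not reproduce the fc1/fc2 figures quoted above.

```python
import numpy as np

def sparsity(W, theta=0.05):
    """Fraction of weights with |W| <= theta, i.e. zeroed in W_eff."""
    return float(np.mean(np.abs(W) <= theta))

# Hypothetical stand-ins for two layers' weight matrices (synthetic data).
rng = np.random.default_rng(0)
layers = {
    "layer_a": rng.normal(0.0, 0.1, size=(64, 128)),
    "layer_b": rng.normal(0.0, 0.1, size=(128, 64)),
}
for name, W in layers.items():
    print(f"{name}: {sparsity(W):.1%} of weights at or below theta")
```

Logging this value per layer during training would show how the zeroed fraction evolves over steps.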
| |
| --- |
| |
| ## Implications for Inference Serialization |
| |
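| At inference time, Config E's effective weights can be decomposed |
| at any granularity. The ~1.58 BPW entries are log2(3), the information |
| content of one ternary symbol; stored at 2 bits per weight, T alone |
| takes 2.0 BPW before scale overhead: |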
| |
| | Granularity | Format | BPW | Accuracy | |
| |----------------|--------------------------|--------|-------------| |
| | Per-layer S | T (2-bit) + 1 FP16 S | ~1.58 | Like D | |
| | Per-group S | T (2-bit) + group FP16 | ~1.58 | Better | |
| | Per-element S | T (2-bit) + per-elem FP | ~17 | Full E | |
| |
| The I2_S format (BitNet.cpp) packs 16 ternary weights in 32 bits |
| (2 bits each) with one FP16 group scale. This is the natural |
| serialization target for Config E — group the magnitudes, pack |
| the ternary patterns. |
| |
| Training uses full FP32 W with STE. Inference uses packed T + S. |
| The bridge between them is the fact that W_eff at active positions |
| equals W — so any group quantization of |W| is just standard |
| weight quantization applied to the surviving weights. |
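
As a rough sketch of the "pack the ternary patterns, group the magnitudes" idea, the code below packs 16 weights into one 32-bit word with a single FP16 scale. The 2-bit code assignment, the choice of mean active magnitude as the group scale, and all function names are assumptions for illustration; this is not the actual I2_S bit layout used by BitNet.cpp.

```python
import numpy as np

CODES = {-1: 0b00, 0: 0b01, 1: 0b10}      # assumed 2-bit encoding, not the real I2_S layout
DECODE = {v: k for k, v in CODES.items()}

def pack_group(w, theta=0.05):
    """Pack 16 weights into one 32-bit word plus one FP16 group scale."""
    assert len(w) == 16
    active = np.abs(w) > theta
    t = (np.sign(w) * active).astype(int)
    # Group scale: mean magnitude of the surviving weights (0 if none survive).
    scale = np.float16(np.abs(w[active]).mean() if active.any() else 0.0)
    word = 0
    for i, ti in enumerate(t):
        word |= CODES[int(ti)] << (2 * i)
    return word, scale

def unpack_group(word, scale):
    """Recover the approximate weights scale * T from a packed group."""
    t = np.array([DECODE[(word >> (2 * i)) & 0b11] for i in range(16)])
    return float(scale) * t

def packed_bpw(group_size, scale_bits=16):
    """Bits per weight for 2-bit codes plus one scale shared by group_size weights."""
    return 2 + scale_bits / group_size

rng = np.random.default_rng(0)             # illustrative weights
w = rng.normal(0.0, 0.1, size=16)
word, scale = pack_group(w)
w_hat = unpack_group(word, scale)
print(f"word=0x{word:08x} scale={float(scale):.4f}")
```

In this sketch the ternary pattern survives packing exactly, while the magnitudes are replaced by the shared group scale. With one FP16 scale per 16 weights, `packed_bpw(16)` gives 3 bits per weight; larger groups amortize the scale further.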
| |
| --- |
| |
| ## The Question: Is Ternary Actually Applied? |
| |
| Config E's effective forward weight is W_eff = |W| ⊙ T. |
| |
| At active positions this equals W itself. So is ternarization |
| even happening? Yes, but subtly: |
| |
| 1. **Forward pass**: The STE ternarizes W into T ∈ {-1, 0, +1}, |
| then re-scales by |W|. The result at active positions is W. |
| But the **sparsity mask** (which positions are zero) IS the |
| ternary structure. T determines shape, S determines magnitude. |
| |
| 2. **Gradient pass**: The STE mask blocks gradient to small weights. |
| This IS ternary training — the gradient is zeroed for positions |
| below threshold. This is the mechanism that creates sparsity. |
| |
| 3. **Inference packing**: T (the ternary pattern) can be packed at |
| 2 bits per weight. S (magnitudes) can be quantized per-group. |
| This IS ternary inference — just with group scales attached. |
| |
| 4. **What Config E is NOT**: It is not a ternary-only matmul in |
| training. During training, the matmul uses full-precision |
| W_eff values (which equal W at active positions). The ternary |
| advantage is realized at inference, not during training. |
| |
| **The ternary IS the sparsity mask.** The magnitude IS the weight. |
| Training maintains both in one parameter. Inference separates them. |
| |