Factorized Scaled Ternary — Config E Equations
The Core Identity
W = S ⊙ T where S = |W|, T = sign(W)
This holds for every real number: |w| × sign(w) = w, including w = 0 (sign(0) = 0). No approximation, no information loss. Loss appears only once T is thresholded, and only at the positions the threshold zeroes.
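A quick numerical check of the identity, as a minimal NumPy sketch (NumPy is our assumption; the note does not name a library). Multiplying a float by +1, -1 or 0 is exact, so the reconstruction is bit-for-bit, not just approximately equal.

```python
import numpy as np

# Sample weights, including an exact zero (np.sign(0.0) == 0.0).
w = np.array([-1.7, -0.03, 0.0, 0.02, 0.9])

S = np.abs(w)    # magnitude, same shape as w
T = np.sign(w)   # direction in {-1, 0, +1}; no threshold applied yet

# Scaling by +1, -1 or 0 is exact in floating point, so S * T reproduces w exactly.
assert np.array_equal(S * T, w)
```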
Per-Config Forward Pass Equations
Config A (FP32 Baseline)
y = xW + b
Standard linear layer. No ternarization. Full-precision weights.
Config C (Learned-S)
T = sign(W) · (|W| > θ) ← ternarize with hard threshold
S = learned scalar ← nn.Parameter, trained via Adam
W_eff = S · T ← all surviving weights share one scale
y = xW_eff + b
Problem: one scalar S cannot capture per-element magnitude variation. At θ=0.05, only ~30% of weights survive, and they all get the same scale regardless of actual magnitude. S converges to ~0.30.
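To make the limitation concrete, here is a minimal NumPy sketch of the Config C forward pass and of the gradient the shared S receives. The toy squared-error loss, layer shape and helper names are illustrative assumptions, not the training code. Because S multiplies every surviving entry, its gradient is a sum over all of them: ∂L/∂S = Σ ∂L/∂W_eff ⊙ T.

```python
import numpy as np

THETA = 0.05  # threshold used throughout the note

def ternarize(W, theta=THETA):
    """sign(W) where |W| > theta, else 0; also returns the boolean mask."""
    mask = np.abs(W) > theta
    return np.sign(W) * mask, mask

def config_c_loss(x, W, b, S):
    """Toy loss 0.5 * ||x (S * T) + b||^2, so dL/dy = y."""
    T, _ = ternarize(W)
    y = x @ (S * T) + b
    return 0.5 * np.sum(y ** 2)

def config_c_grad_S(x, W, b, S):
    """dL/dS = sum_ij dL/dW_eff[i, j] * T[i, j], because W_eff = S * T."""
    T, _ = ternarize(W)
    y = x @ (S * T) + b
    grad_W_eff = x.T @ y  # dL/dW_eff for y = x W_eff + b
    return np.sum(grad_W_eff * T)

# Toy layer (in=2, out=3) with deliberately varied magnitudes.
W = np.array([[0.40, -0.02, 0.10],
              [-0.20, 0.06, -0.90]])
x = np.array([[1.0, -0.5],
              [0.3, 2.0]])
b = np.zeros(3)
S = 0.30  # roughly where the note reports S settling

T, mask = ternarize(W)
# One shared scale cannot match every surviving magnitude.
active_error = np.abs(S * T - W)[mask]

# Check the analytic dL/dS against a central finite difference.
eps = 1e-6
fd = (config_c_loss(x, W, b, S + eps) - config_c_loss(x, W, b, S - eps)) / (2 * eps)
```

Every surviving weight ends up at ±0.30 here, so weights of magnitude 0.06 and 0.90 are both misrepresented; that residual is exactly what Config E removes.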
Config D (Computed-S per-layer / BitNet absmean)
T = sign(W) · (|W| > θ) ← same ternarize
S = mean(|W|) ← single scalar computed from weight stats
W_eff = S · T ← uniform scale like BitNet
y = xW_eff + b
S tracks the actual weight magnitude automatically. This beats C because S always matches the current weight distribution instead of having to be learned. It is still one scalar per layer, so the granularity is coarse.
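A minimal NumPy sketch of the Config D scale, placed next to a BitNet b1.58-style absmean quantizer for comparison. The second function is our reading of the published BitNet b1.58 formula (T = RoundClip(W / (mean|W| + ε), -1, 1)), not code from BitNet or bitnet.cpp. It shows one difference that "like BitNet" glosses over: BitNet's ternarization threshold scales with S (roughly S/2), while Config D keeps θ fixed.

```python
import numpy as np

def config_d_effective(W, theta=0.05):
    """Config D as described: fixed-threshold T, S = mean(|W|) over the whole layer."""
    T = np.sign(W) * (np.abs(W) > theta)
    S = np.mean(np.abs(W))  # recomputed from the current weights, never learned
    return S * T, S

def bitnet_absmean_effective(W, eps=1e-5):
    """BitNet b1.58-style absmean: T = RoundClip(W / (S + eps), -1, 1), W_eff = S * T."""
    S = np.mean(np.abs(W))
    # np.round rounds half to even; exact ties are a measure-zero edge case here.
    T = np.clip(np.round(W / (S + eps)), -1, 1)  # |W| below about S/2 rounds to 0
    return S * T, S

W = np.array([[0.40, -0.02, 0.10],
              [-0.20, 0.06, -0.90]])

W_d, S_d = config_d_effective(W)        # S_d ≈ 0.28, threshold fixed at 0.05
W_b, S_b = bitnet_absmean_effective(W)  # same S, but threshold effectively ≈ 0.14
```

With the same S, the absmean rounding prunes the 0.10 and 0.06 weights that Config D's fixed θ keeps.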
Config E (Factorized — per-element)
T = sign(W) · (|W| > θ) ← ternary mask: {-1, 0, +1}
S = |W| ← per-element magnitude tensor (same shape as W)
W_eff = S ⊙ T ← element-wise product
y = xW_eff + b
At every active position (where |W[i,j]| > θ):
W_eff[i,j] = |W[i,j]| × sign(W[i,j]) = W[i,j]
Exact reconstruction. The only information loss is at positions where |W[i,j]| ≤ θ, which are zeroed out. Each such weight has magnitude at most θ, so its individual contribution is small. Measured gap: ~5% higher validation loss than FP32 at 5000 steps (2.273 vs 2.165).
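The Config E forward pass and its exactness property can be checked numerically. This is a minimal NumPy sketch that uses random toy weights in place of trained ones, and its helper names are hypothetical. It confirms three things: W_eff matches W bit for bit at active positions, W_eff is zero elsewhere, and W_eff never differs from W by more than θ.

```python
import numpy as np

THETA = 0.05

def config_e_effective(W, theta=THETA):
    """Config E: W_eff = S * T with S = |W| per element and T the ternary mask."""
    T = np.sign(W) * (np.abs(W) > theta)  # values in {-1, 0, +1}
    S = np.abs(W)                         # same shape as W
    return S * T, T

def config_e_forward(x, W, b, theta=THETA):
    """y = x W_eff + b."""
    W_eff, _ = config_e_effective(W, theta)
    return x @ W_eff + b

rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.1, size=(64, 32))  # toy weights, not trained ones
W_eff, T = config_e_effective(W)
active = T != 0

exact_at_active = np.array_equal(W_eff[active], W[active])  # exact reconstruction
zero_elsewhere = bool(np.all(W_eff[~active] == 0))          # pruned positions
max_error = np.max(np.abs(W_eff - W))                       # bounded by THETA
sparsity = 1.0 - active.mean()                              # fraction zeroed
```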
STE Backward Pass
The TernarizeSTE function:
Forward: T = sign(W) · (|W| > θ)
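Backward: ∂L/∂W = ∂L/∂W_eff · ∂W_eff/∂W
The STE approximates ∂W_eff/∂W by the identity inside the active zone and by zero outside it:
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
For Config E specifically, W_eff = |W| ⊙ T = W ⊙ mask(|W| > θ) exactly, so the same masked gradient is the true derivative everywhere except on the boundary |W| = θ:
∂L/∂W = ∂L/∂W_eff · mask(|W| > θ)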
The gradient flows through to W directly, and the optimizer updates W with its usual additive step (Adam: W ← W - lr × update). The magnitude |W| and direction sign(W) emerge naturally from this process.
No separate S gradient needed. S = |W| is a deterministic function of W, not a learned parameter.
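The backward rule for Config E can be written out by hand, here as a minimal NumPy sketch rather than an autograd function. The TernarizeSTE implementation is not shown in this note, so the function names and the toy squared-error loss are assumptions. Because W_eff = W ⊙ mask(|W| > θ) exactly, a finite-difference check should agree with the masked gradient at every entry not sitting on the threshold.

```python
import numpy as np

THETA = 0.05

def forward(x, W, b, theta=THETA):
    """Config E forward: W_eff = |W| * (sign(W) * mask)."""
    mask = np.abs(W) > theta
    W_eff = np.abs(W) * (np.sign(W) * mask)
    return x @ W_eff + b, mask

def backward(x, grad_y, mask):
    """STE backward: dL/dW = dL/dW_eff * mask; no separate gradient for S."""
    grad_W_eff = x.T @ grad_y   # dL/dW_eff for y = x W_eff + b
    grad_W = grad_W_eff * mask  # zero below the threshold
    grad_b = grad_y.sum(axis=0)
    return grad_W, grad_b

def loss(x, W, b):
    y, _ = forward(x, W, b)
    return 0.5 * np.sum(y ** 2)  # toy loss, so dL/dy = y

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 6))
W = rng.normal(0.0, 0.1, size=(6, 4))
b = np.zeros(4)

y, mask = forward(x, W, b)
grad_W, grad_b = backward(x, y, mask)

# Central finite differences on every entry of W.
eps = 1e-6
fd = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        fd[i, j] = (loss(x, Wp, b) - loss(x, Wm, b)) / (2 * eps)

# Skip any entry within a few eps of the threshold, where the mask could flip.
away = np.abs(np.abs(W) - THETA) > 10 * eps
```

Only W is updated. S = |W| never appears as a separate variable in the backward pass.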
Compute Comparison
| Operation | FP32 | Config C | Config D | Config E |
|---|---|---|---|---|
| Weight storage (training) | FP32 W | FP32 W + 1 learned scalar | FP32 W | FP32 W |
| Effective weights | W | S·T (uniform) | S·T (uniform) | S⊙T (per-element) |
| Forward matmul | x × W | x × (S·T) | x × (S·T) | x × (S⊙T) |
| STE backward | N/A | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff | mask · ∂L/∂W_eff |
| Extra compute / step | 0 | 1 scalar mul | 1 absmean | 1 abs + element mul |
| BPW (inference) | 32.0 | ~1.58 + 1 scalar/layer | ~1.58 + 1 scalar/layer | ~1.58 + scale overhead (depends on S granularity) |
| Measured val loss | 2.165 | 2.895 | 2.515 | 2.273 |
| vs FP32 | 1.000× | 1.337× | 1.162× | 1.050× |
Why Config E Wins
Identity reconstruction: At active positions, W_eff = W exactly. No quantization error in the training forward pass. No information bottleneck from uniform scaling.
No extra parameter: S = |W| is computed, not learned. Training keeps exactly the FP32 parameter count, and the architecture stays simple.
Natural magnitude evolution: As training progresses, weights that matter grow in magnitude (larger S), while weights that don't shrink toward zero until the threshold absorbs them. This is standard gradient descent; we just observe it through the S ⊙ T lens.
Sparsity emerges organically: With θ = 0.05, ~38% of fc1 weights and ~63% of fc2 weights are zeroed by step 5000. These are genuine zeros: unstructured (element-wise) sparsity that inference can exploit.
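Per-layer sparsity takes only one line of computation per weight matrix. This is a minimal sketch with hypothetical helper names. The random stand-in layers below will not reproduce the ~38% / ~63% figures, which come from trained fc1/fc2 weights.

```python
import numpy as np

def threshold_sparsity(W, theta=0.05):
    """Fraction of entries the ternary threshold zeroes (|W| <= theta)."""
    return float(np.mean(np.abs(W) <= theta))

def sparsity_report(layers, theta=0.05):
    """Per-layer sparsity for a dict mapping layer name -> weight matrix."""
    return {name: threshold_sparsity(W, theta) for name, W in layers.items()}

# Random stand-ins; trained fc1/fc2 weights would be loaded here instead.
rng = np.random.default_rng(0)
layers = {
    "fc1": rng.normal(0.0, 0.20, size=(256, 1024)),
    "fc2": rng.normal(0.0, 0.05, size=(1024, 256)),
}
report = sparsity_report(layers)
```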
Implications for Inference Serialization
At inference time, Config E's effective weights can be split into T and a scale S stored at any granularity, trading bits per weight against accuracy:
| Granularity | Format | BPW | Accuracy |
|---|---|---|---|
| Per-layer S | T + 1 FP16 S per layer | ~1.58 (2.0 with 2-bit codes) | Like D |
| Per-group S | T + 1 FP16 S per group of g weights | ~1.58 + 16/g (2 + 16/g with 2-bit codes) | Better than per-layer |
| Per-element S | T + 1 FP16 S per element | ~17.6 (1.58 + 16) | Full E |
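The BPW figures follow from simple arithmetic: bits for T, plus the FP16 scale bits amortized over the weights that share a scale. In this short sketch, the layer shape and the group size g = 128 are hypothetical. The 1.58 figure is log2 3, the information content of one ternary value. A plain 2-bit code costs 2.0 BPW before scales.

```python
import math

TRIT_BITS = math.log2(3)  # ~1.585: information content of one {-1, 0, +1} value
CODE_BITS = 2.0           # a plain 2-bit code per weight
SCALE_BITS = 16           # one FP16 scale

def bits_per_weight(t_bits, weights_per_scale):
    """T bits per weight plus one FP16 scale amortized over the weights that share it."""
    return t_bits + SCALE_BITS / weights_per_scale

layer_size = 4096 * 4096  # hypothetical layer shape
g = 128                   # hypothetical group size

per_layer = bits_per_weight(CODE_BITS, layer_size)  # ~2.000001
per_group = bits_per_weight(CODE_BITS, g)           # 2.125
per_element = bits_per_weight(TRIT_BITS, 1)         # ~17.585
```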
The I2_S format (BitNet.cpp) packs 16 ternary weights in 32 bits (2 bits each) with one FP16 group scale. This is the natural serialization target for Config E — group the magnitudes, pack the ternary patterns.
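The packing step can be sketched as shown below: 16 ternary values go into one 32-bit word, and one FP16 scale is stored per group. Several choices here are our own assumptions for illustration and do not reflect the actual I2_S bit layout in BitNet.cpp. These include the 2-bit code assignment, the bit order, and the scale (mean |W| over surviving weights in the group).

```python
import numpy as np

# Hypothetical 2-bit codes for this sketch only; NOT the actual I2_S bit layout.
ENCODE = {0: 0b00, 1: 0b01, -1: 0b10}
DECODE = {code: t for t, code in ENCODE.items()}

def pack16(trits):
    """Pack 16 ternary values into one 32-bit word, 2 bits per value."""
    assert len(trits) == 16
    word = 0
    for k, t in enumerate(trits):
        word |= ENCODE[int(t)] << (2 * k)
    return np.uint32(word)

def unpack16(word):
    """Inverse of pack16: recover the 16 ternary values."""
    word = int(word)
    return np.array([DECODE[(word >> (2 * k)) & 0b11] for k in range(16)], dtype=np.int8)

def pack_group(w_group, theta=0.05):
    """Ternary pattern plus one FP16 scale for a group of 16 weights."""
    T = (np.sign(w_group) * (np.abs(w_group) > theta)).astype(np.int8)
    active = T != 0
    # Scale choice (mean |W| over surviving weights) is an assumption for this sketch.
    scale = np.abs(w_group[active]).mean() if active.any() else 0.0
    return pack16(T), np.float16(scale)

def unpack_group(word, scale):
    """Dequantize to float32: scale * T."""
    return np.float32(scale) * unpack16(word).astype(np.float32)

w_group = np.array([0.40, -0.02, 0.10, -0.20, 0.06, -0.90, 0.00, 0.03,
                    -0.051, 0.50, -0.30, 0.01, 0.07, -0.07, 0.20, -0.04])
word, scale = pack_group(w_group)  # 4 bytes of pattern + 2 bytes of scale
w_hat = unpack_group(word, scale)
```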
Training uses full FP32 W with STE. Inference uses packed T + S. The bridge between them is the fact that W_eff at active positions equals W — so any group quantization of |W| is just standard weight quantization applied to the surviving weights.
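Group quantization of the surviving magnitudes can be sketched as follows. The sketch assumes row-major groups and NumPy, and its helper names are hypothetical. For a fixed T, the scale s that minimizes Σ_active (W - s·T)² is the mean of |W| over the active positions, since (W - s·T)² = (|W| - s)² wherever T ≠ 0. Splitting into smaller groups therefore can never increase the error on surviving weights.

```python
import numpy as np

THETA = 0.05

def quantize_groups(W, group_size, theta=THETA):
    """Dequantized weights: one scale per row-major group of group_size weights.
    The scale is the mean |W| over surviving weights (least-squares optimal for
    a fixed T). Assumes W.size is divisible by group_size."""
    w = W.reshape(-1, group_size)
    active = np.abs(w) > theta
    T = np.sign(w) * active
    n_active = np.maximum(active.sum(axis=1), 1)  # empty groups get scale 0
    scales = (np.abs(w) * active).sum(axis=1) / n_active
    return (scales[:, None] * T).reshape(W.shape)

def active_mse(W, W_hat, theta=THETA):
    """Mean squared error over the surviving (active) positions only."""
    active = np.abs(W) > theta
    return float(np.mean((W - W_hat)[active] ** 2))

rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.1, size=(64, 128))  # toy weights

err_per_layer = active_mse(W, quantize_groups(W, W.size))  # one scale for everything
err_per_group = active_mse(W, quantize_groups(W, 128))     # one scale per 128 weights
```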
The Question: Is Ternary Actually Applied?
Config E's effective forward weight is W_eff = |W| ⊙ T.
At active positions this equals W itself. So is ternarization even happening? Yes, but subtly:
Forward pass: The STE ternarizes W into T ∈ {-1, 0, +1}, then re-scales by |W|. The result at active positions is W. But the sparsity mask (which positions are zero) IS the ternary structure. T determines shape, S determines magnitude.
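Gradient pass: The STE mask blocks gradient to small weights. This IS ternary training: the gradient is zeroed for positions below threshold, so once a weight falls below θ it stops receiving gradient and stays pruned (apart from residual Adam momentum). This is the mechanism that makes the sparsity stick.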
Inference packing: T (the ternary pattern) can be packed at 2 bits per weight. S (magnitudes) can be quantized per-group. This IS ternary inference — just with group scales attached.
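The gradient-pass behaviour can be shown with a toy example. This sketch assumes plain SGD and a squared-error loss on random data. The note itself trains with Adam, whose momentum can keep moving a weight for a few steps after its gradient is masked. Under plain SGD, a weight at or below θ gets exactly zero update, so the set of pruned weights can only grow.

```python
import numpy as np

THETA = 0.05
LR = 0.01

def sgd_step(x, target, W):
    """One plain-SGD step on 0.5 * ||x W_eff - target||^2 with the STE mask."""
    mask = np.abs(W) > THETA
    W_eff = W * mask  # equals |W| * T in Config E
    grad_W_eff = x.T @ (x @ W_eff - target)
    return W - LR * grad_W_eff * mask  # weights at or below THETA get no update

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 8))
target = rng.normal(size=(16, 4))
W0 = rng.normal(0.0, 0.1, size=(8, 4))

W = W0.copy()
for _ in range(50):
    W = sgd_step(x, target, W)

dead_at_start = np.abs(W0) <= THETA
dead_at_end = np.abs(W) <= THETA
```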
What Config E is NOT: It is not a ternary-only matmul in training. During training, the matmul uses full-precision W_eff values (which equal W at active positions). The ternary advantage is realized at inference, not during training.
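The ternary advantage at inference can be sketched in compute terms. With T ∈ {-1, 0, +1}, each output is a sum of some inputs minus a sum of others, and the scale is applied once per output. This minimal NumPy sketch assumes one scale per output column (a hypothetical layout). A real kernel would operate on packed T rather than on Python loops.

```python
import numpy as np

def ternary_matmul(x, T, scale):
    """y = x @ (T * scale) using only adds/subtracts of inputs, then one multiply per output.

    T: (in, out) ternary matrix in {-1, 0, +1}; scale: (out,) per-output-column scale.
    """
    y = np.zeros((x.shape[0], T.shape[1]), dtype=x.dtype)
    for j in range(T.shape[1]):
        plus = x[:, T[:, j] == 1].sum(axis=1)    # add inputs whose weight is +1
        minus = x[:, T[:, j] == -1].sum(axis=1)  # subtract inputs whose weight is -1
        y[:, j] = plus - minus                   # zero weights are skipped entirely
    return y * scale                             # scaling happens once per output

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16)).astype(np.float32)
W = rng.normal(0.0, 0.1, size=(16, 8))
T = (np.sign(W) * (np.abs(W) > 0.05)).astype(np.int8)
scale = np.abs(W).max(axis=0).astype(np.float32)  # hypothetical per-column scale

y_ternary = ternary_matmul(x, T, scale)
y_dense = x @ (T.astype(np.float32) * scale)      # reference dense matmul
```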
The ternary IS the sparsity mask. The magnitude IS the weight. Training maintains both in one parameter. Inference separates them.