True Ternary Refactor 2
What Changed In This Pass
tscale.py
- Added Triton detection and a CUDA/Triton execution path for TernaryScaleTensor.
- Added packed ternary forward and grad-input kernels that read:
  - T_packed as 5 trits per byte
  - E as int8 log-scale groups
  - x/grad_y on CUDA
- TernaryScaleTensor.forward() now prefers the Triton path when the input tensors are CUDA tensors.
- Fixed the old TileLang negative exponent issue. The previous integer shift path made 2^-k evaluate to zero. The TileLang fallback now reconstructs 2^E before multiplying by the sign.
- Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size M, not N.
- Fixed ternary update ordering by calling update_E() before ternary_step(), so the gradient sign is not cleared before the exponent update reads it.
- Kept T_packed on its original device after repacking. Without this, repacking moved the buffer back to the CPU.
- Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes grad_w = grad_y.T @ x as a dense FP32 tensor. It now computes only sign(sum_m grad_y[m,n] * x[m,k]), written directly into an int8 CUDA tensor.
- Added a Triton E update kernel. CUDA exponent updates now read packed trits from T_packed and int8 gradient signs directly, then update the int8 E in place without unpacking the full T through PyTorch.
- Added a Triton ternary-step/repack kernel. CUDA ternary_step() now updates T_accum, applies flip thresholds, resets flipped accumulators, and rewrites T_packed in place without calling the Python pack_ternary().
- Removed the dense int8 grad_sign[N,K] allocation from the normal CUDA autograd path. Backward now retains compact grad_y and x views, and the CUDA E update and ternary-step kernels recompute the sign reduction inside the state-update kernels.
- Fused the expensive sign reduction into the ternary-step/repack pass. The fused CUDA path now updates T_accum, rewrites T_packed, and atomically accumulates per-group E scores in one pass over grad_y and x.
- Replaced the separate direct E reduction with a tiny score-apply kernel over E groups. The only remaining temporary is one int32 per scale group, not per logical weight.
trigram.py
- Replaced the MoE routers with TernaryScaleTensor:
  - moe.router
  - moe.router_h
- Added constructor gates for strict text-only ternary training:
  - enable_image
  - enable_vq
  - enable_graph
  - enable_memory_modules
  - enable_moe
- In strict mode, the model can be built without VQ, graph, image, LSTM, MemGram, or ConvVQ modules, which removes the hidden trainable float state from the core text model.
- Added a no-VQ forward path where text relational states go directly into MoE and ByteHead.
train.py
- Default optimizer changed to signsgd.
- Added --compute_dtype {bf16,fp16,none}.
- Added --strict_ternary:
  - Forces SignSGD.
  - Forces compute_dtype=none.
  - Disables VQ, graph, image, and memory modules.
  - Freezes any remaining trainable float parameters.
- Added --freeze_float_params for non-strict runs.
- Added model state audit logging before training.
- Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
- Fixed gradient clipping and optimizer construction to use only trainable parameters.
ternary_audit.py
New helper module for reporting:
- logical ternary weights
- packed ternary bytes
- int8 exponent bytes
- int8 accumulator bytes
- trainable floating-point parameters
- frozen floating-point parameters
- floating-point buffers
Strict text-only audit currently reports zero trainable float params and zero float buffers.
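The byte categories above can be estimated per layer with simple arithmetic. This is a hypothetical helper, not the ternary_audit.py API; it assumes T_packed is packed row by row at 5 trits per byte, one int8 exponent per group_size input columns, and one int8 accumulator per logical weight.

```python
def _ceil_div(a, b):
    """Exact integer ceiling division."""
    return -(-a // b)


def ternary_layer_bytes(out_dim, in_dim, group_size):
    """Estimate storage for one ternary weight of shape (out_dim, in_dim).

    Assumptions (hypothetical layout, not read from ternary_audit.py):
    - T_packed: each row packed separately at 5 trits per byte.
    - E: one int8 exponent per (row, group of group_size columns).
    - T_accum: one int8 accumulator per logical weight.
    """
    packed = out_dim * _ceil_div(in_dim, 5)
    exponent = out_dim * _ceil_div(in_dim, group_size)
    accum = out_dim * in_dim
    return {
        "logical_weights": out_dim * in_dim,
        "packed_bytes": packed,
        "exponent_bytes": exponent,
        "accumulator_bytes": accum,
        "total_bytes": packed + exponent + accum,
    }


report = ternary_layer_bytes(4096, 4096, group_size=12)
for name, value in report.items():
    print(f"{name}: {value:,}")
```

Under these assumptions a 4096 x 4096 layer with group size 12 needs 21,536,768 bytes in total (about 1.28 bytes per logical weight), most of it the int8 accumulator rather than the packed trits.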
testing/test_tscale.py
- Added a CUDA/Triton path test for TernaryScaleTensor.
- The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts the captured gradient state is int8 on CUDA.
- The CUDA/Triton test now verifies that the device-side E update modifies exponent groups and keeps ternary buffers on CUDA.
- The CUDA/Triton test now forces threshold crossings and verifies that GPU repack flips packed trits to +1 and resets accumulators.
- The CUDA/Triton test now asserts that normal backward does not create _hook_grad_T_sign; it uses retained grad_y and x views for direct state updates.
- The CUDA/Triton test now covers the fused ternary-step plus E score application path.
- Made missing TileLang reference tests skip instead of failing when tilelang/kernels/dequant_gemm.py is absent.
Verification Run
Direct CUDA/Triton smoke:
cuda True triton True y_device cuda:0 packed_device cuda:0 E_device cuda:0
Strict one-step train:
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD
Sign-only gradient kernel test:
PASS test_cuda_triton_tscale_path
The same test now also covers the CUDA-side E update kernel and CUDA-side packed ternary repack.
Strict one-step train after replacing dense grad_w:
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
Strict one-step train after moving E update to Triton:
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
Strict one-step train after moving ternary_step() repack to Triton:
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
Strict one-step train after removing dense grad_sign[N,K] from the normal CUDA path:
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
Strict one-step train after fusing sign reduction into ternary-step/repack and applying E from per-group scores:
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
Remaining Problem: grad_w = grad_y.T @ x
The first version of the Triton path fixed forward and grad-input memory behavior, but backward still materialized a dense temporary grad_w in Python:
grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)
That has now been replaced by a Triton sign-only reduction. The dense FP32 grad_w tensor is gone from the Triton backward path.
Current remaining issue:
- The CUDA update path now allocates a small int32 score buffer shaped like E, not like the full logical weight matrix.
- The score buffer has out_dim * ceil(in_dim / group_size) entries. At T32's group size of 12, this is about 1/12 as many elements as the logical weights.
- The fused kernel uses tl.atomic_add into group scores because 5-trit packed bytes and scale groups are not naturally aligned.
This fits the 3B-on-8GB memory goal much better than the previous path. The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider changing the group/pack alignment.
The actual required update is:
T_accum[i, j] += sign(sum_m grad_y[m, i] * x[m, j])
E update uses grouped sign statistics from the same reduction.
So the next fix is to avoid storing even the dense int8 sign tensor by fusing the reduction directly into the ternary state update.
Proposed Fix For Gradient Sign Capture
Stage 1: Sign-Only Grad Kernel
Implemented. Added a Triton kernel:
input:
grad_y[M, N]
x[M, K]
output:
grad_sign[N, K] int8
Each program owns a tile of (N, K), loops across M, accumulates a local float32 (or int32) sum, and then immediately converts it to a sign:
s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g
This removes the dense FP32 grad_w allocation. It still performs the same math, but the result is reduced to int8 at the end of each tile instead of returning a full precision matrix to PyTorch.
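A pure-Python reference of the tile structure above can serve as a test oracle for the Triton kernel. It is a sketch, not the kernel itself: nested lists stand in for tensors, and block_n and block_k are hypothetical tile sizes.

```python
def sign(v):
    """Return -1, 0, or +1 as an int."""
    return (v > 0) - (v < 0)


def grad_sign_tiled(grad_y, x, block_n=2, block_k=2):
    """grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k]), computed tile by tile.

    grad_y is M x N and x is M x K (nested lists). Each (block_n x block_k)
    tile keeps its own float accumulator, loops over all of M, and writes
    only the signs, so no full-precision N x K matrix is ever returned.
    """
    m_dim, n_dim, k_dim = len(grad_y), len(grad_y[0]), len(x[0])
    out = [[0] * k_dim for _ in range(n_dim)]
    for n0 in range(0, n_dim, block_n):
        for k0 in range(0, k_dim, block_k):
            ns = range(n0, min(n0 + block_n, n_dim))
            ks = range(k0, min(k0 + block_k, k_dim))
            acc = {(n, k): 0.0 for n in ns for k in ks}  # tile-local sums
            for m in range(m_dim):
                for n in ns:
                    for k in ks:
                        acc[(n, k)] += grad_y[m][n] * x[m][k]
            for (n, k), s in acc.items():
                out[n][k] = sign(s)
    return out


def grad_sign_dense(grad_y, x):
    """Untiled reference: one full dot product per (n, k), then its sign."""
    m_dim, n_dim, k_dim = len(grad_y), len(grad_y[0]), len(x[0])
    return [[sign(sum(grad_y[m][n] * x[m][k] for m in range(m_dim)))
             for k in range(k_dim)] for n in range(n_dim)]


grad_y = [[1.0, -2.0, 0.5], [0.5, 1.0, -1.0]]  # M=2, N=3
x = [[1.0, 0.0, -1.0], [2.0, 1.0, 1.0]]         # M=2, K=3
print(grad_sign_tiled(grad_y, x))  # [[1, 1, -1], [0, 1, 1], [-1, -1, -1]]
```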
Stage 2: GPU T Accumulator Update And Repack
Implemented as a fused Triton kernel. It computes the sign reduction from grad_y and x, updates T_accum, applies threshold flips, resets flipped accumulator entries, rewrites T_packed in-place, and emits per-group scale scores:
g = sign(sum_m grad_y[m,n] * x[m,k])
score[n, group(k)] += g * old_T[n,k]
T_accum[n, k] = clamp(T_accum[n, k] + g, -128, 127)
if T_accum[n,k] > threshold: T[n,k] = +1, T_accum[n,k] = 0
if T_accum[n,k] < -threshold: T[n,k] = -1, T_accum[n,k] = 0
repack 5 T values into one uint8
This removes Python unpack/repack from the CUDA path.
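The fused update rules above can be checked against a scalar reference. This is a sketch on unpacked nested lists: the repack into 5-trit bytes is omitted, fused_ternary_step is a hypothetical name, and the flip rule follows the pseudocode as written.

```python
def sign(v):
    return (v > 0) - (v < 0)


def clamp_int8(v):
    return max(-128, min(127, v))


def fused_ternary_step(grad_y, x, T, T_accum, threshold, group_size):
    """One fused ternary step on unpacked state (repack to T_packed omitted).

    Mutates T and T_accum in place and returns the per-group scale scores
    score[n][g] = sum over k in group g of sign(grad_w[n][k]) * old_T[n][k].
    """
    m_dim, n_dim, k_dim = len(grad_y), len(T), len(T[0])
    n_groups = -(-k_dim // group_size)  # ceil division
    score = [[0] * n_groups for _ in range(n_dim)]
    for n in range(n_dim):
        for k in range(k_dim):
            g = sign(sum(grad_y[m][n] * x[m][k] for m in range(m_dim)))
            score[n][k // group_size] += g * T[n][k]  # uses T before any flip
            acc = clamp_int8(T_accum[n][k] + g)
            if acc > threshold:
                T[n][k], acc = 1, 0
            elif acc < -threshold:
                T[n][k], acc = -1, 0
            T_accum[n][k] = acc
    return score


T = [[0, 1, -1]]
T_accum = [[2, 0, -2]]
score = fused_ternary_step([[1.0]], [[1.0, -1.0, -1.0]], T, T_accum,
                           threshold=2, group_size=2)
print(T, T_accum, score)  # [[1, 1, -1]] [[0, -1, 0]] [[-1, 1]]
```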
Stage 3: Separate E Group Update
Implemented for the CUDA/Triton path as a small score-apply kernel.
For scale exponents, update per group rather than per weight:
group_score[n, g] = sum_{k in group} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)
This is now a second Triton kernel over (N, groups) that applies the score emitted by the fused ternary-step/repack kernel.
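The group update above can be written as a group-owned loop, where each (n, g) pair reduces its own group and writes its own exponent without atomics. This is a minimal sketch over nested lists; update_E_groups is a hypothetical name and the clamp bounds are the int8 range from the formula.

```python
def sign(v):
    return (v > 0) - (v < 0)


def update_E_groups(grad_sign, T, E, group_size):
    """E[n][g] = clamp(E[n][g] - sign(group_score), -128, 127), in place.

    group_score = sum over k in group g of grad_sign[n][k] * T[n][k].
    Each (n, g) iteration touches only its own group, so a parallel version
    could give each program one group and write E without atomics.
    """
    k_dim = len(T[0])
    for n in range(len(E)):
        for g in range(len(E[n])):
            lo, hi = g * group_size, min((g + 1) * group_size, k_dim)
            score = sum(grad_sign[n][k] * T[n][k] for k in range(lo, hi))
            E[n][g] = max(-128, min(127, E[n][g] - sign(score)))
    return E


print(update_E_groups([[1, 1, -1, 0]], [[1, 1, 1, -1]], [[0, -128]], 2))
# [[-1, -127]]
```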
Current implemented split:
- fused sign-reduction + T_accum + repack + group-score kernel
- small E score-apply kernel
No sign tensor is stored in the normal CUDA path.
Stage 4: Tune Or Remove Group-Score Atomics
Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes are 5-trit chunks while E groups are 12, 6, 24, 48, 64, or 96 weights depending on TScale type.
Options:
- Keep atomics and tune block sizes.
- Change group sizes to align with 5-trit packing where possible.
- Use a group-owned kernel for E and a pack-owned kernel for T, accepting two reductions but no atomics.
- Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.
This is now a speed optimization, not a memory correctness blocker.
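The alignment problem can be quantified with a least common multiple: the pattern of E-group boundaries inside 5-trit bytes repeats every lcm(group_size, 5) weights. This is a small sketch (math.lcm requires Python 3.9+); "straddling" counts packed bytes whose five trits belong to more than one E group, which is where a pack-owned kernel has to split its contribution across groups. The group size 60 below is a hypothetical aligned choice, not an existing TScale type.

```python
import math


def pack_group_alignment(group_size, trits_per_byte=5):
    """Return (period_weights, bytes_per_period, straddling_bytes_per_period)."""
    period = math.lcm(group_size, trits_per_byte)
    n_bytes = period // trits_per_byte
    straddling = 0
    for b in range(n_bytes):
        first = b * trits_per_byte
        last = first + trits_per_byte - 1
        if first // group_size != last // group_size:
            straddling += 1  # this byte's trits belong to two (or more) E groups
    return period, n_bytes, straddling


# Every listed TScale group size g gives (5 * g, g, 4).
for g in (6, 12, 24, 48, 64, 96):
    print(g, pack_group_alignment(g))
# A group size that is a multiple of 5 has no straddling bytes.
print(60, pack_group_alignment(60))  # 60 (60, 12, 0)
```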
Stage 5: Activation Ternary Mode
Forward currently consumes dense floating-point activation tensors. True ternary training should eventually use:
A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E
This makes the gradient-sign kernel cheaper because x[m, k] is then in {-1, 0, +1} rather than a dense float, so each product grad_y[m, n] * x[m, k] reduces to an add, a subtract, or a skip.
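With both operands ternary, every product is -1, 0, or +1, so the forward dot product is an integer count and the int8 E enters only as a final 2^E scale. This sketch assumes a simple threshold quantizer (ternarize and its delta are hypothetical, not part of the current code) and ignores any activation scale.

```python
import math


def ternarize(values, delta):
    """Hypothetical activation quantizer: sign(v) where |v| > delta, else 0."""
    return [(v > delta) - (v < -delta) for v in values]


def ternary_dot(a_t, w_t, exponent):
    """Ternary-by-ternary dot product: integer accumulation, then a 2**E scale."""
    acc = 0  # an int32 register on the GPU; |acc| <= len(a_t)
    for a, w in zip(a_t, w_t):
        if a and w:
            acc += 1 if a == w else -1
    return acc * math.ldexp(1.0, exponent)


def grad_sum_ternary_x(grad_col, x_col):
    """sum_m grad_y[m, n] * x[m, k] for ternary x: add, subtract, or skip."""
    s = 0.0
    for g, xv in zip(grad_col, x_col):
        if xv == 1:
            s += g
        elif xv == -1:
            s -= g
    return s


a_t = ternarize([0.9, -0.1, -0.7, 0.3], delta=0.25)
print(a_t)                                               # [1, 0, -1, 1]
print(ternary_dot(a_t, [1, 1, 1, -1], -1))               # -0.5
print(grad_sum_ternary_x([0.5, 2.0, -1.0], [1, 0, -1]))  # 1.5
```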
Native CUDA Speed Path
cuBLAS
cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and low/mixed precision types, including tensor-core-backed FP/INT formats, but it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.
Use cuBLAS only for baselines or fallback dequantized GEMM.
Triton
Triton is the best immediate path.
Reasons:
- Already installed in this environment.
- Integrates with PyTorch autograd quickly.
- Good for custom packed formats and fused update kernels.
- Lets us remove the dense grad_w allocation without waiting on TileLang setup.
Near-term target:
Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel
TileLang
TileLang is still worth getting working after the algorithm is stable. Its docs describe T.gemm lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.
Use TileLang when:
- Triton semantics are proven.
- Tile sizes and data layout are stable.
- We want a cleaner path toward tensor-core-style tiling and schedule tuning.
TileLang should not be the first blocker because it is not currently installed in this environment.
Custom CUDA/CUTLASS
This is the final performance path, not the first implementation path.
Use custom CUDA/CUTLASS when:
- The Triton kernels prove the training algorithm.
- Profiling shows Triton is bottlenecked by decode/repack/reduction overhead.
- We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.
This path has the highest ceiling and highest development cost.
Recommended Roadmap
- Keep the current Triton forward and grad-input path.
- Add ternary_grad_sign_accum_kernel to eliminate the dense grad_w.
- Add a GPU-side E update.
- Add GPU-side repack of T_packed.
- Benchmark strict ternary training memory with torch.cuda.max_memory_allocated().
- Tune Triton block sizes.
- Install and evaluate TileLang against the Triton kernels.
- Move only proven hot kernels to CUDA/CUTLASS if needed.
Key Constraint
There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than ternary. The goal is not "no accumulator precision"; the goal is:
- no persistent BF16/FP32/FP8 weights
- no persistent FP optimizer state
- no dense full-precision weight gradients
- packed ternary storage
- int8 scale memory
- streaming reductions into ternary/int8 state
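A quick worked check of why the accumulator must be wider than the operands: a dot product of K ternary values can reach magnitude K, so an int8 register (range -128..127) silently wraps for any K above 127. This sketch emulates two's-complement int8 wraparound in Python; K = 4096 is an arbitrary illustrative width.

```python
def wrap_int8(v):
    """Emulate two's-complement int8 wraparound."""
    return (v + 128) % 256 - 128


def ternary_dot_accumulate(a, w, narrow):
    """Accumulate a ternary dot product, optionally in an emulated int8 register."""
    acc = 0
    for x, y in zip(a, w):
        acc += x * y
        if narrow:
            acc = wrap_int8(acc)
    return acc


K = 4096
ones = [1] * K
print(ternary_dot_accumulate(ones, ones, narrow=False))  # 4096, fits easily in int32
print(ternary_dot_accumulate(ones, ones, narrow=True))   # 0, the int8 sum wrapped around
```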
Speed Pass: Scale Update Scheduling
The current bottleneck is the exponent scale update, not the packed forward path. On a 4-step strict smoke test, scheduled scale updates ran at about 1.87 it/s, while disabling E updates reached about 4.90 it/s on the same small batch after compile.
Changes made:
- Added --scale_update_interval.
  - Default: 4.
  - 1: update int8 E every step.
  - 0: disable E updates and always use the fast direct ternary repack path.
- Changed CUDA update_E() back to the direct group-owned Triton kernel for the normal retained grad_y/x path.
- ternary_step() now uses the fast direct repack kernel after update_E(), instead of the fused score/repack kernel, on scheduled scale-update steps.
- This removes the temporary int32 score buffer and tl.atomic_add from the default scheduled scale-update path.
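The --scale_update_interval behavior described above can be sketched as a small scheduling function. The modulo convention (updating on steps divisible by the interval) is an assumption, and should_update_E, state_update, and the callbacks are hypothetical names. The sketch keeps the update_E() before ternary_step() ordering from the tscale.py fixes.

```python
def should_update_E(step, interval):
    """interval == 0 disables E updates; interval == n updates every n-th step."""
    return interval > 0 and step % interval == 0


def state_update(step, interval, update_E, ternary_step_direct):
    """Run the optional E update first, then the fast direct ternary repack."""
    if should_update_E(step, interval):
        update_E()  # runs first so the exponent update still sees the gradient state
    ternary_step_direct()


calls = []
for step in range(8):
    state_update(step, 4,
                 lambda: calls.append(("E", step)),
                 lambda: calls.append(("T", step)))
print([s for kind, s in calls if kind == "E"])  # [0, 4]
```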
Observed before the direct E change:
scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke
Observed after the direct group-owned E change:
scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke
Interpretation:
- The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
- The direct group-owned E kernel is faster even though it performs a separate reduction, because each program owns an E group and writes without atomics.
- The default scale_update_interval=4 now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.
Next speed target:
- Benchmark the direct group-owned E kernel against the fused score kernel on real layer sizes.
- If direct E is faster, remove the fused score kernel or demote it to an experimental path.
- Tune BLOCK_N, BLOCK_K, and BLOCK_M in _triton_update_e_direct_kernel and _triton_ternary_step_direct_kernel.