
True Ternary Refactor 2

What Changed In This Pass

tscale.py

  • Added Triton detection and a CUDA/Triton execution path for TernaryScaleTensor.

  • Added packed ternary forward and grad-input kernels that read:

    • T_packed as 5 trits per byte
    • E as int8 log-scale groups
    • x / grad_y on CUDA
  • TernaryScaleTensor.forward() now prefers the Triton path when input tensors are CUDA tensors.

  • Fixed the old TileLang negative exponent issue. The previous integer shift path made 2^-k become zero. The TileLang fallback now reconstructs 2^E before multiplying by sign.

  • Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size M, not N.

  • Fixed ternary update ordering by calling update_E() before ternary_step(), so the gradient sign is not deleted before the exponent update sees it.

  • Kept T_packed on the original device after repacking. Without this, repacking moved the buffer back to CPU.

  • Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes:

    grad_w = grad_y.T @ x
    

    as a dense FP32 tensor. It now computes only:

    sign(sum_m grad_y[m,n] * x[m,k])
    

    directly into an int8 CUDA tensor.

  • Added a Triton E update kernel. CUDA exponent updates now read packed trits from T_packed and int8 gradient signs directly, then update int8 E in-place without unpacking full T through PyTorch.

  • Added a Triton ternary-step/repack kernel. CUDA ternary_step() now updates T_accum, applies flip thresholds, resets flipped accumulators, and rewrites T_packed in-place without calling Python pack_ternary().

  • Removed the dense int8 grad_sign[N,K] allocation from the normal CUDA autograd path. Backward now retains compact grad_y and x views, and the CUDA E-update and ternary-step kernels recompute the sign reduction internally.

  • Fused the expensive sign reduction into the ternary-step/repack pass. The fused CUDA path now updates T_accum, rewrites T_packed, and atomically accumulates per-group E scores in one pass over grad_y and x.

  • Replaced the separate direct E reduction with a tiny score-apply kernel over E groups. The remaining temporary is int32 per scale group, not per logical weight.
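The packed format and the negative-exponent fix above can be checked on the CPU with a small reference. This is a minimal sketch, assuming trits are stored as base-3 digits {-1, 0, +1} -> {0, 1, 2} with the first trit in the lowest digit, and that E groups are contiguous along the input dimension. The real bit layout in tscale.py may differ, and the function names here are hypothetical.

```python
from typing import List

TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, so five trits fit in one byte


def pack_trits(trits: List[int]) -> bytes:
    """Pack values in {-1, 0, +1} five per byte as base-3 digits (first trit = lowest digit)."""
    out = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        chunk = trits[start:start + TRITS_PER_BYTE]
        chunk = chunk + [0] * (TRITS_PER_BYTE - len(chunk))  # pad the tail with zero trits
        code = 0
        for i, t in enumerate(chunk):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            code += (t + 1) * 3 ** i  # map {-1, 0, +1} -> digit {0, 1, 2}
        out.append(code)  # the largest code is 242, so it always fits in a uint8
    return bytes(out)


def unpack_trits(packed: bytes, n: int) -> List[int]:
    """Inverse of pack_trits; n drops the padding trits of the last byte."""
    trits: List[int] = []
    for code in packed:
        for _ in range(TRITS_PER_BYTE):
            trits.append(code % 3 - 1)
            code //= 3
    return trits[:n]


def dequant_row(packed: bytes, n: int, exps: List[int], group_size: int) -> List[float]:
    """W[k] = T[k] * 2**E[k // group_size].

    2.0 ** e is a float power, so negative exponents give fractions (2**-2 = 0.25).
    An integer shift cannot represent 2**-k, which is why reconstructing 2^E first matters.
    """
    trits = unpack_trits(packed, n)
    return [t * 2.0 ** exps[k // group_size] for k, t in enumerate(trits)]
```

For example, the 7-trit row [1, 0, -1, -1, 1, 0, 1] packs into 2 bytes, and with exps [-2, 1] and group_size 4 it dequantizes to [0.25, 0.0, -0.25, -0.25, 2.0, 0.0, 2.0].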

trigram.py

  • Replaced MoE routers with TernaryScaleTensor:
    • moe.router
    • moe.router_h
  • Added constructor gates for strict text-only ternary training:
    • enable_image
    • enable_vq
    • enable_graph
    • enable_memory_modules
    • enable_moe
  • In strict mode, the model can be built without VQ, graph, image, LSTM, MemGram, or ConvVQ modules, which removes the hidden trainable float state from the core text model.
  • Added a no-VQ forward path where text relational states go directly into MoE and ByteHead.

train.py

  • Default optimizer changed to signsgd.
  • Added --compute_dtype {bf16,fp16,none}.
  • Added --strict_ternary.
    • Forces SignSGD.
    • Forces compute_dtype=none.
    • Disables VQ, graph, image, and memory modules.
    • Freezes any remaining trainable float parameters.
  • Added --freeze_float_params for non-strict runs.
  • Added model state audit logging before training.
  • Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
  • Fixed gradient clipping and optimizer construction to use only trainable parameters.

ternary_audit.py

New helper module for reporting:

  • logical ternary weights
  • packed ternary bytes
  • int8 exponent bytes
  • int8 accumulator bytes
  • trainable floating-point parameters
  • frozen floating-point parameters
  • floating-point buffers

Strict text-only audit currently reports zero trainable float params and zero float buffers.

testing/test_tscale.py

  • Added a CUDA/Triton path test for TernaryScaleTensor.
  • The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts the captured gradient state is int8 on CUDA.
  • The CUDA/Triton test now verifies the device-side E update modifies exponent groups and keeps ternary buffers on CUDA.
  • The CUDA/Triton test now forces threshold crossings and verifies GPU repack flips packed trits to +1 and resets accumulators.
  • The CUDA/Triton test now asserts normal backward does not create _hook_grad_T_sign; it uses retained grad_y and x views for direct state updates.
  • The CUDA/Triton test now covers the fused ternary-step plus E score application path.
  • Made missing TileLang reference tests skip instead of failing when tilelang/kernels/dequant_gemm.py is absent.

Verification Run

Direct CUDA/Triton smoke:

cuda True triton True y_device cuda:0 packed_device cuda:0 E_device cuda:0

Strict one-step train:

Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD

Sign-only gradient kernel test:

PASS test_cuda_triton_tscale_path

The same test now also covers the CUDA-side E update kernel and CUDA-side packed ternary repack.

Strict one-step train after replacing dense grad_w:

Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655

Strict one-step train after moving E update to Triton:

Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655

Strict one-step train after moving ternary_step() repack to Triton:

Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655

Strict one-step train after removing dense grad_sign[N,K] from the normal CUDA path:

Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655

Strict one-step train after fusing sign reduction into ternary-step/repack and applying E from per-group scores:

Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655

Remaining Problem: grad_w = grad_y.T @ x

The first version of the Triton path fixed forward and grad-input memory behavior, but backward still materialized a dense temporary grad_w in Python:

grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)

That has now been replaced by a Triton sign-only reduction. The dense FP32 grad_w tensor is gone from the Triton backward path.

Current remaining issues:

  • The CUDA update path now allocates a small int32 score buffer shaped like E, not like the full logical weight matrix.
  • The score buffer is out_dim * ceil(in_dim / group_size) entries. At T32 group size 12, this is about 1/12 as many elements as the logical weights.
  • The fused kernel uses tl.atomic_add into group scores because 5-trit packed bytes and scale groups are not naturally aligned.

This path fits the 3B-on-8GB memory goal much better than the previous one. The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider aligning scale groups with the 5-trit packing.
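The size of the int32 score buffer follows directly from out_dim * ceil(in_dim / group_size). As a rough check, the sketch below uses a hypothetical 4096 x 4096 layer at group size 12 (not a measured layer from this model).

```python
import math


def score_buffer_entries(out_dim: int, in_dim: int, group_size: int) -> int:
    """One int32 score per E group: out_dim * ceil(in_dim / group_size)."""
    return out_dim * math.ceil(in_dim / group_size)


out_dim, in_dim, group_size = 4096, 4096, 12  # hypothetical layer shape
entries = score_buffer_entries(out_dim, in_dim, group_size)
weights = out_dim * in_dim
print("score entries:", entries, "logical weights:", weights)
print("weights per score entry:", round(weights / entries, 2))
print("int32 score buffer MiB:", entries * 4 / 2**20)
print("dense FP32 grad_w MiB:", weights * 4 / 2**20)
```

For this shape the buffer holds about 1.4 M int32 entries (about 5.3 MiB), compared with 64 MiB for a dense FP32 grad_w of the same layer.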

The actual required update is:

T_accum[i, j] += sign(sum_m grad_y[m, i] * x[m, j])
E update uses grouped sign statistics from the same reduction.

Because only this sign is needed, even the dense int8 sign tensor can be avoided by fusing the reduction directly into the ternary state update. The normal CUDA path now does this.

Proposed Fix For Gradient Sign Capture

Stage 1: Sign-Only Grad Kernel

Implemented. Added a Triton kernel:

input:
  grad_y[M, N]
  x[M, K]

output:
  grad_sign[N, K] int8

Each program owns a tile of (N, K), loops across M, and accumulates a local sum (float32 for float activations; int32 would suffice for integer-valued inputs), then immediately converts it to a sign:

s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g

This removes the dense FP32 grad_w allocation. It still performs the same math, but the result is reduced to int8 at the end of each tile instead of returning a full precision matrix to PyTorch.
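The tiling scheme can be mirrored in plain Python to show that only signs leave each tile. This is a CPU reference sketch, not the Triton kernel: each (n0, k0) tile stands in for one program, and BLOCK sizes here are arbitrary illustrative values.

```python
from typing import List

Matrix = List[List[float]]


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def grad_sign_tiled(grad_y: Matrix, x: Matrix, block_n: int = 2, block_k: int = 2) -> List[List[int]]:
    """grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k]), computed tile by tile."""
    m_dim, n_dim, k_dim = len(grad_y), len(grad_y[0]), len(x[0])
    out = [[0] * k_dim for _ in range(n_dim)]
    # Each (n0, k0) tile plays the role of one Triton program.
    for n0 in range(0, n_dim, block_n):
        for k0 in range(0, k_dim, block_k):
            ns = range(n0, min(n0 + block_n, n_dim))
            ks = range(k0, min(k0 + block_k, k_dim))
            acc = {(n, k): 0.0 for n in ns for k in ks}  # tile-local accumulator
            for m in range(m_dim):  # loop across M inside the program
                for n in ns:
                    for k in ks:
                        acc[(n, k)] += grad_y[m][n] * x[m][k]
            for (n, k), s in acc.items():
                out[n][k] = sign(s)  # only the int8-sized sign is written out
    return out
```

The result is independent of the tile sizes; tiling only bounds how much accumulator state one program holds at a time.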

Stage 2: GPU T Accumulator Update And Repack

Implemented as a fused Triton kernel. It computes the sign reduction from grad_y and x, updates T_accum, applies threshold flips, resets flipped accumulator entries, rewrites T_packed in-place, and emits per-group scale scores:

g = sign(sum_m grad_y[m,n] * x[m,k])
score[n, group(k)] += g * old_T[n,k]
T_accum[n, k] = clamp(T_accum[n, k] + g, -128, 127)
if T_accum[n,k] > threshold:  T[n,k] = +1, T_accum[n,k] = 0
if T_accum[n,k] < -threshold: T[n,k] = -1, T_accum[n,k] = 0
repack 5 T values into one uint8

This removes Python unpack/repack from the CUDA path.
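The per-weight logic of the fused step can be written out for a single row as a reference. This sketch follows the pseudocode above literally (including that a threshold crossing sets T to +1 or -1); it takes the gradient signs g as input instead of recomputing them, and leaves out the final 5-trit repack. The function name is hypothetical.

```python
from typing import List, Tuple


def ternary_step_row(
    g: List[int], t_old: List[int], t_accum: List[int], threshold: int, group_size: int
) -> Tuple[List[int], List[int], List[int]]:
    """One row of the fused step: group scores, int8 accumulate, threshold flips."""
    t_new = list(t_old)
    acc = list(t_accum)
    score = [0] * (-(-len(g) // group_size))  # ceil(K / group_size) group scores
    for k, gk in enumerate(g):
        score[k // group_size] += gk * t_old[k]  # scale score uses the pre-flip T
        acc[k] = max(-128, min(127, acc[k] + gk))  # saturating int8 accumulate
        if acc[k] > threshold:
            t_new[k], acc[k] = 1, 0  # flip to +1 and reset the accumulator
        elif acc[k] < -threshold:
            t_new[k], acc[k] = -1, 0  # flip to -1 and reset the accumulator
    return t_new, acc, score
```

In the CUDA kernel the score update is the tl.atomic_add, because the bytes a program repacks do not line up with E group boundaries.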

Stage 3: Separate E Group Update

Implemented for the CUDA/Triton path as a small score-apply kernel.

For scale exponents, update per group rather than per weight:

group_score[n, g] = sum_{k in group} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)

This is now a second Triton kernel over (N, groups) that applies the score emitted by the fused ternary-step/repack kernel.

Current implemented split:

  1. fused sign-reduction + T_accum + repack + group-score kernel
  2. small E score-apply kernel

No sign tensor is stored in the normal CUDA path.
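The group score and the score-apply rule can be sketched together in a group-owned form, where each (n, group) pair computes its own score from grad_y and x and writes it once, so no atomics are needed. This is a CPU illustration under the formulas above, with hypothetical function names; it is not the Triton kernel.

```python
from typing import List

Matrix = List[List[float]]


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def group_scores_direct(grad_y: Matrix, x: Matrix, T: List[List[int]], group_size: int) -> List[List[int]]:
    """score[n][g] = sum_{k in group g} sign(sum_m grad_y[m][n] * x[m][k]) * T[n][k]."""
    m_dim, n_dim, k_dim = len(grad_y), len(T), len(T[0])
    scores = []
    for n in range(n_dim):
        row = []
        for g0 in range(0, k_dim, group_size):  # one "program" per (n, group)
            s = 0
            for k in range(g0, min(g0 + group_size, k_dim)):
                dot = sum(grad_y[m][n] * x[m][k] for m in range(m_dim))
                s += sign(dot) * T[n][k]
            row.append(s)  # written once by its owner: no atomic_add
        scores.append(row)
    return scores


def apply_scores(E: List[List[int]], score: List[List[int]]) -> List[List[int]]:
    """E[n][g] = clamp(E[n][g] - sign(score[n][g]), -128, 127)."""
    return [
        [max(-128, min(127, e - sign(s))) for e, s in zip(e_row, s_row)]
        for e_row, s_row in zip(E, score)
    ]
```

The trade-off is that the group-owned form recomputes the reduction instead of sharing it with the repack pass.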

Stage 4: Tune Or Remove Group-Score Atomics

Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes are 5-trit chunks while E groups are 12, 6, 24, 48, 64, or 96 weights depending on TScale type.

Options:

  • Keep atomics and tune block sizes.
  • Change group sizes to align with 5-trit packing where possible.
  • Use a group-owned kernel for E and a pack-owned kernel for T, accepting two reductions but no atomics.
  • Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.

This is now a speed optimization, not a memory correctness blocker.
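Counting the packed bytes that straddle an E group boundary shows why the fused kernel needs atomics. This is a minimal sketch, assuming 5 trits per byte laid out contiguously along the input dimension; the function name is hypothetical.

```python
import math


def straddling_bytes(in_dim: int, group_size: int, trits_per_byte: int = 5) -> int:
    """Number of packed bytes whose trits belong to more than one E group."""
    count = 0
    for b in range(math.ceil(in_dim / trits_per_byte)):
        first = b * trits_per_byte
        last = min(first + trits_per_byte, in_dim) - 1
        if first // group_size != last // group_size:
            count += 1
    return count


for gs in (6, 12, 24, 48, 64, 96, 10):
    print(gs, straddling_bytes(960, gs))
```

None of the listed group sizes is a multiple of 5, so each of them leaves some bytes split across two groups; a group size such as 10 would give zero, which is the idea behind the "align group sizes with 5-trit packing" option.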

Stage 5: Activation Ternary Mode

Forward currently consumes normal activation tensors. True ternary training should eventually use:

A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E

This makes the gradient-sign kernel cheaper because x[m, k] is also sign/zero rather than a dense float.
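The benefit of ternary activations can be shown on a single dot product and a single sign reduction. This sketch assumes activations are already in {-1, 0, +1} (the document does not specify how they would be ternarized) and uses hypothetical function names.

```python
from typing import List


def ternary_dot_grouped(a: List[int], w: List[int], exps: List[int], group_size: int) -> float:
    """sum_k a[k] * w[k] * 2**E[group(k)] with a, w in {-1, 0, +1}.

    Inside a group every product is -1, 0, or +1, so the partial sum is an exact
    integer (int32 on a GPU); the scale from E is applied once per group.
    """
    out = 0.0
    for gi, e in enumerate(exps):
        lo, hi = gi * group_size, min((gi + 1) * group_size, len(w))
        acc = sum(a[k] * w[k] for k in range(lo, hi))  # integer accumulator
        out += acc * 2.0 ** e
    return out


def sign_reduce_ternary_x(grad_col: List[float], x_col: List[int]) -> int:
    """sign(sum_m grad_y[m, n] * x[m, k]) when x is ternary: only adds and subtracts."""
    s = 0.0
    for g, xk in zip(grad_col, x_col):
        if xk > 0:
            s += g
        elif xk < 0:
            s -= g
    return (s > 0) - (s < 0)
```

With ternary x the gradient-sign reduction needs no multiplies at all: each grad_y entry is added, subtracted, or skipped.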

Native CUDA Speed Path

cuBLAS

cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and low/mixed precision types, including tensor-core-backed FP/INT formats, but it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.

Use cuBLAS only for baselines or fallback dequantized GEMM.

Triton

Triton is the best immediate path.

Reasons:

  • Already installed in this environment.
  • Integrates with PyTorch autograd quickly.
  • Good for custom packed formats and fused update kernels.
  • Lets us remove the dense grad_w allocation without waiting on TileLang setup.

Near-term target:

Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel

TileLang

TileLang is still worth getting working after the algorithm is stable. Its docs describe T.gemm lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.

Use TileLang when:

  • Triton semantics are proven.
  • Tile sizes and data layout are stable.
  • We want a cleaner path toward tensor-core-style tiling and schedule tuning.

TileLang should not be the first blocker because it is not currently installed in this environment.

Custom CUDA/CUTLASS

This is the final performance path, not the first implementation path.

Use custom CUDA/CUTLASS when:

  • The Triton kernels prove the training algorithm.
  • Profiling shows Triton is bottlenecked by decode/repack/reduction overhead.
  • We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.

This path has the highest ceiling and highest development cost.

Recommended Roadmap

  1. Keep the current Triton forward and grad-input path.
  2. Add ternary_grad_sign_accum_kernel to eliminate dense grad_w.
  3. Add GPU-side E update.
  4. Add GPU-side repack of T_packed.
  5. Benchmark strict ternary training memory with torch.cuda.max_memory_allocated().
  6. Tune Triton block sizes.
  7. Install and evaluate TileLang against the Triton kernels.
  8. Move only proven hot kernels to CUDA/CUTLASS if needed.

Key Constraint

There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than ternary. The goal is not "no accumulator precision"; the goal is:

  • no persistent BF16/FP32/FP8 weights
  • no persistent FP optimizer state
  • no dense full-precision weight gradients
  • packed ternary storage
  • int8 scale memory
  • streaming reductions into ternary/int8 state
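A rough budget for the persistent ternary state makes these goals concrete. This is an estimate under stated assumptions: every weight lives in a TernaryScaleTensor with group size 12, T_accum is one int8 per weight (as the clamp to [-128, 127] implies), and activations, retained grad_y/x views, score buffers, and the CUDA context are ignored. The layer shape and the 3e9 weight count are illustrative.

```python
import math


def layer_state_bytes(out_dim: int, in_dim: int, group_size: int) -> dict:
    """Persistent ternary state for one layer under the layout described in this doc."""
    return {
        "T_packed": out_dim * math.ceil(in_dim / 5),  # 5 trits per byte
        "E": out_dim * math.ceil(in_dim / group_size),  # one int8 exponent per group
        "T_accum": out_dim * in_dim,  # one int8 accumulator per weight
    }


def bytes_per_weight(group_size: int) -> float:
    """Asymptotic bytes per logical weight: packed trit + int8 accumulator + shared int8 E."""
    return 1 / 5 + 1 + 1 / group_size


state = layer_state_bytes(4096, 4096, 12)  # hypothetical layer shape
fp32_adam = 4096 * 4096 * 12  # FP32 weight + two FP32 Adam moments = 12 bytes per weight
print(state, sum(state.values()), fp32_adam)
print("3e9 weights ->", 3e9 * bytes_per_weight(12) / 1e9, "GB of ternary state")
```

Under these assumptions the int8 accumulator dominates (about 1 byte of the roughly 1.28 bytes per weight), and 3e9 weights come to about 3.85 GB of persistent state.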

Speed Pass: Scale Update Scheduling

The current bottleneck is the exponent scale update, not the packed forward path. On a 4-step strict smoke test, scheduled scale updates ran at about 1.87 it/s, while disabling E updates reached about 4.90 it/s on the same small batch after compile.

Changes made:

  • Added --scale_update_interval.
    • Default: 4.
    • 1: update int8 E every step.
    • 0: disable E updates and always use the fast direct ternary repack path.
  • Changed CUDA update_E() back to the direct group-owned Triton kernel for the normal retained grad_y/x path.
  • ternary_step() now uses the fast direct repack kernel after update_E() instead of the fused score/repack kernel on scheduled scale-update steps.
  • This removes the temporary int32 score buffer and tl.atomic_add from the default scheduled scale-update path.
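The scheduling rule can be summarized as a small helper. This is a sketch, not the train.py code: the helper names are hypothetical, and it assumes E is updated on steps 0, interval, 2 * interval, and so on (the exact phase used in train.py is not stated here).

```python
def should_update_scale(step: int, interval: int) -> bool:
    """--scale_update_interval semantics: 0 disables E updates, 1 updates every step."""
    if interval < 0:
        raise ValueError("scale_update_interval must be >= 0")
    if interval == 0:
        return False
    return step % interval == 0  # assumption: update on steps 0, interval, 2*interval, ...


def plan_state_update(step: int, interval: int) -> list:
    """Order of the CUDA state-update calls for one optimizer step (hypothetical helper)."""
    ops = []
    if should_update_scale(step, interval):
        ops.append("update_E")  # runs first so the gradient sign is still available
    ops.append("ternary_step_direct")  # fast direct repack: no score buffer, no atomics
    return ops
```

With the default interval of 4, three of every four steps take only the direct repack path.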

Observed before the direct E change:

scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke

Observed after the direct group-owned E change:

scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke

Interpretation:

  • The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
  • The direct group-owned E kernel is faster even though it performs a separate reduction, because each program owns an E group and writes without atomics.
  • The default scale_update_interval=4 now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.

Next speed target:

  • Benchmark the direct group-owned E kernel versus the fused score kernel on real layer sizes.
  • If direct E is faster, remove or demote the fused score kernel to an experimental path.
  • Tune BLOCK_N, BLOCK_K, and BLOCK_M in _triton_update_e_direct_kernel and _triton_ternary_step_direct_kernel.