
TRUE TERNARY REFACTOR 18: BigInt Correlation Scaling

Goal

Port the successful testing/test_bigint_ternary.py scaling rule into the main ARB ternary training path while staying under the 8 GB VRAM training target.

The important math is now the dense ternary default:

mean_corr = corr_accum / (step_counter * group_size)
S = 2 ** (E + ARB_BIGINT_CORR_STRENGTH * mean_corr)

Default ARB_BIGINT_CORR_STRENGTH is 4.0, matching the successful BigInt test. E stays fixed after initialization for dense TernaryScaleTensor modules. The trainable signal moves through integer corr_accum.
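
As a quick sanity check of the rule above, here is a minimal pure-Python sketch for a single scale group. The function name bigint_scale and the step_counter == 0 guard are assumptions made for illustration; the real computation lives in _get_S() and the Triton kernels and may handle these cases differently.

```python
ARB_BIGINT_CORR_STRENGTH = 4.0  # default stated above


def bigint_scale(E, corr_accum, step_counter, group_size,
                 strength=ARB_BIGINT_CORR_STRENGTH):
    """Return S = 2 ** (E + strength * mean_corr) for one scale group."""
    if step_counter == 0:
        # Assumption: with no accumulated steps, use the fixed exponent alone.
        return 2.0 ** E
    mean_corr = corr_accum / (step_counter * group_size)
    return 2.0 ** (E + strength * mean_corr)


# No accumulated correlation: S is just 2 ** E.
print(bigint_scale(E=-3, corr_accum=0, step_counter=10, group_size=64))   # 0.125
# mean_corr = 320 / (10 * 64) = 0.5, so the exponent shifts by 4.0 * 0.5 = 2.
print(bigint_scale(E=0, corr_accum=320, step_counter=10, group_size=64))  # 4.0
```

If each step contributes at most ±1 per weight, as the sign-based accumulation suggests, mean_corr stays within [-1, 1], so the default strength of 4.0 lets S move between 2 ** (E - 4) and 2 ** (E + 4).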

Main Changes

  • arbitor/kernel/ternary_scale.py

    • Replaced dense T_accum and E_accum training state with corr_accum int64 and step_counter int64.
    • Added BigInt scale expansion in _get_S() and the Triton linear forward/grad-x kernels.
    • Added a Triton kernel that accumulates the correlation sign(grad_y.T @ x) * T directly into corr_accum, summed per scale group.
    • Added a custom autograd function that recomputes effective weights in backward instead of saving w_eff in the graph.
    • Disabled the TileLang dense path for this mode until TileLang supports corr_accum in both forward and backward. This prevents the previous fp16 TileLang path from silently breaking ternary training.
  • arbitor/main.py

    • Updated prepare_ternary_backward() and the post-update cleanup to recognize streamed BigInt updates.
    • Preserved LossComponent routing by sending component-specific dense signs into corr_accum.
    • Skipped old per-group threshold float temporaries for BigInt dense layers. Those thresholds were only used by legacy T_accum flips and were a large avoidable allocation at 3B scale.
  • arbitor/kernel/ternary_audit.py

    • Audit now counts corr_accum and step_counter so training-state memory is no longer underreported.
  • arbitor/attention/context_attention.py, arbitor/components.py, arbitor/decoders.py

    • Converted the context attention gate from nn.Linear to TernaryScaleTensor.
    • Froze LTI injection float constants so pure ternary trainers have zero trainable float parameters.
  • testing/test_tscale.py, testing/test_polarity_validation.py

    • Updated expectations from old E_accum/T_accum behavior to fixed E, integer corr_accum, and BigInt step counters.
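
The correlation accumulation listed under arbitor/kernel/ternary_scale.py can be illustrated with a small reference version. This is a pure-Python sketch, not the Triton kernel: the function names, the nested-list layout, and the choice of contiguous groups over the row-major flattened weight matrix are all assumptions made for illustration.

```python
def sign(v):
    """Return -1, 0, or 1 as an int."""
    return (v > 0) - (v < 0)


def accumulate_corr(grad_y, x, T, corr_accum, group_size):
    """Add sign(grad_y.T @ x) * T into corr_accum, one integer per scale group.

    grad_y: [batch][out] floats, x: [batch][in] floats,
    T: [out][in] ternary weights in {-1, 0, 1}.
    Assumption: groups are contiguous runs of group_size weights in
    row-major order; the real kernel may group differently.
    """
    out_dim, in_dim = len(T), len(T[0])
    for o in range(out_dim):
        for i in range(in_dim):
            # One entry of grad_y.T @ x, reduced over the batch.
            g = sum(grad_y[b][o] * x[b][i] for b in range(len(x)))
            group = (o * in_dim + i) // group_size
            corr_accum[group] += sign(g) * T[o][i]
    return corr_accum


grad_y = [[1.0, -2.0]]
x = [[3.0, 0.0, -1.0, 2.0]]
T = [[1, -1, 0, 1],
     [-1, 0, 1, 1]]
print(accumulate_corr(grad_y, x, T, corr_accum=[0, 0], group_size=4))  # [2, 1]
```

Because every contribution is an integer in {-1, 0, 1}, the accumulator never needs float state, which is consistent with storing corr_accum as int64.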

Validation

Commands run:

python -m compileall -q arbitor training testing
python -m pytest -q testing/test_polarity_validation.py testing/test_tscale.py -k "not model_integration and not runtime_switch"
ARB_TERNARY_BACKEND=triton python training/text.py --steps 1 --batch 1 --ctx 4 --eval-interval 999 --run bigint-smoke
ARB_TERNARY_BACKEND=triton python training/pretrain.py --steps 1 --batch 1 --ctx 4 --text-data training/data/tinyshakespeare.txt --text-weight 1.0 --code-weight 0 --image-weight 0 --audio-weight 0 --video-weight 0 --eval-interval 0 --log-interval 0 --save-interval 0 --no-save --run bigint-pretrain-smoke

Focused tests passed:

42 passed, 2 deselected

3-step full text-model CUDA memory probe:

logical_ternary_weights 3122933472
training_state_mb 1684.27
after_cuda_mb 1706
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036

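This is the key result: allocated VRAM stayed at 1754 MB on all three steps, so nothing accumulated across steps after cleanup. Reserved memory moved by only 4 MB and the peak stayed near 2.0 GB.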
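
The no-growth check can be reproduced from the probe output with a few lines of Python. This is a minimal sketch: the parser and its names are hypothetical, the log lines are copied from the probe above, and the bytes-per-weight figure assumes that "mb" means 2**20 bytes and that training_state_mb covers all persistent state, including the legacy sparse tables.

```python
import re

PROBE_LOG = """\
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
"""

LINE_RE = re.compile(
    r"step (\d+) alloc_mb (\d+) reserved_mb (\d+) peak_mb (\d+)")


def parse_probe(log):
    """Parse probe lines into dicts of integer fields."""
    rows = []
    for line in log.splitlines():
        m = LINE_RE.match(line)
        if m:
            step, alloc, reserved, peak = map(int, m.groups())
            rows.append({"step": step, "alloc_mb": alloc,
                         "reserved_mb": reserved, "peak_mb": peak})
    return rows


rows = parse_probe(PROBE_LOG)
alloc_growth = rows[-1]["alloc_mb"] - rows[0]["alloc_mb"]
print(alloc_growth)                      # 0
print(max(r["peak_mb"] for r in rows))   # 2038

# Average persistent state per logical ternary weight (unit assumption above).
bytes_per_weight = 1684.27 * 2**20 / 3122933472
print(round(bytes_per_weight, 2))        # 0.57
```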

Current State

  • Dense ternary modules now train through BigInt correlation scaling rather than discrete ternary flips.
  • Persistent dense training state is integer only: T_packed uint8, E int8, corr_accum int64, step_counter int64.
  • The 3.12B logical ternary text path reports 0 trainable float params and completes a 1-step pretrain smoke test.
  • Remaining float params are frozen LTI constants, not optimizer state.

Follow-Up

  • Port sparse embedding/VQ tables from legacy T_accum/E_accum to BigInt correlation if their 321 MB T_accum training-state block becomes the next bottleneck.
  • Add TileLang BigInt support only after its kernels accept corr_accum, step_counter, and integer correlation accumulation. The fp16-only TileLang path should stay disabled for BigInt training.