TRUE TERNARY REFACTOR 18: BigInt Correlation Scaling
Goal
Port the successful testing/test_bigint_ternary.py scaling rule into the main ARB ternary training path while staying under the 8 GB VRAM training target.
The core scaling rule, now the default for dense ternary modules, is:
mean_corr = corr_accum / (step_counter * group_size)
S = 2 ** (E + ARB_BIGINT_CORR_STRENGTH * mean_corr)
Default ARB_BIGINT_CORR_STRENGTH is 4.0, matching the successful BigInt test. E stays fixed after initialization for dense TernaryScaleTensor modules. The trainable signal moves through integer corr_accum.
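As a minimal sketch of the rule above, assuming plain Python scalars for a single scale group: the function name bigint_scale and the zero-step fallback are illustrative assumptions, not taken from the ARB code.

```python
def bigint_scale(E, corr_accum, step_counter, group_size, strength=4.0):
    """Effective scale S for one scale group (illustrative, not the ARB kernel)."""
    if step_counter == 0:
        # Assumption: with no updates yet there is no correlation signal, so S = 2**E.
        return 2.0 ** E
    # Python ints are arbitrary precision, so the division is safe for large counters.
    mean_corr = corr_accum / (step_counter * group_size)
    return 2.0 ** (E + strength * mean_corr)


# E = -3, 2 steps, group of 4, accumulated correlation +8 -> mean_corr = 1.0
print(bigint_scale(-3, 8, 2, 4))  # 2 ** (-3 + 4.0 * 1.0) = 2.0
```

If each step adds at most group_size in magnitude to corr_accum, as a per-element sum of signs would, mean_corr stays in [-1, 1] and the default strength of 4.0 bounds the exponent shift to four powers of two in either direction.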
Main Changes
- arbitor/kernel/ternary_scale.py
  - Replaced dense T_accum and E_accum training state with corr_accum int64 and step_counter int64.
  - Added BigInt scale expansion in _get_S() and the Triton linear forward/grad-x kernels.
  - Added a Triton direct correlation accumulation kernel for sign(grad_y.T @ x) * T grouped by scale group.
  - Added a custom autograd function that recomputes effective weights in backward instead of saving w_eff in the graph.
  - Disabled the TileLang dense path for this mode until TileLang supports corr_accum in both forward and backward. This prevents the previous fp16 TileLang path from silently breaking ternary training.
- arbitor/main.py
  - Updated prepare_ternary_backward() and update cleanup to recognize streamed BigInt updates.
  - Preserved LossComponent routing by sending component-specific dense signs into corr_accum.
  - Skipped old per-group threshold float temporaries for BigInt dense layers. Those thresholds were only used by legacy T_accum flips and were a large avoidable allocation at 3B scale.
- arbitor/kernel/ternary_audit.py
  - Audit now counts corr_accum and step_counter so training-state memory is no longer underreported.
- arbitor/attention/context_attention.py, arbitor/components.py, arbitor/decoders.py
  - Converted the context attention gate from nn.Linear to TernaryScaleTensor.
  - Froze LTI injection float constants so pure ternary trainers have zero trainable float parameters.
- testing/test_tscale.py, testing/test_polarity_validation.py
  - Updated expectations from old E_accum/T_accum behavior to fixed E, integer corr_accum, and BigInt step counters.
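The correlation accumulation listed for ternary_scale.py, sign(grad_y.T @ x) * T grouped by scale group, can be sketched in pure Python for reference. This is not the Triton kernel: the function names, the list-of-lists inputs, and the group layout (contiguous runs of group_size along each row of T, with the input width divisible by group_size) are assumptions made for illustration.

```python
def sign(v):
    # Returns -1, 0, or 1 as a Python int.
    return (v > 0) - (v < 0)


def accumulate_corr(grad_y, x, T, corr_accum, group_size):
    """Add one step of sign(grad_y.T @ x) * T into per-group integer counters.

    grad_y: batch x out, x: batch x in, T: out x in with entries in {-1, 0, 1}.
    Assumption: scale groups are contiguous runs of group_size along each row of T.
    """
    batch, n_out, n_in = len(x), len(T), len(T[0])
    groups_per_row = n_in // group_size
    for o in range(n_out):
        for i in range(n_in):
            # Element (o, i) of grad_y.T @ x, i.e. the dense weight gradient.
            g = sum(grad_y[b][o] * x[b][i] for b in range(batch))
            corr_accum[o * groups_per_row + i // group_size] += sign(g) * T[o][i]
    return corr_accum


grad_y = [[1.0, -2.0]]
x = [[0.5, -1.0, 2.0, 0.0]]
T = [[1, -1, 0, 1], [-1, -1, 1, 0]]
print(accumulate_corr(grad_y, x, T, [0, 0, 0, 0], group_size=2))  # [2, 0, 0, -1]
```

Each element contributes -1, 0, or +1 per step, so the counters stay exact integers and only the sign of the gradient, not its magnitude, reaches the training state.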
Validation
Commands run:
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_polarity_validation.py testing/test_tscale.py -k "not model_integration and not runtime_switch"
ARB_TERNARY_BACKEND=triton python training/text.py --steps 1 --batch 1 --ctx 4 --eval-interval 999 --run bigint-smoke
ARB_TERNARY_BACKEND=triton python training/pretrain.py --steps 1 --batch 1 --ctx 4 --text-data training/data/tinyshakespeare.txt --text-weight 1.0 --code-weight 0 --image-weight 0 --audio-weight 0 --video-weight 0 --eval-interval 0 --log-interval 0 --save-interval 0 --no-save --run bigint-pretrain-smoke
Focused tests passed:
42 passed, 2 deselected
3-step full text-model CUDA memory probe:
logical_ternary_weights 3122933472
training_state_mb 1684.27
after_cuda_mb 1706
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
This is the key result: allocated VRAM stayed at 1754 MB on every step, so nothing stacked across steps after cleanup.
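The no-stacking claim can also be checked mechanically from the probe lines. The sketch below parses the three step lines reported above; the helper names parse_probe and alloc_growth are hypothetical and not part of the ARB tooling.

```python
import re

PROBE = """\
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
"""


def parse_probe(text):
    """Return one dict per 'step N key value ...' line."""
    rows = []
    for line in text.splitlines():
        m = re.match(r"step (\d+) (.*)", line.strip())
        if not m:
            continue
        fields = m.group(2).split()
        row = {"step": int(m.group(1))}
        # Fields alternate between a key and its integer value.
        row.update({k: int(v) for k, v in zip(fields[0::2], fields[1::2])})
        rows.append(row)
    return rows


def alloc_growth(rows):
    """Allocated-memory change from the first to the last probed step, in MB."""
    return rows[-1]["alloc_mb"] - rows[0]["alloc_mb"]


rows = parse_probe(PROBE)
print(alloc_growth(rows))  # 0
```

A growth of 0 MB over the probed steps is what "did not stack" means here; a longer run would be needed to rule out slow leaks.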
Current State
- Dense ternary modules now train through BigInt correlation scaling rather than discrete ternary flips.
- Persistent dense training state is integer only: T_packed uint8, E int8, corr_accum int64, step_counter int64.
- The 3.12B logical ternary text path reports 0 trainable float params and runs a 1-step pretrain smoke.
- Remaining float params are frozen LTI constants, not optimizer state.
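One way to check the "0 trainable float params" claim is to count parameter elements that are both floating point and require gradients. The sketch below is duck-typed so it runs without PyTorch: FakeParam is a hypothetical stand-in exposing only requires_grad, is_floating_point(), and numel(), which torch.Tensor also provides. With a real module, the assumption is that model.parameters() would be passed instead.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Stand-in exposing the three torch.Tensor members used below."""
    n: int
    floating: bool
    requires_grad: bool

    def numel(self):
        return self.n

    def is_floating_point(self):
        return self.floating


def trainable_float_params(params):
    """Count elements of parameters that are floating point and require grad."""
    return sum(p.numel() for p in params
               if p.requires_grad and p.is_floating_point())


params = [
    FakeParam(n=16, floating=True, requires_grad=False),     # frozen float constant
    FakeParam(n=1024, floating=False, requires_grad=False),  # integer training state
]
print(trainable_float_params(params))  # 0
```

Frozen float constants still exist in this count's input but contribute nothing, which matches the distinction drawn above between remaining float params and optimizer state.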
Follow-Up
- Port sparse embedding/VQ tables from legacy T_accum/E_accum to BigInt correlation if the T_accum=321 MB training-state block becomes the next bottleneck.
- Add TileLang BigInt support only after its kernels accept corr_accum, step_counter, and integer correlation accumulation. The fp16-only TileLang path should stay disabled for BigInt training.