# TRUE TERNARY REFACTOR 18: BigInt Correlation Scaling

## Goal

Port the successful `testing/test_bigint_ternary.py` scaling rule into the main ARB ternary training path while staying under the 8 GB VRAM training target.

The core math, now the default for dense ternary layers, is:

```text
mean_corr = corr_accum / (step_counter * group_size)
S = 2 ** (E + ARB_BIGINT_CORR_STRENGTH * mean_corr)
```

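The formula can be sketched for a single scale group in plain Python. This is a minimal illustration under stated assumptions, not the `_get_S()` code: `bigint_scale` is a hypothetical helper, Python ints stand in for the int8/int64 tensors, and the zero-step fallback to `2 ** E` is an assumption that avoids dividing by zero before the first update.

```python
# Minimal sketch of the BigInt scale expansion for ONE scale group.
# Assumptions: plain Python ints stand in for the int8 `E`, int64 `corr_accum`
# and int64 `step_counter` tensors; `bigint_scale` is a hypothetical helper,
# not the function used in `_get_S()` or in the Triton kernels.

ARB_BIGINT_CORR_STRENGTH = 4.0  # default strength, matching the BigInt test


def bigint_scale(E: int, corr_accum: int, step_counter: int, group_size: int,
                 strength: float = ARB_BIGINT_CORR_STRENGTH) -> float:
    """Return S = 2 ** (E + strength * mean_corr) for one scale group."""
    if step_counter <= 0:
        # Assumption: with no updates yet there is no correlation signal,
        # so the scale falls back to the fixed exponent alone.
        mean_corr = 0.0
    else:
        mean_corr = corr_accum / (step_counter * group_size)
    return 2.0 ** (E + strength * mean_corr)


if __name__ == "__main__":
    print(bigint_scale(E=-3, corr_accum=0, step_counter=10, group_size=64))
    print(bigint_scale(E=-3, corr_accum=160, step_counter=10, group_size=64))
```

Running it prints `0.125` and then `0.25`: with `E = -3`, a mean correlation of `160 / (10 * 64) = 0.25` adds `4.0 * 0.25 = 1` to the exponent, which doubles the scale.
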
The default `ARB_BIGINT_CORR_STRENGTH` is `4.0`, matching the successful BigInt test. For dense `TernaryScaleTensor` modules, `E` stays fixed after initialization, so the trainable signal flows through the integer `corr_accum` instead.

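Because every accumulated term is a product of two values in {-1, 0, +1}, the mean correlation is bounded, and so is the scale. A short derivation for one scale group g, assuming each step adds exactly one `sign(G) * T` term per weight in the group and increments `step_counter` by one (the symbols c_t, G_t, N and k are introduced here for the derivation, not taken from the code):

$$
\begin{aligned}
c_t &= \sum_{i \in g} \operatorname{sign}(G_{t,i})\, T_i \in [-|g|,\ |g|], \qquad G_t = \text{grad\_y}_t^\top x_t \\
\text{corr\_accum} &= \sum_{t=1}^{N} c_t \in [-N|g|,\ N|g|], \qquad N = \text{step\_counter},\ |g| = \text{group\_size} \\
\text{mean\_corr} &= \frac{\text{corr\_accum}}{N\,|g|} \in [-1,\ 1] \\
S &= 2^{E + k\,\text{mean\_corr}} \in [2^{E-k},\ 2^{E+k}], \qquad k = \text{ARB\_BIGINT\_CORR\_STRENGTH}
\end{aligned}
$$

With the default k = 4, each group's scale stays within a factor of 16 of 2^E in either direction.
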
## Main Changes

- `arbitor/kernel/ternary_scale.py`
  - Replaced the dense `T_accum` and `E_accum` training state with `corr_accum int64` and `step_counter int64`.
  - Added BigInt scale expansion in `_get_S()` and in the Triton linear forward and grad-x kernels.
  - Added a Triton kernel that accumulates `sign(grad_y.T @ x) * T` directly, summed per scale group.
  - Added a custom autograd function that recomputes the effective weights in backward instead of saving `w_eff` in the graph.
  - Disabled the TileLang dense path for this mode until TileLang supports `corr_accum` in both forward and backward. This prevents the previous fp16 TileLang path from silently breaking ternary training.

- `arbitor/main.py`
  - Updated `prepare_ternary_backward()` and the post-update cleanup to recognize streamed BigInt updates.
  - Preserved `LossComponent` routing by sending component-specific dense signs into `corr_accum`.
  - Skipped the legacy per-group float threshold temporaries for BigInt dense layers. Those thresholds were only used by legacy `T_accum` flips and were a large, avoidable allocation at 3B scale.

- `arbitor/kernel/ternary_audit.py`
  - The audit now counts `corr_accum` and `step_counter`, so training-state memory is no longer underreported.

- `arbitor/attention/context_attention.py`, `arbitor/components.py`, `arbitor/decoders.py`
  - Converted the context attention gate from `nn.Linear` to `TernaryScaleTensor`.
  - Froze the LTI injection float constants so pure ternary trainers have zero trainable float parameters.

- `testing/test_tscale.py`, `testing/test_polarity_validation.py`
  - Updated expectations from the old `E_accum`/`T_accum` behavior to fixed `E`, integer `corr_accum`, and BigInt step counters.

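The grouped correlation update handled by the new Triton kernel in `ternary_scale.py` can be sketched in pure Python. This is a hypothetical, unoptimized illustration: `accumulate_correlation` is not the kernel, and it assumes an `(out_features, in_features)` weight layout with scale groups formed from contiguous runs of `group_size` weights in row-major order.

```python
# Pure-Python sketch of one dense correlation update:
#   corr_accum[g] += sum over weights in group g of sign(grad_y.T @ x) * T
# Assumptions (hypothetical, not the Triton kernel): weights form an
# (out_features x in_features) matrix, scale groups are contiguous runs of
# `group_size` weights in row-major order, and step_counter is one scalar.
from typing import List

Matrix = List[List[float]]


def sign(v: float) -> int:
    """Return -1, 0 or +1."""
    return (v > 0) - (v < 0)


def accumulate_correlation(grad_y: Matrix, x: Matrix, T: List[List[int]],
                           corr_accum: List[int], group_size: int) -> None:
    """Add one step of grouped sign correlations into corr_accum in place.

    grad_y: batch x out_features, x: batch x in_features,
    T: out_features x in_features ternary weights in {-1, 0, +1}.
    """
    batch = len(x)
    out_features, in_features = len(T), len(T[0])
    for o in range(out_features):
        for i in range(in_features):
            # (grad_y.T @ x)[o][i] = sum over the batch of grad_y[b][o] * x[b][i]
            g = sum(grad_y[b][o] * x[b][i] for b in range(batch))
            flat = o * in_features + i
            corr_accum[flat // group_size] += sign(g) * T[o][i]


if __name__ == "__main__":
    state = {"corr_accum": [0, 0], "step_counter": 0}
    grad_y = [[1.0, -1.0]]  # batch=1, out_features=2
    x = [[2.0, -3.0]]       # batch=1, in_features=2
    T = [[1, 1], [-1, 0]]
    accumulate_correlation(grad_y, x, T, state["corr_accum"], group_size=2)
    state["step_counter"] += 1
    print(state)  # {'corr_accum': [0, 1], 'step_counter': 1}
```

In group 0 the terms `+1` and `-1` cancel; in group 1 only the first weight contributes, because the second has `T = 0`. The accumulator stays an exact integer, which is what lets the persistent state avoid floats.
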
## Validation

Commands run:

```bash
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_polarity_validation.py testing/test_tscale.py -k "not model_integration and not runtime_switch"
ARB_TERNARY_BACKEND=triton python training/text.py --steps 1 --batch 1 --ctx 4 --eval-interval 999 --run bigint-smoke
ARB_TERNARY_BACKEND=triton python training/pretrain.py --steps 1 --batch 1 --ctx 4 \
  --text-data training/data/tinyshakespeare.txt \
  --text-weight 1.0 --code-weight 0 --image-weight 0 --audio-weight 0 --video-weight 0 \
  --eval-interval 0 --log-interval 0 --save-interval 0 --no-save \
  --run bigint-pretrain-smoke
```

Focused tests passed:

```text
42 passed, 2 deselected
```

CUDA memory probe over 3 steps of the full text model:

```text
logical_ternary_weights 3122933472
training_state_mb 1684.27
after_cuda_mb 1706
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
```

This is the key result: allocated VRAM holds at 1754 MB on every step, so memory does not stack across steps after cleanup, and peak usage stays at or below 2038 MB.

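The "does not stack" claim can be checked mechanically against the probe output. A minimal sketch; `parse_steps` and `alloc_is_flat` are hypothetical helpers, not part of `arbitor`, and the log format is assumed to match the step lines above exactly.

```python
# Hypothetical check (not part of arbitor): parse the per-step probe lines
# and confirm that allocated memory is identical on every step.
import re
from typing import List, Tuple

PROBE_LOG = """\
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
"""

STEP_RE = re.compile(r"step (\d+) alloc_mb (\d+) reserved_mb (\d+) peak_mb (\d+)")


def parse_steps(log: str) -> List[Tuple[int, int, int, int]]:
    """Return (step, alloc_mb, reserved_mb, peak_mb) for each step line."""
    return [tuple(int(v) for v in m.groups()) for m in STEP_RE.finditer(log)]


def alloc_is_flat(log: str, tolerance_mb: int = 0) -> bool:
    """True if allocated memory varies by at most tolerance_mb across steps."""
    allocs = [alloc for _, alloc, _, _ in parse_steps(log)]
    return bool(allocs) and max(allocs) - min(allocs) <= tolerance_mb


if __name__ == "__main__":
    steps = parse_steps(PROBE_LOG)
    print(alloc_is_flat(PROBE_LOG))         # True
    print(max(peak for *_, peak in steps))  # 2038
```

A nonzero `tolerance_mb` can absorb allocator noise on longer runs, while a steady upward drift in `alloc_mb` would still fail the check.
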
## Current State

- Dense ternary modules now train through BigInt correlation scaling rather than discrete ternary flips.
- Persistent dense training state is integer-only: `T_packed uint8`, `E int8`, `corr_accum int64`, `step_counter int64`.
- The 3.12B logical ternary text path reports 0 trainable float params and passes a 1-step pretrain smoke test.
- The remaining float params are frozen LTI constants, not optimizer state.

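A rough per-layer estimate shows why the audit must count `corr_accum`: at 8 bytes per scale group it is eight times the size of `E`. This sketch is hypothetical; it assumes 4 ternary weights per `T_packed` byte, one `E` and one `corr_accum` entry per scale group, and a single `step_counter` per module, none of which is confirmed by the audit code.

```python
# Back-of-the-envelope estimate of persistent integer training state for one
# dense ternary layer. Assumptions (hypothetical, not read from the audit):
# `T_packed` stores `trits_per_byte` ternary weights per uint8 byte,
# `E` and `corr_accum` hold one entry per scale group, and `step_counter`
# is a single int64 per module.
import math
from typing import Dict

BYTES = {"uint8": 1, "int8": 1, "int64": 8}


def dense_state_bytes(n_weights: int, group_size: int,
                      trits_per_byte: int = 4) -> Dict[str, int]:
    """Return the byte count of each persistent state tensor."""
    n_groups = math.ceil(n_weights / group_size)
    return {
        "T_packed": math.ceil(n_weights / trits_per_byte) * BYTES["uint8"],
        "E": n_groups * BYTES["int8"],
        "corr_accum": n_groups * BYTES["int64"],
        "step_counter": 1 * BYTES["int64"],
    }


if __name__ == "__main__":
    state = dense_state_bytes(n_weights=4096 * 4096, group_size=128)
    for name, nbytes in state.items():
        print(f"{name:>12}: {nbytes / 2**20:8.3f} MiB")
```

For a 4096 x 4096 layer with `group_size=128`, this gives 4 MiB for `T_packed`, 0.125 MiB for `E`, 1 MiB for `corr_accum`, and 8 bytes for `step_counter`. An audit that skipped `corr_accum` would therefore miss a block a quarter the size of the packed weights under these assumptions.
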
## Follow-Up

- Port the sparse embedding/VQ tables from legacy `T_accum`/`E_accum` to BigInt correlation if their 321 MB `T_accum` training-state block becomes the next bottleneck.
- Add TileLang BigInt support only after its kernels accept `corr_accum`, `step_counter`, and integer correlation accumulation. Until then, the fp16-only TileLang path should stay disabled for BigInt training.