# TRUE TERNARY REFACTOR 18: BigInt Correlation Scaling
## Goal
Port the successful `testing/test_bigint_ternary.py` scaling rule into the main ARB ternary training path while staying under the 8 GB VRAM training target.
The important math is now the dense ternary default:
```text
mean_corr = corr_accum / (step_counter * group_size)
S = 2 ** (E + ARB_BIGINT_CORR_STRENGTH * mean_corr)
```
Default `ARB_BIGINT_CORR_STRENGTH` is `4.0`, matching the successful BigInt test. `E` stays fixed after initialization for dense `TernaryScaleTensor` modules. The trainable signal moves through integer `corr_accum`.
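As a plain-Python reading of the rule above, the sketch below computes `S` for a single scale group. The function name `bigint_scale` and the zero-step guard are assumptions made for illustration; the real computation lives in `_get_S()` and the Triton kernels and may differ in detail.

```python
ARB_BIGINT_CORR_STRENGTH = 4.0  # default value named in this document


def bigint_scale(E, corr_accum, step_counter, group_size,
                 strength=ARB_BIGINT_CORR_STRENGTH):
    """Effective scale S for one scale group (illustrative helper).

    E            -- fixed integer exponent of the group
    corr_accum   -- integer correlation accumulator of the group
    step_counter -- number of accumulated steps
    group_size   -- number of ternary weights sharing this scale
    """
    denom = step_counter * group_size
    # Assumption: before the first step the correlation term is zero,
    # so the scale falls back to the plain 2 ** E.
    mean_corr = corr_accum / denom if denom else 0.0
    return 2.0 ** (E + strength * mean_corr)


print(bigint_scale(E=-3, corr_accum=0, step_counter=10, group_size=64))
print(bigint_scale(E=0, corr_accum=320, step_counter=10, group_size=64))
```

If each step contributes at most `group_size` in magnitude to `corr_accum` (an assumption about the accumulation, not something stated here), `mean_corr` stays in `[-1, 1]` and the default strength lets `S` move by at most four powers of two around `2 ** E`.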
## Main Changes
- `arbitor/kernel/ternary_scale.py`
  - Replaced dense `T_accum` and `E_accum` training state with `corr_accum int64` and `step_counter int64`.
  - Added BigInt scale expansion in `_get_S()` and the Triton linear forward/grad-x kernels.
  - Added a Triton direct correlation accumulation kernel for `sign(grad_y.T @ x) * T` grouped by scale group.
  - Added a custom autograd function that recomputes effective weights in backward instead of saving `w_eff` in the graph.
  - Disabled the TileLang dense path for this mode until TileLang supports `corr_accum` in both forward and backward. This prevents the previous fp16 TileLang path from silently breaking ternary training.
- `arbitor/main.py`
  - Updated `prepare_ternary_backward()` and the update cleanup step to recognize streamed BigInt updates.
  - Preserved `LossComponent` routing by sending component-specific dense signs into `corr_accum`.
  - Skipped the old per-group threshold float temporaries for BigInt dense layers. Those thresholds were only used by legacy `T_accum` flips and were a large avoidable allocation at 3B scale.
- `arbitor/kernel/ternary_audit.py`
  - Audit now counts `corr_accum` and `step_counter` so training-state memory is no longer underreported.
- `arbitor/attention/context_attention.py`, `arbitor/components.py`, `arbitor/decoders.py`
  - Converted the context attention gate from `nn.Linear` to `TernaryScaleTensor`.
  - Froze LTI injection float constants so pure ternary trainers have zero trainable float parameters.
- `testing/test_tscale.py`, `testing/test_polarity_validation.py`
  - Updated expectations from old `E_accum/T_accum` behavior to fixed `E`, integer `corr_accum`, and BigInt step counters.
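The direct correlation accumulation listed under `arbitor/kernel/ternary_scale.py` can be read as the unfused reference loop below. This is a minimal sketch, not the Triton kernel: the function names, the nested-list tensors, and the row-major grouping of `group_size` consecutive weights per scale group are all assumptions. It accumulates `sign(grad_y.T @ x) * T` exactly as written above; whether the real kernel negates the term is not stated in this note.

```python
def sign(v):
    """Return -1, 0, or 1 according to the sign of v."""
    return (v > 0) - (v < 0)


def accumulate_corr(grad_y, x, T, corr_accum, group_size):
    """Add sign(grad_y.T @ x) * T into per-group integer accumulators.

    grad_y     -- [batch][out_features] floats
    x          -- [batch][in_features] floats
    T          -- [out_features][in_features] ternary weights in {-1, 0, 1}
    corr_accum -- flat list of ints, one per scale group (assumed layout:
                  group_size consecutive weights in row-major order)
    """
    batch = len(x)
    out_features, in_features = len(T), len(T[0])
    for o in range(out_features):
        for i in range(in_features):
            # One element of grad_y.T @ x, reduced over the batch.
            g = sum(grad_y[b][o] * x[b][i] for b in range(batch))
            group = (o * in_features + i) // group_size
            corr_accum[group] += sign(g) * T[o][i]
    return corr_accum


acc = accumulate_corr(
    grad_y=[[1.0, -2.0]],
    x=[[0.5, -1.0]],
    T=[[1, -1], [0, 1]],
    corr_accum=[0, 0],
    group_size=2,
)
print(acc)
```

Only the sign of each gradient element is used, so the accumulator stays an integer and zero-valued ternary weights contribute nothing.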
## Validation
Commands run:
```bash
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_polarity_validation.py testing/test_tscale.py -k "not model_integration and not runtime_switch"
ARB_TERNARY_BACKEND=triton python training/text.py --steps 1 --batch 1 --ctx 4 --eval-interval 999 --run bigint-smoke
ARB_TERNARY_BACKEND=triton python training/pretrain.py --steps 1 --batch 1 --ctx 4 --text-data training/data/tinyshakespeare.txt --text-weight 1.0 --code-weight 0 --image-weight 0 --audio-weight 0 --video-weight 0 --eval-interval 0 --log-interval 0 --save-interval 0 --no-save --run bigint-pretrain-smoke
```
Focused tests passed:
```text
42 passed, 2 deselected
```
3-step full text-model CUDA memory probe:
```text
logical_ternary_weights 3122933472
training_state_mb 1684.27
after_cuda_mb 1706
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
```
This is the key result: allocated VRAM held at 1754 MB on all three steps, so memory did not accumulate from step to step after cleanup.
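A check of this kind can be automated against the probe output. The sketch below parses the three step lines from the probe above, pasted in as a string; the helper names `parse_probe` and `alloc_growth_mb` are hypothetical and not part of the repository.

```python
import re

# The three step lines from the probe output above.
PROBE_LOG = """\
step 0 alloc_mb 1754 reserved_mb 1838 peak_mb 1998
step 1 alloc_mb 1754 reserved_mb 1838 peak_mb 2038
step 2 alloc_mb 1754 reserved_mb 1842 peak_mb 2036
"""

STEP_RE = re.compile(
    r"step (\d+) alloc_mb (\d+) reserved_mb (\d+) peak_mb (\d+)"
)


def parse_probe(log):
    """Return one dict of integer fields per 'step N ...' line."""
    rows = []
    for line in log.splitlines():
        m = STEP_RE.match(line.strip())
        if m:
            step, alloc, reserved, peak = map(int, m.groups())
            rows.append({"step": step, "alloc_mb": alloc,
                         "reserved_mb": reserved, "peak_mb": peak})
    return rows


def alloc_growth_mb(rows):
    """Allocated-memory growth from the first to the last probed step."""
    return rows[-1]["alloc_mb"] - rows[0]["alloc_mb"]


rows = parse_probe(PROBE_LOG)
print(alloc_growth_mb(rows))
```

A growth of zero over the probed steps is the property the note relies on; a longer run would be needed to rule out slow leaks.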
## Current State
- Dense ternary modules now train through BigInt correlation scaling rather than discrete ternary flips.
- Persistent dense training state is integer only: `T_packed uint8`, `E int8`, `corr_accum int64`, `step_counter int64`.
- The text path, with 3.12B logical ternary weights, reports 0 trainable float params and completes a 1-step pretrain smoke run.
- Remaining float params are frozen LTI constants, not optimizer state.
## Follow-Up
- Port sparse embedding/VQ tables from legacy `T_accum/E_accum` to BigInt correlation if the `T_accum=321 MB` training-state block becomes the next bottleneck.
- Add TileLang BigInt support only after its kernels accept `corr_accum`, `step_counter`, and integer correlation accumulation. The fp16-only TileLang path should stay disabled for BigInt training.