| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType |
|
|
|
|
def test_ternary_sign_values():
    """Check that every entry returned by ``_get_T()`` lies in {-1, 0, 1}."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary = layer._get_T()
    # Reduce the tensor to its distinct Python scalars so that a failure
    # reports exactly which unexpected values were seen.
    observed = set(ternary.detach().flatten().tolist())
    allowed = {-1, 0, 1}
    assert observed <= allowed, f"T has non-ternary values: {observed}"
    print(" PASS test_ternary_sign_values")
|
|
|
|
def test_scale_positive():
    """Check that every entry returned by ``_get_S()`` is positive and finite.

    The scale is described by this test suite as S = 2^E, which is why a
    non-positive or non-finite entry is treated as a failure.
    """
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    scale = layer._get_S()
    # Strict positivity first, then rule out inf/nan separately so each
    # failure mode gets its own message.
    assert torch.all(scale > 0), "Some S values are not positive"
    assert torch.all(torch.isfinite(scale)), "Some S values are non-finite"
    print(" PASS test_scale_positive")
|
|
|
|
def test_effective_weight_polarity():
    """Check that the product S * T equals S * sign(T) element-wise.

    In other words, each effective weight is expected to be one of
    -S, 0 or +S.
    """
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary = layer._get_T().float()
    scale = layer._get_S()
    weight = scale * ternary
    # Round-trip the sign through int8 to mirror a pure-polarity encoding.
    polarity = torch.sign(ternary).to(torch.int8)
    expected = scale * polarity.float()
    # NOTE(review): when T already holds only -1/0/1, sign(T) == T and this
    # comparison holds by construction.
    assert torch.all(weight == expected), "W should equal S * sign(T)"
    print(" PASS test_effective_weight_polarity")
|
|
|
|
def test_e_update_changes_magnitude():
    """Check that bumping ``E[0]`` leaves T untouched while W follows the new S.

    The first entry of ``E`` is incremented in place, after which the
    dequantized weight must equal ``_get_S() * _get_T()`` and ``_get_T()``
    must match its value from before the change.
    """
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    polarity_before = layer._get_T().clone()
    saved_exponents = layer.E.clone()

    # Perturb only the exponent storage.
    layer.E[0] = saved_exponents[0] + 1

    weight_after = layer.dequantize()
    polarity_after = layer._get_T()
    assert torch.all(polarity_before == polarity_after), "T changed when only E was modified"

    scale_after = layer._get_S()
    expected_weight = scale_after * polarity_after.float()
    assert torch.all(weight_after == expected_weight), "W should reflect new S"

    # Put the original exponents back before reporting success.
    layer.E = saved_exponents
    print(" PASS test_e_update_changes_magnitude")
|
|
|
|
def test_state_dict_no_float_weights():
    """Check the dtypes of recognised entries in an ``ARBModel`` state_dict.

    Each key is matched by substring against an ordered rule table; the
    first rule that matches decides the dtype the tensor must have.
    """
    from arbitor.main import ARBModel

    disabled_features = dict.fromkeys(
        (
            "enable_image", "enable_audio", "enable_vq",
            "enable_graph", "enable_memory_modules", "enable_moe",
        ),
        False,
    )
    model = ARBModel(**disabled_features)

    # (substrings, required dtype, dtype name used in the failure message).
    # Order matters: the first rule with a matching substring wins.
    # NOTE(review): the bare "E" fragment matches any key containing a
    # capital E, not only a parameter literally named "E" — confirm intended.
    dtype_rules = (
        (("T_packed",), torch.uint8, "uint8"),
        (("T_accum", "E_accum", "E", "group_lr"), torch.int8, "int8"),
        (("corr_accum", "step_counter"), torch.int64, "int64"),
        (("bias",), torch.int32, "int32"),
    )

    for key, tensor in model.state_dict().items():
        for fragments, required_dtype, dtype_name in dtype_rules:
            if any(fragment in key for fragment in fragments):
                assert tensor.dtype == required_dtype, (
                    f"{key} should be {dtype_name}, got {tensor.dtype}"
                )
                break
        # NOTE(review): keys matching no rule are not checked at all, so a
        # float tensor under an unrecognised name would pass — confirm.
    print(" PASS test_state_dict_no_float_weights")
|
|