"""Invariant checks for TernaryScaleTensor and the ARBModel state_dict dtypes."""

import os
import sys

import torch

# Make the repository root (two directories above this file) importable
# before pulling in the project package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType


def test_ternary_sign_values():
    """Every entry returned by ``_get_T()`` must be one of -1, 0 or 1."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary = layer._get_T()
    unique = set(ternary.detach().flatten().tolist())
    allowed = {-1, 0, 1}
    assert unique <= allowed, f"T has non-ternary values: {unique}"
    print(" PASS test_ternary_sign_values")


def test_scale_positive():
    """Every entry returned by ``_get_S()`` must be strictly positive and finite."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    scales = layer._get_S()
    positive_mask = scales > 0
    assert positive_mask.all(), "Some S values are not positive"
    finite_mask = torch.isfinite(scales)
    assert finite_mask.all(), "Some S values are non-finite"
    print(" PASS test_scale_positive")


def test_effective_weight_polarity():
    """The product ``S * T`` must equal ``S * sign(T)`` element-wise."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary = layer._get_T().float()
    scales = layer._get_S()
    weights = scales * ternary
    # Round-trip the sign through int8 so the comparison uses pure -1/0/+1.
    polarity = ternary.sign().to(torch.int8)
    expected = scales * polarity.float()
    assert (weights == expected).all(), "W should equal S * sign(T)"
    print(" PASS test_effective_weight_polarity")


def test_e_update_changes_magnitude():
    """Bumping ``E[0]`` must leave T untouched while W tracks the new S."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary_before = layer._get_T().clone()
    saved_exponents = layer.E.clone()

    # Perturb only the first exponent entry, in place.
    layer.E[0] = saved_exponents[0] + 1

    weights_after = layer.dequantize()
    ternary_after = layer._get_T()
    assert (ternary_before == ternary_after).all(), "T changed when only E was modified"

    scales_after = layer._get_S()
    expected_weights = scales_after * ternary_after.float()
    assert (weights_after == expected_weights).all(), "W should reflect new S"

    # Put the original exponents back on the layer.
    layer.E = saved_exponents
    print(" PASS test_e_update_changes_magnitude")


def test_state_dict_no_float_weights():
    """Check the dtype of recognised entries in the ARBModel state_dict.

    Each key is matched by substring against an ordered list of rules; the
    first rule that matches decides the expected dtype. Keys that match no
    rule are not checked.
    """
    from arbitor.main import ARBModel

    model = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
    )

    # (substrings, expected dtype, dtype name used in the failure message).
    # Order matters: the first matching rule wins.
    # NOTE(review): the bare "E" marker matches any key containing an
    # upper-case E anywhere in its path - confirm this breadth is intended.
    dtype_rules = (
        (("T_packed",), torch.uint8, "uint8"),
        (("T_accum", "E_accum", "E", "group_lr"), torch.int8, "int8"),
        (("corr_accum", "step_counter"), torch.int64, "int64"),
        (("bias",), torch.int32, "int32"),
    )

    state = model.state_dict()
    for key, tensor in state.items():
        for markers, expected_dtype, label in dtype_rules:
            if any(marker in key for marker in markers):
                assert tensor.dtype == expected_dtype, (
                    f"{key} should be {label}, got {tensor.dtype}"
                )
                break
    print(" PASS test_state_dict_no_float_weights")