File size: 2,846 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType
def test_ternary_sign_values():
    """Check that every entry of the ternary tensor T is -1, 0 or 1."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    # Flatten to a plain Python list so the distinct values can be collected in a set.
    flat_values = layer._get_T().detach().flatten().tolist()
    unique = set(flat_values)
    allowed = {-1, 0, 1}
    assert unique <= allowed, f"T has non-ternary values: {unique}"
    print(" PASS test_ternary_sign_values")
def test_scale_positive():
    """Check that the scale tensor S is strictly positive and finite everywhere."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    scales = layer._get_S()

    all_positive = (scales > 0).all()
    assert all_positive, "Some S values are not positive"

    # Positivity alone does not exclude +inf, so finiteness is checked separately.
    all_finite = torch.isfinite(scales).all()
    assert all_finite, "Some S values are non-finite"

    print(" PASS test_scale_positive")
def test_effective_weight_polarity():
    """Check that the effective weight W = S * T only takes the values {-S, 0, +S}.

    Two properties are verified:

    1. Replacing T by sign(T) leaves the product S * T unchanged, i.e. T
       contributes polarity only and never magnitude.
    2. The tensor returned by ``dequantize()`` equals S * T, so the property is
       checked on the weight the layer actually produces and not only on a
       product rebuilt inside this test.
    """
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    T = lin._get_T().float()
    S = lin._get_S()
    expected = S * T

    # Round-trip through int8 so any fractional part in T would show up as a mismatch.
    T_binary = T.sign().to(torch.int8)
    assert (expected == S * T_binary.float()).all(), "W should equal S * sign(T)"

    # Same exact-equality relation that test_e_update_changes_magnitude relies on.
    W = lin.dequantize()
    assert (W == expected).all(), "dequantize() should equal S * T"

    print(" PASS test_effective_weight_polarity")
def test_e_update_changes_magnitude():
    """Check that editing E changes the scale S (and so W) but leaves T untouched.

    The first entry of ``lin.E`` is incremented by one, after which the test
    asserts that:

    * T is identical to its value before the edit,
    * S differs from its value before the edit (the magnitude really moved),
    * ``dequantize()`` equals the new S multiplied by T.

    ``lin.E`` is reassigned to its original value at the end.
    """
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)

    # Snapshot everything that is compared after the edit.
    T_before = lin._get_T().clone()
    S_before = lin._get_S().clone()
    E_orig = lin.E.clone()

    lin.E[0] = E_orig[0] + 1

    W_after = lin.dequantize()
    T_after = lin._get_T()
    assert (T_before == T_after).all(), "T changed when only E was modified"

    S_after = lin._get_S()
    # Without this check the test would pass even if E had no effect on S.
    assert not torch.equal(S_before, S_after), "S did not change when E was modified"
    assert (W_after == S_after * T_after.float()).all(), "W should reflect new S"

    lin.E = E_orig
    print(" PASS test_e_update_changes_magnitude")
def _expected_state_dtype(key):
    """Return the dtype required for state_dict entry *key*, or None if unchecked.

    Only the final dotted component of the key (the tensor's own name) is
    inspected. Matching against the whole key would let a module path that
    merely contains a capital "E" (for example ``Encoder.bias``) be
    classified as the exponent tensor ``E``.
    """
    leaf = key.rpartition(".")[2]
    if "T_packed" in leaf:
        return torch.uint8
    if leaf == "E" or any(tag in leaf for tag in ("T_accum", "E_accum", "group_lr")):
        return torch.int8
    if any(tag in leaf for tag in ("corr_accum", "step_counter")):
        return torch.int64
    if "bias" in leaf:
        return torch.int32
    return None


def test_state_dict_no_float_weights():
    """Check the integer dtypes of the known tensor families in the model state_dict.

    An ``ARBModel`` is built with every optional component flag visible here
    set to False, and each state_dict entry whose name belongs to a known
    family is required to have that family's integer dtype:

    * ``T_packed``                          -> uint8
    * ``T_accum``, ``E_accum``, ``E``, ``group_lr`` -> int8
    * ``corr_accum``, ``step_counter``      -> int64
    * ``bias``                              -> int32

    NOTE(review): entries matching none of these families are not checked, so
    this test does not by itself prove that no float tensor exists — confirm
    whether a stricter catch-all assertion is wanted.
    """
    from arbitor.main import ARBModel

    model = ARBModel(
        enable_image=False, enable_audio=False, enable_vq=False,
        enable_graph=False, enable_memory_modules=False, enable_moe=False,
    )
    for key, tensor in model.state_dict().items():
        expected = _expected_state_dtype(key)
        if expected is None:
            continue
        # str(torch.int8) is "torch.int8"; keep only the short name for the message.
        short_name = str(expected).rpartition(".")[2]
        assert tensor.dtype == expected, f"{key} should be {short_name}, got {tensor.dtype}"
    print(" PASS test_state_dict_no_float_weights")
|