File size: 2,846 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import sys
import os

# Put the directory two levels above this file at the front of sys.path so the
# `arbitor` imports below can resolve without the package being installed
# (presumably that directory is the repository root; confirm).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType


def test_ternary_sign_values():
    """Check that every entry of the T tensor is one of -1, 0 or 1."""
    allowed = {-1, 0, 1}
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    # Flatten to a plain Python list so the distinct values can be collected
    # into a set and compared against the allowed ternary alphabet.
    flat_values = layer._get_T().detach().flatten().tolist()
    unique = set(flat_values)
    assert unique <= allowed, f"T has non-ternary values: {unique}"
    print(" PASS test_ternary_sign_values")


def test_scale_positive():
    """Check that every scale value S is strictly positive and finite."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    scales = layer._get_S()

    all_positive = torch.gt(scales, 0).all()
    assert all_positive, "Some S values are not positive"

    all_finite = torch.isfinite(scales).all()
    assert all_finite, "Some S values are non-finite"

    print(" PASS test_scale_positive")


def test_effective_weight_polarity():
    """Check that S * T is unchanged when T is replaced by its sign.

    The product W = S * T is compared element-wise with S * sign(T); the two
    agree wherever T already equals its own sign.
    """
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary = layer._get_T().float()
    scales = layer._get_S()
    weights = scales * ternary

    # Sign of T, passed through int8 and back to float before the comparison.
    signs = ternary.sign().to(torch.int8)
    expected = scales * signs.float()

    assert torch.eq(weights, expected).all(), "W should equal S * sign(T)"
    print(" PASS test_effective_weight_polarity")


def test_e_update_changes_magnitude():
    """Bump E at index 0 and check T is untouched while W follows the new S."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    ternary_before = layer._get_T().clone()
    saved_exponents = layer.E.clone()

    # Raise only index 0 of E by one, writing into the layer's tensor in place.
    layer.E[0] = saved_exponents[0] + 1

    weights_after = layer.dequantize()
    ternary_after = layer._get_T()
    ternary_unchanged = (ternary_before == ternary_after).all()
    assert ternary_unchanged, "T changed when only E was modified"

    scales_after = layer._get_S()
    expected_weights = scales_after * ternary_after.float()
    weights_match = (weights_after == expected_weights).all()
    assert weights_match, "W should reflect new S"

    # Put the saved exponents back on the layer before finishing.
    layer.E = saved_exponents
    print(" PASS test_e_update_changes_magnitude")


def test_state_dict_no_float_weights():
    """Check the dtypes of the quantized tensors in the ARBModel state_dict.

    Builds an ARBModel with the six ``enable_*`` flags below set to False and
    asserts the dtype of each state_dict entry, selected by name:

    * ``T_packed``                                  -> ``torch.uint8``
    * ``T_accum``, ``E_accum``, ``E``, ``group_lr`` -> ``torch.int8``
    * ``corr_accum``, ``step_counter``              -> ``torch.int64``
    * ``bias``                                      -> ``torch.int32``

    Entries whose key matches none of these names are not checked.
    """
    from arbitor.main import ARBModel
    model = ARBModel(
        enable_image=False, enable_audio=False, enable_vq=False,
        enable_graph=False, enable_memory_modules=False, enable_moe=False,
    )
    for key, tensor in model.state_dict().items():
        # state_dict keys are dot-separated paths ending in the tensor's own
        # name. The bare "E" entry must be matched on that last component: a
        # substring test for "E" also fires on any key that merely contains a
        # capital E (e.g. inside a module name), forcing it onto the int8 rule
        # and shadowing the int64/int32 branches below.
        leaf = key.rpartition(".")[2]
        if "T_packed" in key:
            assert tensor.dtype == torch.uint8, f"{key} should be uint8, got {tensor.dtype}"
        elif "T_accum" in key or "E_accum" in key or leaf == "E" or "group_lr" in key:
            assert tensor.dtype == torch.int8, f"{key} should be int8, got {tensor.dtype}"
        elif "corr_accum" in key or "step_counter" in key:
            assert tensor.dtype == torch.int64, f"{key} should be int64, got {tensor.dtype}"
        elif "bias" in key:
            assert tensor.dtype == torch.int32, f"{key} should be int32, got {tensor.dtype}"
    print(" PASS test_state_dict_no_float_weights")