File size: 5,975 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
import sys
import os

# Put the directory two levels above this file's directory at the front of the
# module search path. Presumably this lets the ``arbitor`` imports below
# resolve when the file is run straight from a checkout -- TODO confirm.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT, TernaryScaleTensor
from arbitor.components import LossComponents, LossWeights


def _cuda_available():
    """Return what ``torch.cuda.is_available()`` reports.

    The CUDA-only tests in this file call this first and skip themselves when
    the result is falsy.
    """
    has_cuda = torch.cuda.is_available()
    return has_cuda


def test_component_context_lifecycle():
    """Exercise the clear / set / get cycle of ``_COMPONENT_CONTEXT``.

    Checked in order:
      * after ``clear()`` the context reports ``(None, 1.0)``;
      * ``set(name, weight)`` is reflected by the next ``get()``, and a later
        ``set`` replaces the earlier values;
      * ``clear()`` restores the defaults after values were set;
      * ``set(None)`` leaves the reported name as ``None``.

    The context is module-global state shared with the other tests in this
    file, so it is cleared again on the way out even when an assertion fails.
    """
    _COMPONENT_CONTEXT.clear()
    try:
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"default name should be None, got {name}"
        assert weight == 1.0, f"default weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("lm", 1.0)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "lm", f"after set, name should be lm, got {name}"
        assert weight == 1.0, f"after set, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("vq", 0.5)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "vq", f"after set vq, name should be vq, got {name}"
        assert weight == 0.5, f"after set vq, weight should be 0.5, got {weight}"

        # clear() must undo the non-default weight too, matching the
        # expectation already asserted right after the initial clear() above.
        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after clear, name should be None, got {name}"
        assert weight == 1.0, f"after clear, weight should be 1.0, got {weight}"

        # Only the name is checked here: the weight that set() falls back to
        # when it is omitted is not visible from this file.
        _COMPONENT_CONTEXT.set(None)
        name, _ = _COMPONENT_CONTEXT.get()
        assert name is None, f"after set(None), name should be None, got {name}"
    finally:
        # Never leave a component active for whichever test runs next.
        _COMPONENT_CONTEXT.clear()

    print(" PASS test_component_context_lifecycle")


def test_triton_fn_per_component_hook():
    """Backward pass under the "lm" component context leaves no hook attributes.

    Runs one forward/backward pass through a ``TernaryScaleTensor(8, 4)`` on
    CUDA while ``_COMPONENT_CONTEXT`` is set to ``("lm", 1.0)``, then asserts
    that none of the ``_hook_*_lm`` attributes exist on the module and that
    ``T_accum`` holds a non-zero value.

    The test prints a SKIP line and returns early when CUDA is unavailable or
    when ``_HAS_TRITON`` is falsy.
    """
    if not _cuda_available():
        print(" SKIP test_triton_fn_per_component_hook (no CUDA)")
        return
    # Imported lazily, after the CUDA check, exactly as the flag is only
    # needed on machines that can run the rest of the test.
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_triton_fn_per_component_hook (no Triton)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("lm", 1.0)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # Reset the global context even if forward/backward raises, so the
        # "lm" component cannot leak into the tests that run afterwards.
        _COMPONENT_CONTEXT.clear()

    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d_lm"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d_lm"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"

    print(" PASS test_triton_fn_per_component_hook")


def test_ternary_fn_per_component_hook():
    """Backward pass under the "moe" component context, tolerant of three outcomes.

    Runs one forward/backward pass through a ``TernaryScaleTensor(8, 4)`` on
    CUDA while ``_COMPONENT_CONTEXT`` is set to ``("moe", 0.5)`` and then
    accepts whichever of these the module exhibits:

      * a ``_hook_grad_T_sign_moe`` attribute of shape ``(4, 8)`` and dtype
        ``torch.int8``;
      * a ``_hook_grad_2d_moe`` attribute of shape ``(2, 4)`` together with a
        ``_hook_x_2d_moe`` attribute;
      * neither attribute, in which case ``T_accum`` must be non-zero.

    Any hook attribute that is found gets deleted from the module again.

    The test prints a SKIP line and returns early when CUDA is unavailable or
    when ``_HAS_TILELANG`` is falsy.
    """
    if not _cuda_available():
        print(" SKIP test_ternary_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TILELANG
    if not _HAS_TILELANG:
        print(" SKIP test_ternary_fn_per_component_hook (no Tilelang)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("moe", 0.5)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # Reset the global context even if forward/backward raises, so the
        # "moe" component cannot leak into the tests that run afterwards.
        _COMPONENT_CONTEXT.clear()

    if hasattr(lin, "_hook_grad_T_sign_moe"):
        h = lin._hook_grad_T_sign_moe
        # (4, 8) mirrors the TernaryScaleTensor(8, 4) constructor arguments in
        # reverse -- presumably (out_features, in_features); TODO confirm.
        assert h.shape == (4, 8), f"expected shape (4,8), got {h.shape}"
        assert h.dtype == torch.int8, f"expected dtype torch.int8, got {h.dtype}"
        del lin._hook_grad_T_sign_moe
    elif hasattr(lin, "_hook_grad_2d_moe"):
        h = lin._hook_grad_2d_moe
        # (2, 4): leading dim matches the 2 rows of ``x`` above.
        assert h.shape == (2, 4), f"expected shape (2,4), got {h.shape}"
        # The cleanup below deletes both attributes, so the activation hook
        # has to exist alongside the gradient hook; fail with a message
        # instead of an AttributeError from ``del``.
        assert hasattr(lin, "_hook_x_2d_moe"), "per-component hook not found"
        del lin._hook_grad_2d_moe
        del lin._hook_x_2d_moe
    else:
        assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"

    print(" PASS test_ternary_fn_per_component_hook")


def test_merged_hooks_backward_compat():
    """Backward pass with no component context set leaves no hook attributes.

    Runs one forward/backward pass through a ``TernaryScaleTensor(8, 4)`` on
    CUDA without touching ``_COMPONENT_CONTEXT`` and asserts that neither the
    un-suffixed ``_hook_*`` attributes nor ``_hook_grad_T_sign_lm`` exist on
    the module, and that ``T_accum`` is non-zero.

    The test prints a SKIP line and returns early when CUDA is unavailable or
    when ``_HAS_TRITON`` is falsy.
    """
    if not _cuda_available():
        print(" SKIP test_merged_hooks_backward_compat (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_merged_hooks_backward_compat (no Triton)")
        return

    layer = TernaryScaleTensor(8, 4).to("cuda")
    inputs = torch.ones(2, 8, device="cuda", requires_grad=True)

    total = layer(inputs).sum()
    total.backward()

    # Attribute name -> message raised when that attribute is still present.
    must_be_absent = (
        ("_hook_grad_T_sign", "streaming backward should not retain grad-sign hooks"),
        ("_hook_grad_2d", "fp32 grad hook should not be retained"),
        ("_hook_x_2d", "fp32 activation hook should not be retained"),
    )
    for attr_name, failure_msg in must_be_absent:
        assert not hasattr(layer, attr_name), failure_msg

    accumulated = layer.T_accum.abs().sum().item()
    assert accumulated > 0, "streaming backward should accumulate int8 T state"

    # No context was set, so no component-suffixed attribute may show up.
    assert not hasattr(layer, "_hook_grad_T_sign_lm"), "per-component hook leaked without context"

    print(" PASS test_merged_hooks_backward_compat")


def test_losscomponents_active_fields():
    """Check what ``LossComponents.active_fields`` reports in four setups.

    Each entry of ``active_fields`` is indexed as ``[0]`` name, ``[1]`` value
    and ``[2]`` weight. Covered cases:

      * only ``lm`` given, default ``LossWeights`` -> one entry ``("lm", 1.0, 1.0)``;
      * nothing given -> empty list;
      * ``lm`` given and ``vq_commitment=None`` -> still one entry;
      * ``lm``, ``moe_aux`` and ``moe_ponder`` given -> three entries, the
        ``moe_aux`` entry carries weight 0.1, and no entry is named "weights".
    """
    # Case 1: a single populated component with default weights.
    single = LossComponents(lm=torch.tensor(1.0), weights=LossWeights())
    single_fields = single.active_fields
    assert len(single_fields) == 1, f"expected 1 field, got {len(single_fields)}"
    only_entry = single_fields[0]
    assert only_entry[0] == "lm"
    assert only_entry[1].item() == 1.0
    assert only_entry[2] == 1.0

    # Case 2: no components at all.
    empty = LossComponents()
    assert empty.active_fields == [], f"expected empty, got {empty.active_fields}"

    # Case 3: a component explicitly passed as None is not reported.
    with_none = LossComponents(
        lm=torch.tensor(2.0),
        vq_commitment=None,
        weights=LossWeights(vq_commitment=0.5),
    )
    assert len(with_none.active_fields) == 1, f"expected 1 (vq=None), got {len(with_none.active_fields)}"

    # Case 4: several components with custom weights.
    multi = LossComponents(
        lm=torch.tensor(1.0),
        moe_aux=torch.tensor(0.5),
        moe_ponder=torch.tensor(0.1),
        weights=LossWeights(moe_aux=0.1, moe_ponder=0.2),
    )
    multi_fields = multi.active_fields
    assert len(multi_fields) == 3, f"expected 3 fields, got {len(multi_fields)}"

    names = set(entry[0] for entry in multi_fields)
    for expected_name in ("lm", "moe_aux", "moe_ponder"):
        assert expected_name in names

    moe_aux_weights = [entry[2] for entry in multi_fields if entry[0] == "moe_aux"]
    for aux_weight in moe_aux_weights:
        assert aux_weight == 0.1, f"moe_aux weight should be 0.1, got {aux_weight}"

    assert "weights" not in names, "weights field leaked into active_fields"

    print(" PASS test_losscomponents_active_fields")