File size: 5,975 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT, TernaryScaleTensor
from arbitor.components import LossComponents, LossWeights
def _cuda_available() -> bool:
    """Return whatever ``torch.cuda.is_available()`` reports.

    The GPU-only tests in this module call this helper to decide whether
    to run or to print a SKIP line and return early.
    """
    has_cuda = torch.cuda.is_available()
    return has_cuda
def test_component_context_lifecycle():
    """Exercise clear/set/get on the module-level ``_COMPONENT_CONTEXT``.

    Checks, in order:
      * after ``clear()`` the context reads back as ``(None, 1.0)``;
      * ``set(name, weight)`` is reflected by the next ``get()``;
      * a second ``set`` overwrites the first;
      * ``clear()`` restores the defaults (name *and* weight);
      * ``set(None)`` leaves the name as ``None``.

    ``_COMPONENT_CONTEXT`` is shared process-wide state, so the body runs
    inside ``try/finally``: a failing assertion must not leave a component
    name or weight set for whichever test runs next.
    """
    try:
        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"default name should be None, got {name}"
        assert weight == 1.0, f"default weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("lm", 1.0)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "lm", f"after set, name should be lm, got {name}"
        assert weight == 1.0, f"after set, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("vq", 0.5)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "vq", f"after set vq, name should be vq, got {name}"
        assert weight == 0.5, f"after set vq, weight should be 0.5, got {weight}"

        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after clear, name should be None, got {name}"
        # The first check in this test already expects weight 1.0 after
        # clear(); verify it here too, where the previous weight was 0.5
        # and a stale value would actually be observable.
        assert weight == 1.0, f"after clear, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set(None)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after set(None), name should be None, got {name}"
    finally:
        # Always hand a cleared context back to the rest of the suite.
        _COMPONENT_CONTEXT.clear()
    print(" PASS test_component_context_lifecycle")
def test_triton_fn_per_component_hook():
    """Forward/backward under the ``("lm", 1.0)`` context when Triton is present.

    Prints a SKIP line and returns when CUDA or Triton is unavailable.
    Otherwise runs a ``TernaryScaleTensor(8, 4)`` on a ``(2, 8)`` tensor of
    ones, back-propagates ``y.sum()``, and asserts that:
      * none of the ``_hook_grad_T_sign_lm`` / ``_hook_grad_2d_lm`` /
        ``_hook_x_2d_lm`` attributes exist on the module afterwards;
      * ``lin.T_accum`` holds at least one non-zero entry.
    """
    if not _cuda_available():
        print(" SKIP test_triton_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_triton_fn_per_component_hook (no Triton)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("lm", 1.0)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # The context is process-wide; clear it even if forward/backward
        # raises so the "lm" component does not leak into later tests.
        _COMPONENT_CONTEXT.clear()

    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d_lm"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d_lm"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    print(" PASS test_triton_fn_per_component_hook")
def test_ternary_fn_per_component_hook():
    """Forward/backward under the ``("moe", 0.5)`` context when Tilelang is present.

    Prints a SKIP line and returns when CUDA or Tilelang is unavailable.
    After back-propagating ``y.sum()`` through a ``TernaryScaleTensor(8, 4)``
    the test accepts any one of three outcomes, checked in this order:

      1. ``_hook_grad_T_sign_moe`` exists: it must be int8 with shape (4, 8).
      2. ``_hook_grad_2d_moe`` exists: it must have shape (2, 4), and the
         companion ``_hook_x_2d_moe`` attribute must exist as well.
      3. Neither exists: ``lin.T_accum`` must contain a non-zero entry.

    Hook attributes found in cases 1 and 2 are deleted from the module
    before the test returns.
    """
    if not _cuda_available():
        print(" SKIP test_ternary_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TILELANG
    if not _HAS_TILELANG:
        print(" SKIP test_ternary_fn_per_component_hook (no Tilelang)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("moe", 0.5)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # The context is process-wide; clear it even if forward/backward
        # raises so the "moe" component does not leak into later tests.
        _COMPONENT_CONTEXT.clear()

    if hasattr(lin, "_hook_grad_T_sign_moe"):
        h = lin._hook_grad_T_sign_moe
        assert h.shape == (4, 8), f"expected shape (4,8), got {h.shape}"
        assert h.dtype == torch.int8, f"expected dtype torch.int8, got {h.dtype}"
        del lin._hook_grad_T_sign_moe
    elif hasattr(lin, "_hook_grad_2d_moe"):
        h = lin._hook_grad_2d_moe
        assert h.shape == (2, 4), f"expected shape (2,4), got {h.shape}"
        # This branch removes both attributes, so fail with a readable
        # message (rather than an AttributeError from `del`) if the
        # activation hook is missing.
        assert hasattr(lin, "_hook_x_2d_moe"), "per-component activation hook not found"
        del lin._hook_grad_2d_moe
        del lin._hook_x_2d_moe
    else:
        assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    print(" PASS test_ternary_fn_per_component_hook")
def test_merged_hooks_backward_compat():
    """Forward/backward with no component context set (Triton path).

    Prints a SKIP line and returns when CUDA or Triton is unavailable.
    Otherwise asserts that, after back-propagating ``y.sum()`` through a
    ``TernaryScaleTensor(8, 4)``:
      * no un-suffixed ``_hook_grad_T_sign`` / ``_hook_grad_2d`` /
        ``_hook_x_2d`` attribute exists on the module;
      * ``lin.T_accum`` holds at least one non-zero entry;
      * no ``_hook_grad_T_sign_lm`` attribute exists either.
    """
    if not _cuda_available():
        print(" SKIP test_merged_hooks_backward_compat (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_merged_hooks_backward_compat (no Triton)")
        return

    # This test is about the "no context" case. The context is global, so
    # reset it explicitly instead of trusting earlier tests to have
    # cleaned up after themselves.
    _COMPONENT_CONTEXT.clear()

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)
    y = lin(x)
    loss = y.sum()
    loss.backward()

    assert not hasattr(lin, "_hook_grad_T_sign"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "per-component hook leaked without context"
    print(" PASS test_merged_hooks_backward_compat")
def test_losscomponents_active_fields():
    """Check what ``LossComponents.active_fields`` reports for several inputs.

    Each reported entry is indexed as ``entry[0]`` (name), ``entry[1]``
    (loss tensor) and ``entry[2]`` (weight). Four scenarios are covered:
    a single ``lm`` loss, no losses at all, a loss explicitly passed as
    ``None``, and three losses with custom weights.
    """
    # One populated loss with default weights.
    single = LossComponents(lm=torch.tensor(1.0), weights=LossWeights())
    fields = single.active_fields
    assert len(fields) == 1, f"expected 1 field, got {len(fields)}"
    assert fields[0][0] == "lm"
    assert fields[0][1].item() == 1.0
    assert fields[0][2] == 1.0

    # Nothing populated at all.
    lc2 = LossComponents()
    assert lc2.active_fields == [], f"expected empty, got {lc2.active_fields}"

    # A component passed as None must not be reported, even with a weight.
    lc3 = LossComponents(
        lm=torch.tensor(2.0),
        vq_commitment=None,
        weights=LossWeights(vq_commitment=0.5),
    )
    assert len(lc3.active_fields) == 1, f"expected 1 (vq=None), got {len(lc3.active_fields)}"

    # Three populated components, two of them with custom weights.
    custom_weights = LossWeights(moe_aux=0.1, moe_ponder=0.2)
    lc4 = LossComponents(
        lm=torch.tensor(1.0),
        moe_aux=torch.tensor(0.5),
        moe_ponder=torch.tensor(0.1),
        weights=custom_weights,
    )
    fields4 = lc4.active_fields
    assert len(fields4) == 3, f"expected 3 fields, got {len(fields4)}"

    names = set()
    for entry in fields4:
        names.add(entry[0])
    for required in ("lm", "moe_aux", "moe_ponder"):
        assert required in names

    for entry in fields4:
        if entry[0] != "moe_aux":
            continue
        assert entry[2] == 0.1, f"moe_aux weight should be 0.1, got {entry[2]}"

    # The weights container itself is configuration, not a loss term.
    assert "weights" not in names, "weights field leaked into active_fields"
    print(" PASS test_losscomponents_active_fields")
|