# ARBS / testing/test_gradient_capture.py
# Uploaded by CLIWorks ("Upload folder using huggingface_hub"), commit d8bc908.
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT, TernaryScaleTensor
from arbitor.components import LossComponents, LossWeights
def _cuda_available():
return torch.cuda.is_available()
def test_component_context_lifecycle():
    """Exercise the set/get/clear cycle of ``_COMPONENT_CONTEXT``.

    ``get()`` returns a ``(name, weight)`` pair. The test checks that:

    * right after ``clear()`` the pair is ``(None, 1.0)``;
    * ``get()`` reflects the most recent ``set(name, weight)``;
    * the name is ``None`` again after ``clear()`` and after ``set(None)``.

    The context is always cleared on exit, even on failure, so a broken
    assertion here cannot leak a component tag into later tests.
    """
    try:
        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"default name should be None, got {name}"
        assert weight == 1.0, f"default weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("lm", 1.0)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "lm", f"after set, name should be lm, got {name}"
        assert weight == 1.0, f"after set, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("vq", 0.5)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "vq", f"after set vq, name should be vq, got {name}"
        assert weight == 0.5, f"after set vq, weight should be 0.5, got {weight}"

        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after clear, name should be None, got {name}"

        # set() is also called with only a name; the name must read back as None.
        _COMPONENT_CONTEXT.set(None)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after set(None), name should be None, got {name}"
    finally:
        # _COMPONENT_CONTEXT is a module-level object shared by every test in
        # this file; reset it even if an assertion above failed.
        _COMPONENT_CONTEXT.clear()
    print(" PASS test_component_context_lifecycle")
def test_triton_fn_per_component_hook():
    """Backward through the Triton path under the "lm" component context.

    After ``backward()``:

    * none of the per-component hook attributes (``_hook_grad_T_sign_lm``,
      ``_hook_grad_2d_lm``, ``_hook_x_2d_lm``) may remain on the layer;
    * ``T_accum`` must have received a non-zero update.

    Prints a SKIP notice and returns early when CUDA or Triton is
    unavailable.
    """
    if not _cuda_available():
        print(" SKIP test_triton_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_triton_fn_per_component_hook (no Triton)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)
    _COMPONENT_CONTEXT.set("lm", 1.0)
    try:
        loss = lin(x).sum()
        loss.backward()
    finally:
        # Reset the module-level context even if forward/backward raises, so
        # the "lm" tag cannot leak into tests that run afterwards.
        _COMPONENT_CONTEXT.clear()

    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d_lm"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d_lm"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    print(" PASS test_triton_fn_per_component_hook")
def test_ternary_fn_per_component_hook():
    """Backward under the "moe" component context (weight 0.5) on the Tilelang path.

    The test accepts whichever gradient-capture outcome the backend produced:

    * an int8 grad-sign hook ``_hook_grad_T_sign_moe`` of shape (4, 8);
    * the 2-D hooks ``_hook_grad_2d_moe`` (shape (2, 4)) together with
      ``_hook_x_2d_moe``;
    * no hooks at all, in which case ``T_accum`` must be non-zero.

    Any hook attributes found are deleted afterwards. Prints a SKIP notice
    and returns early when CUDA or Tilelang is unavailable.
    """
    if not _cuda_available():
        print(" SKIP test_ternary_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TILELANG
    if not _HAS_TILELANG:
        print(" SKIP test_ternary_fn_per_component_hook (no Tilelang)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)
    _COMPONENT_CONTEXT.set("moe", 0.5)
    try:
        loss = lin(x).sum()
        loss.backward()
    finally:
        # Reset the module-level context even if forward/backward raises.
        _COMPONENT_CONTEXT.clear()

    if hasattr(lin, "_hook_grad_T_sign_moe"):
        h = lin._hook_grad_T_sign_moe
        assert h.shape == (4, 8), f"expected shape (4,8), got {h.shape}"
        assert h.dtype == torch.int8, f"expected dtype torch.int8, got {h.dtype}"
        del lin._hook_grad_T_sign_moe
    elif hasattr(lin, "_hook_grad_2d_moe"):
        h = lin._hook_grad_2d_moe
        assert h.shape == (2, 4), f"expected shape (2,4), got {h.shape}"
        # This branch expects the activation hook too. Check it explicitly so
        # a missing hook gets a clear message instead of an AttributeError
        # raised by `del`.
        assert hasattr(lin, "_hook_x_2d_moe"), "activation hook _hook_x_2d_moe not found"
        del lin._hook_grad_2d_moe
        del lin._hook_x_2d_moe
    else:
        assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    print(" PASS test_ternary_fn_per_component_hook")
def test_merged_hooks_backward_compat():
    """Backward through the Triton path with no component context set.

    After ``backward()``:

    * neither the merged hook attributes (``_hook_grad_T_sign``,
      ``_hook_grad_2d``, ``_hook_x_2d``) nor the "lm" per-component
      grad-sign hook may remain on the layer;
    * ``T_accum`` must have received a non-zero update.

    Prints a SKIP notice and returns early when CUDA or Triton is
    unavailable.
    """
    if not _cuda_available():
        print(" SKIP test_merged_hooks_backward_compat (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_merged_hooks_backward_compat (no Triton)")
        return

    # This test covers the no-context path. Establish that precondition
    # explicitly rather than relying on earlier tests to have cleared the
    # shared module-level context.
    _COMPONENT_CONTEXT.clear()

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)
    loss = lin(x).sum()
    loss.backward()

    assert not hasattr(lin, "_hook_grad_T_sign"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "per-component hook leaked without context"
    print(" PASS test_merged_hooks_backward_compat")
def test_losscomponents_active_fields():
    """Check which loss terms ``LossComponents.active_fields`` reports.

    Each entry is indexed as ``(name, value_tensor, weight)``. Terms left
    unset or passed as ``None`` must not appear, and the ``weights``
    attribute itself must never be listed as a field.
    """
    # A single LM term with default weights.
    single = LossComponents(lm=torch.tensor(1.0), weights=LossWeights())
    single_fields = single.active_fields
    assert len(single_fields) == 1, f"expected 1 field, got {len(single_fields)}"
    first = single_fields[0]
    assert first[0] == "lm"
    assert first[1].item() == 1.0
    assert first[2] == 1.0

    # Nothing set at all.
    empty = LossComponents()
    assert empty.active_fields == [], f"expected empty, got {empty.active_fields}"

    # An explicit None term is skipped even when it has a weight configured.
    with_none = LossComponents(lm=torch.tensor(2.0), vq_commitment=None, weights=LossWeights(vq_commitment=0.5))
    assert len(with_none.active_fields) == 1, f"expected 1 (vq=None), got {len(with_none.active_fields)}"

    # Several terms with custom weights.
    multi = LossComponents(
        lm=torch.tensor(1.0),
        moe_aux=torch.tensor(0.5),
        moe_ponder=torch.tensor(0.1),
        weights=LossWeights(moe_aux=0.1, moe_ponder=0.2),
    )
    multi_fields = multi.active_fields
    assert len(multi_fields) == 3, f"expected 3 fields, got {len(multi_fields)}"
    present = {entry[0] for entry in multi_fields}
    for required in ("lm", "moe_aux", "moe_ponder"):
        assert required in present
    for entry in (e for e in multi_fields if e[0] == "moe_aux"):
        assert entry[2] == 0.1, f"moe_aux weight should be 0.1, got {entry[2]}"
    assert "weights" not in present, "weights field leaked into active_fields"
    print(" PASS test_losscomponents_active_fields")