"""Tests for the component context and for ``LossComponents.active_fields``.

Each test prints a ``PASS``/``SKIP`` line so the module can also be driven by a
plain script runner in addition to a test collector.
"""
import os
import sys

import torch

# Make the repository root importable when this file is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT, TernaryScaleTensor  # noqa: E402
from arbitor.components import LossComponents, LossWeights  # noqa: E402


def _cuda_available():
    """Return True when torch reports a usable CUDA device."""
    return torch.cuda.is_available()


def test_component_context_lifecycle():
    """Check the set / get / clear round trip of ``_COMPONENT_CONTEXT``.

    The context is module-global state, so it is cleared in a ``finally``
    clause: a failing assertion must not leave a component name active for
    the tests that run afterwards.
    """
    _COMPONENT_CONTEXT.clear()
    try:
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"default name should be None, got {name}"
        assert weight == 1.0, f"default weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("lm", 1.0)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "lm", f"after set, name should be lm, got {name}"
        assert weight == 1.0, f"after set, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("vq", 0.5)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "vq", f"after set vq, name should be vq, got {name}"
        assert weight == 0.5, f"after set vq, weight should be 0.5, got {weight}"

        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after clear, name should be None, got {name}"

        _COMPONENT_CONTEXT.set(None)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after set(None), name should be None, got {name}"
    finally:
        _COMPONENT_CONTEXT.clear()
    print("  PASS test_component_context_lifecycle")


def test_triton_fn_per_component_hook():
    """Backward under the ``lm`` context must leave no per-component hooks.

    Skipped (with a printed notice) when CUDA or Triton is unavailable.
    """
    if not _cuda_available():
        print("  SKIP test_triton_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON

    if not _HAS_TRITON:
        print("  SKIP test_triton_fn_per_component_hook (no Triton)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("lm", 1.0)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # Always reset the global context, even if forward/backward raises.
        _COMPONENT_CONTEXT.clear()

    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d_lm"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d_lm"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    print("  PASS test_triton_fn_per_component_hook")


def test_ternary_fn_per_component_hook():
    """Backward under the ``moe`` context: accept any of three outcomes.

    Either an int8 grad-sign hook of shape (4, 8) is present, or an fp32
    grad hook of shape (2, 4) is present (together with its activation hook),
    or no hook is retained and ``T_accum`` is non-zero. Hooks that are found
    are deleted from the module afterwards.

    Skipped (with a printed notice) when CUDA or Tilelang is unavailable.
    """
    if not _cuda_available():
        print("  SKIP test_ternary_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TILELANG

    if not _HAS_TILELANG:
        print("  SKIP test_ternary_fn_per_component_hook (no Tilelang)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("moe", 0.5)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # Always reset the global context, even if forward/backward raises.
        _COMPONENT_CONTEXT.clear()

    if hasattr(lin, "_hook_grad_T_sign_moe"):
        h = lin._hook_grad_T_sign_moe
        assert h.shape == (4, 8), f"expected shape (4,8), got {h.shape}"
        assert h.dtype == torch.int8
        del lin._hook_grad_T_sign_moe
    elif hasattr(lin, "_hook_grad_2d_moe"):
        h = lin._hook_grad_2d_moe
        assert h.shape == (2, 4), f"expected shape (2,4), got {h.shape}"
        del lin._hook_grad_2d_moe
        del lin._hook_x_2d_moe
    else:
        assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    print("  PASS test_ternary_fn_per_component_hook")


def test_merged_hooks_backward_compat():
    """Backward with no component context must leave no hooks of either kind.

    Checks both the unsuffixed hook names and the ``_lm``-suffixed grad-sign
    hook. Skipped (with a printed notice) when CUDA or Triton is unavailable.
    """
    if not _cuda_available():
        print("  SKIP test_merged_hooks_backward_compat (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON

    if not _HAS_TRITON:
        print("  SKIP test_merged_hooks_backward_compat (no Triton)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)
    y = lin(x)
    loss = y.sum()
    loss.backward()

    assert not hasattr(lin, "_hook_grad_T_sign"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"
    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "per-component hook leaked without context"
    print("  PASS test_merged_hooks_backward_compat")


def test_losscomponents_active_fields():
    """Check which entries ``LossComponents.active_fields`` reports.

    The assertions treat each entry as an indexable triple whose items are
    compared against the field name, the loss tensor and the weight.
    """
    lc = LossComponents(lm=torch.tensor(1.0), weights=LossWeights())
    fields = lc.active_fields
    assert len(fields) == 1, f"expected 1 field, got {len(fields)}"
    assert fields[0][0] == "lm"
    assert fields[0][1].item() == 1.0
    assert fields[0][2] == 1.0

    lc2 = LossComponents()
    assert lc2.active_fields == [], f"expected empty, got {lc2.active_fields}"

    # A component explicitly passed as None must not be listed, even when a
    # weight is configured for it.
    lc3 = LossComponents(
        lm=torch.tensor(2.0),
        vq_commitment=None,
        weights=LossWeights(vq_commitment=0.5),
    )
    assert len(lc3.active_fields) == 1, f"expected 1 (vq=None), got {len(lc3.active_fields)}"

    lc4 = LossComponents(
        lm=torch.tensor(1.0),
        moe_aux=torch.tensor(0.5),
        moe_ponder=torch.tensor(0.1),
        weights=LossWeights(moe_aux=0.1, moe_ponder=0.2),
    )
    fields4 = lc4.active_fields
    assert len(fields4) == 3, f"expected 3 fields, got {len(fields4)}"
    names = {f[0] for f in fields4}
    assert "lm" in names
    assert "moe_aux" in names
    assert "moe_ponder" in names
    for f in fields4:
        if f[0] == "moe_aux":
            assert f[2] == 0.1, f"moe_aux weight should be 0.1, got {f[2]}"
    assert "weights" not in names, "weights field leaked into active_fields"
    print("  PASS test_losscomponents_active_fields")