| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT, TernaryScaleTensor |
| from arbitor.components import LossComponents, LossWeights |
|
|
|
|
| def _cuda_available(): |
| return torch.cuda.is_available() |
|
|
|
|
def test_component_context_lifecycle():
    """Exercise set/get/clear on the global ``_COMPONENT_CONTEXT``.

    Checks that:
    * a cleared context reports ``(None, 1.0)``;
    * ``set(name, weight)`` overwrites both the name and the weight;
    * ``clear()`` restores the default name *and* weight;
    * ``set(None)`` (called with only a name) leaves no active component.

    The context is cleared in a ``finally`` so that a failing assertion cannot
    leak a component name/weight into later tests that run forward/backward
    passes under whatever context happens to be active.
    """
    try:
        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"default name should be None, got {name}"
        assert weight == 1.0, f"default weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("lm", 1.0)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "lm", f"after set, name should be lm, got {name}"
        assert weight == 1.0, f"after set, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set("vq", 0.5)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name == "vq", f"after set vq, name should be vq, got {name}"
        assert weight == 0.5, f"after set vq, weight should be 0.5, got {weight}"

        _COMPONENT_CONTEXT.clear()
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after clear, name should be None, got {name}"
        # clear() must also drop the 0.5 weight set by the "vq" step above;
        # the cleared default weight is 1.0 (checked at the top of the test).
        assert weight == 1.0, f"after clear, weight should be 1.0, got {weight}"

        _COMPONENT_CONTEXT.set(None)
        name, weight = _COMPONENT_CONTEXT.get()
        assert name is None, f"after set(None), name should be None, got {name}"
    finally:
        # Never leave a component active for the tests that follow.
        _COMPONENT_CONTEXT.clear()

    print(" PASS test_component_context_lifecycle")
|
|
|
|
def test_triton_fn_per_component_hook():
    """Backward under an ``"lm"`` component context must not retain hooks.

    Prints SKIP and returns when CUDA or Triton is unavailable. Otherwise it
    runs one forward/backward pass of a small ``TernaryScaleTensor(8, 4)``
    inside the ``"lm"`` component context, then checks that:
    * no ``_hook_*_lm`` tensors were left on the module;
    * ``T_accum`` is non-zero afterwards.
    """
    if not _cuda_available():
        print(" SKIP test_triton_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_triton_fn_per_component_hook (no Triton)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("lm", 1.0)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # Reset the global context even if forward/backward raises; otherwise
        # every later test would silently run under a stale "lm" component.
        _COMPONENT_CONTEXT.clear()

    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d_lm"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d_lm"), "fp32 activation hook should not be retained"
    # NOTE(review): this only proves T_accum is non-zero after backward; it
    # would also pass if T_accum started non-zero. Confirm its initial value.
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"

    print(" PASS test_triton_fn_per_component_hook")
|
|
|
|
def test_ternary_fn_per_component_hook():
    """Backward under a ``"moe"`` context (weight 0.5) on the TileLang path.

    Prints SKIP and returns when CUDA or TileLang is unavailable. After one
    forward/backward of ``TernaryScaleTensor(8, 4)`` on a (2, 8) input, the
    test accepts exactly one of three outcomes:

    * an int8 grad-sign hook ``_hook_grad_T_sign_moe`` of shape (4, 8);
    * an fp32 grad hook ``_hook_grad_2d_moe`` of shape (2, 4) together with
      an activation hook ``_hook_x_2d_moe``;
    * no hook at all, in which case ``T_accum`` must be non-zero.

    Any hook that is found is deleted from the module after inspection.
    """
    if not _cuda_available():
        print(" SKIP test_ternary_fn_per_component_hook (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TILELANG
    if not _HAS_TILELANG:
        print(" SKIP test_ternary_fn_per_component_hook (no Tilelang)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    _COMPONENT_CONTEXT.set("moe", 0.5)
    try:
        y = lin(x)
        loss = y.sum()
        loss.backward()
    finally:
        # Reset the global context even if forward/backward raises.
        _COMPONENT_CONTEXT.clear()

    if hasattr(lin, "_hook_grad_T_sign_moe"):
        h = lin._hook_grad_T_sign_moe
        assert h.shape == (4, 8), f"expected shape (4,8), got {h.shape}"
        assert h.dtype == torch.int8, f"expected int8 grad-sign hook, got {h.dtype}"
        del lin._hook_grad_T_sign_moe
    elif hasattr(lin, "_hook_grad_2d_moe"):
        h = lin._hook_grad_2d_moe
        assert h.shape == (2, 4), f"expected shape (2,4), got {h.shape}"
        # The grad and activation hooks are expected together; report a
        # missing activation hook as an assertion failure rather than a bare
        # AttributeError from the `del` below.
        assert hasattr(lin, "_hook_x_2d_moe"), "fp32 grad hook present without matching activation hook"
        del lin._hook_grad_2d_moe
        del lin._hook_x_2d_moe
    else:
        assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"

    print(" PASS test_ternary_fn_per_component_hook")
|
|
|
|
def test_merged_hooks_backward_compat():
    """Backward with no active component context must leave no hooks behind.

    Prints SKIP and returns when CUDA or Triton is unavailable. Runs one
    forward/backward pass of ``TernaryScaleTensor(8, 4)`` with the global
    component context cleared, then checks that:
    * neither the un-suffixed hooks nor an ``_lm``-suffixed grad-sign hook
      remain on the module;
    * ``T_accum`` is non-zero afterwards.
    """
    if not _cuda_available():
        print(" SKIP test_merged_hooks_backward_compat (no CUDA)")
        return
    from arbitor.kernel.ternary_scale import _HAS_TRITON
    if not _HAS_TRITON:
        print(" SKIP test_merged_hooks_backward_compat (no Triton)")
        return

    # This test is about the "no context" path. Clear explicitly so that a
    # component name leaked by an earlier (failed) test cannot turn this run
    # into a per-component one and mask or fake the leak check below.
    _COMPONENT_CONTEXT.clear()

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.ones(2, 8, device="cuda", requires_grad=True)

    y = lin(x)
    loss = y.sum()
    loss.backward()

    assert not hasattr(lin, "_hook_grad_T_sign"), "streaming backward should not retain grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d"), "fp32 grad hook should not be retained"
    assert not hasattr(lin, "_hook_x_2d"), "fp32 activation hook should not be retained"
    assert lin.T_accum.abs().sum().item() > 0, "streaming backward should accumulate int8 T state"

    assert not hasattr(lin, "_hook_grad_T_sign_lm"), "per-component hook leaked without context"

    print(" PASS test_merged_hooks_backward_compat")
|
|
|
|
def test_losscomponents_active_fields():
    """``LossComponents.active_fields`` lists only the populated loss terms.

    Each entry is indexed as ``entry[0]`` = field name, ``entry[1]`` = loss
    tensor, ``entry[2]`` = the weight taken from the attached ``LossWeights``.
    Checks that:
    * unset or ``None`` terms are omitted;
    * an empty ``LossComponents`` yields ``[]``;
    * every active term carries its configured weight;
    * the ``weights`` container itself never appears as a field.
    """
    lc = LossComponents(lm=torch.tensor(1.0), weights=LossWeights())
    fields = lc.active_fields
    assert len(fields) == 1, f"expected 1 field, got {len(fields)}"
    assert fields[0][0] == "lm"
    assert fields[0][1].item() == 1.0
    assert fields[0][2] == 1.0

    lc2 = LossComponents()
    assert lc2.active_fields == [], f"expected empty, got {lc2.active_fields}"

    # A non-zero weight must not resurrect a term whose value is None.
    lc3 = LossComponents(lm=torch.tensor(2.0), vq_commitment=None, weights=LossWeights(vq_commitment=0.5))
    assert len(lc3.active_fields) == 1, f"expected 1 (vq=None), got {len(lc3.active_fields)}"

    lc4 = LossComponents(
        lm=torch.tensor(1.0),
        moe_aux=torch.tensor(0.5),
        moe_ponder=torch.tensor(0.1),
        weights=LossWeights(moe_aux=0.1, moe_ponder=0.2),
    )
    fields4 = lc4.active_fields
    assert len(fields4) == 3, f"expected 3 fields, got {len(fields4)}"
    # Index entries by name so each weight is checked directly, instead of
    # scanning the list for one specific entry.
    by_name = {entry[0]: entry for entry in fields4}
    names = set(by_name)
    assert "lm" in names
    assert "moe_aux" in names
    assert "moe_ponder" in names
    assert by_name["moe_aux"][2] == 0.1, f"moe_aux weight should be 0.1, got {by_name['moe_aux'][2]}"
    assert by_name["moe_ponder"][2] == 0.2, f"moe_ponder weight should be 0.2, got {by_name['moe_ponder'][2]}"
    # lm was not overridden, so it keeps the LossWeights default checked above.
    assert by_name["lm"][2] == 1.0, f"lm weight should be 1.0, got {by_name['lm'][2]}"

    assert "weights" not in names, "weights field leaked into active_fields"

    print(" PASS test_losscomponents_active_fields")
|
|