"""Smoke tests for the TileLang-backed ternary linear kernel.

The GPU tests print a SKIP line and return early when CUDA or TileLang is
unavailable, so the module can run on CPU-only machines.
"""
import torch
import sys
import os

# Make the repository root importable when this file is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel.ternary_scale import _HAS_TILELANG, TernaryScaleTensor, _TernaryLinearFn, _tilelang_training_enabled


def _cuda_available():
    """Return True when PyTorch reports a usable CUDA device."""
    return torch.cuda.is_available()


def test_tilelang_output_float32():
    """Check that a no-grad forward pass on CUDA yields a float32 tensor."""
    if not _cuda_available():
        print(" SKIP test_tilelang_output_float32 (no CUDA)")
        return
    if not _HAS_TILELANG:
        print(" SKIP test_tilelang_output_float32 (no Tilelang)")
        return

    lin = TernaryScaleTensor(8, 4).to("cuda")
    x = torch.randn(2, 8, device="cuda")
    with torch.no_grad():
        y = lin(x)

    assert y.dtype == torch.float32, f"Expected float32 output, got {y.dtype}"
    print(" PASS test_tilelang_output_float32")


def test_tilelang_training_disabled_by_default():
    """Check that TileLang training is off unless explicitly enabled."""
    assert not _tilelang_training_enabled(), "TileLang training should be opt-in because the fp16 path can produce NaNs"
    print(" PASS test_tilelang_training_disabled_by_default")


def test_tilelang_training_forward_finite():
    """Run five training steps on a minimal model and require a finite loss.

    ``ARB_TERNARY_BACKEND`` is set to ``"auto"`` for the duration of the test
    and restored to its previous state afterwards, so the override cannot
    leak into tests that run later in the same process.
    """
    if not _cuda_available():
        print(" SKIP test_tilelang_training_forward_finite (no CUDA)")
        return
    if not _HAS_TILELANG:
        print(" SKIP test_tilelang_training_forward_finite (no Tilelang)")
        return

    # Train a small model with Tilelang and verify no NaN
    previous_backend = os.environ.get("ARB_TERNARY_BACKEND")
    os.environ["ARB_TERNARY_BACKEND"] = "auto"
    try:
        # Imported after the override is in place, matching the original
        # statement order.
        from arbitor.main import ARBModel

        model = ARBModel(
            enable_image=False,
            enable_audio=False,
            enable_vq=False,
            enable_graph=False,
            enable_memory_modules=False,
            enable_moe=False,
        ).cuda()

        from arbitor.config import VOCAB

        for step in range(5):
            x = torch.randint(0, VOCAB, (1, 4), device="cuda")
            # Targets are the final column of x (shape (1, 1)).
            _, lc, _, _ = model(x, targets=x[:, 3:])
            assert torch.isfinite(lc.total), f"Non-finite loss at step {step}"
            model._ternary_update_memory(loss_components=lc)
    finally:
        # Restore the caller's environment whether the test passed or failed.
        if previous_backend is None:
            os.environ.pop("ARB_TERNARY_BACKEND", None)
        else:
            os.environ["ARB_TERNARY_BACKEND"] = previous_backend

    print(" PASS test_tilelang_training_forward_finite")