| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.kernel.ternary_scale import _HAS_TILELANG, TernaryScaleTensor, _TernaryLinearFn, _tilelang_training_enabled |
|
|
|
|
| def _cuda_available(): |
| return torch.cuda.is_available() |
|
|
|
|
def test_tilelang_output_float32():
    """Check that a CUDA forward of TernaryScaleTensor under no_grad is float32.

    When CUDA or TileLang is unavailable, a SKIP line is printed and the test
    returns early. Note that under pytest this still counts as a pass.
    """
    skip_reason = None
    if not _cuda_available():
        skip_reason = "no CUDA"
    elif not _HAS_TILELANG:
        skip_reason = "no Tilelang"
    if skip_reason is not None:
        print(f" SKIP test_tilelang_output_float32 ({skip_reason})")
        return

    device = "cuda"
    layer = TernaryScaleTensor(8, 4).to(device)
    inputs = torch.randn(2, 8, device=device)
    with torch.no_grad():
        out = layer(inputs)
    assert out.dtype == torch.float32, f"Expected float32 output, got {out.dtype}"
    print(" PASS test_tilelang_output_float32")
|
|
|
|
def test_tilelang_training_disabled_by_default():
    """Assert that TileLang training is opt-in, because the fp16 path can produce NaNs."""
    enabled = _tilelang_training_enabled()
    assert not enabled, "TileLang training should be opt-in because the fp16 path can produce NaNs"
    print(" PASS test_tilelang_training_disabled_by_default")
|
|
|
|
def test_tilelang_training_forward_finite():
    """Run five training steps of a minimal ARBModel and require a finite loss each step.

    The test sets ``ARB_TERNARY_BACKEND=auto`` only for its own duration. The
    caller's previous value (or its absence) is restored in a ``finally``
    block, so the override cannot leak into tests that run later in the same
    process, even if an assertion fails mid-loop.

    When CUDA or TileLang is unavailable, a SKIP line is printed and the test
    returns early.
    """
    if not _cuda_available():
        print(" SKIP test_tilelang_training_forward_finite (no CUDA)")
        return
    if not _HAS_TILELANG:
        print(" SKIP test_tilelang_training_forward_finite (no Tilelang)")
        return

    env_key = "ARB_TERNARY_BACKEND"
    previous = os.environ.get(env_key)
    os.environ[env_key] = "auto"
    try:
        # Imported after the env override. NOTE(review): this matters only if
        # arbitor reads the backend at import time and was not already
        # imported; confirm.
        from arbitor.main import ARBModel
        model = ARBModel(
            enable_image=False, enable_audio=False, enable_vq=False,
            enable_graph=False, enable_memory_modules=False, enable_moe=False,
        ).cuda()
        from arbitor.config import VOCAB
        for step in range(5):
            x = torch.randint(0, VOCAB, (1, 4), device="cuda")
            _, lc, _, _ = model(x, targets=x[:, 3:])
            # .all() is identical for a 0-dim loss and avoids an
            # "ambiguous truth value" error if the loss is ever non-scalar.
            assert torch.isfinite(lc.total).all(), f"Non-finite loss at step {step}"
            model._ternary_update_memory(loss_components=lc)
    finally:
        # Restore exactly what the caller had: remove the variable if it was
        # unset before, otherwise put back the old value.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
    print(" PASS test_tilelang_training_forward_finite")
|
|