import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.kernel.ternary_scale import _HAS_TILELANG, TernaryScaleTensor, _TernaryLinearFn, _tilelang_training_enabled
def _cuda_available():
return torch.cuda.is_available()
def test_tilelang_output_float32():
    """Check that a ``TernaryScaleTensor`` forward pass on CUDA returns float32.

    The test prints a SKIP line and returns early when CUDA is unavailable
    or when ``_HAS_TILELANG`` is false; CUDA is checked first.
    """
    # Work out whether a prerequisite is missing before touching the GPU.
    skip_message = None
    if not _cuda_available():
        skip_message = " SKIP test_tilelang_output_float32 (no CUDA)"
    elif not _HAS_TILELANG:
        skip_message = " SKIP test_tilelang_output_float32 (no Tilelang)"
    if skip_message is not None:
        print(skip_message)
        return

    device = "cuda"
    # NOTE(review): (8, 4) presumably means in/out feature sizes; the input
    # below has a trailing dimension of 8 -- confirm against the class.
    layer = TernaryScaleTensor(8, 4).to(device)
    batch = torch.randn(2, 8, device=device)

    # Inference only: no autograd graph is needed to inspect the dtype.
    with torch.no_grad():
        y = layer(batch)

    assert y.dtype == torch.float32, f"Expected float32 output, got {y.dtype}"
    print(" PASS test_tilelang_output_float32")
def test_tilelang_training_disabled_by_default():
    """Check that ``_tilelang_training_enabled()`` is falsy by default.

    The assertion message records the rationale: the TileLang training path
    is meant to be opt-in because its fp16 path can produce NaNs.
    """
    training_enabled = _tilelang_training_enabled()
    assert not training_enabled, (
        "TileLang training should be opt-in because the fp16 path can produce NaNs"
    )
    print(" PASS test_tilelang_training_disabled_by_default")
def test_tilelang_training_forward_finite():
    """Run a few training-style steps of a small ``ARBModel`` and check the loss.

    For five steps the test feeds a random ``(1, 4)`` batch of token ids to
    the model, asserts that ``lc.total`` is finite, and then calls
    ``model._ternary_update_memory``. It prints a SKIP line and returns early
    when CUDA is unavailable or when ``_HAS_TILELANG`` is false.

    ``ARB_TERNARY_BACKEND`` is set to ``"auto"`` only for the duration of the
    test; its previous value (or absence) is restored afterwards so the
    override does not leak into other tests running in the same process.
    """
    if not _cuda_available():
        print(" SKIP test_tilelang_training_forward_finite (no CUDA)")
        return
    if not _HAS_TILELANG:
        print(" SKIP test_tilelang_training_forward_finite (no Tilelang)")
        return

    # Train a small model with Tilelang and verify no NaN
    # NOTE(review): whether "auto" actually selects the TileLang training
    # path is decided inside the arbitor package -- confirm.
    backend_var = "ARB_TERNARY_BACKEND"
    previous_backend = os.environ.get(backend_var)
    os.environ[backend_var] = "auto"
    try:
        # Imported after the env var is set, as in the original ordering.
        from arbitor.main import ARBModel
        model = ARBModel(
            enable_image=False, enable_audio=False, enable_vq=False,
            enable_graph=False, enable_memory_modules=False, enable_moe=False,
        ).cuda()
        from arbitor.config import VOCAB
        for step in range(5):
            x = torch.randint(0, VOCAB, (1, 4), device="cuda")
            # x has four columns, so x[:, 3:] is its last column.
            _, lc, _, _ = model(x, targets=x[:, 3:])
            assert torch.isfinite(lc.total), f"Non-finite loss at step {step}"
            model._ternary_update_memory(loss_components=lc)
    finally:
        # Undo the environment override even if an assertion fails.
        if previous_backend is None:
            os.environ.pop(backend_var, None)
        else:
            os.environ[backend_var] = previous_backend
    print(" PASS test_tilelang_training_forward_finite")
|