# File size: 1,994 Bytes
# d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel.ternary_scale import _HAS_TILELANG, TernaryScaleTensor, _TernaryLinearFn, _tilelang_training_enabled


def _cuda_available():
    """Return whatever ``torch.cuda.is_available()`` reports."""
    has_cuda = torch.cuda.is_available()
    return has_cuda


def test_tilelang_output_float32():
    """Check that a CUDA ``TernaryScaleTensor(8, 4)`` produces float32 output.

    The check is skipped (a SKIP line is printed and the function returns)
    when CUDA is unavailable or ``_HAS_TILELANG`` is falsy.
    """
    skip_message = None
    if not _cuda_available():
        skip_message = " SKIP test_tilelang_output_float32 (no CUDA)"
    elif not _HAS_TILELANG:
        skip_message = " SKIP test_tilelang_output_float32 (no Tilelang)"
    if skip_message is not None:
        print(skip_message)
        return

    layer = TernaryScaleTensor(8, 4).to("cuda")
    batch = torch.randn(2, 8, device="cuda")
    # Inference only: no autograd graph is needed to inspect the output dtype.
    with torch.no_grad():
        y = layer(batch)
    assert y.dtype == torch.float32, f"Expected float32 output, got {y.dtype}"
    print(" PASS test_tilelang_output_float32")


def test_tilelang_training_disabled_by_default():
    """Check that ``_tilelang_training_enabled()`` is falsy in this environment."""
    training_enabled = _tilelang_training_enabled()
    assert not training_enabled, "TileLang training should be opt-in because the fp16 path can produce NaNs"
    print(" PASS test_tilelang_training_disabled_by_default")


def test_tilelang_training_forward_finite():
    """Run five forward/update steps on a small ``ARBModel`` and require a finite loss.

    The test is skipped (a SKIP line is printed and the function returns)
    when CUDA is unavailable or ``_HAS_TILELANG`` is falsy.

    ``ARB_TERNARY_BACKEND`` is set to ``"auto"`` for the duration of the test
    and restored to its previous state afterwards, so the override does not
    leak into other tests running in the same process.

    Raises:
        AssertionError: if ``lc.total`` is non-finite at any step.
    """
    if not _cuda_available():
        print(" SKIP test_tilelang_training_forward_finite (no CUDA)")
        return
    if not _HAS_TILELANG:
        print(" SKIP test_tilelang_training_forward_finite (no Tilelang)")
        return

    # Remember the caller's setting so it can be put back in ``finally``.
    previous_backend = os.environ.get("ARB_TERNARY_BACKEND")
    os.environ["ARB_TERNARY_BACKEND"] = "auto"
    try:
        # Imported after the env var is set, matching the original ordering.
        from arbitor.main import ARBModel
        model = ARBModel(
            enable_image=False, enable_audio=False, enable_vq=False,
            enable_graph=False, enable_memory_modules=False, enable_moe=False,
        ).cuda()
        from arbitor.config import VOCAB
        for step in range(5):
            # Random token ids in [0, VOCAB) with shape (1, 4); the target is
            # the last column of the same tensor.
            x = torch.randint(0, VOCAB, (1, 4), device="cuda")
            _, lc, _, _ = model(x, targets=x[:, 3:])
            assert torch.isfinite(lc.total), f"Non-finite loss at step {step}"
            # NOTE(review): presumably applies the training update for this
            # step; confirm against ARBModel._ternary_update_memory.
            model._ternary_update_memory(loss_components=lc)
    finally:
        if previous_backend is None:
            os.environ.pop("ARB_TERNARY_BACKEND", None)
        else:
            os.environ["ARB_TERNARY_BACKEND"] = previous_backend
    print(" PASS test_tilelang_training_forward_finite")