# ARBS / testing/model/test_tscale.py
# Uploaded by CLIWorks via huggingface_hub (commit d8bc908, verified).
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.kernel import ternary_scale as tscale
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, TILE_SIZE, GROUP_SIZES
from arbitor.optim.sign_sgd import SignSGD
from arbitor.components import StickyZoneSTE
from arbitor.config import VOCAB, CTX, SPECIAL_VOCAB
from arbitor.main import ARBModel
# ─── TernaryScaleTensor Tests ───
def test_tscale_shape():
    """A 32 -> 16 TernaryScaleTensor maps a (2, 10, 32) batch to (2, 10, 16)."""
    layer = TernaryScaleTensor(32, 16)
    inputs = torch.randn(2, 10, 32)
    result = layer(inputs)
    expected_shape = (2, 10, 16)
    assert result.shape == expected_shape, f"Shape: {result.shape}"
    print(" PASS test_tscale_shape")
def test_tscale_ternary_output():
    """The matrix returned by ``_get_T`` contains only the values -1, 0 and +1."""
    layer = TernaryScaleTensor(32, 16, threshold=0.05)
    ternary = layer._get_T()
    observed = {value for value in ternary.detach().flatten().tolist()}
    allowed = {-1, 0, 1}
    assert observed <= allowed, f"Non-ternary values in T: {observed}"
    print(" PASS test_tscale_ternary_output")
def test_tscale_T64_per_element_s():
    """Dequantizing a T64 layer (in=32, out=16) yields an (out, in) = (16, 32) weight."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    weight = layer.dequantize()
    assert tuple(weight.shape) == (16, 32), f"Dequantize shape: {weight.shape}"
    print(" PASS test_tscale_T64_per_element_s")
def test_tscale_T32_group_s():
    """A T32 layer has at least one exponent group per output row and dequantizes to (16, 96)."""
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T32)
    weight = layer.dequantize()
    # E's leading dimension divided by out_dim gives the groups-per-row count.
    groups_per_row = layer.E.shape[0] // layer.out_dim
    assert groups_per_row > 0, f"Groups per row: {groups_per_row}"
    assert weight.shape == (16, 96), f"Dequantize shape: {weight.shape}"
    print(" PASS test_tscale_T32_group_s")
def test_tscale_to_switching():
    """``tscale_to`` switches the scale type in place while the dequantized shape stays fixed."""
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T64)
    reference_shape = layer.dequantize().shape
    assert layer.tscale_type == TScaleType.T64
    # Walk T64 -> T32 -> T4, checking the type tag and the dequantized shape each time.
    for target in (TScaleType.T32, TScaleType.T4):
        layer.tscale_to(target)
        assert layer.tscale_type == target
        assert layer.dequantize().shape == reference_shape
    print(" PASS test_tscale_to_switching")
def test_tscale_cast_alias():
    """``tscale_cast`` changes the scale type and returns the same module so calls can chain."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    returned = layer.tscale_cast(TScaleType.T8)
    assert returned is layer, "tscale_cast should return self"
    assert layer.tscale_type == TScaleType.T8
    print(" PASS test_tscale_cast_alias")
def test_tscale_gradient_flow():
    """Backpropagating through a T32 layer produces a gradient on the input tensor."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    inputs = torch.randn(2, 10, 32).requires_grad_()
    layer(inputs).sum().backward()
    assert inputs.grad is not None, "No gradient on input"
    print(" PASS test_tscale_gradient_flow")
def test_tscale_all_types_forward():
    """Every ``TScaleType`` member runs a finite forward pass of the expected shape."""
    for variant in TScaleType:
        layer = TernaryScaleTensor(96, 16, tscale_type=variant)
        result = layer(torch.randn(2, 4, 96))
        label = variant.name
        assert result.shape == (2, 4, 16), f"{label}: shape {result.shape}"
        assert torch.isfinite(result).all(), f"{label}: non-finite output"
    print(" PASS test_tscale_all_types_forward")
def test_tscale_dequantize():
    """A T32 layer dequantizes to a finite (16, 32) effective weight."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    weight = layer.dequantize()
    assert weight.shape == (16, 32), f"Shape: {weight.shape}"
    assert torch.isfinite(weight).all()
    print(" PASS test_tscale_dequantize")
def test_tscale_effective_bpw():
    """A T4 layer reports fewer effective bits per weight than a same-sized T64 layer."""
    wide = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T64)
    narrow = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T4)
    bpw_t64 = wide.effective_bpw
    bpw_t4 = narrow.effective_bpw
    assert bpw_t4 < bpw_t64, "T4 should have lower BPW than T64"
    print(f" T64 BPW: {bpw_t64:.2f}, T4 BPW: {bpw_t4:.2f}")
    print(" PASS test_tscale_effective_bpw")
def test_tscale_model_integration():
    """ARBModel returns losses and backpropagates for T64, T32 and T8 scale types."""
    for variant in (TScaleType.T64, TScaleType.T32, TScaleType.T8):
        model = ARBModel(tscale_type=variant)
        tokens = torch.randint(0, VOCAB, (2, 10))
        # Targets are the tokens from position 3 onward.
        _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
        assert losses is not None
        losses.total.backward()
    print(" PASS test_tscale_model_integration")
def test_tscale_runtime_switch():
    """Switching every ternary layer of a live model to T4 keeps logits finite and same-shaped."""
    model = ARBModel(tscale_type=TScaleType.T64)
    tokens = torch.randint(0, VOCAB, (1, 10))
    baseline_logits, _, _, _ = model(tokens)
    ternary_layers = (m for m in model.modules() if isinstance(m, TernaryScaleTensor))
    for layer in ternary_layers:
        layer.tscale_to(TScaleType.T4)
    switched_logits, _, _, _ = model(tokens)
    assert torch.isfinite(switched_logits).all(), "Non-finite after tscale.to(T4)"
    assert switched_logits.shape == baseline_logits.shape, "Shape mismatch after tscale switch"
    print(" PASS test_tscale_runtime_switch")
# ─── SignSGD Tests ───
def test_sign_sgd_step():
    """A single SignSGD step with a non-zero gradient changes the weights."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01)
    net(torch.randn(2, 10)).sum().backward()
    snapshot = net.weight.clone()
    opt.step()
    assert not torch.equal(net.weight, snapshot), "Weights did not change"
    print(" PASS test_sign_sgd_step")
def test_sign_sgd_no_momentum():
    """SignSGD keeps no per-parameter state such as momentum buffers, even after stepping.

    Checking ``optimizer.state`` before any ``step()`` proves nothing, because
    every torch optimizer starts with empty state. This test therefore takes a
    real step and checks afterwards.
    """
    model = torch.nn.Linear(10, 5)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    assert len(optimizer.state) == 0, "SignSGD should have no state (no momentum)"
    model(torch.randn(2, 10)).sum().backward()
    optimizer.step()
    # Tolerate empty per-parameter dicts: touching the defaultdict creates them,
    # but they hold no momentum or other buffers.
    assert all(not s for s in optimizer.state.values()), "SignSGD should have no state (no momentum)"
    print(" PASS test_sign_sgd_no_momentum")
def test_sign_sgd_memory():
    """``SignSGD.get_memory_mb`` reports a positive footprint for a 100x100 linear layer."""
    layer = torch.nn.Linear(100, 100)
    optimizer = SignSGD(layer.parameters(), lr=0.01)
    megabytes = optimizer.get_memory_mb()
    assert megabytes > 0, "Memory should be positive"
    print(f" SignSGD memory: {megabytes:.2f} MB")
    print(" PASS test_sign_sgd_memory")
def test_sign_sgd_with_tscale_model():
    """SignSGD steps a T32 ARBModel, followed by its ternary memory update, without building state."""
    model = ARBModel(tscale_type=TScaleType.T32)
    opt = SignSGD(model.parameters(), lr=0.01)
    tokens = torch.randint(0, VOCAB, (2, 10))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    losses.total.backward()
    opt.step()
    model._ternary_update_memory()
    assert not opt.state, "SignSGD should have no state"
    print(" PASS test_sign_sgd_with_tscale_model")
def test_sign_sgd_weight_decay():
    """Weight decay moves the weights on its own, independent of the gradient.

    The first check (weights change after a normal step) passes for any
    optimizer with a non-zero gradient, so it never exercised ``weight_decay``.
    The second step uses an explicitly all-zero gradient so that only the
    decay term can move the parameters.
    """
    model = torch.nn.Linear(10, 5)
    optimizer = SignSGD(model.parameters(), lr=0.01, weight_decay=0.01)
    x = torch.randn(2, 10)
    loss = model(x).sum()
    loss.backward()
    w_before = model.weight.clone()
    optimizer.step()
    w_diff = (model.weight - w_before).abs().sum().item()
    assert w_diff > 0, "Weights should change with weight_decay"

    # Isolate decay: zero (not None) gradients, so the parameters still get updated.
    for p in model.parameters():
        p.grad = torch.zeros_like(p)
    w_before = model.weight.detach().clone()
    optimizer.step()
    w_after = model.weight.detach()
    assert not torch.equal(w_after, w_before), "Weight decay alone should move weights"
    # Coupled and decoupled decay both pull weights toward zero.
    assert w_after.abs().sum().item() < w_before.abs().sum().item(), "Weight decay should shrink weights"
    print(" PASS test_sign_sgd_weight_decay")
# ─── TileLang PyTorch Reference Tests ───
def _load_dequant_gemm_ref():
    """Load ``dequant_gemm_pytorch_ref`` from ../tilelang/kernels/dequant_gemm.py.

    Returns None when the reference file is missing, so callers can skip.
    """
    import importlib.util
    kernel_path = os.path.join(os.path.dirname(__file__), "..", "tilelang", "kernels", "dequant_gemm.py")
    if not os.path.exists(kernel_path):
        return None
    spec = importlib.util.spec_from_file_location("dequant_gemm", kernel_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod.dequant_gemm_pytorch_ref


def test_dequant_gemm_pytorch_ref():
    """The PyTorch dequant-GEMM reference returns a finite (M, N) result."""
    dequant_gemm_pytorch_ref = _load_dequant_gemm_ref()
    if dequant_gemm_pytorch_ref is None:
        print(" SKIP test_dequant_gemm_pytorch_ref (tilelang reference file missing)")
        return
    M, N, K, group_size = 4, 8, 96, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)
    output = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)
    assert output.shape == (M, N), f"Shape: {output.shape}"
    assert torch.isfinite(output).all(), "Non-finite output"
    print(" PASS test_dequant_gemm_pytorch_ref")


def test_dequant_gemm_matches_manual():
    """The reference matches a manual ``x @ (sign * 2**exp).T``, with one exponent per group.

    Each exponent applies to ``group_size`` consecutive columns of a row and
    can be negative, so its scale is a fractional power of two (e.g. 2**-3).
    """
    dequant_gemm_pytorch_ref = _load_dequant_gemm_ref()
    if dequant_gemm_pytorch_ref is None:
        print(" SKIP test_dequant_gemm_matches_manual (tilelang reference file missing)")
        return
    M, N, K, group_size = 2, 4, 48, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)
    result = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)
    exp_expanded = exponents.repeat_interleave(group_size, dim=1)
    # Real 2**e. The old integer form `1 >> -e` truncated every negative
    # exponent to 0, which zeroed those weights instead of scaling them.
    two_pow = torch.exp2(exp_expanded.to(torch.float32))
    w = signs.to(torch.float32) * two_pow
    # Powers of two and fp16 inputs are exact in fp32, so this is a clean
    # baseline; the tolerance only needs to absorb fp16 output rounding.
    expected = x.to(torch.float32) @ w.t()
    assert torch.allclose(result.to(torch.float32), expected, atol=1e-2, rtol=1e-2), "PyTorch ref mismatch"
    print(" PASS test_dequant_gemm_matches_manual")
# ─── Integration: SignSGD + TernaryScaleTensor training step ───
def test_full_training_step():
    """After one SignSGD step and a ternary memory update, a second forward gives a finite loss."""
    model = ARBModel(tscale_type=TScaleType.T32)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    tokens = torch.randint(0, VOCAB, (2, 10))
    targets = tokens[:, 3:]
    _, first_losses, _, _ = model(tokens, targets=targets)
    first_losses.total.backward()
    optimizer.step()
    model._ternary_update_memory()
    _, second_losses, _, _ = model(tokens, targets=targets)
    assert torch.isfinite(second_losses.total), "Non-finite loss after step"
    print(" PASS test_full_training_step")
def test_multiple_steps_converge():
    """Run 50 SignSGD + ternary-update steps on a fixed batch and check every loss is finite.

    Despite the name, this only checks numerical stability; it does not assert
    that the loss decreases.
    """
    num_steps = 50
    model = ARBModel(tscale_type=TScaleType.T32)
    optimizer = SignSGD(model.parameters(), lr=0.001)
    x = torch.randint(0, VOCAB, (4, 10))
    loss_history = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        _, losses_out, _, _ = model(x, targets=x[:, 3:])
        loss_val = losses_out.total
        loss_val.backward()
        optimizer.step()
        model._ternary_update_memory(accum_threshold=3)
        loss_history.append(loss_val.item())
    assert torch.isfinite(torch.tensor(loss_history)).all(), "Non-finite loss during training"
    # Separator restored: the original f-string printed min and max with no delimiter.
    print(f" Loss range: {min(loss_history):.4f} -> {max(loss_history):.4f} over {num_steps} steps")
    print(" PASS test_multiple_steps_converge")
def test_cuda_triton_correctness_linear():
    """CPU and CUDA/Triton TernaryScaleTensor agree on forward output and input gradient.

    Both modules start from identical state: the GPU copy is synced before any
    CPU forward/backward runs, in case the backward touches module buffers.
    This matches the update_E and ternary_step correctness tests.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_linear (CUDA/Triton unavailable)")
        return
    ATOL = 1e-3
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())
        x = torch.randn(4, 4, 32, requires_grad=True)
        cpu_out = lin_cpu(x)
        grad_out = torch.randn_like(cpu_out)
        cpu_out.backward(grad_out)
        cpu_grad_x = x.grad.clone()
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = lin_gpu(x_gpu)
        # Use the same upstream gradient on both devices.
        gpu_out.backward(grad_out.cuda())
        gpu_grad_x = x_gpu.grad.clone()
        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < ATOL, f"{tt.name} fwd_diff={fwd_diff}"
        assert bwd_diff < ATOL, f"{tt.name} bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_linear")
def test_cuda_triton_correctness_rmsnorm():
    """CPU and CUDA/Triton TernaryRMSNorm agree on forward output and input gradient.

    The GPU module is synced from the CPU module before any CPU pass, so both
    sides start from identical parameters and buffers.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_rmsnorm (CUDA/Triton unavailable)")
        return
    from arbitor.kernel.ternary_scale import TernaryRMSNorm
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        norm_cpu = TernaryRMSNorm(256, tscale_type=tt)
        norm_gpu = TernaryRMSNorm(256, tscale_type=tt).cuda()
        norm_gpu.load_state_dict(norm_cpu.state_dict())
        x = torch.randn(2, 4, 256, requires_grad=True)
        cpu_out = norm_cpu(x)
        cpu_out.sum().backward()
        cpu_grad_x = x.grad.clone()
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = norm_gpu(x_gpu)
        gpu_out.sum().backward()
        gpu_grad_x = x_gpu.grad.clone()
        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{tt.name} rmsnorm fwd_diff={fwd_diff}"
        assert bwd_diff < 1e-5, f"{tt.name} rmsnorm bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_rmsnorm")
def test_cuda_triton_correctness_embedding():
    """CPU and CUDA/Triton ByteEmbedding agree on forward output and on retained grad signs.

    The GPU module is synced from the CPU module before any CPU pass, so both
    sides start from identical state. The grad-sign comparison only runs when
    both modules expose ``_hook_grad_T_sign``.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_embedding (CUDA/Triton unavailable)")
        return
    from arbitor.main import ByteEmbedding
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        emb_cpu = ByteEmbedding(tscale_type=tt)
        emb_gpu = ByteEmbedding(tscale_type=tt).cuda()
        emb_gpu.load_state_dict(emb_cpu.state_dict())
        x = torch.tensor([0, 1, 2, 5, 10])
        cpu_out = emb_cpu(x)
        cpu_out.sum().backward()
        x_gpu = x.cuda()
        gpu_out = emb_gpu(x_gpu)
        gpu_out.sum().backward()
        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{tt.name} embed fwd_diff={fwd_diff}"
        if hasattr(emb_cpu, '_hook_grad_T_sign') and hasattr(emb_gpu, '_hook_grad_T_sign'):
            gs_match = (emb_gpu._hook_grad_T_sign.cpu() == emb_cpu._hook_grad_T_sign).float().mean().item()
            assert gs_match > 0.99, f"{tt.name} embed grad_sign match={gs_match}"
    print(" PASS test_cuda_triton_correctness_embedding")
def test_cuda_triton_correctness_update_E():
    """``update_E`` produces matching E and E_accum on CPU and CUDA from identical starting state."""
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_update_E (CUDA/Triton unavailable)")
        return

    def _backward_then_update(layer, inputs):
        # One forward/backward with a sum loss, then the exponent update; snapshot the results.
        layer(inputs).sum().backward()
        layer.update_E()
        return layer.E.clone(), layer.E_accum.clone()

    variants = (TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64)
    for variant in variants:
        cpu_layer = TernaryScaleTensor(32, 16, tscale_type=variant)
        gpu_layer = TernaryScaleTensor(32, 16, tscale_type=variant).cuda()
        gpu_layer.load_state_dict(cpu_layer.state_dict())
        cpu_in = torch.randn(4, 4, 32, requires_grad=True)
        gpu_in = cpu_in.detach().clone().cuda().requires_grad_(True)
        cpu_E, cpu_E_accum = _backward_then_update(cpu_layer, cpu_in)
        gpu_E, gpu_E_accum = _backward_then_update(gpu_layer, gpu_in)
        # Compare fixed-point E residual update results.
        e_gap = (cpu_E.float() - gpu_E.cpu().float()).abs().max().item()
        assert e_gap < 0.01, f"{variant.name} CPU-GPU E update mismatch: {e_gap}"
        accum_gap = (cpu_E_accum.float() - gpu_E_accum.cpu().float()).abs().max().item()
        assert accum_gap < 0.01, f"{variant.name} CPU-GPU E_accum update mismatch: {accum_gap}"
    print(" PASS test_cuda_triton_correctness_update_E")
def test_cuda_triton_correctness_ternary_step():
    """``ternary_step`` yields bit-identical T and T_accum on CPU and CUDA from identical state."""
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_ternary_step (CUDA/Triton unavailable)")
        return

    def _backward_then_step(layer, inputs):
        # One forward/backward with a sum loss, then a ternary flip step; snapshot the results.
        layer(inputs).sum().backward()
        layer.ternary_step(accum_threshold=3)
        return layer._get_T().clone(), layer.T_accum.clone()

    variants = (TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64)
    for variant in variants:
        cpu_layer = TernaryScaleTensor(32, 16, tscale_type=variant)
        gpu_layer = TernaryScaleTensor(32, 16, tscale_type=variant).cuda()
        gpu_layer.load_state_dict(cpu_layer.state_dict())
        cpu_in = torch.randn(4, 4, 32, requires_grad=True)
        gpu_in = cpu_in.detach().clone().cuda().requires_grad_(True)
        cpu_T, cpu_T_accum = _backward_then_step(cpu_layer, cpu_in)
        gpu_T, gpu_T_accum = _backward_then_step(gpu_layer, gpu_in)
        t_agree = (cpu_T == gpu_T.cpu()).float().mean().item()
        accum_agree = (cpu_T_accum == gpu_T_accum.cpu()).float().mean().item()
        assert t_agree == 1.0, f"{variant.name} T_match={t_agree}"
        assert accum_agree == 1.0, f"{variant.name} Taccum_match={accum_agree}"
    print(" PASS test_cuda_triton_correctness_ternary_step")
def test_cuda_triton_tscale_path():
    """Exercise the CUDA/Triton streaming path of TernaryScaleTensor.

    Checks that forward/backward stay on CUDA, that the backward streams
    updates into ``T_accum``/``E_accum`` without keeping the large
    ``_hook_*`` tensors, and that a forced ``ternary_step`` flips T in the
    descent direction and clears the accumulators of flipped entries.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_tscale_path (CUDA/Triton unavailable)")
        return
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    # Snapshot the exponent residuals BEFORE the streaming backward. The old
    # code cloned after backward and compared immediately, so torch.equal was
    # always True and the "changed" half of the check was dead.
    E_accum_before = lin.E_accum.clone()
    x = torch.randn(2, 4, 32, device="cuda", requires_grad=True)
    out = lin(x)
    assert out.is_cuda, "Triton path should produce CUDA output"
    assert out.shape == (2, 4, 16), f"Shape: {out.shape}"
    grad_out = torch.randn_like(out)
    out.backward(grad_out)
    torch.cuda.synchronize()
    assert x.grad is not None and x.grad.is_cuda, "CUDA grad_x missing"
    assert lin.T_accum.abs().sum().item() > 0, \
        "Triton path should stream updates into int8 T_accum"
    assert not hasattr(lin, "_hook_grad_T_sign"), \
        "Triton path should not retain full weight-shaped grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d") and not hasattr(lin, "_hook_x_2d"), \
        "Triton path should not retain fp32 grad/x views"
    assert not torch.equal(lin.E_accum, E_accum_before) or lin.E_accum.abs().sum().item() > 0, \
        "Streaming CUDA E update did not modify exponent residual state"
    assert lin.T_packed.is_cuda and lin.E.is_cuda, "Ternary buffers moved off CUDA after update"

    # Force a flip: all-positive gradients with threshold 0 should drive every T entry to -1.
    lin_force = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    lin_force._hook_grad_2d = torch.ones(2, 16, device="cuda")
    lin_force._hook_x_2d = torch.ones(2, 32, device="cuda")
    lin_force.ternary_step(accum_threshold=0)
    forced_T = lin_force._get_T()
    assert forced_T.is_cuda, "Unpacked CUDA ternary state should stay on CUDA"
    assert (forced_T == -1).all(), "CUDA ternary repack should move positive gradients in descent direction"
    assert lin_force.T_accum.abs().sum().item() == 0, "CUDA ternary repack should reset flipped accumulators"
    print(" PASS test_cuda_triton_tscale_path")
if __name__ == "__main__":
    # Standalone runner: run every test in order, report PASS/FAIL per test,
    # and exit non-zero when anything fails so shell/CI callers can detect it.
    import traceback

    tests = [
        test_tscale_shape,
        test_tscale_ternary_output,
        test_tscale_T64_per_element_s,
        test_tscale_T32_group_s,
        test_tscale_to_switching,
        test_tscale_cast_alias,
        test_tscale_gradient_flow,
        test_tscale_all_types_forward,
        test_tscale_dequantize,
        test_tscale_effective_bpw,
        test_tscale_model_integration,
        test_tscale_runtime_switch,
        test_sign_sgd_step,
        test_sign_sgd_no_momentum,
        test_sign_sgd_memory,
        test_sign_sgd_with_tscale_model,
        test_sign_sgd_weight_decay,
        test_dequant_gemm_pytorch_ref,
        test_dequant_gemm_matches_manual,
        test_cuda_triton_correctness_linear,
        test_cuda_triton_correctness_rmsnorm,
        test_cuda_triton_correctness_embedding,
        test_cuda_triton_correctness_update_E,
        test_cuda_triton_correctness_ternary_step,
        test_cuda_triton_tscale_path,
        test_full_training_step,
        test_multiple_steps_converge,
    ]
    print("Running TernaryScale + SignSGD + TileLang Phase 2 tests...\n")
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
            passed += 1
        except Exception as e:
            print(f" FAIL {test.__name__}: {e}")
            traceback.print_exc()
            failed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")
    sys.exit(1 if failed else 0)