"""Tests for TernaryScaleTensor, SignSGD and the dequant-GEMM PyTorch reference.

The module can be collected by a test runner (every check is a ``test_*``
function) or executed directly as a script, in which case each test is run in
turn and the process exits with a non-zero status if any test failed.
"""

import torch
import sys
import os

# Make the repository root importable before pulling in the project package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel import ternary_scale as tscale
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, TILE_SIZE, GROUP_SIZES
from arbitor.optim.sign_sgd import SignSGD
from arbitor.components import StickyZoneSTE
from arbitor.config import VOCAB, CTX, SPECIAL_VOCAB
from arbitor.main import ARBModel


# Scale types exercised by every CPU-vs-GPU parity test below.
_PARITY_TSCALE_TYPES = (
    TScaleType.T4,
    TScaleType.T6,
    TScaleType.T8,
    TScaleType.T16,
    TScaleType.T32,
    TScaleType.T64,
)


def _cuda_triton_available():
    """Return True when both a CUDA device and the Triton backend are usable."""
    return bool(torch.cuda.is_available() and tscale._HAS_TRITON)


def _load_dequant_gemm_ref():
    """Load ``dequant_gemm_pytorch_ref`` from the TileLang kernel file.

    Returns:
        The reference function, or ``None`` when the kernel file does not
        exist (callers treat that as a skip).
    """
    import importlib.util

    kernel_path = os.path.join(
        os.path.dirname(__file__), "..", "tilelang", "kernels", "dequant_gemm.py"
    )
    if not os.path.exists(kernel_path):
        return None
    spec = importlib.util.spec_from_file_location("dequant_gemm", kernel_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod.dequant_gemm_pytorch_ref


# ─── TernaryScaleTensor Tests ───

def test_tscale_shape():
    """A (2, 10, 32) input through a 32->16 layer yields a (2, 10, 16) output."""
    lin = TernaryScaleTensor(32, 16)
    x = torch.randn(2, 10, 32)
    out = lin(x)
    assert out.shape == (2, 10, 16), f"Shape: {out.shape}"
    print(" PASS test_tscale_shape")


def test_tscale_ternary_output():
    """Every value returned by ``_get_T`` is one of -1, 0 or 1."""
    lin = TernaryScaleTensor(32, 16, threshold=0.05)
    T = lin._get_T()
    unique = set(T.detach().flatten().tolist())
    assert unique.issubset({-1, 0, 1}), f"Non-ternary values in T: {unique}"
    print(" PASS test_tscale_ternary_output")


def test_tscale_T64_per_element_s():
    """Dequantizing a T64 layer gives an (out_dim, in_dim) matrix."""
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    dq = lin.dequantize()
    assert dq.shape == (16, 32), f"Dequantize shape: {dq.shape}"
    print(" PASS test_tscale_T64_per_element_s")


def test_tscale_T32_group_s():
    """A T32 layer has at least one exponent group per row and dequantizes to (16, 96)."""
    lin = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T32)
    dq = lin.dequantize()
    gpr = lin.E.shape[0] // lin.out_dim
    assert gpr > 0, f"Groups per row: {gpr}"
    assert dq.shape == (16, 96), f"Dequantize shape: {dq.shape}"
    print(" PASS test_tscale_T32_group_s")


def test_tscale_to_switching():
    """``tscale_to`` updates ``tscale_type`` and keeps the dequantized shape."""
    lin = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T64)
    dq_before = lin.dequantize()
    assert lin.tscale_type == TScaleType.T64

    lin.tscale_to(TScaleType.T32)
    assert lin.tscale_type == TScaleType.T32
    dq_after = lin.dequantize()
    assert dq_before.shape == dq_after.shape

    lin.tscale_to(TScaleType.T4)
    assert lin.tscale_type == TScaleType.T4
    dq_t4 = lin.dequantize()
    assert dq_t4.shape == dq_before.shape
    print(" PASS test_tscale_to_switching")


def test_tscale_cast_alias():
    """``tscale_cast`` switches the scale type and returns the same object."""
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    result = lin.tscale_cast(TScaleType.T8)
    assert result is lin, "tscale_cast should return self"
    assert lin.tscale_type == TScaleType.T8
    print(" PASS test_tscale_cast_alias")


def test_tscale_gradient_flow():
    """Backward through the layer populates the gradient of the input."""
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    x = torch.randn(2, 10, 32)
    x.requires_grad_(True)
    out = lin(x)
    out.sum().backward()
    assert x.grad is not None, "No gradient on input"
    print(" PASS test_tscale_gradient_flow")


def test_tscale_all_types_forward():
    """Every member of ``TScaleType`` produces a finite, correctly shaped output."""
    for tscale_type in TScaleType:
        lin = TernaryScaleTensor(96, 16, tscale_type=tscale_type)
        x = torch.randn(2, 4, 96)
        out = lin(x)
        assert out.shape == (2, 4, 16), f"{tscale_type.name}: shape {out.shape}"
        assert torch.isfinite(out).all(), f"{tscale_type.name}: non-finite output"
    print(" PASS test_tscale_all_types_forward")


def test_tscale_dequantize():
    """``dequantize`` returns a finite (16, 32) matrix for a T32 layer."""
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    w_eff = lin.dequantize()
    assert w_eff.shape == (16, 32), f"Shape: {w_eff.shape}"
    assert torch.isfinite(w_eff).all()
    print(" PASS test_tscale_dequantize")


def test_tscale_effective_bpw():
    """T4 reports a strictly lower ``effective_bpw`` than T64 at equal size."""
    lin64 = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T64)
    lin4 = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T4)
    assert lin4.effective_bpw < lin64.effective_bpw, "T4 should have lower BPW than T64"
    print(f" T64 BPW: {lin64.effective_bpw:.2f}, T4 BPW: {lin4.effective_bpw:.2f}")
    print(" PASS test_tscale_effective_bpw")


def test_tscale_model_integration():
    """ARBModel runs forward and backward for the T64, T32 and T8 scale types."""
    for tscale_type in [TScaleType.T64, TScaleType.T32, TScaleType.T8]:
        model = ARBModel(tscale_type=tscale_type)
        x = torch.randint(0, VOCAB, (2, 10))
        _, losses, _, _ = model(x, targets=x[:, 3:])
        assert losses is not None
        losses.total.backward()
    print(" PASS test_tscale_model_integration")


def test_tscale_runtime_switch():
    """Switching every TernaryScaleTensor in a model to T4 keeps logits finite and same-shaped."""
    model = ARBModel(tscale_type=TScaleType.T64)
    x = torch.randint(0, VOCAB, (1, 10))
    logits64, _, _, _ = model(x)
    for module in model.modules():
        if isinstance(module, TernaryScaleTensor):
            module.tscale_to(TScaleType.T4)
    logits4, _, _, _ = model(x)
    assert torch.isfinite(logits4).all(), "Non-finite after tscale.to(T4)"
    assert logits4.shape == logits64.shape, "Shape mismatch after tscale switch"
    print(" PASS test_tscale_runtime_switch")


# ─── SignSGD Tests ───

def test_sign_sgd_step():
    """One optimizer step changes the weights of a plain linear layer."""
    model = torch.nn.Linear(10, 5)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    x = torch.randn(2, 10)
    loss = model(x).sum()
    loss.backward()
    w_before = model.weight.clone()
    optimizer.step()
    assert not torch.equal(model.weight, w_before), "Weights did not change"
    print(" PASS test_sign_sgd_step")


def test_sign_sgd_no_momentum():
    """A freshly constructed SignSGD holds no per-parameter state."""
    model = torch.nn.Linear(10, 5)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    assert len(optimizer.state) == 0, "SignSGD should have no state (no momentum)"
    print(" PASS test_sign_sgd_no_momentum")


def test_sign_sgd_memory():
    """``get_memory_mb`` reports a positive number for a 100x100 linear layer."""
    model = torch.nn.Linear(100, 100)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    mem = optimizer.get_memory_mb()
    assert mem > 0, "Memory should be positive"
    print(f" SignSGD memory: {mem:.2f} MB")
    print(" PASS test_sign_sgd_memory")


def test_sign_sgd_with_tscale_model():
    """SignSGD still holds no state after a full step on an ARBModel."""
    model = ARBModel(tscale_type=TScaleType.T32)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    x = torch.randint(0, VOCAB, (2, 10))
    _, losses, _, _ = model(x, targets=x[:, 3:])
    losses.total.backward()
    optimizer.step()
    model._ternary_update_memory()
    assert len(optimizer.state) == 0, "SignSGD should have no state"
    print(" PASS test_sign_sgd_with_tscale_model")


def test_sign_sgd_weight_decay():
    """A step with ``weight_decay`` set still moves the weights."""
    model = torch.nn.Linear(10, 5)
    optimizer = SignSGD(model.parameters(), lr=0.01, weight_decay=0.01)
    x = torch.randn(2, 10)
    loss = model(x).sum()
    loss.backward()
    w_before = model.weight.clone()
    optimizer.step()
    w_diff = (model.weight - w_before).abs().sum().item()
    assert w_diff > 0, "Weights should change with weight_decay"
    print(" PASS test_sign_sgd_weight_decay")


# ─── TileLang PyTorch Reference Tests ───

def test_dequant_gemm_pytorch_ref():
    """The reference dequant-GEMM returns a finite (M, N) result; skipped if the file is absent."""
    dequant_gemm_pytorch_ref = _load_dequant_gemm_ref()
    if dequant_gemm_pytorch_ref is None:
        print(" SKIP test_dequant_gemm_pytorch_ref (tilelang reference file missing)")
        return

    M, N, K, group_size = 4, 8, 96, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)
    output = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)
    assert output.shape == (M, N), f"Shape: {output.shape}"
    assert torch.isfinite(output).all(), "Non-finite output"
    print(" PASS test_dequant_gemm_pytorch_ref")


def test_dequant_gemm_matches_manual():
    """The reference dequant-GEMM agrees with an inline shift-based computation."""
    dequant_gemm_pytorch_ref = _load_dequant_gemm_ref()
    if dequant_gemm_pytorch_ref is None:
        print(" SKIP test_dequant_gemm_matches_manual (tilelang reference file missing)")
        return

    M, N, K, group_size = 2, 4, 48, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)
    result = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)

    # Broadcast each group's exponent across the group_size columns it covers.
    exp_expanded = exponents.repeat_interleave(group_size, dim=1)
    pos_mask = exp_expanded >= 0
    # NOTE(review): `1 >> k` is an integer shift, so every negative exponent
    # yields 0 here rather than 2**exp. Presumably this mirrors what the
    # reference kernel does -- confirm against dequant_gemm.py.
    # torch.where evaluates both branches eagerly; lanes that shift by a
    # negative count are the ones the mask discards.
    two_pow = torch.where(
        pos_mask,
        (1 << exp_expanded.to(torch.int32)).to(torch.float16),
        (1 >> (-exp_expanded.to(torch.int32))).to(torch.float16),
    )
    w = signs.to(torch.float16) * two_pow
    expected = x @ w.t()
    assert torch.allclose(result, expected, atol=1e-3), "PyTorch ref mismatch"
    print(" PASS test_dequant_gemm_matches_manual")


# ─── Integration: SignSGD + TernaryScaleTensor training step ───

def test_full_training_step():
    """The loss is still finite on the forward pass that follows one training step."""
    model = ARBModel(tscale_type=TScaleType.T32)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    x = torch.randint(0, VOCAB, (2, 10))
    _, losses, _, _ = model(x, targets=x[:, 3:])
    losses.total.backward()
    optimizer.step()
    model._ternary_update_memory()
    _, losses2, _, _ = model(x, targets=x[:, 3:])
    assert torch.isfinite(losses2.total), "Non-finite loss after step"
    print(" PASS test_full_training_step")


def test_multiple_steps_converge():
    """Fifty training steps on one batch never produce a non-finite loss.

    Despite the name, only finiteness is asserted; the loss range is printed
    for inspection.
    """
    model = ARBModel(tscale_type=TScaleType.T32)
    optimizer = SignSGD(model.parameters(), lr=0.001)
    x = torch.randint(0, VOCAB, (4, 10))
    losses = []
    for _ in range(50):
        optimizer.zero_grad()
        _, losses_out, _, _ = model(x, targets=x[:, 3:])
        loss_val = losses_out.total
        loss_val.backward()
        optimizer.step()
        model._ternary_update_memory(accum_threshold=3)
        losses.append(loss_val.item())
    assert torch.isfinite(torch.tensor(losses)).all(), "Non-finite loss during training"
    print(f" Loss range: {min(losses):.4f} – {max(losses):.4f} over 50 steps")
    print(" PASS test_multiple_steps_converge")


def test_cuda_triton_correctness_linear():
    """CPU and CUDA TernaryScaleTensor agree on output and input gradient within 1e-3."""
    if not _cuda_triton_available():
        print(" SKIP test_cuda_triton_correctness_linear (CUDA/Triton unavailable)")
        return

    ATOL = 1e-3
    for tt in _PARITY_TSCALE_TYPES:
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        x = torch.randn(4, 4, 32, requires_grad=True)
        cpu_out = lin_cpu(x)
        grad_out = torch.randn_like(cpu_out)
        cpu_out.backward(grad_out)
        cpu_grad_x = x.grad.clone()

        # Same weights and same input/upstream gradient on the GPU side.
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = lin_gpu(x_gpu)
        gpu_out.backward(grad_out.cuda())
        gpu_grad_x = x_gpu.grad.clone()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < ATOL, f"{tt.name} fwd_diff={fwd_diff}"
        assert bwd_diff < ATOL, f"{tt.name} bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_linear")


def test_cuda_triton_correctness_rmsnorm():
    """CPU and CUDA TernaryRMSNorm agree on output and input gradient within 1e-5."""
    if not _cuda_triton_available():
        print(" SKIP test_cuda_triton_correctness_rmsnorm (CUDA/Triton unavailable)")
        return

    from arbitor.kernel.ternary_scale import TernaryRMSNorm

    for tt in _PARITY_TSCALE_TYPES:
        norm_cpu = TernaryRMSNorm(256, tscale_type=tt)
        x = torch.randn(2, 4, 256, requires_grad=True)
        cpu_out = norm_cpu(x)
        cpu_out.sum().backward()
        cpu_grad_x = x.grad.clone()

        norm_gpu = TernaryRMSNorm(256, tscale_type=tt).cuda()
        norm_gpu.load_state_dict(norm_cpu.state_dict())
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = norm_gpu(x_gpu)
        gpu_out.sum().backward()
        gpu_grad_x = x_gpu.grad.clone()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{tt.name} rmsnorm fwd_diff={fwd_diff}"
        assert bwd_diff < 1e-5, f"{tt.name} rmsnorm bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_rmsnorm")


def test_cuda_triton_correctness_embedding():
    """CPU and CUDA ByteEmbedding agree on output; grad-sign hooks are compared when both exist."""
    if not _cuda_triton_available():
        print(" SKIP test_cuda_triton_correctness_embedding (CUDA/Triton unavailable)")
        return

    from arbitor.main import ByteEmbedding

    for tt in _PARITY_TSCALE_TYPES:
        emb_cpu = ByteEmbedding(tscale_type=tt)
        x = torch.tensor([0, 1, 2, 5, 10])
        cpu_out = emb_cpu(x)
        cpu_out.sum().backward()

        emb_gpu = ByteEmbedding(tscale_type=tt).cuda()
        emb_gpu.load_state_dict(emb_cpu.state_dict())
        x_gpu = x.cuda()
        gpu_out = emb_gpu(x_gpu)
        gpu_out.sum().backward()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{tt.name} embed fwd_diff={fwd_diff}"
        # The hook attribute is only compared when both modules expose it.
        if hasattr(emb_cpu, '_hook_grad_T_sign') and hasattr(emb_gpu, '_hook_grad_T_sign'):
            gs_match = (
                emb_gpu._hook_grad_T_sign.cpu() == emb_cpu._hook_grad_T_sign
            ).float().mean().item()
            assert gs_match > 0.99, f"{tt.name} embed grad_sign match={gs_match}"
    print(" PASS test_cuda_triton_correctness_embedding")


def test_cuda_triton_correctness_update_E():
    """After one backward pass, ``update_E`` leaves CPU and CUDA ``E``/``E_accum`` within 0.01."""
    if not _cuda_triton_available():
        print(" SKIP test_cuda_triton_correctness_update_E (CUDA/Triton unavailable)")
        return

    for tt in _PARITY_TSCALE_TYPES:
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())
        x_cpu = torch.randn(4, 4, 32, requires_grad=True)
        x_gpu = x_cpu.detach().clone().cuda().requires_grad_(True)

        cpu_out = lin_cpu(x_cpu)
        cpu_out.sum().backward()
        lin_cpu.update_E()
        E_cpu = lin_cpu.E.clone()
        E_accum_cpu = lin_cpu.E_accum.clone()

        gpu_out = lin_gpu(x_gpu)
        gpu_out.sum().backward()
        lin_gpu.update_E()
        E_gpu = lin_gpu.E.clone()
        E_accum_gpu = lin_gpu.E_accum.clone()

        # Compare fixed-point E residual update results.
        E_diff = (E_cpu.float() - E_gpu.cpu().float()).abs().max().item()
        assert E_diff < 0.01, f"{tt.name} CPU-GPU E update mismatch: {E_diff}"
        E_accum_diff = (E_accum_cpu.float() - E_accum_gpu.cpu().float()).abs().max().item()
        assert E_accum_diff < 0.01, f"{tt.name} CPU-GPU E_accum update mismatch: {E_accum_diff}"
    print(" PASS test_cuda_triton_correctness_update_E")


def test_cuda_triton_correctness_ternary_step():
    """After ``ternary_step(accum_threshold=3)``, CPU and CUDA ``T`` and ``T_accum`` match exactly."""
    if not _cuda_triton_available():
        print(" SKIP test_cuda_triton_correctness_ternary_step (CUDA/Triton unavailable)")
        return

    for tt in _PARITY_TSCALE_TYPES:
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())
        x_cpu = torch.randn(4, 4, 32, requires_grad=True)
        x_gpu = x_cpu.detach().clone().cuda().requires_grad_(True)

        cpu_out = lin_cpu(x_cpu)
        cpu_out.sum().backward()
        lin_cpu.ternary_step(accum_threshold=3)
        T_cpu = lin_cpu._get_T().clone()
        Taccum_cpu = lin_cpu.T_accum.clone()

        gpu_out = lin_gpu(x_gpu)
        gpu_out.sum().backward()
        lin_gpu.ternary_step(accum_threshold=3)
        T_gpu = lin_gpu._get_T().clone()
        Taccum_gpu = lin_gpu.T_accum.clone()

        T_match = (T_cpu == T_gpu.cpu()).float().mean().item()
        Taccum_match = (Taccum_cpu == Taccum_gpu.cpu()).float().mean().item()
        assert T_match == 1.0, f"{tt.name} T_match={T_match}"
        assert Taccum_match == 1.0, f"{tt.name} Taccum_match={Taccum_match}"
    print(" PASS test_cuda_triton_correctness_ternary_step")


def test_cuda_triton_tscale_path():
    """Exercise the CUDA forward/backward path and a forced ``ternary_step`` repack.

    Checks that the output and input gradient live on CUDA, that ``T_accum``
    and ``E_accum`` carry updates after backward, that no ``_hook_*``
    attributes are left on the module, and that a forced step with all-ones
    hooks drives every ternary value to -1 and clears ``T_accum``.
    """
    if not _cuda_triton_available():
        print(" SKIP test_cuda_triton_tscale_path (CUDA/Triton unavailable)")
        return

    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    # Snapshot the exponent residual BEFORE forward/backward so the
    # "state changed" comparison below is meaningful; taken afterwards it
    # would always compare the tensor with a copy of itself.
    E_accum_before = lin.E_accum.clone()
    x = torch.randn(2, 4, 32, device="cuda", requires_grad=True)
    out = lin(x)
    assert out.is_cuda, "Triton path should produce CUDA output"
    assert out.shape == (2, 4, 16), f"Shape: {out.shape}"

    grad_out = torch.randn_like(out)
    out.backward(grad_out)
    assert x.grad is not None and x.grad.is_cuda, "CUDA grad_x missing"
    assert lin.T_accum.abs().sum().item() > 0, \
        "Triton path should stream updates into int8 T_accum"
    assert not hasattr(lin, "_hook_grad_T_sign"), \
        "Triton path should not retain full weight-shaped grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d") and not hasattr(lin, "_hook_x_2d"), \
        "Triton path should not retain fp32 grad/x views"

    torch.cuda.synchronize()
    assert not torch.equal(lin.E_accum, E_accum_before) or lin.E_accum.abs().sum().item() > 0, \
        "Streaming CUDA E update did not modify exponent residual state"
    assert not hasattr(lin, "_hook_grad_T_sign"), \
        "No retained grad-sign hook should remain after streaming backward"
    assert lin.T_packed.is_cuda and lin.E.is_cuda, "Ternary buffers moved off CUDA after update"

    # Forced repack: inject all-ones hooks and step with a zero threshold.
    lin_force = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    lin_force._hook_grad_2d = torch.ones(2, 16, device="cuda")
    lin_force._hook_x_2d = torch.ones(2, 32, device="cuda")
    lin_force.ternary_step(accum_threshold=0)
    forced_T = lin_force._get_T()
    assert forced_T.is_cuda, "Unpacked CUDA ternary state should stay on CUDA"
    assert (forced_T == -1).all(), "CUDA ternary repack should move positive gradients in descent direction"
    assert lin_force.T_accum.abs().sum().item() == 0, "CUDA ternary repack should reset flipped accumulators"
    print(" PASS test_cuda_triton_tscale_path")


if __name__ == "__main__":
    import traceback

    tests = [
        test_tscale_shape,
        test_tscale_ternary_output,
        test_tscale_T64_per_element_s,
        test_tscale_T32_group_s,
        test_tscale_to_switching,
        test_tscale_cast_alias,
        test_tscale_gradient_flow,
        test_tscale_all_types_forward,
        test_tscale_dequantize,
        test_tscale_effective_bpw,
        test_tscale_model_integration,
        test_tscale_runtime_switch,
        test_sign_sgd_step,
        test_sign_sgd_no_momentum,
        test_sign_sgd_memory,
        test_sign_sgd_with_tscale_model,
        test_sign_sgd_weight_decay,
        test_dequant_gemm_pytorch_ref,
        test_dequant_gemm_matches_manual,
        test_cuda_triton_correctness_linear,
        test_cuda_triton_correctness_rmsnorm,
        test_cuda_triton_correctness_embedding,
        test_cuda_triton_correctness_update_E,
        test_cuda_triton_correctness_ternary_step,
        test_cuda_triton_tscale_path,
        test_full_training_step,
        test_multiple_steps_converge,
    ]
    print("Running TernaryScale + SignSGD + TileLang Phase 2 tests...\n")
    passed = 0
    failed = 0
    for test in tests:
        # Top-level boundary: report every failure and keep running the rest.
        try:
            test()
            passed += 1
        except Exception as e:
            print(f" FAIL {test.__name__}: {e}")
            traceback.print_exc()
            failed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")
    # Propagate failure to the caller (shell / CI) instead of always exiting 0.
    if failed:
        sys.exit(1)