| import math |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.kernel import ternary_scale as tscale |
| from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, TILE_SIZE, GROUP_SIZES |
| from arbitor.optim.sign_sgd import SignSGD |
| from arbitor.components import StickyZoneSTE |
| from arbitor.config import VOCAB, CTX, SPECIAL_VOCAB |
| from arbitor.main import ARBModel |
| from arbitor.components import LossComponents |
| from arbitor.kernel.ternary_scale import TernaryRMSNorm |
| from arbitor.sequencers import ByteEmbedding |
|
|
|
|
| def _cuda_available(min_gib=10): |
| """Check CUDA is available with enough GPU memory (min_gib GiB).""" |
| if not torch.cuda.is_available(): |
| return False |
| free, total = torch.cuda.mem_get_info() |
| if total < min_gib * 1e9: |
| return False |
| return True |
|
|
|
|
| |
|
|
def test_tscale_shape():
    """Check that a (2, 10, 32) input yields a (2, 10, 16) output."""
    layer = TernaryScaleTensor(32, 16)
    result = layer(torch.randn(2, 10, 32))
    assert result.shape == (2, 10, 16), f"Shape: {result.shape}"
    print(" PASS test_tscale_shape")
|
|
|
|
def test_tscale_ternary_output():
    """Check that every value returned by ``_get_T`` is -1, 0 or 1."""
    layer = TernaryScaleTensor(32, 16, threshold=0.05)
    values = set(layer._get_T().detach().flatten().tolist())
    assert values <= {-1, 0, 1}, f"Non-ternary values in T: {values}"
    print(" PASS test_tscale_ternary_output")
|
|
|
|
def test_tscale_T64_per_element_s():
    """Check the dequantized weight shape of a T64 layer is (out, in) = (16, 32)."""
    weights = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64).dequantize()
    assert weights.shape == (16, 32), f"Dequantize shape: {weights.shape}"
    print(" PASS test_tscale_T64_per_element_s")
|
|
|
|
def test_tscale_T32_group_s():
    """Check a T32 layer has at least one E row per output row and dequantizes to (16, 96)."""
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T32)
    weights = layer.dequantize()
    groups_per_row = layer.E.shape[0] // layer.out_dim
    assert groups_per_row > 0, f"Groups per row: {groups_per_row}"
    assert weights.shape == (16, 96), f"Dequantize shape: {weights.shape}"
    print(" PASS test_tscale_T32_group_s")
|
|
|
|
def test_tscale_to_switching():
    """Switch a layer T64 -> T32 -> T4 and check type and dequantized shape each time."""
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T64)
    reference_shape = layer.dequantize().shape
    assert layer.tscale_type == TScaleType.T64
    # Each switch must update the reported type and keep the weight shape.
    for target in (TScaleType.T32, TScaleType.T4):
        layer.tscale_to(target)
        assert layer.tscale_type == target
        assert layer.dequantize().shape == reference_shape
    print(" PASS test_tscale_to_switching")
|
|
|
|
def test_tscale_cast_alias():
    """Check ``tscale_cast`` returns the same object and changes its type to T8."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    returned = layer.tscale_cast(TScaleType.T8)
    assert returned is layer, "tscale_cast should return self"
    assert layer.tscale_type == TScaleType.T8
    print(" PASS test_tscale_cast_alias")
|
|
|
|
def test_tscale_gradient_flow():
    """Check that backward through a T32 layer populates the input gradient."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    inputs = torch.randn(2, 10, 32, requires_grad=True)
    layer(inputs).sum().backward()
    assert inputs.grad is not None, "No gradient on input"
    print(" PASS test_tscale_gradient_flow")
|
|
|
|
def test_tscale_all_types_forward():
    """Run a forward pass for every ``TScaleType`` and check shape and finiteness."""
    expected_shape = (2, 4, 16)
    for kind in TScaleType:
        layer = TernaryScaleTensor(96, 16, tscale_type=kind)
        result = layer(torch.randn(2, 4, 96))
        assert result.shape == expected_shape, f"{kind.name}: shape {result.shape}"
        assert torch.isfinite(result).all(), f"{kind.name}: non-finite output"
    print(" PASS test_tscale_all_types_forward")
|
|
|
|
def test_tscale_dequantize():
    """Check a T32 layer dequantizes to a finite (16, 32) tensor."""
    weights = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).dequantize()
    assert weights.shape == (16, 32), f"Shape: {weights.shape}"
    assert torch.isfinite(weights).all()
    print(" PASS test_tscale_dequantize")
|
|
|
|
def test_tscale_effective_bpw():
    """Check a T4 layer reports a higher ``effective_bpw`` than a T64 layer of equal size."""
    coarse = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T64)
    fine = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T4)
    assert fine.effective_bpw > coarse.effective_bpw, "T4 (gs=4) should have higher BPW than T64 (gs=64)"
    print(f" T64 BPW: {coarse.effective_bpw:.2f}, T4 BPW: {fine.effective_bpw:.2f}")
    print(" PASS test_tscale_effective_bpw")
|
|
|
|
def test_tscale_model_integration():
    """Run forward + backward of ``ARBModel`` on CUDA for T64, T32 and T8.

    Prints SKIP and returns when ``_cuda_available()`` is False.
    """
    if not _cuda_available():
        print(" SKIP test_tscale_model_integration (need CUDA + >10GB GPU)")
        return
    for kind in (TScaleType.T64, TScaleType.T32, TScaleType.T8):
        model = ARBModel(tscale_type=kind).to("cuda")
        tokens = torch.randint(0, VOCAB, (2, 10), device="cuda")
        _, loss_parts, _, _ = model(tokens, targets=tokens[:, 3:])
        assert loss_parts is not None
        loss_parts.total.backward()
    print(" PASS test_tscale_model_integration")
|
|
|
|
def test_tscale_runtime_switch():
    """Switch every ``TernaryScaleTensor`` in a T64 model to T4 and re-run forward.

    Prints SKIP and returns when ``_cuda_available()`` is False.
    """
    if not _cuda_available():
        print(" SKIP test_tscale_runtime_switch (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T64).to("cuda")
    tokens = torch.randint(0, VOCAB, (1, 10), device="cuda")

    logits_before, _, _, _ = model(tokens)
    for submodule in model.modules():
        if not isinstance(submodule, TernaryScaleTensor):
            continue
        submodule.tscale_to(TScaleType.T4)
    logits_after, _, _, _ = model(tokens)

    assert torch.isfinite(logits_after).all(), "Non-finite after tscale.to(T4)"
    assert logits_after.shape == logits_before.shape, "Shape mismatch after tscale switch"
    print(" PASS test_tscale_runtime_switch")
|
|
|
|
| |
|
|
def test_sign_sgd_step():
    """Check one ``SignSGD.step()`` changes the weights of a small linear layer."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01)
    batch = torch.randn(2, 10)
    net(batch).sum().backward()
    snapshot = net.weight.clone()
    opt.step()
    assert not torch.equal(net.weight, snapshot), "Weights did not change"
    print(" PASS test_sign_sgd_step")
|
|
|
|
def test_sign_sgd_no_momentum():
    """Check a freshly built ``SignSGD`` optimizer holds no per-parameter state."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01)
    assert not opt.state, "SignSGD should have no state (no momentum)"
    print(" PASS test_sign_sgd_no_momentum")
|
|
|
|
def test_sign_sgd_memory():
    """Check ``SignSGD.get_memory_mb()`` reports a positive number and print it."""
    net = torch.nn.Linear(100, 100)
    memory_mb = SignSGD(net.parameters(), lr=0.01).get_memory_mb()
    assert memory_mb > 0, "Memory should be positive"
    print(f" SignSGD memory: {memory_mb:.2f} MB")
    print(" PASS test_sign_sgd_memory")
|
|
|
|
def test_sign_sgd_with_tscale_model():
    """Run forward, backward and ``_ternary_update_memory`` on a T32 ``ARBModel``.

    Prints SKIP and returns when ``_cuda_available()`` is False.
    """
    if not _cuda_available():
        print(" SKIP test_sign_sgd_with_tscale_model (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T32).to("cuda")
    tokens = torch.randint(0, VOCAB, (2, 10), device="cuda")
    _, loss_parts, _, _ = model(tokens, targets=tokens[:, 3:])
    loss_parts.total.backward()
    model._ternary_update_memory()
    print(" PASS test_sign_sgd_with_tscale_model")
|
|
|
|
def test_sign_sgd_weight_decay():
    """Check a ``SignSGD`` step with ``weight_decay=0.01`` moves the weights."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01, weight_decay=0.01)
    batch = torch.randn(2, 10)
    net(batch).sum().backward()
    snapshot = net.weight.clone()
    opt.step()
    total_change = (net.weight - snapshot).abs().sum().item()
    assert total_change > 0, "Weights should change with weight_decay"
    print(" PASS test_sign_sgd_weight_decay")
|
|
|
|
| |
|
|
def test_dequant_gemm_pytorch_ref():
    """Call ``dequant_gemm_pytorch_ref`` on random inputs and check shape and finiteness.

    The function is loaded from ``../tilelang/kernels/dequant_gemm.py`` relative
    to this file; the test prints SKIP and returns when that file is missing.
    """
    import importlib.util

    here = os.path.dirname(__file__)
    kernel_path = os.path.join(here, "..", "tilelang", "kernels", "dequant_gemm.py")
    if not os.path.exists(kernel_path):
        print(" SKIP test_dequant_gemm_pytorch_ref (tilelang reference file missing)")
        return

    # Load the kernel file as a standalone module named "dequant_gemm".
    module_spec = importlib.util.spec_from_file_location("dequant_gemm", kernel_path)
    kernel_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(kernel_module)
    reference = kernel_module.dequant_gemm_pytorch_ref

    M, N, K, group_size = 4, 8, 96, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)

    output = reference(signs, exponents, x, group_size)
    assert output.shape == (M, N), f"Shape: {output.shape}"
    assert torch.isfinite(output).all(), "Non-finite output"
    print(" PASS test_dequant_gemm_pytorch_ref")
|
|
|
|
def test_dequant_gemm_matches_manual():
    """Compare ``dequant_gemm_pytorch_ref`` with an inline dequantize-then-matmul.

    The reference is loaded from ``../tilelang/kernels/dequant_gemm.py`` relative
    to this file; the test prints SKIP and returns when that file is missing.
    """
    import importlib.util

    kernel_path = os.path.join(os.path.dirname(__file__), "..", "tilelang", "kernels", "dequant_gemm.py")
    if not os.path.exists(kernel_path):
        print(" SKIP test_dequant_gemm_matches_manual (tilelang reference file missing)")
        return
    spec = importlib.util.spec_from_file_location("dequant_gemm", kernel_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    dequant_gemm_pytorch_ref = mod.dequant_gemm_pytorch_ref

    M, N, K, group_size = 2, 4, 48, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)

    result = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)

    # One exponent per group, repeated so each weight column has its own entry.
    exp_i32 = exponents.repeat_interleave(group_size, dim=1).to(torch.int32)
    pos_mask = exp_i32 >= 0
    # torch.where evaluates both branches for every element, so each shift
    # count is clamped to be non-negative. The clamp never alters a value that
    # the mask actually selects.
    shifted_left = (1 << exp_i32.clamp(min=0)).to(torch.float16)
    # NOTE(review): `1 >> k` is 0 for k >= 1, so negative exponents produce
    # zero weights here rather than 2**e -- confirm this mirrors the reference.
    shifted_right = (1 >> (-exp_i32).clamp(min=0)).to(torch.float16)
    two_pow = torch.where(pos_mask, shifted_left, shifted_right)
    w = signs.to(torch.float16) * two_pow
    expected = x @ w.t()

    assert torch.allclose(result, expected, atol=1e-3), "PyTorch ref mismatch"
    print(" PASS test_dequant_gemm_matches_manual")
|
|
|
|
| |
|
|
def test_full_training_step():
    """Run one forward/backward/update cycle, then check the next loss is finite.

    Prints SKIP and returns when ``_cuda_available()`` is False.
    """
    if not _cuda_available():
        print(" SKIP test_full_training_step (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T32).to("cuda")
    tokens = torch.randint(0, VOCAB, (2, 10), device="cuda")
    targets = tokens[:, 3:]

    _, first_losses, _, _ = model(tokens, targets=targets)
    first_losses.total.backward()
    model._ternary_update_memory()

    _, second_losses, _, _ = model(tokens, targets=targets)
    assert torch.isfinite(second_losses.total), "Non-finite loss after step"
    print(" PASS test_full_training_step")
|
|
|
|
def test_multiple_steps_converge():
    """Run 50 forward/backward/update steps and check every recorded loss is finite.

    Prints SKIP and returns when ``_cuda_available()`` is False. Only finiteness
    is asserted; the loss range is printed for inspection.
    """
    if not _cuda_available():
        print(" SKIP test_multiple_steps_converge (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T32).to("cuda")
    tokens = torch.randint(0, VOCAB, (4, 10), device="cuda")
    history = []
    for _ in range(50):
        _, loss_parts, _, _ = model(tokens, targets=tokens[:, 3:])
        total = loss_parts.total
        total.backward()
        model._ternary_update_memory(accum_threshold=3)
        history.append(total.item())

    assert torch.isfinite(torch.tensor(history)).all(), "Non-finite loss during training"
    print(f" Loss range: {min(history):.4f} – {max(history):.4f} over 50 steps")
    print(" PASS test_multiple_steps_converge")
|
|
|
|
def test_cuda_triton_correctness_linear():
    """Compare CPU and CUDA forward/backward of ``TernaryScaleTensor`` per scale type.

    For each listed ``TScaleType`` a CPU layer and a CUDA layer share the same
    state dict; outputs and input gradients must agree within ``ATOL``.
    Prints SKIP and returns when CUDA or Triton is unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_linear (CUDA/Triton unavailable)")
        return
    # No function-scope imports here: the previous ones were unused and could
    # fail this test with an ImportError unrelated to the linear path.
    ATOL = 2e-3
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        x = torch.randn(4, 4, 32, requires_grad=True)
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())

        # CPU pass with an explicit upstream gradient that the CUDA pass reuses.
        cpu_out = lin_cpu(x)
        grad_out = torch.randn_like(cpu_out)
        cpu_out.backward(grad_out)
        cpu_grad_x = x.grad.clone()

        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = lin_gpu(x_gpu)
        gpu_out.backward(grad_out.cuda())
        gpu_grad_x = x_gpu.grad.clone()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < ATOL, f"{tt.name} fwd_diff={fwd_diff}"
        assert bwd_diff < ATOL, f"{tt.name} bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_linear")
|
|
|
|
def test_cuda_triton_correctness_rmsnorm():
    """Compare CPU and CUDA forward/backward of ``TernaryRMSNorm`` per scale type.

    Outputs and input gradients must agree within 1e-5.
    Prints SKIP and returns when CUDA or Triton is unavailable.
    """
    if not (torch.cuda.is_available() and tscale._HAS_TRITON):
        print(" SKIP test_cuda_triton_correctness_rmsnorm (CUDA/Triton unavailable)")
        return
    from arbitor.kernel.ternary_scale import TernaryRMSNorm

    tolerance = 1e-5
    for kind in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        host_norm = TernaryRMSNorm(256, tscale_type=kind)
        host_x = torch.randn(2, 4, 256, requires_grad=True)
        host_out = host_norm(host_x)
        host_out.sum().backward()
        host_grad = host_x.grad.clone()

        # The CUDA module starts from the CPU module's exact state.
        device_norm = TernaryRMSNorm(256, tscale_type=kind).cuda()
        device_norm.load_state_dict(host_norm.state_dict())
        device_x = host_x.detach().clone().cuda().requires_grad_(True)
        device_out = device_norm(device_x)
        device_out.sum().backward()
        device_grad = device_x.grad.clone()

        fwd_diff = (host_out - device_out.cpu()).abs().max().item()
        bwd_diff = (host_grad - device_grad.cpu()).abs().max().item()
        assert fwd_diff < tolerance, f"{kind.name} rmsnorm fwd_diff={fwd_diff}"
        assert bwd_diff < tolerance, f"{kind.name} rmsnorm bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_rmsnorm")
|
|
|
|
def test_cuda_triton_correctness_embedding():
    """Compare CPU and CUDA ``ByteEmbedding`` outputs per scale type.

    Forward outputs must agree within 1e-5. When both modules expose
    ``_hook_grad_T_sign`` after backward, more than 99% of entries must match.
    Prints SKIP and returns when CUDA or Triton is unavailable.
    """
    if not (torch.cuda.is_available() and tscale._HAS_TRITON):
        print(" SKIP test_cuda_triton_correctness_embedding (CUDA/Triton unavailable)")
        return
    from arbitor.main import ByteEmbedding

    hook_name = '_hook_grad_T_sign'
    for kind in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        host_emb = ByteEmbedding(tscale_type=kind)
        ids = torch.tensor([0, 1, 2, 5, 10])
        host_out = host_emb(ids)
        host_out.sum().backward()

        device_emb = ByteEmbedding(tscale_type=kind).cuda()
        device_emb.load_state_dict(host_emb.state_dict())
        device_out = device_emb(ids.cuda())
        device_out.sum().backward()

        fwd_diff = (host_out - device_out.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{kind.name} embed fwd_diff={fwd_diff}"
        # The grad-sign comparison only runs when both sides kept the hook tensor.
        if not (hasattr(host_emb, hook_name) and hasattr(device_emb, hook_name)):
            continue
        agreement = device_emb._hook_grad_T_sign.cpu() == host_emb._hook_grad_T_sign
        gs_match = agreement.float().mean().item()
        assert gs_match > 0.99, f"{kind.name} embed grad_sign match={gs_match}"
    print(" PASS test_cuda_triton_correctness_embedding")
|
|
|
|
def test_cuda_triton_correctness_update_E():
    """Compare ``E``, ``corr_accum`` and ``step_counter`` on CPU vs CUDA after one backward.

    Prints SKIP and returns when CUDA or Triton is unavailable.
    """
    if not (torch.cuda.is_available() and tscale._HAS_TRITON):
        print(" SKIP test_cuda_triton_correctness_update_E (CUDA/Triton unavailable)")
        return

    def _backward_and_snapshot(layer, inputs):
        # One forward/backward pass, then copies of the buffers under test.
        layer(inputs).sum().backward()
        return layer.E.clone(), layer.corr_accum.clone(), layer.step_counter.clone()

    for kind in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        host_layer = TernaryScaleTensor(32, 16, tscale_type=kind)
        device_layer = TernaryScaleTensor(32, 16, tscale_type=kind).cuda()
        device_layer.load_state_dict(host_layer.state_dict())

        host_x = torch.randn(4, 4, 32, requires_grad=True)
        device_x = host_x.detach().clone().cuda().requires_grad_(True)

        E_cpu, corr_cpu, step_cpu = _backward_and_snapshot(host_layer, host_x)
        E_gpu, corr_gpu, step_gpu = _backward_and_snapshot(device_layer, device_x)

        E_diff = (E_cpu.float() - E_gpu.cpu().float()).abs().max().item()
        assert E_diff < 0.01, f"{kind.name} CPU-GPU E update mismatch: {E_diff}"
        corr_diff = (corr_cpu - corr_gpu.cpu()).abs().max().item()
        assert corr_diff == 0, f"{kind.name} CPU-GPU corr_accum update mismatch: {corr_diff}"
        assert int(step_cpu.item()) == int(step_gpu.cpu().item()) == 1, \
            f"{kind.name} CPU-GPU step_counter mismatch: cpu={step_cpu.item()} gpu={step_gpu.cpu().item()}"
    print(" PASS test_cuda_triton_correctness_update_E")
|
|
|
|
def test_cuda_triton_correctness_ternary_step():
    """Check ``ternary_step`` leaves identical ``T`` and ``corr_accum`` on CPU and CUDA.

    Prints SKIP and returns when CUDA or Triton is unavailable.
    """
    if not (torch.cuda.is_available() and tscale._HAS_TRITON):
        print(" SKIP test_cuda_triton_correctness_ternary_step (CUDA/Triton unavailable)")
        return

    def _step_and_snapshot(layer, inputs):
        # Forward, backward, one ternary step; then copies of T and corr_accum.
        layer(inputs).sum().backward()
        layer.ternary_step(accum_threshold=3)
        return layer._get_T().clone(), layer.corr_accum.clone()

    for kind in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        host_layer = TernaryScaleTensor(32, 16, tscale_type=kind)
        device_layer = TernaryScaleTensor(32, 16, tscale_type=kind).cuda()
        device_layer.load_state_dict(host_layer.state_dict())

        host_x = torch.randn(4, 4, 32, requires_grad=True)
        device_x = host_x.detach().clone().cuda().requires_grad_(True)

        T_cpu, corr_cpu = _step_and_snapshot(host_layer, host_x)
        T_gpu, corr_gpu = _step_and_snapshot(device_layer, device_x)

        T_match = (T_cpu == T_gpu.cpu()).float().mean().item()
        corr_match = (corr_cpu == corr_gpu.cpu()).float().mean().item()
        assert T_match == 1.0, f"{kind.name} T_match={T_match}"
        assert corr_match == 1.0, f"{kind.name} corr_match={corr_match}"
    print(" PASS test_cuda_triton_correctness_ternary_step")
|
|
|
|
def test_cuda_triton_tscale_path():
    """Exercise a CUDA ``TernaryScaleTensor`` and check its buffers after backward.

    First part: forward/backward on CUDA, then checks on output device and
    shape, input gradient, ``corr_accum``, ``step_counter`` and the absence of
    ``_hook_*`` attributes. Second part: hook tensors are set by hand and
    ``update_E()`` is called directly.
    Prints SKIP and returns when CUDA or Triton is unavailable.
    """
    if not (torch.cuda.is_available() and tscale._HAS_TRITON):
        print(" SKIP test_cuda_triton_tscale_path (CUDA/Triton unavailable)")
        return

    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    inputs = torch.randn(2, 4, 32, device="cuda", requires_grad=True)
    result = layer(inputs)
    assert result.is_cuda, "Triton path should produce CUDA output"
    assert result.shape == (2, 4, 16), f"Shape: {result.shape}"
    upstream = torch.randn_like(result)
    result.backward(upstream)
    assert inputs.grad is not None and inputs.grad.is_cuda, "CUDA grad_x missing"
    assert layer.corr_accum.abs().sum().item() > 0, \
        "Triton path should stream updates into int64 corr_accum"
    assert int(layer.step_counter.item()) == 1, "Triton path should advance the BigInt step counter"
    assert not hasattr(layer, "_hook_grad_T_sign"), \
        "Triton path should not retain full weight-shaped grad-sign hooks"
    assert not (hasattr(layer, "_hook_grad_2d") or hasattr(layer, "_hook_x_2d")), \
        "Triton path should not retain fp32 grad/x views"
    # Re-check the grad-sign hook once all queued CUDA work has finished.
    torch.cuda.synchronize()
    assert not hasattr(layer, "_hook_grad_T_sign"), \
        "No retained grad-sign hook should remain after streaming backward"
    assert layer.T_packed.is_cuda and layer.E.is_cuda, "Ternary buffers moved off CUDA after update"

    forced = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    forced._hook_grad_2d = torch.ones(2, 16, device="cuda")
    forced._hook_x_2d = torch.ones(2, 32, device="cuda")
    forced.update_E()
    unpacked = forced._get_T()
    assert unpacked.is_cuda, "Unpacked CUDA ternary state should stay on CUDA"
    assert forced.corr_accum.abs().sum().item() > 0, "Forced CUDA hook should update BigInt corr_accum"
    assert int(forced.step_counter.item()) == 1, "Forced CUDA hook should advance the BigInt step counter"
    print(" PASS test_cuda_triton_tscale_path")
|
|
|
|
def test_small_ternary_training_loss_finite():
    """Check a reduced ``ARBModel`` gives a finite loss and leaves no hook attributes.

    The model is built with the image, audio, VQ, graph, memory and MoE flags
    set to False. Prints SKIP and returns when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print(" SKIP test_small_ternary_training_loss_finite (CUDA unavailable)")
        return
    disabled = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
    )
    model = ARBModel(tscale_type=TScaleType.T32, **disabled).cuda()
    tokens = torch.randint(0, VOCAB, (1, 4), device="cuda")
    _, loss_parts, _, _ = model(tokens, targets=tokens[:, 3:])
    assert torch.isfinite(loss_parts.total), "Small ternary training loss is non-finite"
    model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_components=loss_parts)

    hook_attrs = ("_hook_grad_T_sign", "_hook_grad_2d", "_hook_x_2d")
    leftovers = []
    for name, module in model.named_modules():
        if any(hasattr(module, attr) for attr in hook_attrs):
            leftovers.append(name)
    assert not leftovers, f"Ternary update left stale hooks: {leftovers[:5]}"
    print(" PASS test_small_ternary_training_loss_finite")
|
|
|
|
def test_ternary_update_rejects_nonfinite_loss():
    """Check ``_ternary_update_memory`` warns when given a NaN ``lm`` loss component."""
    import warnings

    disabled = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
    )
    model = ARBModel(tscale_type=TScaleType.T32, **disabled)
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including repeats that would otherwise be filtered.
        warnings.simplefilter("always")
        nan_losses = LossComponents(lm=torch.tensor(float("nan")))
        model._ternary_update_memory(loss_components=nan_losses)
        assert len(caught) > 0, "Expected a warning for non-finite loss"
        first_message = caught[0].message
        assert "Non-finite loss" in str(first_message), f"Unexpected warning: {first_message}"
    print(" PASS test_ternary_update_rejects_nonfinite_loss")
|
|
|
|
| |
|
|
def test_e_rms_weighted_delta():
    """Recompute an RMS-weighted delta by hand and check its magnitude is in [1, 4].

    The magnitude is ``round(log2(1 + rms))`` of the first column group,
    clamped to [1, 3]; the sign is the negated sign of ``sum(raw_grad * T)``.
    """
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    grad = torch.randn(4, 16)
    x = torch.randn(4, 32)
    raw_grad = grad.T @ x

    group_size = layer.group_size
    n_groups = (32 + group_size - 1) // group_size
    # RMS of each column group; the last group may be narrower than group_size.
    group_rms = [
        raw_grad[:, lo:min(lo + group_size, 32)].pow(2).mean().sqrt().item()
        for lo in (g * group_size for g in range(n_groups))
    ]
    first_rms = group_rms[0]
    score = (raw_grad * layer._get_T().float()).sum().item()
    if score > 0:
        direction = 1
    elif score < 0:
        direction = -1
    else:
        direction = 0
    magnitude = max(1, min(3, round(math.log2(1 + first_rms))))
    delta = -direction * magnitude
    assert 1 <= abs(delta) <= 4, f"delta magnitude {abs(delta)} out of range"
    print(" PASS test_e_rms_weighted_delta")
|
|
|
|
def test_e_rms_vs_sign_only():
    """Check the clamped ``round(log2(1 + rms))`` step grows with gradient magnitude."""

    def _step(raw):
        # RMS over the whole tensor, mapped to an integer step in [1, 3].
        rms = raw.pow(2).mean().sqrt()
        return max(1, min(3, round(math.log2(1 + rms.item()))))

    delta_low = _step(torch.ones(16, 32) * 0.1)
    delta_high = _step(torch.ones(16, 32) * 10.0)
    assert delta_low != delta_high, "RMS delta should differ for different magnitudes"
    assert delta_low < delta_high, "Higher RMS should give larger delta"
    print(" PASS test_e_rms_vs_sign_only")
|
|
|
|
def test_e_zscore_normalization():
    """Check z-scoring two differently scaled RMS vectors gives mean ~0 and std ~1."""

    def _zscore(values):
        return (values - values.mean()) / (values.std() + 1e-8)

    z_a = _zscore(torch.tensor([10.0, 12.0, 8.0, 11.0]))
    z_b = _zscore(torch.tensor([1.0, 1.2, 0.8, 1.1]))
    assert abs(z_a.mean().item()) < 1e-6, f"z_a mean not ~0: {z_a.mean().item()}"
    assert abs(z_b.mean().item()) < 1e-6, f"z_b mean not ~0: {z_b.mean().item()}"
    assert abs(z_a.std().item() - 1.0) < 0.1, f"z_a std not ~1: {z_a.std().item()}"
    assert abs(z_b.std().item() - 1.0) < 0.1, f"z_b std not ~1: {z_b.std().item()}"
    print(" PASS test_e_zscore_normalization")
|
|
|
|
def test_e_zscore_zero_std():
    """Check the guarded z-score falls back to zeros when the input has zero std."""
    flat = torch.ones(8) * 5.0
    has_spread = flat.std() > 1e-8
    # The division still runs (and yields NaN here); torch.where discards it.
    normalized = (flat - flat.mean()) / (flat.std())
    z = torch.where(has_spread, normalized, torch.zeros_like(flat))
    assert torch.isfinite(z).all(), "z-scores should be finite when std=0"
    assert (z == 0).all(), "z-scores should be zero when std=0"
    print(" PASS test_e_zscore_zero_std")
|
|
|
|
def test_group_lr_registration():
    """Check freshly built modules expose the expected accumulator buffers.

    ``TernaryScaleTensor`` and ``ByteEmbedding`` must have an int64
    ``corr_accum`` and a zero ``step_counter``; ``TernaryRMSNorm`` must have ``E``.
    """
    layer = TernaryScaleTensor(32, 16)
    assert hasattr(layer, "corr_accum")
    assert layer.corr_accum.dtype == torch.int64
    assert layer.corr_accum.shape == layer.E.shape
    assert int(layer.step_counter.item()) == 0

    embedding = ByteEmbedding()
    assert hasattr(embedding, "corr_accum")
    assert embedding.corr_accum.dtype == torch.int64
    assert embedding.corr_accum.shape[0] > 0
    assert int(embedding.step_counter.item()) == 0

    norm = TernaryRMSNorm(256)
    assert hasattr(norm, "E")
    print(" PASS test_group_lr_registration")
|
|
|
|
def test_group_lr_effect():
    """Check integer scaling ``delta * group_lr // 8`` for group_lr values 8 and 1."""
    delta = torch.tensor(4, dtype=torch.int8).to(torch.int16)

    def _effective(group_lr_value):
        # Widen to int16 before multiplying, then floor-divide by 8.
        group_lr = torch.tensor(group_lr_value, dtype=torch.int8)
        return delta * group_lr.to(torch.int16) // 8

    eff_high = _effective(8)
    eff_low = _effective(1)
    assert eff_high.item() == 4, f"high LR should give full delta, got {eff_high.item()}"
    assert eff_low.item() == 0, f"low LR should give 0 delta, got {eff_low.item()}"
    print(" PASS test_group_lr_effect")
|
|
|
|
def test_group_lr_dynamic_update():
    """Check the +1 / -1 group-LR update rule and its clamp to the range [1, 8]."""
    group_lr = torch.ones(4, dtype=torch.int8)
    rms_prev = torch.tensor([1.0, 5.0, 3.0, 2.0])
    rms_curr = torch.tensor([2.0, 3.0, 3.0, 1.0])
    rms_growth = rms_curr - rms_prev

    # +1 where RMS grew, -1 where it shrank, computed in int16 and clamped.
    grew = (rms_growth > 0).to(torch.int16)
    shrank = (rms_growth < 0).to(torch.int16)
    candidate = group_lr.to(torch.int16) + grew - shrank
    updated = torch.clamp(candidate, 1, 8).to(torch.int8)
    assert updated[0].item() == 2, f"RMS increased -> LR should increase, got {updated[0].item()}"
    assert updated[1].item() == 1, f"RMS decreased -> LR should decrease, got {updated[1].item()}"
    assert updated[2].item() == 1, f"RMS unchanged -> LR unchanged, got {updated[2].item()}"

    too_high = torch.clamp(torch.tensor([100], dtype=torch.int16), 1, 8)
    too_low = torch.clamp(torch.tensor([-100], dtype=torch.int16), 1, 8)
    assert too_high.item() == 8, f"clamp max, got {too_high.item()}"
    assert too_low.item() == 1, f"clamp min, got {too_low.item()}"
    print(" PASS test_group_lr_dynamic_update")
|
|
|
|
def test_e_stats_cpu_fallback():
    """Compute per-group RMS of ``grad.T @ x`` in plain PyTorch and check the step range.

    Each group's RMS must be finite and ``round(log2(1 + rms))`` clamped to
    [1, 3] must lie in that range.
    """
    N, K, group_size = 16, 32, 12
    grad = torch.randn(4, N)
    x = torch.randn(4, K)
    raw_grad = grad.T @ x

    n_groups = (K + group_size - 1) // group_size
    # Column groups of width group_size; the last one is truncated at K.
    rms_vals = [
        raw_grad[:, lo:min(lo + group_size, K)].pow(2).mean().sqrt().item()
        for lo in range(0, n_groups * group_size, group_size)
    ]
    assert bool(torch.isfinite(torch.tensor(rms_vals)).all()), "finite check"
    for value in rms_vals:
        step = max(1, min(3, round(math.log2(1 + value))))
        assert 1 <= step <= 3, "clamp range"
    print(" PASS test_e_stats_cpu_fallback")
|
|
|
|
def test_e_per_component_routing():
    """Run three forward + ``_ternary_update_memory`` rounds with loss components.

    The test passes when no exception is raised.
    Prints SKIP and returns when ``_cuda_available()`` is False.
    """
    if not _cuda_available():
        print(" SKIP test_e_per_component_routing (CUDA)")
        return
    disabled = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
    )
    model = ARBModel(**disabled).cuda()
    tokens = torch.randint(0, VOCAB, (1, 4), device="cuda")
    for _ in range(3):
        _, loss_parts, _, _ = model(tokens, targets=tokens[:, 3:])
        model._ternary_update_memory(loss_components=loss_parts)
    print(" PASS test_e_per_component_routing")
|
|
|
|
def test_ensure_group_lr_backward_compat():
    """Check the accumulator attributes exist on freshly constructed modules."""
    for module in (TernaryScaleTensor(32, 16), ByteEmbedding()):
        assert hasattr(module, "corr_accum")
        assert hasattr(module, "step_counter")
    norm = TernaryRMSNorm(256)
    assert hasattr(norm, "E")
    print(" PASS test_ensure_group_lr_backward_compat")
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Script entry point: run every test in order, print a summary, and exit
    # with a non-zero status when any test raised.
    import traceback

    tests = [
        test_tscale_shape,
        test_tscale_ternary_output,
        test_tscale_T64_per_element_s,
        test_tscale_T32_group_s,
        test_tscale_to_switching,
        test_tscale_cast_alias,
        test_tscale_gradient_flow,
        test_tscale_all_types_forward,
        test_tscale_dequantize,
        test_tscale_effective_bpw,
        test_tscale_model_integration,
        test_tscale_runtime_switch,
        test_sign_sgd_step,
        test_sign_sgd_no_momentum,
        test_sign_sgd_memory,
        test_sign_sgd_with_tscale_model,
        test_sign_sgd_weight_decay,
        test_dequant_gemm_pytorch_ref,
        test_dequant_gemm_matches_manual,
        test_cuda_triton_correctness_linear,
        test_cuda_triton_correctness_rmsnorm,
        test_cuda_triton_correctness_embedding,
        test_cuda_triton_correctness_update_E,
        test_cuda_triton_correctness_ternary_step,
        test_cuda_triton_tscale_path,
        test_small_ternary_training_loss_finite,
        test_ternary_update_rejects_nonfinite_loss,
        test_full_training_step,
        test_multiple_steps_converge,
        test_e_rms_weighted_delta,
        test_e_rms_vs_sign_only,
        test_e_zscore_normalization,
        test_e_zscore_zero_std,
        test_group_lr_registration,
        test_group_lr_effect,
        test_group_lr_dynamic_update,
        test_e_stats_cpu_fallback,
        test_e_per_component_routing,
        test_ensure_group_lr_backward_compat,
    ]
    print("Running TernaryScale + SignSGD + TileLang Phase 2 tests...\n")
    passed = 0
    failed = 0
    for test in tests:
        # A test that prints SKIP and returns normally is counted as passed.
        try:
            test()
        except Exception as e:
            print(f" FAIL {test.__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")
    # Report failures through the exit status so callers such as CI or shell
    # scripts can detect them without parsing the output.
    if failed:
        sys.exit(1)
|
|