File size: 29,768 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
import math
import torch
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel import ternary_scale as tscale
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, TILE_SIZE, GROUP_SIZES
from arbitor.optim.sign_sgd import SignSGD
from arbitor.components import StickyZoneSTE
from arbitor.config import VOCAB, CTX, SPECIAL_VOCAB
from arbitor.main import ARBModel
from arbitor.components import LossComponents
from arbitor.kernel.ternary_scale import TernaryRMSNorm
from arbitor.sequencers import ByteEmbedding


def _cuda_available(min_gib=10):
    """Check CUDA is available with enough GPU memory (min_gib GiB)."""
    if not torch.cuda.is_available():
        return False
    free, total = torch.cuda.mem_get_info()
    if total < min_gib * 1e9:
        return False
    return True


# ─── TernaryScaleTensor Tests ───

def test_tscale_shape():
    """A 32 -> 16 TernaryScaleTensor maps (2, 10, 32) inputs to (2, 10, 16)."""
    layer = TernaryScaleTensor(32, 16)
    inputs = torch.randn(2, 10, 32)
    result = layer(inputs)
    expected = (2, 10, 16)
    assert result.shape == expected, f"Shape: {result.shape}"
    print(" PASS test_tscale_shape")


def test_tscale_ternary_output():
    """The unpacked ternary weights returned by ``_get_T`` hold only -1, 0 or +1."""
    layer = TernaryScaleTensor(32, 16, threshold=0.05)
    values = set(layer._get_T().detach().flatten().tolist())
    assert values.issubset({-1, 0, 1}), f"Non-ternary values in T: {values}"
    print(" PASS test_tscale_ternary_output")


def test_tscale_T64_per_element_s():
    """Dequantizing a T64 layer yields an (out_dim, in_dim) = (16, 32) weight."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    weights = layer.dequantize()
    assert weights.shape == (16, 32), f"Dequantize shape: {weights.shape}"
    print(" PASS test_tscale_T64_per_element_s")


def test_tscale_T32_group_s():
    """A T32 layer has at least one E group per output row and dequantizes to (16, 96)."""
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T32)
    weights = layer.dequantize()
    groups_per_row = layer.E.shape[0] // layer.out_dim
    assert groups_per_row > 0, f"Groups per row: {groups_per_row}"
    assert weights.shape == (16, 96), f"Dequantize shape: {weights.shape}"
    print(" PASS test_tscale_T32_group_s")


def test_tscale_to_switching():
    """``tscale_to`` changes the scale type in place without changing the dequantized shape.

    Starts at T64, then switches to T32 and then T4. After each switch it checks
    the reported type and the dequantized weight shape.
    """
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T64)
    reference_shape = layer.dequantize().shape
    assert layer.tscale_type == TScaleType.T64
    for target in (TScaleType.T32, TScaleType.T4):
        layer.tscale_to(target)
        assert layer.tscale_type == target
        assert layer.dequantize().shape == reference_shape
    print(" PASS test_tscale_to_switching")


def test_tscale_cast_alias():
    """``tscale_cast`` behaves like ``tscale_to`` and returns the same layer (fluent API)."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    returned = layer.tscale_cast(TScaleType.T8)
    assert returned is layer, "tscale_cast should return self"
    assert layer.tscale_type == TScaleType.T8
    print(" PASS test_tscale_cast_alias")


def test_tscale_gradient_flow():
    """Backpropagating through a T32 layer populates the input's ``.grad``."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    inputs = torch.randn(2, 10, 32).requires_grad_(True)
    layer(inputs).sum().backward()
    assert inputs.grad is not None, "No gradient on input"
    print(" PASS test_tscale_gradient_flow")


def test_tscale_all_types_forward():
    """Every TScaleType member produces a correctly shaped, finite forward output."""
    for kind in TScaleType:
        layer = TernaryScaleTensor(96, 16, tscale_type=kind)
        result = layer(torch.randn(2, 4, 96))
        label = kind.name
        assert result.shape == (2, 4, 16), f"{label}: shape {result.shape}"
        assert torch.isfinite(result).all(), f"{label}: non-finite output"
    print(" PASS test_tscale_all_types_forward")


def test_tscale_dequantize():
    """The dequantized T32 weight has shape (16, 32) and contains only finite values."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    effective = layer.dequantize()
    assert effective.shape == (16, 32), f"Shape: {effective.shape}"
    assert torch.isfinite(effective).all()
    print(" PASS test_tscale_dequantize")


def test_tscale_effective_bpw():
    """A T4 layer reports more effective bits per weight than a T64 layer of the same shape."""
    coarse = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T64)
    fine = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T4)
    assert fine.effective_bpw > coarse.effective_bpw, "T4 (gs=4) should have higher BPW than T64 (gs=64)"
    print(f"   T64 BPW: {coarse.effective_bpw:.2f}, T4 BPW: {fine.effective_bpw:.2f}")
    print(" PASS test_tscale_effective_bpw")


def test_tscale_model_integration():
    """Run a full ARBModel forward/backward for several scale types (CUDA only).

    Each model is released before the next one is built. Otherwise the
    previous model, its autograd graph and its gradients stay alive while the
    next model is allocated, which roughly doubles peak GPU memory.
    """
    if not _cuda_available():
        print(" SKIP test_tscale_model_integration (need CUDA + >10GB GPU)")
        return
    for tscale_type in [TScaleType.T64, TScaleType.T32, TScaleType.T8]:
        model = ARBModel(tscale_type=tscale_type).to("cuda")
        x = torch.randint(0, VOCAB, (2, 10), device="cuda")
        _, losses, _, _ = model(x, targets=x[:, 3:])
        assert losses is not None
        losses.total.backward()
        # Drop this iteration's references before constructing the next model.
        del model, x, losses
        torch.cuda.empty_cache()
    print(" PASS test_tscale_model_integration")


def test_tscale_runtime_switch():
    """Switching every ternary layer of a model to T4 at runtime keeps logits usable.

    The logits after the switch must be finite and have the same shape as the
    T64 logits. Skips unless a large CUDA GPU is present.
    """
    if not _cuda_available():
        print(" SKIP test_tscale_runtime_switch (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T64).to("cuda")
    tokens = torch.randint(0, VOCAB, (1, 10), device="cuda")

    before, _, _, _ = model(tokens)
    for submodule in model.modules():
        if not isinstance(submodule, TernaryScaleTensor):
            continue
        submodule.tscale_to(TScaleType.T4)
    after, _, _, _ = model(tokens)

    assert torch.isfinite(after).all(), "Non-finite after tscale.to(T4)"
    assert after.shape == before.shape, "Shape mismatch after tscale switch"
    print(" PASS test_tscale_runtime_switch")


# ─── SignSGD Tests ───

def test_sign_sgd_step():
    """A single SignSGD step on a Linear layer changes its weights."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01)
    net(torch.randn(2, 10)).sum().backward()
    snapshot = net.weight.clone()
    opt.step()
    assert not torch.equal(net.weight, snapshot), "Weights did not change"
    print(" PASS test_sign_sgd_step")


def test_sign_sgd_no_momentum():
    """Check that SignSGD keeps no per-parameter state (e.g. momentum buffers).

    Checking ``optimizer.state`` only before a step is vacuous, because every
    torch optimizer starts with empty state. The test therefore also takes one
    real step. Empty per-parameter entries are tolerated, since
    ``Optimizer.state`` is a defaultdict and merely reading ``self.state[p]``
    creates one.
    """
    model = torch.nn.Linear(10, 5)
    optimizer = SignSGD(model.parameters(), lr=0.01)
    assert len(optimizer.state) == 0, "SignSGD should have no state (no momentum)"
    model(torch.randn(2, 10)).sum().backward()
    optimizer.step()
    stateful = [s for s in optimizer.state.values() if s]
    assert not stateful, \
        f"SignSGD should have no state (no momentum), found {len(stateful)} stateful params"
    print(" PASS test_sign_sgd_no_momentum")


def test_sign_sgd_memory():
    """``SignSGD.get_memory_mb`` reports a positive footprint for a 100x100 Linear layer."""
    layer = torch.nn.Linear(100, 100)
    opt = SignSGD(layer.parameters(), lr=0.01)
    footprint = opt.get_memory_mb()
    assert footprint > 0, "Memory should be positive"
    print(f"   SignSGD memory: {footprint:.2f} MB")
    print(" PASS test_sign_sgd_memory")


def test_sign_sgd_with_tscale_model():
    """A T32 ARBModel completes backward and a ternary memory update (CUDA only).

    NOTE(review): despite the name, this test does not construct a SignSGD
    optimizer; it exercises ``model._ternary_update_memory`` instead.
    """
    if not _cuda_available():
        print(" SKIP test_sign_sgd_with_tscale_model (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T32).to("cuda")
    tokens = torch.randint(0, VOCAB, (2, 10), device="cuda")
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    losses.total.backward()
    model._ternary_update_memory()
    print(" PASS test_sign_sgd_with_tscale_model")


def test_sign_sgd_weight_decay():
    """A SignSGD step with ``weight_decay`` enabled still moves the weights."""
    layer = torch.nn.Linear(10, 5)
    opt = SignSGD(layer.parameters(), lr=0.01, weight_decay=0.01)
    layer(torch.randn(2, 10)).sum().backward()
    before = layer.weight.clone()
    opt.step()
    change = (layer.weight - before).abs().sum().item()
    assert change > 0, "Weights should change with weight_decay"
    print(" PASS test_sign_sgd_weight_decay")


# ─── TileLang PyTorch Reference Tests ───

def test_dequant_gemm_pytorch_ref():
    """The TileLang dequant-GEMM PyTorch reference returns a finite (M, N) result.

    The reference module is loaded directly from its file path. The test is
    skipped when that file does not exist.
    """
    import importlib.util
    ref_file = os.path.join(os.path.dirname(__file__), "..", "tilelang", "kernels", "dequant_gemm.py")
    if not os.path.exists(ref_file):
        print(" SKIP test_dequant_gemm_pytorch_ref (tilelang reference file missing)")
        return
    spec = importlib.util.spec_from_file_location("dequant_gemm", ref_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    ref_fn = module.dequant_gemm_pytorch_ref

    M, N, K, group_size = 4, 8, 96, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)

    result = ref_fn(signs, exponents, x, group_size)
    assert result.shape == (M, N), f"Shape: {result.shape}"
    assert torch.isfinite(result).all(), "Non-finite output"
    print(" PASS test_dequant_gemm_pytorch_ref")


def test_dequant_gemm_matches_manual():
    """Compare the TileLang dequant-GEMM PyTorch reference with a hand-built dequantization.

    The weight is ``signs * 2**exponent``, with exponents repeated across each
    ``group_size`` column group, and the expected output is ``x @ w.T``. The
    test is skipped when the reference file is missing.
    """
    import importlib.util
    kernel_path = os.path.join(os.path.dirname(__file__), "..", "tilelang", "kernels", "dequant_gemm.py")
    if not os.path.exists(kernel_path):
        print(" SKIP test_dequant_gemm_matches_manual (tilelang reference file missing)")
        return
    spec = importlib.util.spec_from_file_location("dequant_gemm", kernel_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    dequant_gemm_pytorch_ref = mod.dequant_gemm_pytorch_ref

    M, N, K, group_size = 2, 4, 48, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)

    result = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)

    exp_expanded = exponents.repeat_interleave(group_size, dim=1)
    pos_mask = exp_expanded >= 0
    # NOTE(review): `1 >> k` is integer 0 for every k >= 1, so negative
    # exponents contribute a zero weight here rather than 2**-k. Confirm this
    # matches the reference kernel's intended semantics.
    two_pow = torch.where(pos_mask,
                           (1 << exp_expanded.to(torch.int32)).to(torch.float16),
                           (1 >> (-exp_expanded.to(torch.int32))).to(torch.float16))
    w = signs.to(torch.float16) * two_pow
    expected = x @ w.t()

    assert torch.allclose(result, expected, atol=1e-3), "PyTorch ref mismatch"
    print(" PASS test_dequant_gemm_matches_manual")


# ─── Integration: SignSGD + TernaryScaleTensor training step ───

def test_full_training_step():
    """After one backward plus ternary memory update, a second forward gives a finite loss.

    Skips unless a large CUDA GPU is present.
    """
    if not _cuda_available():
        print(" SKIP test_full_training_step (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T32).to("cuda")
    tokens = torch.randint(0, VOCAB, (2, 10), device="cuda")
    _, first, _, _ = model(tokens, targets=tokens[:, 3:])
    first.total.backward()
    model._ternary_update_memory()

    _, second, _, _ = model(tokens, targets=tokens[:, 3:])
    assert torch.isfinite(second.total), "Non-finite loss after step"
    print(" PASS test_full_training_step")


def test_multiple_steps_converge():
    """Run 50 ternary training steps on a fixed batch and require every loss to be finite.

    NOTE(review): despite the name, this only checks finiteness, not that the
    loss decreases. The observed loss range is printed for inspection.
    """
    if not _cuda_available():
        print(" SKIP test_multiple_steps_converge (need CUDA + >10GB GPU)")
        return
    model = ARBModel(tscale_type=TScaleType.T32).to("cuda")
    tokens = torch.randint(0, VOCAB, (4, 10), device="cuda")
    history = []
    for _ in range(50):
        _, components, _, _ = model(tokens, targets=tokens[:, 3:])
        total = components.total
        total.backward()
        model._ternary_update_memory(accum_threshold=3)
        history.append(total.item())

    assert torch.isfinite(torch.tensor(history)).all(), "Non-finite loss during training"
    print(f"   Loss range: {min(history):.4f} – {max(history):.4f} over 50 steps")
    print(" PASS test_multiple_steps_converge")


def test_cuda_triton_correctness_linear():
    """Compare the Triton CUDA TernaryScaleTensor forward/backward with the CPU path.

    For each scale type, a CPU layer and a CUDA layer get identical state via
    ``load_state_dict``. The same input and upstream gradient go through both,
    and the max absolute differences of the outputs and of the input gradients
    must stay below ``ATOL``. Skips when CUDA or Triton is unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_linear (CUDA/Triton unavailable)")
        return
    ATOL = 2e-3
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        x = torch.randn(4, 4, 32, requires_grad=True)
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())
        cpu_out = lin_cpu(x)
        grad_out = torch.randn_like(cpu_out)
        cpu_out.backward(grad_out)
        cpu_grad_x = x.grad.clone()

        # Fresh leaf on the GPU so CPU and GPU gradients do not share storage.
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = lin_gpu(x_gpu)
        gpu_out.backward(grad_out.cuda())
        gpu_grad_x = x_gpu.grad.clone()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < ATOL, f"{tt.name} fwd_diff={fwd_diff}"
        assert bwd_diff < ATOL, f"{tt.name} bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_linear")


def test_cuda_triton_correctness_rmsnorm():
    """Compare the CUDA TernaryRMSNorm forward/backward with the CPU path.

    For each scale type, CPU and CUDA norms share state. The max absolute
    differences of the outputs and of the input gradients (upstream gradient of
    ones, via ``.sum()``) must both be below 1e-5. Skips when CUDA or Triton is
    unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_rmsnorm (CUDA/Triton unavailable)")
        return
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        norm_cpu = TernaryRMSNorm(256, tscale_type=tt)
        x = torch.randn(2, 4, 256, requires_grad=True)
        cpu_out = norm_cpu(x)
        cpu_out.sum().backward()
        cpu_grad_x = x.grad.clone()

        norm_gpu = TernaryRMSNorm(256, tscale_type=tt).cuda()
        norm_gpu.load_state_dict(norm_cpu.state_dict())
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = norm_gpu(x_gpu)
        gpu_out.sum().backward()
        gpu_grad_x = x_gpu.grad.clone()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{tt.name} rmsnorm fwd_diff={fwd_diff}"
        assert bwd_diff < 1e-5, f"{tt.name} rmsnorm bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_rmsnorm")


def test_cuda_triton_correctness_embedding():
    """Compare the CUDA ByteEmbedding forward with the CPU path for every scale type.

    If both modules expose ``_hook_grad_T_sign`` after backward, more than 99%
    of the recorded gradient signs must also agree. Skips when CUDA or Triton
    is unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_embedding (CUDA/Triton unavailable)")
        return
    from arbitor.main import ByteEmbedding
    scale_types = (TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64)
    for tt in scale_types:
        cpu_embed = ByteEmbedding(tscale_type=tt)
        ids = torch.tensor([0, 1, 2, 5, 10])
        cpu_result = cpu_embed(ids)
        cpu_result.sum().backward()

        gpu_embed = ByteEmbedding(tscale_type=tt).cuda()
        gpu_embed.load_state_dict(cpu_embed.state_dict())
        gpu_result = gpu_embed(ids.cuda())
        gpu_result.sum().backward()

        fwd_diff = (cpu_result - gpu_result.cpu()).abs().max().item()
        assert fwd_diff < 1e-5, f"{tt.name} embed fwd_diff={fwd_diff}"
        both_have_hooks = hasattr(cpu_embed, '_hook_grad_T_sign') and hasattr(gpu_embed, '_hook_grad_T_sign')
        if not both_have_hooks:
            continue
        agreement = (gpu_embed._hook_grad_T_sign.cpu() == cpu_embed._hook_grad_T_sign).float().mean().item()
        assert agreement > 0.99, f"{tt.name} embed grad_sign match={agreement}"
    print(" PASS test_cuda_triton_correctness_embedding")


def test_cuda_triton_correctness_update_E():
    """After one backward, the CPU and CUDA layers agree on E, corr_accum and step_counter.

    E must match within 0.01 and corr_accum must match exactly. Both
    step_counters must equal 1. Skips when CUDA or Triton is unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_update_E (CUDA/Triton unavailable)")
        return

    def _snapshot(layer):
        # Clone the tensors compared below, in a fixed order.
        return layer.E.clone(), layer.corr_accum.clone(), layer.step_counter.clone()

    scale_types = (TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64)
    for tt in scale_types:
        cpu_layer = TernaryScaleTensor(32, 16, tscale_type=tt)
        gpu_layer = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        gpu_layer.load_state_dict(cpu_layer.state_dict())

        x_cpu = torch.randn(4, 4, 32, requires_grad=True)
        x_gpu = x_cpu.detach().clone().cuda().requires_grad_(True)

        cpu_layer(x_cpu).sum().backward()
        E_cpu, corr_cpu, step_cpu = _snapshot(cpu_layer)
        gpu_layer(x_gpu).sum().backward()
        E_gpu, corr_gpu, step_gpu = _snapshot(gpu_layer)

        # E is fixed after init; BigInt corr_accum carries the continuous scale adjustment.
        E_diff = (E_cpu.float() - E_gpu.cpu().float()).abs().max().item()
        assert E_diff < 0.01, f"{tt.name} CPU-GPU E update mismatch: {E_diff}"
        corr_diff = (corr_cpu - corr_gpu.cpu()).abs().max().item()
        assert corr_diff == 0, f"{tt.name} CPU-GPU corr_accum update mismatch: {corr_diff}"
        assert int(step_cpu.item()) == int(step_gpu.cpu().item()) == 1, \
            f"{tt.name} CPU-GPU step_counter mismatch: cpu={step_cpu.item()} gpu={step_gpu.cpu().item()}"
    print(" PASS test_cuda_triton_correctness_update_E")


def test_cuda_triton_correctness_ternary_step():
    """After backward plus ``ternary_step``, CPU and CUDA layers hold identical T and corr_accum.

    Skips when CUDA or Triton is unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_ternary_step (CUDA/Triton unavailable)")
        return
    scale_types = (TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64)
    for tt in scale_types:
        cpu_layer = TernaryScaleTensor(32, 16, tscale_type=tt)
        gpu_layer = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        gpu_layer.load_state_dict(cpu_layer.state_dict())

        x_cpu = torch.randn(4, 4, 32, requires_grad=True)
        x_gpu = x_cpu.detach().clone().cuda().requires_grad_(True)

        cpu_layer(x_cpu).sum().backward()
        cpu_layer.ternary_step(accum_threshold=3)
        T_cpu = cpu_layer._get_T().clone()
        corr_cpu = cpu_layer.corr_accum.clone()

        gpu_layer(x_gpu).sum().backward()
        gpu_layer.ternary_step(accum_threshold=3)
        T_gpu = gpu_layer._get_T().clone()
        corr_gpu = gpu_layer.corr_accum.clone()

        T_match = (T_cpu == T_gpu.cpu()).float().mean().item()
        corr_match = (corr_cpu == corr_gpu.cpu()).float().mean().item()
        assert T_match == 1.0, f"{tt.name} T_match={T_match}"
        assert corr_match == 1.0, f"{tt.name} corr_match={corr_match}"
    print(" PASS test_cuda_triton_correctness_ternary_step")


def test_cuda_triton_tscale_path():
    """Exercise the CUDA/Triton streaming update path of TernaryScaleTensor.

    Part 1: a forward/backward on CUDA must produce CUDA outputs and grads,
    stream updates into the int64 ``corr_accum``, advance ``step_counter`` to 1,
    and leave no retained ``_hook_*`` attributes.
    Part 2: manually planted ``_hook_grad_2d`` / ``_hook_x_2d`` tensors, followed
    by ``update_E()``, must also update ``corr_accum`` and ``step_counter``.
    Skips when CUDA or Triton is unavailable.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_tscale_path (CUDA/Triton unavailable)")
        return

    # Part 1: regular streaming backward.
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    inputs = torch.randn(2, 4, 32, device="cuda", requires_grad=True)
    result = layer(inputs)
    assert result.is_cuda, "Triton path should produce CUDA output"
    assert result.shape == (2, 4, 16), f"Shape: {result.shape}"
    result.backward(torch.randn_like(result))
    assert inputs.grad is not None and inputs.grad.is_cuda, "CUDA grad_x missing"
    assert layer.corr_accum.abs().sum().item() > 0, \
        "Triton path should stream updates into int64 corr_accum"
    assert int(layer.step_counter.item()) == 1, "Triton path should advance the BigInt step counter"
    assert not hasattr(layer, "_hook_grad_T_sign"), \
        "Triton path should not retain full weight-shaped grad-sign hooks"
    assert not hasattr(layer, "_hook_grad_2d") and not hasattr(layer, "_hook_x_2d"), \
        "Triton path should not retain fp32 grad/x views"
    # Re-check after all queued kernels have finished.
    torch.cuda.synchronize()
    assert not hasattr(layer, "_hook_grad_T_sign"), \
        "No retained grad-sign hook should remain after streaming backward"
    assert layer.T_packed.is_cuda and layer.E.is_cuda, "Ternary buffers moved off CUDA after update"

    # Part 2: update_E driven by manually planted hook tensors.
    forced = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    forced._hook_grad_2d = torch.ones(2, 16, device="cuda")
    forced._hook_x_2d = torch.ones(2, 32, device="cuda")
    forced.update_E()
    assert forced._get_T().is_cuda, "Unpacked CUDA ternary state should stay on CUDA"
    assert forced.corr_accum.abs().sum().item() > 0, "Forced CUDA hook should update BigInt corr_accum"
    assert int(forced.step_counter.item()) == 1, "Forced CUDA hook should advance the BigInt step counter"
    print(" PASS test_cuda_triton_tscale_path")


def test_small_ternary_training_loss_finite():
    """A minimal text-only ARBModel gives a finite loss, and its ternary update leaves no hooks behind.

    All optional subsystems are disabled so the model fits on small GPUs.
    Only CUDA availability is required, not the large-memory check.
    """
    if not torch.cuda.is_available():
        print(" SKIP test_small_ternary_training_loss_finite (CUDA unavailable)")
        return
    config = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
        tscale_type=TScaleType.T32,
    )
    model = ARBModel(**config).cuda()
    tokens = torch.randint(0, VOCAB, (1, 4), device="cuda")
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert torch.isfinite(losses.total), "Small ternary training loss is non-finite"
    model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_components=losses)

    hook_names = ("_hook_grad_T_sign", "_hook_grad_2d", "_hook_x_2d")
    leftovers = []
    for name, module in model.named_modules():
        if any(hasattr(module, hook) for hook in hook_names):
            leftovers.append(name)
    assert not leftovers, f"Ternary update left stale hooks: {leftovers[:5]}"
    print(" PASS test_small_ternary_training_loss_finite")


def test_ternary_update_rejects_nonfinite_loss():
    """Check that ``_ternary_update_memory`` emits a "Non-finite loss" warning for a NaN loss.

    All recorded warnings are searched for the expected message, instead of
    inspecting only the first one. Otherwise an unrelated warning raised
    earlier in the call (e.g. a library deprecation notice) would make the
    test fail spuriously.
    """
    import warnings
    model = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
        tscale_type=TScaleType.T32,
    )
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        lc = LossComponents(lm=torch.tensor(float("nan")))
        model._ternary_update_memory(loss_components=lc)
    messages = [str(entry.message) for entry in caught]
    assert messages, "Expected a warning for non-finite loss"
    assert any("Non-finite loss" in message for message in messages), \
        f"Unexpected warnings: {messages}"
    print(" PASS test_ternary_update_rejects_nonfinite_loss")


# ─── Phase 12: E Gradient Field + Statistical Metrics Tests ───

def test_e_rms_weighted_delta():
    """Check that the RMS-weighted E delta has a magnitude clamped to [1, 3].

    The delta formula is ``-sign(score) * clamp(round(log2(1 + rms)), 1, 3)``.
    Here ``score`` is the correlation of the raw weight gradient with the
    ternary weights, and ``rms`` is the gradient RMS of the first column group.
    When ``score`` is exactly zero the delta is legitimately 0. The upper bound
    is 3, matching the clamp (the previous bound of 4 was unreachable).
    """
    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    grad = torch.randn(4, 16)
    x = torch.randn(4, 32)
    raw_grad = grad.T @ x
    # Only the first column group's RMS feeds the delta below.
    first_group = raw_grad[:, :min(lin.group_size, 32)]
    rms = first_group.pow(2).mean().sqrt().item()
    score = (raw_grad * lin._get_T().float()).sum().item()
    sign = (score > 0) - (score < 0)
    delta = -sign * max(1, min(3, round(math.log2(1 + rms))))
    lower = 0 if score == 0 else 1
    assert lower <= abs(delta) <= 3, f"delta magnitude {abs(delta)} out of range"
    print(" PASS test_e_rms_weighted_delta")


def test_e_rms_vs_sign_only():
    """Check that the RMS-weighted delta grows with gradient magnitude even when the sign is the same.

    The delta magnitude is ``clamp(round(log2(1 + rms)), 1, 3)``. Constant
    gradients of 0.1 and 10.0 give deltas of 1 and 3 respectively.
    """
    raw_low = torch.ones(16, 32) * 0.1
    raw_high = torch.ones(16, 32) * 10.0
    rms_low = raw_low.pow(2).mean().sqrt()
    rms_high = raw_high.pow(2).mean().sqrt()
    delta_low = max(1, min(3, round(math.log2(1 + rms_low.item()))))
    delta_high = max(1, min(3, round(math.log2(1 + rms_high.item()))))
    assert delta_low != delta_high, "RMS delta should differ for different magnitudes"
    assert delta_low < delta_high, "Higher RMS should give larger delta"
    print(" PASS test_e_rms_vs_sign_only")


def test_e_zscore_normalization():
    """Z-scoring each component's RMS series gives ~0 mean and ~1 std, regardless of scale.

    Component "a" is ten times larger than component "b". After normalization
    both are comparable.
    """
    series = {
        "a": torch.tensor([10.0, 12.0, 8.0, 11.0]),
        "b": torch.tensor([1.0, 1.2, 0.8, 1.1]),
    }
    z = {key: (vals - vals.mean()) / (vals.std() + 1e-8) for key, vals in series.items()}
    for key in ("a", "b"):
        centre = z[key].mean().item()
        assert abs(centre) < 1e-6, f"z_{key} mean not ~0: {z[key].mean().item()}"
    for key in ("a", "b"):
        spread = z[key].std().item()
        assert abs(spread - 1.0) < 0.1, f"z_{key} std not ~1: {z[key].std().item()}"
    print(" PASS test_e_zscore_normalization")


def test_e_zscore_zero_std():
    """A constant RMS series (std == 0) must map to all-zero, finite z-scores.

    ``torch.where`` still evaluates the 0/0 branch eagerly, but that NaN
    result is discarded in favour of zeros.
    """
    flat = torch.full((8,), 5.0)
    spread = flat.std()
    z = torch.where(spread > 1e-8, (flat - flat.mean()) / spread, torch.zeros_like(flat))
    assert torch.isfinite(z).all(), "z-scores should be finite when std=0"
    assert (z == 0).all(), "z-scores should be zero when std=0"
    print(" PASS test_e_zscore_zero_std")


def test_group_lr_registration():
    """Ternary modules register their BigInt scale state with the expected dtypes and shapes.

    TernaryScaleTensor and ByteEmbedding carry an int64 ``corr_accum`` and a
    ``step_counter`` that starts at 0. TernaryRMSNorm exposes ``E``.
    """
    layer = TernaryScaleTensor(32, 16)
    assert hasattr(layer, "corr_accum")
    assert layer.corr_accum.dtype == torch.int64
    assert layer.corr_accum.shape == layer.E.shape
    assert int(layer.step_counter.item()) == 0

    embed = ByteEmbedding()
    assert hasattr(embed, "corr_accum")
    assert embed.corr_accum.dtype == torch.int64
    assert embed.corr_accum.shape[0] > 0
    assert int(embed.step_counter.item()) == 0

    norm = TernaryRMSNorm(256)
    assert hasattr(norm, "E")
    print(" PASS test_group_lr_registration")


def test_group_lr_effect():
    """The int8 group learning rate scales the delta as ``delta * lr // 8``.

    With delta 4, lr 8 keeps the full delta and lr 1 floors it to 0. The
    product is widened to int16 so it cannot overflow int8.
    """
    def widen(t):
        return t.to(torch.int16)

    delta = torch.tensor(4, dtype=torch.int8)
    lr_high = torch.tensor(8, dtype=torch.int8)
    lr_low = torch.tensor(1, dtype=torch.int8)
    scaled_high = widen(delta) * widen(lr_high) // 8
    scaled_low = widen(delta) * widen(lr_low) // 8
    assert scaled_high.item() == 4, f"high LR should give full delta, got {scaled_high.item()}"
    assert scaled_low.item() == 0, f"low LR should give 0 delta, got {scaled_low.item()}"
    print(" PASS test_group_lr_effect")


def test_group_lr_dynamic_update():
    """Each group LR moves by +1/-1/0 following the sign of its RMS change, clamped to [1, 8]."""
    lr = torch.ones(4, dtype=torch.int8)
    previous = torch.tensor([1.0, 5.0, 3.0, 2.0])
    current = torch.tensor([2.0, 3.0, 3.0, 1.0])
    growth = current - previous
    step = (growth > 0).to(torch.int16) - (growth < 0).to(torch.int16)
    updated = torch.clamp(lr.to(torch.int16) + step, 1, 8).to(torch.int8)
    assert updated[0].item() == 2, f"RMS increased -> LR should increase, got {updated[0].item()}"
    assert updated[1].item() == 1, f"RMS decreased -> LR should decrease, got {updated[1].item()}"
    assert updated[2].item() == 1, f"RMS unchanged -> LR unchanged, got {updated[2].item()}"

    # Values far outside the range are clamped to the boundaries.
    ceiling = torch.clamp(torch.tensor([100], dtype=torch.int16), 1, 8)
    floor = torch.clamp(torch.tensor([-100], dtype=torch.int16), 1, 8)
    assert ceiling.item() == 8, f"clamp max, got {ceiling.item()}"
    assert floor.item() == 1, f"clamp min, got {floor.item()}"
    print(" PASS test_group_lr_dynamic_update")


def test_e_stats_cpu_fallback():
    """Check the CPU fallback for per-group gradient RMS: finite values and clamped deltas.

    The raw weight gradient ``grad.T @ x`` has shape (N, K). Its K columns are
    split into ``group_size`` chunks (the last chunk may be short), and each
    chunk's RMS must be finite. The derived delta
    ``clamp(round(log2(1 + rms)), 1, 3)`` must lie in [1, 3].
    """
    N, K, group_size = 16, 32, 12
    grad = torch.randn(4, N)
    x = torch.randn(4, K)
    raw_grad = grad.T @ x
    rms_vals = [
        raw_grad[:, start:min(start + group_size, K)].pow(2).mean().sqrt().item()
        for start in range(0, K, group_size)
    ]
    # Use Tensor.all() instead of Python's all() over 0-d tensors.
    assert torch.isfinite(torch.tensor(rms_vals)).all(), "finite check"
    assert all(1 <= max(1, min(3, round(math.log2(1 + r)))) <= 3 for r in rms_vals), "clamp range"
    print(" PASS test_e_stats_cpu_fallback")


def test_e_per_component_routing():
    """Smoke test: three forward + per-component ternary updates run without raising (CUDA only)."""
    if not _cuda_available():
        print(" SKIP test_e_per_component_routing (CUDA)")
        return
    model = ARBModel(enable_image=False, enable_audio=False, enable_vq=False, enable_graph=False, enable_memory_modules=False, enable_moe=False).cuda()
    tokens = torch.randint(0, VOCAB, (1, 4), device="cuda")
    for _ in range(3):
        _, components, _, _ = model(tokens, targets=tokens[:, 3:])
        model._ternary_update_memory(loss_components=components)
    # Reaching this point without an exception is the pass condition.
    print(" PASS test_e_per_component_routing")


def test_ensure_group_lr_backward_compat():
    """Freshly constructed ternary modules expose the attributes that older callers rely on."""
    layer = TernaryScaleTensor(32, 16)
    for attr in ("corr_accum", "step_counter"):
        assert hasattr(layer, attr)
    embed = ByteEmbedding()
    for attr in ("corr_accum", "step_counter"):
        assert hasattr(embed, attr)
    norm = TernaryRMSNorm(256)
    assert hasattr(norm, "E")
    print(" PASS test_ensure_group_lr_backward_compat")


# ─── Main ───

if __name__ == "__main__":
    # Simple standalone runner: executes each test, reports PASS/FAIL and
    # exits non-zero when any test failed so CI and shell callers can detect it.
    import traceback

    tests = [
        test_tscale_shape,
        test_tscale_ternary_output,
        test_tscale_T64_per_element_s,
        test_tscale_T32_group_s,
        test_tscale_to_switching,
        test_tscale_cast_alias,
        test_tscale_gradient_flow,
        test_tscale_all_types_forward,
        test_tscale_dequantize,
        test_tscale_effective_bpw,
        test_tscale_model_integration,
        test_tscale_runtime_switch,
        test_sign_sgd_step,
        test_sign_sgd_no_momentum,
        test_sign_sgd_memory,
        test_sign_sgd_with_tscale_model,
        test_sign_sgd_weight_decay,
        test_dequant_gemm_pytorch_ref,
        test_dequant_gemm_matches_manual,
        test_cuda_triton_correctness_linear,
        test_cuda_triton_correctness_rmsnorm,
        test_cuda_triton_correctness_embedding,
        test_cuda_triton_correctness_update_E,
        test_cuda_triton_correctness_ternary_step,
        test_cuda_triton_tscale_path,
        test_small_ternary_training_loss_finite,
        test_ternary_update_rejects_nonfinite_loss,
        test_full_training_step,
        test_multiple_steps_converge,
        test_e_rms_weighted_delta,
        test_e_rms_vs_sign_only,
        test_e_zscore_normalization,
        test_e_zscore_zero_std,
        test_group_lr_registration,
        test_group_lr_effect,
        test_group_lr_dynamic_update,
        test_e_stats_cpu_fallback,
        test_e_per_component_routing,
        test_ensure_group_lr_backward_compat,
    ]
    print("Running TernaryScale + SignSGD + TileLang Phase 2 tests...\n")
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:
            print(f" FAIL {test.__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")
    sys.exit(1 if failed else 0)