File size: 3,163 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Unit tests for KGVQCodebook and CompositeProposalHead (Phase 17 Plan 02)."""
import torch
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.components import KGVQCodebook, CompositeProposalHead
from arbitor.config import HIDDEN_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES


# Codebook size passed as `codebook_size` to KGVQCodebook and
# CompositeProposalHead in the tests in this module; also used as the exclusive
# upper bound when checking returned code indices.
# NOTE(review): presumably set independently of the imported KGVQ_CODEBOOK_SIZE
# (which this module does not reference) to keep the tests small — TODO confirm.
TEST_KGVQ_SIZE = 4096


def test_kgvq_shape_and_range():
    """Check KGVQCodebook forward output shapes, index bounds and loss finiteness.

    Feeds a random (2, K_MAX_COMPOSITES, KGVQ_CODEBOOK_DIM) tensor through a
    codebook of TEST_KGVQ_SIZE entries and asserts that:
      * the quantized output has the same shape as the input,
      * indices have shape (2, K_MAX_COMPOSITES) and lie in [0, TEST_KGVQ_SIZE),
      * the returned loss is finite.
    """
    batch = 2
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    inputs = torch.randn(batch, K_MAX_COMPOSITES, KGVQ_CODEBOOK_DIM)

    quantized, indices, loss = codebook(inputs)

    # Quantization must not change the tensor layout.
    assert quantized.shape == inputs.shape, f"shape {quantized.shape}"
    # One code index per (batch item, composite slot).
    assert indices.shape == (batch, K_MAX_COMPOSITES)
    # Every index must address a valid codebook row.
    assert indices.max() < TEST_KGVQ_SIZE
    assert indices.min() >= 0
    assert torch.isfinite(loss)
    print(" PASS test_kgvq_shape_and_range")


def test_kgvq_ema_update():
    """Check that a forward pass changes the codebook's `cluster_size` statistics.

    Snapshots `cluster_size`, runs one forward pass on a random
    (2, 5, KGVQ_CODEBOOK_DIM) input, and asserts that the values differ from
    the snapshot and that their sum is positive.
    """
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    inputs = torch.randn(2, 5, KGVQ_CODEBOOK_DIM)

    # clone() so the comparison is against the pre-forward values rather than
    # a reference to the same tensor.
    usage_before = codebook.cluster_size.clone()

    # Outputs are irrelevant here; only the effect on cluster_size is checked.
    _quantized, _indices, _loss = codebook(inputs)

    usage_after = codebook.cluster_size
    assert not torch.equal(usage_before, usage_after), "cluster usage should update"
    assert codebook.cluster_size.sum().item() > 0, "cluster_size should accumulate"
    print(" PASS test_kgvq_ema_update")


def test_kgvq_dead_code_reset():
    """Check that `_dead_code_reset` leaves `cluster_size` untouched early on.

    Zeroes `cluster_size`, runs a single forward pass on one random vector,
    then calls `_dead_code_reset` directly and asserts that `cluster_size`
    is unchanged by that call.
    """
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    codebook.cluster_size.zero_()
    sample = torch.randn(1, 1, KGVQ_CODEBOOK_DIM)
    _quantized, _indices, _loss = codebook(sample)

    # Per the original author's note: after one forward, the used entry has
    # cluster_size=0.01 (EMA). This test only checks that it is > 0.
    n_used = (codebook.cluster_size > 0).sum().item()
    assert n_used >= 1, "at least 1 entry should be used after forward"

    # NOTE(review): this count of below-threshold entries is computed but never
    # asserted on; it does exercise the `threshold_ema_dead_code` attribute.
    n_dead = (codebook.cluster_size < codebook.threshold_ema_dead_code).sum().item()

    # Per the original author's note: reset shouldn't fire during warmup
    # (n_initialized < codebook_size/4).
    flat_sample = sample.reshape(-1, KGVQ_CODEBOOK_DIM)
    snapshot = codebook.cluster_size.clone()
    codebook._dead_code_reset(flat_sample)
    assert torch.equal(snapshot, codebook.cluster_size), "dead code should not fire during warmup"
    print(" PASS test_kgvq_dead_code_reset")


def test_composite_head_variable_count():
    """Check CompositeProposalHead output shapes, halt range and -1 padding.

    Runs the head on a random (2, HIDDEN_DIM) input and asserts that:
      * `composite_ids` and `halt` both have shape (2, K_MAX_COMPOSITES),
      * every `halt` value lies strictly between 0 and 1,
      * at least one composite id equals -1 (labelled "halted" by the
        assertion message),
      * the returned VQ loss is finite.
    """
    batch = 2
    proposal_head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)
    pooled = torch.randn(batch, HIDDEN_DIM)

    composite_ids, vq_loss, halt = proposal_head(pooled)

    expected_shape = (batch, K_MAX_COMPOSITES)
    assert composite_ids.shape == expected_shape
    assert halt.shape == expected_shape
    # Open interval (0, 1): checked as two separate bounds.
    assert (halt > 0).all()
    assert (halt < 1).all()
    assert (composite_ids == -1).any(), "at least one halted"
    assert torch.isfinite(vq_loss)
    print(" PASS test_composite_head_variable_count")


def test_composite_head_diversity_loss():
    """Check that the head's VQ loss is finite for constant and random inputs.

    Runs the same CompositeProposalHead first on an all-ones (2, HIDDEN_DIM)
    input and then on a random one, asserting both returned losses are finite.

    NOTE(review): despite the test name, nothing here asserts on a
    diversity-specific quantity; only finiteness is checked.
    """
    proposal_head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)

    # Identical rows in the batch.
    constant_input = torch.ones(2, HIDDEN_DIM)
    _, loss_constant, _ = proposal_head(constant_input)

    # Random rows; drawn after the first call, matching the original ordering.
    random_input = torch.randn(2, HIDDEN_DIM)
    _, loss_random, _ = proposal_head(random_input)

    assert torch.isfinite(loss_constant)
    assert torch.isfinite(loss_random)
    print(" PASS test_composite_head_diversity_loss")


if __name__ == "__main__":
    # Script entry point: run each test in definition order. An assertion
    # failure propagates and stops the run before the summary line is printed.
    for run_test in (
        test_kgvq_shape_and_range,
        test_kgvq_ema_update,
        test_kgvq_dead_code_reset,
        test_composite_head_variable_count,
        test_composite_head_diversity_loss,
    ):
        run_test()
    print("\nAll composite head tests PASS")