| """Unit tests for KGVQCodebook and CompositeProposalHead (Phase 17 Plan 02).""" |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.components import KGVQCodebook, CompositeProposalHead |
| from arbitor.config import HIDDEN_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES |
|
|
|
|
# Codebook size passed to every KGVQCodebook / CompositeProposalHead built in
# these tests, instead of the configured KGVQ_CODEBOOK_SIZE from arbitor.config.
# NOTE(review): presumably a fixed test-local size so the index-range checks do
# not depend on the production config value — confirm.
TEST_KGVQ_SIZE = 4096
|
|
|
|
def test_kgvq_shape_and_range():
    """Quantizer keeps the input shape and emits indices inside the codebook.

    Also checks that the third returned value (the loss term) is finite.
    """
    batch = 2
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    latents = torch.randn(batch, K_MAX_COMPOSITES, KGVQ_CODEBOOK_DIM)
    quantized, codes, loss_value = codebook(latents)

    assert quantized.shape == latents.shape, f"shape {quantized.shape}"
    assert codes.shape == (batch, K_MAX_COMPOSITES)
    # Every selected code must address a real codebook row: [0, TEST_KGVQ_SIZE).
    assert codes.max() < TEST_KGVQ_SIZE
    assert codes.min() >= 0
    assert torch.isfinite(loss_value)
    print(" PASS test_kgvq_shape_and_range")
|
|
|
|
def test_kgvq_ema_update():
    """A forward pass must change and accumulate the ``cluster_size`` buffer."""
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    inputs = torch.randn(2, 5, KGVQ_CODEBOOK_DIM)
    # Snapshot a copy: the buffer may be updated in place by the forward pass.
    usage_snapshot = codebook.cluster_size.clone()

    _, _, _ = codebook(inputs)

    assert not torch.equal(usage_snapshot, codebook.cluster_size), "cluster usage should update"
    assert codebook.cluster_size.sum().item() > 0, "cluster_size should accumulate"
    print(" PASS test_kgvq_ema_update")
|
|
|
|
def test_kgvq_dead_code_reset():
    """Calling ``_dead_code_reset`` right after one forward must not touch usage.

    The ``cluster_size`` buffer is zeroed, then one forward pass is run on a single
    vector, and at least one codebook entry should become used. A direct call to
    ``_dead_code_reset`` afterwards is expected to leave ``cluster_size`` unchanged.
    The assertion message calls this the "warmup" phase.
    """
    kgvq = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    # Zero usage so the single forward pass below is the only source of counts.
    kgvq.cluster_size.zero_()
    x = torch.randn(1, 1, KGVQ_CODEBOOK_DIM)
    _, _, _ = kgvq(x)

    used = (kgvq.cluster_size > 0).sum().item()
    assert used >= 1, "at least 1 entry should be used after forward"

    before = kgvq.cluster_size.clone()
    # _dead_code_reset is given the inputs flattened to (N, KGVQ_CODEBOOK_DIM).
    kgvq._dead_code_reset(x.reshape(-1, KGVQ_CODEBOOK_DIM))
    assert torch.equal(before, kgvq.cluster_size), "dead code should not fire during warmup"
    print(" PASS test_kgvq_dead_code_reset")
|
|
|
|
def test_composite_head_variable_count():
    """Head emits K_MAX_COMPOSITES slots per row, halts at least one, loss is finite.

    Halted slots are marked with id -1. Halt values must lie strictly in (0, 1).
    NOTE(review): whether any slot halts depends on the randomly initialised head
    and the random input. No seed is set in this file.
    """
    batch = 2
    head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)
    pooled = torch.randn(batch, HIDDEN_DIM)
    ids, loss_value, halt_scores = head(pooled)
    expected_shape = (batch, K_MAX_COMPOSITES)

    assert ids.shape == expected_shape
    assert halt_scores.shape == expected_shape
    assert (halt_scores > 0).all() and (halt_scores < 1).all()
    assert (ids == -1).any(), "at least one halted"
    assert torch.isfinite(loss_value)
    print(" PASS test_composite_head_variable_count")
|
|
|
|
def test_composite_head_diversity_loss():
    """VQ loss is finite for both a constant input and a random input.

    NOTE(review): despite the name, only finiteness of the returned loss is
    checked here. No diversity-specific property is asserted.
    """
    head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)
    # Inputs are built lazily so each is created right before its own forward
    # pass, keeping the same RNG call order as building them inline.
    input_factories = (
        lambda: torch.ones(2, HIDDEN_DIM),
        lambda: torch.randn(2, HIDDEN_DIM),
    )
    losses = []
    for make_input in input_factories:
        _, loss_value, _ = head(make_input())
        losses.append(loss_value)
    assert all(torch.isfinite(loss_value) for loss_value in losses)
    print(" PASS test_composite_head_diversity_loss")
|
|
|
|
if __name__ == "__main__":
    # Run every test directly (without pytest) in declaration order; the first
    # failing assertion propagates and stops the run.
    for run_case in (
        test_kgvq_shape_and_range,
        test_kgvq_ema_update,
        test_kgvq_dead_code_reset,
        test_composite_head_variable_count,
        test_composite_head_diversity_loss,
    ):
        run_case()
    print("\nAll composite head tests PASS")
|
|