"""Unit tests for KGVQCodebook and CompositeProposalHead (Phase 17 Plan 02)."""
import os
import sys

import torch

# Make the repository root importable so `arbitor` resolves when this file is
# run directly (two directories above this file).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.components import KGVQCodebook, CompositeProposalHead  # noqa: E402
from arbitor.config import HIDDEN_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES  # noqa: E402

# Explicit codebook size passed to every module under test here, instead of the
# configured KGVQ_CODEBOOK_SIZE.
TEST_KGVQ_SIZE = 4096


def test_kgvq_shape_and_range():
    """Quantizer keeps the input shape and returns in-range indices and a finite loss."""
    kgvq = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    x = torch.randn(2, K_MAX_COMPOSITES, KGVQ_CODEBOOK_DIM)
    quantized, indices, loss = kgvq(x)
    assert quantized.shape == x.shape, f"shape {quantized.shape}"
    # One codebook index per input vector, within [0, TEST_KGVQ_SIZE).
    assert indices.shape == (2, K_MAX_COMPOSITES)
    assert indices.max() < TEST_KGVQ_SIZE
    assert indices.min() >= 0
    assert torch.isfinite(loss)
    print(" PASS test_kgvq_shape_and_range")


def test_kgvq_ema_update():
    """A forward pass changes the codebook's ``cluster_size`` usage buffer."""
    kgvq = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    x = torch.randn(2, 5, KGVQ_CODEBOOK_DIM)
    # Snapshot (copy) the buffer; the module may update it in place.
    cluster_before = kgvq.cluster_size.clone()
    kgvq(x)
    cluster_after = kgvq.cluster_size
    assert not torch.equal(cluster_before, cluster_after), "cluster usage should update"
    assert kgvq.cluster_size.sum().item() > 0, "cluster_size should accumulate"
    print(" PASS test_kgvq_ema_update")


def test_kgvq_dead_code_reset():
    """``_dead_code_reset`` leaves ``cluster_size`` untouched while the codebook warms up."""
    kgvq = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    # Start from an all-zero usage histogram so that only entries hit by the
    # forward pass below register as used.
    kgvq.cluster_size.zero_()
    x = torch.randn(1, 1, KGVQ_CODEBOOK_DIM)
    kgvq(x)
    # The single input vector should mark at least one entry as used; its EMA
    # value is presumably small (e.g. 1 - decay) -- TODO confirm in KGVQCodebook.
    used = (kgvq.cluster_size > 0).sum().item()
    assert used >= 1, "at least 1 entry should be used after forward"
    # Reset is expected not to fire during warmup. NOTE(review): the warmup rule
    # (reportedly n_initialized < codebook_size / 4) lives in KGVQCodebook; confirm there.
    before = kgvq.cluster_size.clone()
    kgvq._dead_code_reset(x.reshape(-1, KGVQ_CODEBOOK_DIM))
    assert torch.equal(before, kgvq.cluster_size), "dead code should not fire during warmup"
    print(" PASS test_kgvq_dead_code_reset")


def test_composite_head_variable_count():
    """Head emits per-slot ids and halt values; at least one slot is halted (id == -1)."""
    head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)
    pool_out = torch.randn(2, HIDDEN_DIM)
    composite_ids, vq_loss, halt = head(pool_out)
    assert composite_ids.shape == (2, K_MAX_COMPOSITES)
    assert halt.shape == (2, K_MAX_COMPOSITES)
    # Halt values must lie strictly inside (0, 1).
    assert (halt > 0).all() and (halt < 1).all()
    # NOTE(review): inputs and parameter init are unseeded, so this assertion
    # depends on the random draw; consider torch.manual_seed(...) if it flakes.
    assert (composite_ids == -1).any(), "at least one halted"
    assert torch.isfinite(vq_loss)
    print(" PASS test_composite_head_variable_count")


def test_composite_head_diversity_loss():
    """``vq_loss`` stays finite for both a constant (all-ones) and a random input.

    NOTE(review): despite the name, this only checks finiteness; it does not
    compare the loss between the two inputs.
    """
    head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)
    pool_out = torch.ones(2, HIDDEN_DIM)
    _, vq_loss1, _ = head(pool_out)
    pool_out2 = torch.randn(2, HIDDEN_DIM)
    _, vq_loss2, _ = head(pool_out2)
    assert torch.isfinite(vq_loss1) and torch.isfinite(vq_loss2)
    print(" PASS test_composite_head_diversity_loss")


if __name__ == "__main__":
    test_kgvq_shape_and_range()
    test_kgvq_ema_update()
    test_kgvq_dead_code_reset()
    test_composite_head_variable_count()
    test_composite_head_diversity_loss()
    print("\nAll composite head tests PASS")