# ARBS/testing/kg/test_composite_head.py
# Uploaded by CLIWorks with huggingface_hub ("Upload folder using huggingface_hub"), commit d8bc908 (verified).
"""Unit tests for KGVQCodebook and CompositeProposalHead (Phase 17 Plan 02)."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.components import KGVQCodebook, CompositeProposalHead
from arbitor.config import HIDDEN_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES
# Codebook size passed to every KGVQCodebook / CompositeProposalHead built in these
# tests (used here rather than the config's KGVQ_CODEBOOK_SIZE).
TEST_KGVQ_SIZE: int = 4_096
def test_kgvq_shape_and_range():
    """KGVQCodebook keeps the input shape and returns in-range indices plus a finite loss."""
    batch = 2
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    inputs = torch.randn(batch, K_MAX_COMPOSITES, KGVQ_CODEBOOK_DIM)

    quantized, indices, loss = codebook(inputs)

    assert quantized.shape == inputs.shape, f"shape {quantized.shape}"
    assert indices.shape == (batch, K_MAX_COMPOSITES)
    # Every index must address a real codebook row: 0 <= idx < TEST_KGVQ_SIZE.
    assert indices.max() < TEST_KGVQ_SIZE
    assert indices.min() >= 0
    assert torch.isfinite(loss)
    print(" PASS test_kgvq_shape_and_range")
def test_kgvq_ema_update():
    """A forward pass must change the codebook's cluster_size usage statistics."""
    codebook = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    inputs = torch.randn(2, 5, KGVQ_CODEBOOK_DIM)
    # clone() so the snapshot is not an alias if the buffer is updated in place.
    snapshot = codebook.cluster_size.clone()

    _quantized, _indices, _loss = codebook(inputs)

    assert not torch.equal(snapshot, codebook.cluster_size), "cluster usage should update"
    assert codebook.cluster_size.sum().item() > 0, "cluster_size should accumulate"
    print(" PASS test_kgvq_ema_update")
def test_kgvq_dead_code_reset():
    """_dead_code_reset must leave cluster_size untouched while the codebook is warming up.

    cluster_size is zeroed and then a single vector is quantized, so almost every entry
    should remain at zero usage, i.e. below ``threshold_ema_dead_code``. The test first
    checks that such dead entries exist. Without them the "no reset" assertion would
    pass vacuously. It then checks that calling _dead_code_reset does not modify
    cluster_size.
    """
    kgvq = KGVQCodebook(codebook_size=TEST_KGVQ_SIZE)
    kgvq.cluster_size.zero_()
    x = torch.randn(1, 1, KGVQ_CODEBOOK_DIM)
    _, _, _ = kgvq(x)

    # NOTE(review): the used entry's cluster_size is presumably a small EMA value
    # (e.g. 0.01 with decay 0.99) - confirm against KGVQCodebook.
    used = (kgvq.cluster_size > 0).sum().item()
    assert used >= 1, "at least 1 entry should be used after forward"

    # Guard against a vacuous test: there must be dead entries the reset *could* act on.
    # NOTE(review): assumes threshold_ema_dead_code > 0.
    dead = (kgvq.cluster_size < kgvq.threshold_ema_dead_code).sum().item()
    assert dead > 0, "expected dead entries so the warmup guard is actually exercised"

    # Reset shouldn't fire during warmup. Per the original note, warmup means
    # n_initialized < codebook_size / 4; verify against the implementation.
    before = kgvq.cluster_size.clone()
    kgvq._dead_code_reset(x.reshape(-1, KGVQ_CODEBOOK_DIM))
    assert torch.equal(before, kgvq.cluster_size), "dead code should not fire during warmup"
    print(" PASS test_kgvq_dead_code_reset")
def test_composite_head_variable_count():
    """CompositeProposalHead returns fixed-width id/halt tensors with some slots halted.

    A halted slot is reported as composite id -1. Halt values are asserted to lie
    strictly inside (0, 1), and the VQ loss must be finite.
    """
    # The "at least one halted" check depends on randomly initialised weights and a
    # random input; seed the global RNG so the outcome is reproducible, not flaky.
    torch.manual_seed(0)
    head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)
    pool_out = torch.randn(2, HIDDEN_DIM)
    composite_ids, vq_loss, halt = head(pool_out)
    assert composite_ids.shape == (2, K_MAX_COMPOSITES)
    assert halt.shape == (2, K_MAX_COMPOSITES)
    assert (halt > 0).all() and (halt < 1).all()
    assert (composite_ids == -1).any(), "at least one halted"
    assert torch.isfinite(vq_loss)
    print(" PASS test_composite_head_variable_count")
def test_composite_head_diversity_loss():
    """The head's VQ loss stays finite for both a constant and a random pooled input."""
    head = CompositeProposalHead(codebook_size=TEST_KGVQ_SIZE)

    constant_input = torch.ones(2, HIDDEN_DIM)
    _, loss_constant, _ = head(constant_input)

    random_input = torch.randn(2, HIDDEN_DIM)
    _, loss_random, _ = head(random_input)

    for loss in (loss_constant, loss_random):
        assert torch.isfinite(loss)
    print(" PASS test_composite_head_diversity_loss")
if __name__ == "__main__":
    # Script entry point: run every test in definition order, then report overall success.
    for _test in (
        test_kgvq_shape_and_range,
        test_kgvq_ema_update,
        test_kgvq_dead_code_reset,
        test_composite_head_variable_count,
        test_composite_head_diversity_loss,
    ):
        _test()
    print("\nAll composite head tests PASS")