ARBS / testing /model /test_flash.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
FlashVQ Correctness Tests — CPU path, GPU path, and CPU vs GPU equivalence.
Test structure follows testing/test_tscale.py pattern:
- Each test is a standalone function
- Manual runner at bottom for direct execution
- CUDA/Triton tests skip gracefully when unavailable
Tests 1-7: CPU path correctness (Task 1)
Tests 8-11: GPU path correctness + CPU vs GPU equivalence (Task 2)
"""
import torch
import torch.nn.functional as F
import sys
import os
import flash_vq
from arbitor.kernel.flash_vq import FlashVQCodebook, _HAS_TRITON
try:
from arbitor.main import VQAdapter, MultimodalVQBridge, HIDDEN_DIM, CODEBOOK_DIM
from arbitor.kernel.ternary_scale import TScaleType
_HAS_TRIGRAM = True
except ImportError:
_HAS_TRIGRAM = False
# ─── Test Helpers ───
def _make_cpu_vq(codebook_size=8192, codebook_dim=32, seed=42, rotation_trick=True):
    """Build a FlashVQCodebook after seeding torch's global RNG.

    Args:
        codebook_size: Forwarded to ``FlashVQCodebook(codebook_size=...)``.
        codebook_dim: Forwarded to ``FlashVQCodebook(codebook_dim=...)``.
        seed: Value passed to ``torch.manual_seed`` before construction.
        rotation_trick: Forwarded to ``FlashVQCodebook(rotation_trick=...)``.

    Returns:
        The freshly constructed codebook; it is not moved to any device here.
    """
    # Seed first so any random initialisation inside the constructor repeats.
    torch.manual_seed(seed)
    settings = dict(
        codebook_size=codebook_size,
        codebook_dim=codebook_dim,
        decay=0.99,
        commitment_weight=1.0,
        threshold_ema_dead_code=2,
        kmeans_init=False,
        kmeans_iters=10,
        rotation_trick=rotation_trick,
    )
    return FlashVQCodebook(**settings)
# ─── Task 1: CPU Path Tests (Tests 1-7) ───
def test_flash_vq_cpu_forward_shapes():
    """Test 1: ``_cpu_forward`` returns (quantized, indices, loss) with the expected shapes.

    A random [4, 16, 32] batch is flattened to [64, 32] before the call.
    """
    codebook = _make_cpu_vq()
    flat_input = torch.randn(4, 16, 32).reshape(-1, 32)

    quantized, indices, loss = codebook._cpu_forward(flat_input)

    # One quantized row per flattened input row (N = B * T = 64).
    assert quantized.shape == (64, 32), f"quantized shape: {quantized.shape}"
    # One code index per row.
    assert indices.shape == (64,), f"indices shape: {indices.shape}"

    # The loss must be a zero-dimensional, single-element tensor.
    assert loss.numel() == 1, f"loss shape: {loss.shape}"
    assert loss.dim() == 0, f"loss dim: {loss.dim()}"

    # Every index must address a valid codebook row.
    assert indices.min() >= 0, f"negative index: {indices.min()}"
    assert indices.max() < codebook.codebook_size, f"index too large: {indices.max()}"

    # The trailing dimension equals the codebook dimension.
    assert quantized.shape[-1] == 32, f"quantized last dim: {quantized.shape[-1]}"
    print(" PASS test_flash_vq_cpu_forward_shapes")
def test_flash_vq_cpu_quantized_matches_codebook():
    """Test 2: the quantized output equals the pre-forward codebook rows at ``indices``.

    The comparison is made on residuals relative to the input, i.e.
    ``quantized - x`` against ``(embed[indices] - x).detach()``.
    """
    codebook = _make_cpu_vq()
    flat_input = torch.randn(4, 16, 32).reshape(-1, 32)

    # Snapshot the codebook first: the forward pass is expected to update
    # ``embed`` (EMA), so the comparison must use the pre-call values.
    codebook_before = codebook.embed.clone()
    quantized, indices, _loss = codebook._cpu_forward(flat_input)

    residual_actual = quantized - flat_input
    residual_expected = (codebook_before[indices] - flat_input).detach()

    assert torch.allclose(residual_actual, residual_expected, atol=1e-6), \
        "STE: quantized - x should equal (embed[indices] - x).detach()"
    print(" PASS test_flash_vq_cpu_quantized_matches_codebook")
def test_flash_vq_cpu_cosine_sim():
    """Test 3: the selected indices equal the argmax of a manual cosine similarity.

    The reference is ``F.normalize(x) @ F.normalize(codebook).T`` computed
    against a snapshot of the codebook taken before the forward pass.
    """
    codebook = _make_cpu_vq()
    flat_input = torch.randn(4, 16, 32).reshape(-1, 32)

    # Take the snapshot before the call; the forward pass is expected to
    # update ``embed`` afterwards.
    codebook_before = codebook.embed.clone()
    _quantized, indices, _loss = codebook._cpu_forward(flat_input)

    # Reference nearest-code lookup by cosine similarity.
    similarity = F.normalize(flat_input, dim=-1) @ F.normalize(codebook_before, dim=-1).T
    manual_indices = similarity.argmax(dim=-1)

    assert torch.equal(indices, manual_indices), \
        f"Indices differ! First 10 indices: {indices[:10]} vs {manual_indices[:10]}"
    print(" PASS test_flash_vq_cpu_cosine_sim")
def test_flash_vq_cpu_ema_update():
    """Test 4: an EMA update modifies ``embed`` and ``cluster_size``.

    Part one calls ``_ema_update`` directly with hand-picked indices
    (16 inputs spread evenly over codes 0-3) so the effect on
    ``cluster_size`` is predictable. Part two runs a full ``_cpu_forward``
    on a fresh codebook and only checks that ``embed`` moved.
    """
    codebook = _make_cpu_vq(rotation_trick=False)
    embed_start = codebook.embed.clone()
    counts_start = codebook.cluster_size.clone()

    flat_input = torch.randn(2, 8, 32).reshape(-1, 32)
    # Rows 0..15 are assigned to codes 0,1,2,3,0,1,... (four rows per code).
    forced_indices = torch.arange(16, dtype=torch.long) % 4

    # Direct call, bypassing the rest of the forward pass.
    codebook._ema_update(flat_input, forced_indices)

    assert not torch.equal(embed_start, codebook.embed), \
        "Embed did not change after EMA update"
    assert not torch.equal(counts_start, codebook.cluster_size), \
        "cluster_size did not change after EMA update"

    # Only the four assigned codes should have accumulated any count;
    # this presumes cluster_size starts at zero — the assertion below on
    # the unassigned tail would fail otherwise.
    assert (codebook.cluster_size[:4] > 0).all(), \
        "Assigned entries should have non-zero cluster_size"
    assert (codebook.cluster_size[4:] == 0).all(), \
        "Unassigned entries should have zero cluster_size"

    # Smoke-check the complete forward path on an independent codebook.
    codebook_full = _make_cpu_vq(rotation_trick=False)
    embed_start_full = codebook_full.embed.clone()
    codebook_full._cpu_forward(torch.randn(4, 16, 32).reshape(-1, 32))
    assert not torch.equal(embed_start_full, codebook_full.embed), \
        "Embed did not change after full forward pass"
    print(" PASS test_flash_vq_cpu_ema_update")
def test_flash_vq_cpu_dead_code_reset():
    """Test 5: ``_dead_code_reset`` replaces inactive codebook entries.

    All cluster sizes are zeroed except the first ten, the set of
    zero-count ("dead") rows is recorded *before* the reset, and the test
    then requires that at least one of those rows was overwritten.
    """
    vq = _make_cpu_vq()

    # Mark everything dead, then revive the first ten entries.
    vq.cluster_size[:] = 0.0
    vq.cluster_size[:10] = 5.0

    x_flat = torch.randn(2, 8, 32).reshape(-1, 32)

    # Capture state before the reset; the reset may rewrite cluster_size,
    # so the dead set must not be derived from post-reset values.
    embed_before = vq.embed.clone()
    dead_indices_before = torch.where(vq.cluster_size == 0)[0]

    n_dead_before = vq.get_dead_code_count()
    assert n_dead_before == vq.codebook_size - 10, \
        f"Expected {vq.codebook_size - 10} dead entries, got {n_dead_before}"
    assert dead_indices_before.numel() > 0, \
        "Test setup produced no dead entries"

    vq._dead_code_reset(x_flat)

    # A row counts as replaced when any of its components changed.
    # NOTE(review): only "at least one dead row replaced" is asserted,
    # because the replacement policy of _dead_code_reset is not visible
    # here — tighten to "all" if the implementation guarantees it.
    row_changed = (embed_before[dead_indices_before]
                   != vq.embed[dead_indices_before]).any(dim=-1)
    assert row_changed.any(), \
        "No dead entry embed was replaced by the dead code reset"
    print(" PASS test_flash_vq_cpu_dead_code_reset")
def test_flash_vq_cpu_rotation_trick_grad():
    """Test 6: gradients flow through the rotation-trick path.

    The input gradient obtained with ``rotation_trick=True`` must exist,
    be non-zero, and — when both runs select the same codes — differ from
    the gradient obtained with ``rotation_trick=False``.
    """
    torch.manual_seed(42)

    # --- Rotation-trick run ---
    vq_rot = _make_cpu_vq(rotation_trick=True, seed=42)
    x = torch.randn(2, 4, 32)
    x_flat = x.reshape(-1, 32).detach().clone().requires_grad_(True)

    quantized_rot, indices_rot, _loss_rot = vq_rot._cpu_forward(x_flat)
    quantized_rot.sum().backward()

    # Check for a missing gradient *before* touching it, so a broken graph
    # is reported by this assertion rather than by an AttributeError.
    assert x_flat.grad is not None, "Rotation trick gradient is None"
    rot_grad = x_flat.grad.clone()
    assert rot_grad.abs().sum().item() > 0, "Rotation trick gradient is all zeros"

    # --- Reference run without the rotation trick, same seed and input ---
    torch.manual_seed(42)
    vq_ste = _make_cpu_vq(rotation_trick=False, seed=42)
    x_flat2 = x.reshape(-1, 32).detach().clone().requires_grad_(True)

    quantized_ste, indices_ste, _loss_ste = vq_ste._cpu_forward(x_flat2)
    quantized_ste.sum().backward()
    assert x_flat2.grad is not None, "STE gradient is None"
    ste_grad = x_flat2.grad.clone()

    # The gradients are only comparable if both runs picked the same codes.
    if torch.equal(indices_rot, indices_ste):
        grad_diff = (rot_grad - ste_grad).abs().max().item()
        assert grad_diff > 1e-8, \
            f"Rotation trick gradient equals STE gradient (diff={grad_diff})"
    print(" PASS test_flash_vq_cpu_rotation_trick_grad")
def test_flash_vq_cpu_commitment_loss():
    """Test 7: the commitment loss is a non-negative scalar equal to the MSE.

    The helper builds the codebook with ``commitment_weight=1.0``, so the
    loss is compared directly with ``F.mse_loss(x, quantized.detach())``.
    """
    codebook = _make_cpu_vq(rotation_trick=False)
    flat_input = torch.randn(4, 16, 32).reshape(-1, 32)

    quantized, _indices, loss = codebook._cpu_forward(flat_input)

    assert loss.item() >= 0.0, f"Commitment loss is negative: {loss.item()}"
    assert loss.dim() == 0, f"Loss is not scalar: {loss.shape}"

    # Reference value: plain MSE against the detached quantized output.
    reference = F.mse_loss(flat_input, quantized.detach())
    assert torch.allclose(loss, reference, atol=1e-6), \
        f"Loss mismatch: {loss.item()} vs {reference.item()}"
    print(" PASS test_flash_vq_cpu_commitment_loss")
# ─── Task 2: GPU Path Tests (Tests 8-11) ───
def _make_gpu_vq(codebook_size=8192, codebook_dim=32, seed=42, rotation_trick=True):
    """Build a codebook via ``_make_cpu_vq`` and move it to CUDA.

    Takes the same arguments as ``_make_cpu_vq`` and forwards them
    positionally; returns the result of calling ``.cuda()`` on the module.
    """
    return _make_cpu_vq(codebook_size, codebook_dim, seed, rotation_trick).cuda()
def test_flash_vq_gpu_vs_cpu_forward():
    """Test 8: ``_triton_forward`` on GPU agrees with ``_cpu_forward``.

    Quantized outputs and losses must agree to within 1e-3; the selected
    indices must be identical. Prints a SKIP line and returns when CUDA or
    Triton is unavailable.
    """
    gpu_ready = torch.cuda.is_available() and _HAS_TRITON
    if not gpu_ready:
        print(" SKIP test_flash_vq_gpu_vs_cpu_forward (CUDA/Triton unavailable)")
        return

    torch.manual_seed(42)
    cpu_codebook = _make_cpu_vq(rotation_trick=False)
    gpu_codebook = _make_gpu_vq(rotation_trick=False)

    flat_input = torch.randn(2, 8, 32).reshape(-1, 32)

    q_cpu, idx_cpu, loss_cpu = cpu_codebook._cpu_forward(flat_input)
    # The GPU run receives an independent copy of the same values.
    q_gpu, idx_gpu, loss_gpu = gpu_codebook._triton_forward(flat_input.detach().clone().cuda())

    # Largest element-wise deviation between the two quantized outputs.
    fwd_diff = (q_cpu - q_gpu.cpu()).abs().max().item()
    assert fwd_diff < 1e-3, \
        f"CPU vs GPU quantized max diff: {fwd_diff} (exceeds 1e-3)"

    # Code selection has to be bit-exact.
    assert torch.equal(idx_cpu, idx_gpu.cpu()), \
        "CPU vs GPU indices differ"

    loss_diff = abs(loss_cpu.item() - loss_gpu.cpu().item())
    assert loss_diff < 1e-3, \
        f"CPU vs GPU loss diff: {loss_diff}"
    print(f" PASS test_flash_vq_gpu_vs_cpu_forward (fwd_diff={fwd_diff:.6f})")
def test_flash_vq_gpu_vs_cpu_gradients():
    """Test 9: input gradients from the GPU path match the CPU path.

    Both codebooks are built with ``rotation_trick=True``; the gradients of
    ``quantized.sum()`` with respect to the input must agree to within
    1e-3. Prints a SKIP line and returns when CUDA or Triton is unavailable.
    """
    gpu_ready = torch.cuda.is_available() and _HAS_TRITON
    if not gpu_ready:
        print(" SKIP test_flash_vq_gpu_vs_cpu_gradients (CUDA/Triton unavailable)")
        return

    torch.manual_seed(42)
    cpu_codebook = _make_cpu_vq(rotation_trick=True, seed=42)
    gpu_codebook = _make_gpu_vq(rotation_trick=True, seed=42)

    base = torch.randn(2, 4, 32)
    cpu_input = base.reshape(-1, 32).detach().clone().requires_grad_(True)

    # CPU: forward, then back-propagate the sum of the quantized output.
    cpu_quantized, _cpu_idx, _cpu_loss = cpu_codebook._cpu_forward(cpu_input)
    cpu_quantized.sum().backward()
    cpu_grad = cpu_input.grad.clone()

    # GPU: same values on a fresh leaf tensor living on the device.
    gpu_input = cpu_input.detach().clone().cuda().requires_grad_(True)
    gpu_quantized, _gpu_idx, _gpu_loss = gpu_codebook._triton_forward(gpu_input)
    gpu_quantized.sum().backward()
    gpu_grad = gpu_input.grad.clone()

    bwd_diff = (cpu_grad - gpu_grad.cpu()).abs().max().item()
    assert bwd_diff < 1e-3, \
        f"CPU vs GPU gradient max diff: {bwd_diff} (exceeds 1e-3)"
    print(f" PASS test_flash_vq_gpu_vs_cpu_gradients (bwd_diff={bwd_diff:.6f})")
def test_flash_vq_gpu_small_codebook():
    """Test 10: GPU and CPU paths also agree for ``codebook_size=4096``.

    Same comparison as the default-size forward test, without the loss
    check. Prints a SKIP line and returns when CUDA or Triton is
    unavailable.
    """
    gpu_ready = torch.cuda.is_available() and _HAS_TRITON
    if not gpu_ready:
        print(" SKIP test_flash_vq_gpu_small_codebook (CUDA/Triton unavailable)")
        return

    torch.manual_seed(42)
    cpu_codebook = _make_cpu_vq(codebook_size=4096, rotation_trick=False)
    gpu_codebook = _make_gpu_vq(codebook_size=4096, rotation_trick=False)

    flat_input = torch.randn(2, 8, 32).reshape(-1, 32)

    q_cpu, idx_cpu, _loss_cpu = cpu_codebook._cpu_forward(flat_input)
    q_gpu, idx_gpu, _loss_gpu = gpu_codebook._triton_forward(flat_input.detach().clone().cuda())

    fwd_diff = (q_cpu - q_gpu.cpu()).abs().max().item()
    assert fwd_diff < 1e-3, \
        f"CPU vs GPU (4096) quantized max diff: {fwd_diff}"
    assert torch.equal(idx_cpu, idx_gpu.cpu()), \
        "CPU vs GPU (4096) indices differ"
    print(f" PASS test_flash_vq_gpu_small_codebook (fwd_diff={fwd_diff:.6f})")
# ─── Task 3: VQAdapter Integration Tests ───
def test_flash_vq_in_vqadapter():
    """Test 11: ``VQAdapter`` forward and its helper methods behave sanely.

    Checks output/indices shapes and index range for a [2, 8, 512] input,
    then exercises ``get_codebook_utilization``, ``get_dead_code_count``
    and ``l2_distance_matching``. Prints a SKIP line and returns when the
    optional imports at the top of the file failed.
    """
    if not _HAS_TRIGRAM:
        print(" SKIP test_flash_vq_in_vqadapter (trigram.py not importable)")
        return

    adapter = VQAdapter(codebook_size=128, codebook_dim=32, tscale_type=TScaleType.T4)

    # Overwrite the inner codebook state with small random embeddings and
    # zeroed usage counts before running in eval mode.
    adapter.vq.embed.data = torch.randn(128, 32) * 0.02
    adapter.vq.cluster_size.data.zero_()
    adapter.eval()

    batch = torch.randn(2, 8, 512)  # [B, T, feature_dim]
    with torch.no_grad():
        output, vq_loss, indices = adapter(batch)

    # The adapter returns features with the same shape as its input.
    assert output.shape == (2, 8, 512), f"output shape: {output.shape}"
    assert vq_loss.numel() == 1, f"vq_loss shape: {vq_loss.shape}"
    # One code index per (batch, time) position, inside the codebook.
    assert indices.shape == (2, 8), f"indices shape: {indices.shape}"
    assert indices.min() >= 0, f"negative index: {indices.min()}"
    assert indices.max() < 128, f"index too large: {indices.max()}"

    # Utilisation is reported as a float fraction in [0, 1].
    util = adapter.get_codebook_utilization()
    assert isinstance(util, float), f"util type: {type(util)}"
    assert 0.0 <= util <= 1.0, f"util out of range: {util}"

    # Dead-code count must be an int (or whatever ``Tensor.item()`` yields)
    # and must not be negative.
    dead = adapter.get_dead_code_count()
    assert isinstance(dead, (int, type(torch.tensor(0).item()))), f"dead type: {type(dead)}"
    dead_val = int(dead)
    assert dead_val >= 0, f"dead count negative: {dead_val}"

    # ``l2_distance_matching`` is fed codebook-dim inputs, so keep only the
    # first 32 features of the batch.
    codebook_dim_batch = batch[..., :32]
    with torch.no_grad():
        l2_idx, l2_dist = adapter.l2_distance_matching(codebook_dim_batch)
    assert l2_idx.shape == (2, 8), f"l2 indices shape: {l2_idx.shape}"
    assert l2_dist.shape == (2, 8), f"l2 distances shape: {l2_dist.shape}"
    assert l2_dist.min() >= 0.0, "l2 distance should be non-negative"
    print(" PASS test_flash_vq_in_vqadapter")
def test_flash_vq_multimodal_bridge():
    """Test 12: every adapter of ``MultimodalVQBridge`` produces valid output.

    Runs the text, image and audio adapters on the same [2, 8, 512] input,
    checks output shapes and index ranges, and verifies the bridge-level
    utilisation report. Prints a SKIP line and returns when the optional
    imports at the top of the file failed.
    """
    if not _HAS_TRIGRAM:
        print(" SKIP test_flash_vq_multimodal_bridge (trigram.py not importable)")
        return

    bridge = MultimodalVQBridge(
        text_codebook_size=256,
        image_codebook_size=128,
        audio_codebook_size=128,
        codebook_dim=32,
        enable_image=True,
        enable_audio=True,
    )
    bridge.eval()

    batch = torch.randn(2, 8, 512)
    expected_shape = (2, 8, 512)
    with torch.no_grad():
        text_out, _text_loss, text_idx = bridge.text_vq(batch)
        image_out, _image_loss, image_idx = bridge.image_vq(batch)
        audio_out, _audio_loss, audio_idx = bridge.audio_vq(batch)

    # Each adapter returns features shaped like its input.
    assert text_out.shape == expected_shape, f"text output shape: {text_out.shape}"
    assert image_out.shape == expected_shape, f"image output shape: {image_out.shape}"
    assert audio_out.shape == expected_shape, f"audio output shape: {audio_out.shape}"

    # Indices must stay below the per-modality codebook size set above.
    assert text_idx.max() < 256, f"text index too large: {text_idx.max()}"
    assert image_idx.max() < 128, f"image index too large: {image_idx.max()}"
    assert audio_idx.max() < 128, f"audio index too large: {audio_idx.max()}"

    # The bridge reports one utilisation value per modality, each in [0, 1].
    all_util = bridge.get_codebook_utilization()
    for modality in ('text', 'image', 'audio'):
        assert modality in all_util
    for mod, u in all_util.items():
        assert 0.0 <= u <= 1.0, f"{mod} utilization out of range: {u}"
    print(" PASS test_flash_vq_multimodal_bridge")
# ─── Manual Test Runner ───
if __name__ == "__main__":
    # Manual runner: executes every test in order, prints a summary, and
    # exits with status 1 if any test failed.
    import contextlib
    import io
    import traceback

    cpu_tests = [
        test_flash_vq_cpu_forward_shapes,
        test_flash_vq_cpu_quantized_matches_codebook,
        test_flash_vq_cpu_cosine_sim,
        test_flash_vq_cpu_ema_update,
        test_flash_vq_cpu_dead_code_reset,
        test_flash_vq_cpu_rotation_trick_grad,
        test_flash_vq_cpu_commitment_loss,
    ]
    gpu_tests = [
        test_flash_vq_gpu_vs_cpu_forward,
        test_flash_vq_gpu_vs_cpu_gradients,
        test_flash_vq_gpu_small_codebook,
    ]
    integration_tests = [
        test_flash_vq_in_vqadapter,
        test_flash_vq_multimodal_bridge,
    ]
    all_tests = cpu_tests + gpu_tests + integration_tests

    print("Running FlashVQ tests...\n")
    passed = 0
    failed = 0
    skipped = 0
    for test in all_tests:
        # Tests announce a skip by printing a SKIP line and returning
        # normally, so their stdout is captured to tell "skipped" apart
        # from "passed"; the captured text is replayed unchanged.
        buffer = io.StringIO()
        error = None
        try:
            with contextlib.redirect_stdout(buffer):
                test()
        except Exception as exc:
            error = exc
        finally:
            output = buffer.getvalue()
            sys.stdout.write(output)

        if error is None:
            skip_reported = any(
                line.lstrip().startswith("SKIP") for line in output.splitlines()
            )
            if skip_reported:
                skipped += 1
            else:
                passed += 1
            continue

        # Still honour skips signalled through an exception message.
        msg = str(error)
        if msg.startswith(" SKIP"):
            print(msg)
            skipped += 1
        else:
            print(f" FAIL {test.__name__}: {error}")
            traceback.print_exception(type(error), error, error.__traceback__)
            failed += 1

    total_run = passed + failed
    print(f"\n{passed} passed, {failed} failed, {skipped} skipped out of {len(all_tests)} tests (attempted {total_run})")
    sys.exit(1 if failed > 0 else 0)