"""
FlashVQ Correctness Tests — CPU path, GPU path, and CPU vs GPU equivalence.

Test structure follows the testing/test_tscale.py pattern:
    - Each test is a standalone function (collectable by pytest)
    - Manual runner at the bottom for direct execution
    - CUDA/Triton and integration tests skip gracefully when their
      dependencies are unavailable. They raise ``unittest.SkipTest``,
      which both pytest and the manual runner report as a skip.

Tests 1-7:   CPU path correctness (Task 1)
Tests 8-10:  GPU path correctness + CPU vs GPU equivalence (Task 2)
Tests 11-12: VQAdapter / MultimodalVQBridge integration (Task 3)
"""

import os
import sys
import traceback
import unittest

import torch
import torch.nn.functional as F

# Put the directory two levels above this file (presumably the repository
# root) at the front of sys.path, so the project imports below resolve when
# this file is executed directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

# NOTE(review): `flash_vq` is not referenced anywhere in this file, and it is
# distinct from `arbitor.kernel.flash_vq`. Confirm it is still needed.
import flash_vq  # noqa: E402,F401
from arbitor.kernel.flash_vq import FlashVQCodebook, _HAS_TRITON  # noqa: E402

try:
    from arbitor.main import VQAdapter, MultimodalVQBridge, HIDDEN_DIM, CODEBOOK_DIM
    from arbitor.kernel.ternary_scale import TScaleType
    _HAS_TRIGRAM = True
except ImportError:
    _HAS_TRIGRAM = False


# ─── Test Helpers ───


def _skip(test_name, reason):
    """Skip the running test.

    ``unittest.SkipTest`` is reported as a skip by pytest. The manual runner
    at the bottom of this file catches it explicitly. The message keeps the
    `` SKIP <name> (<reason>)`` format used by the runner's output.
    """
    raise unittest.SkipTest(f" SKIP {test_name} ({reason})")


def _require_cuda_triton(test_name):
    """Skip ``test_name`` unless both CUDA and Triton are available."""
    if not torch.cuda.is_available() or not _HAS_TRITON:
        _skip(test_name, "CUDA/Triton unavailable")


def _require_trigram(test_name):
    """Skip ``test_name`` unless the ``arbitor.main`` integration imports succeeded."""
    if not _HAS_TRIGRAM:
        _skip(test_name, "arbitor.main not importable")


def _make_cpu_vq(codebook_size=8192, codebook_dim=32, seed=42, rotation_trick=True):
    """Create a deterministic FlashVQCodebook on CPU.

    The global torch RNG is seeded with ``seed`` before construction. Callers
    therefore do not need to reseed, and any tensors they draw afterwards are
    reproducible as well.
    """
    torch.manual_seed(seed)
    vq = FlashVQCodebook(
        codebook_size=codebook_size,
        codebook_dim=codebook_dim,
        decay=0.99,
        commitment_weight=1.0,
        threshold_ema_dead_code=2,
        kmeans_init=False,
        kmeans_iters=10,
        rotation_trick=rotation_trick,
    )
    return vq


# ─── Task 1: CPU Path Tests (Tests 1-7) ───


def test_flash_vq_cpu_forward_shapes():
    """
    Test 1: FlashVQCodebook CPU forward with random input returns
    (quantized, indices, commitment_loss) with correct shapes.
    """
    vq = _make_cpu_vq()
    x = torch.randn(4, 16, 32)
    quantized, indices, loss = vq._cpu_forward(x.reshape(-1, 32))

    # quantized: [N, D] where N = B*T
    assert quantized.shape == (64, 32), f"quantized shape: {quantized.shape}"
    # indices: [N]
    assert indices.shape == (64,), f"indices shape: {indices.shape}"
    # commitment_loss: a 0-dim scalar tensor
    assert loss.numel() == 1, f"loss shape: {loss.shape}"
    assert loss.dim() == 0, f"loss dim: {loss.dim()}"
    # indices must address valid codebook entries
    assert indices.min() >= 0, f"negative index: {indices.min()}"
    assert indices.max() < vq.codebook_size, f"index too large: {indices.max()}"
    # quantized vectors live in the codebook dimension
    assert quantized.shape[-1] == 32, f"quantized last dim: {quantized.shape[-1]}"
    print(" PASS test_flash_vq_cpu_forward_shapes")


def test_flash_vq_cpu_quantized_matches_codebook():
    """
    Test 2: FlashVQCodebook CPU quantized output matches codebook[indices]
    (straight-through estimator).
    """
    vq = _make_cpu_vq()
    x = torch.randn(4, 16, 32)
    x_flat = x.reshape(-1, 32)

    # Snapshot the codebook first: the forward pass's EMA update mutates embed in place.
    embed_snapshot = vq.embed.clone()

    quantized, indices, loss = vq._cpu_forward(x_flat)

    # STE: quantized = x_flat + (embed[indices] - x_flat).detach(),
    # so quantized - x_flat must equal embed_snapshot[indices] - x_flat.
    expected_quantized = embed_snapshot[indices]
    diff_vq = quantized - x_flat
    diff_raw = expected_quantized - x_flat
    assert torch.allclose(diff_vq, diff_raw.detach(), atol=1e-6), \
        "STE: quantized - x should equal (embed[indices] - x).detach()"
    print(" PASS test_flash_vq_cpu_quantized_matches_codebook")


def test_flash_vq_cpu_cosine_sim():
    """
    Test 3: FlashVQCodebook CPU cosine similarity matches
    F.normalize(x) @ F.normalize(codebook).T argmax.
    """
    vq = _make_cpu_vq()
    x = torch.randn(4, 16, 32)
    x_flat = x.reshape(-1, 32)

    # Snapshot the codebook before the EMA update modifies it.
    embed_snapshot = vq.embed.clone()

    quantized, indices, loss = vq._cpu_forward(x_flat)

    # Reference nearest-neighbour search by cosine similarity on the pre-update codebook.
    x_norm = F.normalize(x_flat, dim=-1)
    embed_norm = F.normalize(embed_snapshot, dim=-1)
    manual_indices = (x_norm @ embed_norm.T).argmax(dim=-1)

    assert torch.equal(indices, manual_indices), \
        f"Indices differ! First 10 indices: {indices[:10]} vs {manual_indices[:10]}"
    print(" PASS test_flash_vq_cpu_cosine_sim")


def test_flash_vq_cpu_ema_update():
    """
    Test 4: FlashVQCodebook CPU EMA update changes embed and cluster_size
    after forward pass (with rotation_trick=False for deterministic EMA).

    Tests EMA in isolation by calling _ema_update directly, then verifies
    embed and cluster_size changed for assigned codebook entries.
    """
    vq = _make_cpu_vq(rotation_trick=False)
    embed_before = vq.embed.clone()
    cluster_size_before = vq.cluster_size.clone()

    x = torch.randn(2, 8, 32)
    x_flat = x.reshape(-1, 32)

    # Force a predictable assignment: the 16 inputs cycle over codebook entries 0-3.
    indices = torch.arange(16, dtype=torch.long) % 4

    # Call the EMA update directly, isolated from the dead-code reset.
    vq._ema_update(x_flat, indices)

    assert not torch.equal(embed_before, vq.embed), \
        "Embed did not change after EMA update"
    assert not torch.equal(cluster_size_before, vq.cluster_size), \
        "cluster_size did not change after EMA update"

    # Presumably cluster_size starts at 0, so with decay=0.99 each of entries 0-3
    # (4 assignments each) ends near 0 * 0.99 + 4 * 0.01 = 0.04.
    # Only the sign is asserted here.
    assert (vq.cluster_size[:4] > 0).all(), \
        "Assigned entries should have non-zero cluster_size"
    assert (vq.cluster_size[4:] == 0).all(), \
        "Unassigned entries should have zero cluster_size"

    # The full forward pass (EMA + dead code reset) must also run and move the codebook.
    vq2 = _make_cpu_vq(rotation_trick=False)
    embed_before2 = vq2.embed.clone()
    q, idx, loss = vq2._cpu_forward(torch.randn(4, 16, 32).reshape(-1, 32))
    assert not torch.equal(embed_before2, vq2.embed), \
        "Embed did not change after full forward pass"
    print(" PASS test_flash_vq_cpu_ema_update")


def test_flash_vq_cpu_dead_code_reset():
    """
    Test 5: FlashVQCodebook CPU dead code reset replaces inactive codebook entries.
    """
    vq = _make_cpu_vq()
    # Mark every entry as unused, then revive the first 10.
    vq.cluster_size[:] = 0.0
    vq.cluster_size[:10] = 5.0

    x = torch.randn(2, 8, 32)
    x_flat = x.reshape(-1, 32)

    embed_before = vq.embed.clone()
    n_dead_before = vq.get_dead_code_count()
    assert n_dead_before == vq.codebook_size - 10, \
        f"Expected {vq.codebook_size - 10} dead entries, got {n_dead_before}"

    # Record the dead entries *before* the reset. Looking them up afterwards
    # would make the check vacuous if the reset gives replaced entries a
    # non-zero cluster_size.
    dead_indices = torch.where(vq.cluster_size == 0)[0]
    assert len(dead_indices) > 0, "Expected at least one entry with zero cluster_size"

    vq._dead_code_reset(x_flat)

    idx = dead_indices[0]
    assert not torch.equal(embed_before[idx], vq.embed[idx]), \
        f"Dead entry {idx} embed was not replaced"
    print(" PASS test_flash_vq_cpu_dead_code_reset")


def test_flash_vq_cpu_rotation_trick_grad():
    """
    Test 6: FlashVQCodebook CPU rotation trick gradient flows correctly.
    Gradient should not be zero, and should differ from STE gradient.
    """
    # _make_cpu_vq seeds the RNG, so `x` below is reproducible.
    vq_rot = _make_cpu_vq(rotation_trick=True, seed=42)
    x = torch.randn(2, 4, 32)
    x_flat = x.reshape(-1, 32).clone().requires_grad_(True)

    # Forward + backward through the rotation-trick path.
    quantized_rot, indices_rot, loss_rot = vq_rot._cpu_forward(x_flat)
    quantized_rot.sum().backward()

    # Check for a missing gradient *before* touching it. `.clone()` on None
    # would raise AttributeError and hide the real failure.
    assert x_flat.grad is not None, "Rotation trick gradient is None"
    rot_grad = x_flat.grad.clone()
    assert rot_grad.abs().sum().item() > 0, "Rotation trick gradient is all zeros"

    # Same input and same codebook initialisation, but plain STE (no rotation).
    vq_ste = _make_cpu_vq(rotation_trick=False, seed=42)
    x_flat2 = x.reshape(-1, 32).clone().requires_grad_(True)
    quantized_ste, indices_ste, loss_ste = vq_ste._cpu_forward(x_flat2)
    quantized_ste.sum().backward()
    assert x_flat2.grad is not None, "STE gradient is None"
    ste_grad = x_flat2.grad.clone()

    # The gradients are only comparable when both paths picked the same codebook entries.
    if torch.equal(indices_rot, indices_ste):
        grad_diff = (rot_grad - ste_grad).abs().max().item()
        assert grad_diff > 1e-8, \
            f"Rotation trick gradient equals STE gradient (diff={grad_diff})"
    print(" PASS test_flash_vq_cpu_rotation_trick_grad")


def test_flash_vq_cpu_commitment_loss():
    """
    Test 7: FlashVQCodebook CPU commitment loss is non-negative scalar.
    """
    vq = _make_cpu_vq(rotation_trick=False)
    x = torch.randn(4, 16, 32)
    x_flat = x.reshape(-1, 32)
    quantized, indices, loss = vq._cpu_forward(x_flat)

    assert loss.item() >= 0.0, f"Commitment loss is negative: {loss.item()}"
    assert loss.dim() == 0, f"Loss is not scalar: {loss.shape}"

    # With commitment_weight=1.0 the loss should be MSE(x, quantized.detach()).
    expected_loss = F.mse_loss(x_flat, quantized.detach())
    assert torch.allclose(loss, expected_loss, atol=1e-6), \
        f"Loss mismatch: {loss.item()} vs {expected_loss.item()}"
    print(" PASS test_flash_vq_cpu_commitment_loss")


# ─── Task 2: GPU Path Tests (Tests 8-10) ───


def _make_gpu_vq(codebook_size=8192, codebook_dim=32, seed=42, rotation_trick=True):
    """Create a deterministic FlashVQCodebook on GPU.

    It is built on CPU with the same seed as ``_make_cpu_vq`` and then moved,
    so it starts from a codebook identical to the CPU instance.
    """
    vq = _make_cpu_vq(codebook_size, codebook_dim, seed, rotation_trick)
    vq = vq.cuda()
    return vq


def test_flash_vq_gpu_vs_cpu_forward():
    """
    Test 8: FlashVQCodebook GPU forward output matches CPU forward output
    within atol=1e-3.
    """
    _require_cuda_triton("test_flash_vq_gpu_vs_cpu_forward")

    vq_cpu = _make_cpu_vq(rotation_trick=False)
    vq_gpu = _make_gpu_vq(rotation_trick=False)

    x = torch.randn(2, 8, 32)
    x_flat = x.reshape(-1, 32)

    quantized_cpu, indices_cpu, loss_cpu = vq_cpu._cpu_forward(x_flat)

    x_gpu = x_flat.detach().clone().cuda()
    quantized_gpu, indices_gpu, loss_gpu = vq_gpu._triton_forward(x_gpu)

    quantized_gpu_cpu = quantized_gpu.cpu()
    loss_gpu_cpu = loss_gpu.cpu()

    fwd_diff = (quantized_cpu - quantized_gpu_cpu).abs().max().item()
    assert fwd_diff < 1e-3, \
        f"CPU vs GPU quantized max diff: {fwd_diff} (exceeds 1e-3)"

    # Indices must match exactly.
    assert torch.equal(indices_cpu, indices_gpu.cpu()), \
        "CPU vs GPU indices differ"

    loss_diff = abs(loss_cpu.item() - loss_gpu_cpu.item())
    assert loss_diff < 1e-3, \
        f"CPU vs GPU loss diff: {loss_diff}"
    print(f" PASS test_flash_vq_gpu_vs_cpu_forward (fwd_diff={fwd_diff:.6f})")


def test_flash_vq_gpu_vs_cpu_gradients():
    """
    Test 9: FlashVQCodebook GPU gradient (rotation trick backward) matches
    CPU gradient within atol=1e-3.
    """
    _require_cuda_triton("test_flash_vq_gpu_vs_cpu_gradients")

    vq_cpu = _make_cpu_vq(rotation_trick=True, seed=42)
    vq_gpu = _make_gpu_vq(rotation_trick=True, seed=42)

    x = torch.randn(2, 4, 32)
    x_flat = x.reshape(-1, 32).detach().clone().requires_grad_(True)

    # CPU forward + backward
    q_cpu, idx_cpu, loss_cpu = vq_cpu._cpu_forward(x_flat)
    q_cpu.sum().backward()
    cpu_grad = x_flat.grad.clone()

    # GPU forward + backward on an independent leaf tensor with the same values
    x_gpu = x_flat.detach().clone().cuda().requires_grad_(True)
    q_gpu, idx_gpu, loss_gpu = vq_gpu._triton_forward(x_gpu)
    q_gpu.sum().backward()
    gpu_grad = x_gpu.grad.clone()

    bwd_diff = (cpu_grad - gpu_grad.cpu()).abs().max().item()
    assert bwd_diff < 1e-3, \
        f"CPU vs GPU gradient max diff: {bwd_diff} (exceeds 1e-3)"
    print(f" PASS test_flash_vq_gpu_vs_cpu_gradients (bwd_diff={bwd_diff:.6f})")


def test_flash_vq_gpu_small_codebook():
    """
    Test 10: FlashVQCodebook GPU path with codebook_size=4096 also matches
    CPU path (multi-codebook support per D-102).
    """
    _require_cuda_triton("test_flash_vq_gpu_small_codebook")

    vq_cpu = _make_cpu_vq(codebook_size=4096, rotation_trick=False)
    vq_gpu = _make_gpu_vq(codebook_size=4096, rotation_trick=False)

    x = torch.randn(2, 8, 32)
    x_flat = x.reshape(-1, 32)

    q_cpu, idx_cpu, loss_cpu = vq_cpu._cpu_forward(x_flat)
    x_gpu = x_flat.detach().clone().cuda()
    q_gpu, idx_gpu, loss_gpu = vq_gpu._triton_forward(x_gpu)

    fwd_diff = (q_cpu - q_gpu.cpu()).abs().max().item()
    assert fwd_diff < 1e-3, \
        f"CPU vs GPU (4096) quantized max diff: {fwd_diff}"
    assert torch.equal(idx_cpu, idx_gpu.cpu()), \
        "CPU vs GPU (4096) indices differ"
    print(f" PASS test_flash_vq_gpu_small_codebook (fwd_diff={fwd_diff:.6f})")


# ─── Task 3: VQAdapter Integration Tests (Tests 11-12) ───


def test_flash_vq_in_vqadapter():
    """
    Test 11: VQAdapter with FlashVQCodebook forward produces correct shapes
    and all VQAdapter methods work (get_codebook_utilization, get_dead_code_count,
    l2_distance_matching).
    """
    _require_trigram("test_flash_vq_in_vqadapter")

    torch.manual_seed(42)
    vq = VQAdapter(codebook_size=128, codebook_dim=32, tscale_type=TScaleType.T4)
    # Overwrite the codebook with small random CPU vectors and clear the usage
    # statistics, so the test starts from a known state.
    vq.vq.embed.data = torch.randn(128, 32) * 0.02
    vq.vq.cluster_size.data.zero_()
    vq.eval()

    x = torch.randn(2, 8, 512)  # [B, T, trigram_dim]
    with torch.no_grad():
        output, vq_loss, indices = vq(x)

    # output: [B, T, 512] (same as trigram_dim)
    assert output.shape == (2, 8, 512), f"output shape: {output.shape}"
    # vq_loss: scalar
    assert vq_loss.numel() == 1, f"vq_loss shape: {vq_loss.shape}"
    # indices: [B, T], within the codebook
    assert indices.shape == (2, 8), f"indices shape: {indices.shape}"
    assert indices.min() >= 0, f"negative index: {indices.min()}"
    assert indices.max() < 128, f"index too large: {indices.max()}"

    # get_codebook_utilization returns a float in [0, 1]
    util = vq.get_codebook_utilization()
    assert isinstance(util, float), f"util type: {type(util)}"
    assert 0.0 <= util <= 1.0, f"util out of range: {util}"

    # get_dead_code_count returns a non-negative int
    dead = vq.get_dead_code_count()
    assert isinstance(dead, int), f"dead type: {type(dead)}"
    assert dead >= 0, f"dead count negative: {dead}"

    # l2_distance_matching returns (indices, distances) and expects codebook_dim input
    x_codebook_dim = x[..., :32]  # slice to match codebook_dim
    with torch.no_grad():
        l2_idx, l2_dist = vq.l2_distance_matching(x_codebook_dim)
    assert l2_idx.shape == (2, 8), f"l2 indices shape: {l2_idx.shape}"
    assert l2_dist.shape == (2, 8), f"l2 distances shape: {l2_dist.shape}"
    assert l2_dist.min() >= 0.0, "l2 distance should be non-negative"
    print(" PASS test_flash_vq_in_vqadapter")


def test_flash_vq_multimodal_bridge():
    """
    Test 12: MultimodalVQBridge with FlashVQCodebook — all three VQAdapters
    (text, image, audio) produce correct outputs.
    """
    _require_trigram("test_flash_vq_multimodal_bridge")

    torch.manual_seed(42)
    bridge = MultimodalVQBridge(
        text_codebook_size=256,
        image_codebook_size=128,
        audio_codebook_size=128,
        codebook_dim=32,
        enable_image=True,
        enable_audio=True,
    )
    bridge.eval()

    x = torch.randn(2, 8, 512)
    with torch.no_grad():
        text_out, text_loss, text_idx = bridge.text_vq(x)
        image_out, image_loss, image_idx = bridge.image_vq(x)
        audio_out, audio_loss, audio_idx = bridge.audio_vq(x)

    assert text_out.shape == (2, 8, 512), f"text output shape: {text_out.shape}"
    assert image_out.shape == (2, 8, 512), f"image output shape: {image_out.shape}"
    assert audio_out.shape == (2, 8, 512), f"audio output shape: {audio_out.shape}"

    assert text_idx.max() < 256, f"text index too large: {text_idx.max()}"
    assert image_idx.max() < 128, f"image index too large: {image_idx.max()}"
    assert audio_idx.max() < 128, f"audio index too large: {audio_idx.max()}"

    # Bridge-level codebook utilization: one entry per modality, each in [0, 1]
    all_util = bridge.get_codebook_utilization()
    assert 'text' in all_util
    assert 'image' in all_util
    assert 'audio' in all_util
    for mod, u in all_util.items():
        assert 0.0 <= u <= 1.0, f"{mod} utilization out of range: {u}"
    print(" PASS test_flash_vq_multimodal_bridge")


# ─── Manual Test Runner ───


def main():
    """Run every test in order, print a summary, and return a process exit code.

    A test counts as skipped when it raises ``unittest.SkipTest``. Any other
    exception counts as a failure and its traceback is printed. Returns 1 if
    any test failed, otherwise 0.
    """
    cpu_tests = [
        test_flash_vq_cpu_forward_shapes,
        test_flash_vq_cpu_quantized_matches_codebook,
        test_flash_vq_cpu_cosine_sim,
        test_flash_vq_cpu_ema_update,
        test_flash_vq_cpu_dead_code_reset,
        test_flash_vq_cpu_rotation_trick_grad,
        test_flash_vq_cpu_commitment_loss,
    ]
    gpu_tests = [
        test_flash_vq_gpu_vs_cpu_forward,
        test_flash_vq_gpu_vs_cpu_gradients,
        test_flash_vq_gpu_small_codebook,
    ]
    integration_tests = [
        test_flash_vq_in_vqadapter,
        test_flash_vq_multimodal_bridge,
    ]
    all_tests = cpu_tests + gpu_tests + integration_tests

    print("Running FlashVQ tests...\n")
    passed = failed = skipped = 0

    for test in all_tests:
        try:
            test()
        except unittest.SkipTest as skip:
            print(skip)
            skipped += 1
        except Exception as e:  # top-level runner boundary: report, keep going
            print(f" FAIL {test.__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1

    total_run = passed + failed
    print(f"\n{passed} passed, {failed} failed, {skipped} skipped out of {len(all_tests)} tests (attempted {total_run})")
    return 1 if failed > 0 else 0


if __name__ == "__main__":
    sys.exit(main())