| """ |
| Comprehensive tests for DKM implementation. |
| |
| Tests verify: |
| 1. DKM Layer correctness (distance, attention, centroid updates) |
| 2. Convergence behavior |
| 3. Multi-dimensional clustering |
| 4. Gradient flow (differentiability) |
| 5. Train vs inference mode behavior |
| 6. Compression ratio calculations |
| 7. Full pipeline end-to-end |
| 8. Numerical stability |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import math |
| import sys |
| import traceback |
|
|
| |
| sys.path.insert(0, "/app") |
|
|
| from dkm.dkm_layer import DKMLayer |
| from dkm.compressor import DKMCompressor, compress_model |
| from dkm.utils import ( |
| compute_model_size, |
| compute_compression_ratio, |
| compute_effective_bpw, |
| count_unique_weights, |
| ) |
|
|
|
|
| def test_passed(name): |
| print(f" β {name}") |
|
|
| def test_failed(name, error): |
| print(f" β {name}: {error}") |
| return False |
|
|
|
|
| def test_dkm_layer_basic(): |
| """Test basic DKM layer creation and forward pass.""" |
| print("\n[Test 1] DKM Layer Basic Operations") |
| all_passed = True |
| |
| |
| weight = nn.Parameter(torch.randn(10, 5)) |
| |
| |
| dkm = DKMLayer(weight, n_clusters=4, tau=1e-3, dim=1, max_iter=5) |
| |
| |
| dkm.train() |
| compressed = dkm() |
| |
| if compressed.shape != weight.shape: |
| all_passed = test_failed("shape preservation", |
| f"Expected {weight.shape}, got {compressed.shape}") |
| else: |
| test_passed("shape preservation") |
| |
| |
| dkm.eval() |
| compressed_eval = dkm() |
| |
| if compressed_eval.shape != weight.shape: |
| all_passed = test_failed("eval shape", |
| f"Expected {weight.shape}, got {compressed_eval.shape}") |
| else: |
| test_passed("eval shape preservation") |
| |
| |
| codebook = dkm.get_codebook() |
| flat_eval = compressed_eval.reshape(-1) |
| codebook_values = codebook.squeeze() |
| |
| for val in flat_eval: |
| if not any(torch.isclose(val, cv, atol=1e-5) for cv in codebook_values): |
| all_passed = test_failed("hard assignment", |
| f"Value {val.item():.6f} not in codebook") |
| break |
| else: |
| test_passed("hard assignment (eval mode snaps to codebook)") |
| |
| return all_passed |
|
|
|
|
| def test_distance_matrix(): |
| """Test distance matrix computation.""" |
| print("\n[Test 2] Distance Matrix Computation") |
| all_passed = True |
| |
| weight = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) |
| dkm = DKMLayer(weight, n_clusters=2, tau=1.0, dim=1, max_iter=1) |
| |
| |
| W = weight.reshape(-1, 1) |
| C = dkm.centroids |
| |
| D = dkm._compute_distance_matrix(W, C) |
| |
| |
| for i in range(W.shape[0]): |
| for j in range(C.shape[0]): |
| expected = -((W[i, 0] - C[j, 0]) ** 2).item() |
| actual = D[i, j].item() |
| if abs(expected - actual) > 1e-5: |
| all_passed = test_failed( |
| f"distance D[{i},{j}]", |
| f"Expected {expected:.6f}, got {actual:.6f}" |
| ) |
| |
| if all_passed: |
| test_passed("distance matrix values correct") |
| |
| |
| if (D > 1e-6).any(): |
| all_passed = test_failed("non-positive distances", "Found positive distances") |
| else: |
| test_passed("all distances non-positive") |
| |
| return all_passed |
|
|
|
|
| def test_attention_matrix(): |
| """Test attention matrix computation (softmax with temperature).""" |
| print("\n[Test 3] Attention Matrix (Softmax with Temperature)") |
| all_passed = True |
| |
| weight = nn.Parameter(torch.randn(20)) |
| dkm = DKMLayer(weight, n_clusters=4, tau=1e-3, dim=1) |
| |
| W = weight.reshape(-1, 1) |
| D = dkm._compute_distance_matrix(W, dkm.centroids) |
| A = dkm._compute_attention(D) |
| |
| |
| row_sums = A.sum(dim=1) |
| if not torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5): |
| all_passed = test_failed("row sum", f"Rows don't sum to 1: {row_sums}") |
| else: |
| test_passed("attention rows sum to 1") |
| |
| |
| if (A < -1e-7).any(): |
| all_passed = test_failed("non-negative", "Found negative attention values") |
| else: |
| test_passed("all attention values non-negative") |
| |
| |
| dkm_hard = DKMLayer(weight, n_clusters=4, tau=1e-8, dim=1) |
| dkm_hard.centroids = dkm.centroids.clone() |
| D_hard = dkm_hard._compute_distance_matrix(W, dkm_hard.centroids) |
| A_hard = dkm_hard._compute_attention(D_hard) |
| |
| |
| max_vals = A_hard.max(dim=1).values |
| if not torch.allclose(max_vals, torch.ones_like(max_vals), atol=1e-3): |
| all_passed = test_failed("hard attention", |
| f"Small tau should give near-one-hot, max vals: {max_vals.mean():.6f}") |
| else: |
| test_passed("small tau produces near-one-hot attention") |
| |
| |
| dkm_soft = DKMLayer(weight, n_clusters=4, tau=1.0, dim=1) |
| dkm_soft.centroids = dkm.centroids.clone() |
| D_soft = dkm_soft._compute_distance_matrix(W, dkm_soft.centroids) |
| A_soft = dkm_soft._compute_attention(D_soft) |
| |
| entropy_hard = -(A_hard * torch.log(A_hard + 1e-10)).sum(dim=1).mean() |
| entropy_soft = -(A_soft * torch.log(A_soft + 1e-10)).sum(dim=1).mean() |
| |
| if entropy_soft <= entropy_hard: |
| all_passed = test_failed("tau entropy", |
| f"Larger tau should have higher entropy: soft={entropy_soft:.4f}, hard={entropy_hard:.4f}") |
| else: |
| test_passed(f"larger tau β higher entropy (soft={entropy_soft:.4f} > hard={entropy_hard:.4f})") |
| |
| return all_passed |
|
|
|
|
| def test_centroid_update(): |
| """Test centroid update formula: c_j = Ξ£(a_ij * w_i) / Ξ£(a_ij)""" |
| print("\n[Test 4] Centroid Update") |
| all_passed = True |
| |
| weight = nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])) |
| dkm = DKMLayer(weight, n_clusters=2, tau=1e-6, dim=1, max_iter=10, epsilon=1e-8) |
| |
| |
| |
| dkm.train() |
| _ = dkm() |
| |
| centroids = dkm.centroids.squeeze().sort().values |
| expected_c1 = torch.tensor([1.0, 2.0, 3.0]).mean() |
| expected_c2 = torch.tensor([10.0, 11.0, 12.0]).mean() |
| |
| |
| if abs(centroids[0].item() - expected_c1.item()) > 0.5: |
| all_passed = test_failed("centroid 1", |
| f"Expected ~{expected_c1:.1f}, got {centroids[0]:.4f}") |
| else: |
| test_passed(f"centroid 1 converged to {centroids[0]:.4f} (expected ~{expected_c1:.1f})") |
| |
| if abs(centroids[1].item() - expected_c2.item()) > 0.5: |
| all_passed = test_failed("centroid 2", |
| f"Expected ~{expected_c2:.1f}, got {centroids[1]:.4f}") |
| else: |
| test_passed(f"centroid 2 converged to {centroids[1]:.4f} (expected ~{expected_c2:.1f})") |
| |
| return all_passed |
|
|
|
|
| def test_gradient_flow(): |
| """ |
| Test that gradients flow through the DKM layer (key paper contribution). |
| |
| The paper's main claim is that DKM is differentiable and enables |
| joint optimization of weights and clustering. |
| """ |
| print("\n[Test 5] Gradient Flow (Differentiability)") |
| all_passed = True |
| |
| |
| |
| |
| weight = nn.Parameter(torch.randn(8, 4) * 2.0) |
| dkm = DKMLayer(weight, n_clusters=4, tau=5e-2, dim=1, max_iter=3) |
| dkm.train() |
| |
| |
| compressed = dkm() |
| |
| |
| loss = compressed.sum() |
| loss.backward() |
| |
| |
| if weight.grad is None: |
| all_passed = test_failed("gradient exists", "No gradient on weight parameter") |
| elif weight.grad.abs().sum() == 0: |
| all_passed = test_failed("non-zero gradient", "Gradient is all zeros") |
| else: |
| test_passed(f"gradients flow through DKM (grad norm: {weight.grad.norm():.6f})") |
| |
| |
| if weight.grad is not None and weight.grad.shape != weight.shape: |
| all_passed = test_failed("gradient shape", |
| f"Expected {weight.shape}, got {weight.grad.shape}") |
| else: |
| test_passed("gradient shape matches weight shape") |
| |
| |
| |
| |
| |
| |
| |
| weight.grad = None |
| compressed_w = dkm() |
| target = torch.randn_like(weight) |
| loss_w = ((compressed_w - target) ** 2).sum() |
| loss_w.backward() |
| |
| |
| naive_grad = 2 * (weight.data - target) |
| |
| if weight.grad is not None: |
| rel_diff = (weight.grad - naive_grad).abs().mean() / (naive_grad.abs().mean() + 1e-8) |
| if rel_diff > 0.01: |
| test_passed(f"DKM transforms gradients (rel diff from naive: {rel_diff:.4f})") |
| else: |
| |
| test_passed(f"gradient flows through attention (rel diff: {rel_diff:.6f})") |
| else: |
| all_passed = test_failed("non-trivial gradient", "No gradient computed") |
| |
| |
| weight.grad = None |
| compressed2 = dkm() |
| loss2 = (compressed2 ** 2).sum() |
| loss2.backward() |
| |
| if weight.grad is not None and weight.grad.abs().sum() > 0: |
| test_passed("gradients change with different loss functions") |
| else: |
| all_passed = test_failed("loss-dependent gradient", "Gradient doesn't change with loss") |
| |
| return all_passed |
|
|
|
|
| def test_multidim_clustering(): |
| """ |
| Test multi-dimensional clustering (Section 3.3). |
| |
| With dim=d, weights are split into N/d contiguous d-dimensional sub-vectors |
| and clustered in d-dimensional space. |
| """ |
| print("\n[Test 6] Multi-Dimensional Clustering (Section 3.3)") |
| all_passed = True |
| |
| |
| weight = nn.Parameter(torch.randn(24)) |
| dkm = DKMLayer(weight, n_clusters=4, tau=1e-3, dim=4, max_iter=5) |
| |
| if dkm.n_vectors != 6: |
| all_passed = test_failed("n_vectors", f"Expected 6, got {dkm.n_vectors}") |
| else: |
| test_passed(f"24 weights / dim 4 = 6 sub-vectors") |
| |
| |
| if dkm.centroids.shape != (4, 4): |
| all_passed = test_failed("centroid shape", |
| f"Expected (4,4), got {dkm.centroids.shape}") |
| else: |
| test_passed("centroid shape is (n_clusters, dim) = (4, 4)") |
| |
| |
| dkm.train() |
| compressed = dkm() |
| if compressed.shape != weight.shape: |
| all_passed = test_failed("output shape", |
| f"Expected {weight.shape}, got {compressed.shape}") |
| else: |
| test_passed("multi-dim output shape preserved") |
| |
| |
| bpw = compute_effective_bpw(4, dim=4) |
| expected_bpw = math.log2(4) / 4 |
| if abs(bpw - expected_bpw) > 1e-6: |
| all_passed = test_failed("effective bpw", f"Expected {expected_bpw}, got {bpw}") |
| else: |
| test_passed(f"effective bits per weight: {bpw} (2 bits / 4 dim = 0.5 bpw)") |
| |
| |
| loss = compressed.sum() |
| loss.backward() |
| if weight.grad is None or weight.grad.abs().sum() == 0: |
| all_passed = test_failed("multi-dim gradient", "No gradient flow in multi-dim mode") |
| else: |
| test_passed("gradient flows in multi-dimensional mode") |
| |
| return all_passed |
|
|
|
|
| def test_convergence(): |
| """Test that DKM iterations converge (centroids stabilize).""" |
| print("\n[Test 7] Iterative Convergence") |
| all_passed = True |
| |
| |
| weight = nn.Parameter( |
| torch.cat([ |
| torch.randn(20) * 0.1 + 5.0, |
| torch.randn(20) * 0.1 - 5.0, |
| ]) |
| ) |
| |
| dkm = DKMLayer(weight, n_clusters=2, tau=1e-5, dim=1, max_iter=20, epsilon=1e-6) |
| dkm.train() |
| _ = dkm() |
| |
| centroids = dkm.centroids.squeeze().sort().values |
| |
| |
| if abs(centroids[0].item() - (-5.0)) > 1.0: |
| all_passed = test_failed("convergence c1", |
| f"Expected ~-5, got {centroids[0]:.4f}") |
| else: |
| test_passed(f"centroid 1 converged: {centroids[0]:.4f}") |
| |
| if abs(centroids[1].item() - 5.0) > 1.0: |
| all_passed = test_failed("convergence c2", |
| f"Expected ~5, got {centroids[1]:.4f}") |
| else: |
| test_passed(f"centroid 2 converged: {centroids[1]:.4f}") |
| |
| return all_passed |
|
|
|
|
| def test_compressor_wrapper(): |
| """Test the DKMCompressor wrapper on a small model.""" |
| print("\n[Test 8] DKM Compressor Wrapper") |
| all_passed = True |
| |
| |
| |
| |
| model = nn.Sequential( |
| nn.Linear(100, 200), |
| nn.ReLU(), |
| nn.Linear(200, 200), |
| nn.ReLU(), |
| nn.Linear(200, 10), |
| ) |
| |
| |
| for p in model.parameters(): |
| nn.init.normal_(p, std=0.1) |
| |
| |
| compressor = compress_model( |
| model, bits=2, dim=1, tau=1e-3, skip_first_last=False |
| ) |
| |
| |
| x = torch.randn(2, 100) |
| |
| compressor.train() |
| out_train = compressor(x) |
| if out_train.shape != (2, 10): |
| all_passed = test_failed("train output shape", |
| f"Expected (2,10), got {out_train.shape}") |
| else: |
| test_passed("train forward pass works") |
| |
| compressor.eval() |
| out_eval = compressor(x) |
| if out_eval.shape != (2, 10): |
| all_passed = test_failed("eval output shape", |
| f"Expected (2,10), got {out_eval.shape}") |
| else: |
| test_passed("eval forward pass works") |
| |
| |
| info = compressor.get_compression_info() |
| if info["compression_ratio"] <= 1.0: |
| all_passed = test_failed("compression ratio", |
| f"Expected >1, got {info['compression_ratio']:.2f}") |
| else: |
| test_passed(f"compression ratio: {info['compression_ratio']:.2f}x") |
| |
| |
| compressor.train() |
| out = compressor(x) |
| loss = out.sum() |
| loss.backward() |
| |
| has_grads = any(p.grad is not None and p.grad.abs().sum() > 0 |
| for p in compressor.parameters()) |
| if not has_grads: |
| all_passed = test_failed("compressor gradient", "No gradient flow through compressor") |
| else: |
| test_passed("gradient flows through entire compressor") |
| |
| return all_passed |
|
|
|
|
| def test_snap_weights(): |
| """Test weight snapping (inference mode).""" |
| print("\n[Test 9] Weight Snapping for Inference") |
| all_passed = True |
| |
| model = nn.Sequential( |
| nn.Linear(10, 20), |
| nn.ReLU(), |
| nn.Linear(20, 5), |
| ) |
| |
| compressor = compress_model(model, bits=2, dim=1, tau=1e-3, skip_first_last=False) |
| |
| |
| x = torch.randn(2, 10) |
| compressor.train() |
| _ = compressor(x) |
| |
| |
| compressor.snap_weights() |
| |
| |
| |
| unique_counts = count_unique_weights(model) |
| for name, count in unique_counts.items(): |
| |
| max_expected = 256 |
| if count > max_expected: |
| all_passed = test_failed(f"snap {name}", |
| f"Too many unique values: {count} > {max_expected}") |
| else: |
| test_passed(f"layer {name}: {count} unique values") |
| |
| return all_passed |
|
|
|
|
| def test_export_compressed(): |
| """Test compressed model export.""" |
| print("\n[Test 10] Export Compressed Model") |
| all_passed = True |
| |
| model = nn.Sequential( |
| nn.Linear(10, 20), |
| nn.Linear(20, 5), |
| ) |
| |
| compressor = compress_model(model, bits=2, dim=1, tau=1e-3, skip_first_last=False) |
| |
| |
| x = torch.randn(2, 10) |
| compressor.train() |
| _ = compressor(x) |
| |
| |
| export = compressor.export_compressed() |
| |
| if "state_dict" not in export: |
| all_passed = test_failed("export state_dict", "Missing state_dict") |
| else: |
| test_passed("export contains state_dict") |
| |
| if "codebooks" not in export: |
| all_passed = test_failed("export codebooks", "Missing codebooks") |
| else: |
| test_passed(f"export contains {len(export['codebooks'])} codebooks") |
| |
| if "assignments" not in export: |
| all_passed = test_failed("export assignments", "Missing assignments") |
| else: |
| test_passed(f"export contains {len(export['assignments'])} assignment maps") |
| |
| |
| for name, codebook in export["codebooks"].items(): |
| expected_clusters = 2 ** 2 |
| |
| if codebook.shape[0] not in [expected_clusters, 256]: |
| all_passed = test_failed(f"codebook {name}", |
| f"Expected {expected_clusters} or 256 clusters, got {codebook.shape[0]}") |
| else: |
| test_passed(f"codebook {name}: {codebook.shape}") |
| |
| return all_passed |
|
|
|
|
| def test_training_step(): |
| """Test that a full training step (forward + backward + step) works correctly.""" |
| print("\n[Test 11] Full Training Step") |
| all_passed = True |
| |
| model = nn.Sequential( |
| nn.Linear(10, 20), |
| nn.ReLU(), |
| nn.Linear(20, 5), |
| ) |
| |
| compressor = compress_model(model, bits=2, dim=1, tau=1e-3, skip_first_last=False) |
| |
| optimizer = optim.SGD(compressor.parameters(), lr=0.01, momentum=0.9) |
| criterion = nn.CrossEntropyLoss() |
| |
| |
| compressor.train() |
| initial_loss = None |
| |
| for step in range(10): |
| x = torch.randn(8, 10) |
| y = torch.randint(0, 5, (8,)) |
| |
| optimizer.zero_grad() |
| out = compressor(x) |
| loss = criterion(out, y) |
| loss.backward() |
| optimizer.step() |
| |
| if step == 0: |
| initial_loss = loss.item() |
| |
| final_loss = loss.item() |
| |
| if math.isnan(final_loss) or math.isinf(final_loss): |
| all_passed = test_failed("numerical stability", f"Loss is {final_loss}") |
| else: |
| test_passed(f"training is numerically stable (loss: {initial_loss:.4f} β {final_loss:.4f})") |
| |
| return all_passed |
|
|
|
|
| def test_paper_configurations(): |
| """ |
| Test configurations mentioned in the paper: |
| - 2-bit scalar clustering (Table 1) |
| - 4/4 multi-dim (1 effective bpw) |
| - 8/8 multi-dim (1 effective bpw) |
| - 4/8 (0.5 effective bpw) |
| """ |
| print("\n[Test 12] Paper Configurations (Table 1)") |
| all_passed = True |
| |
| configs = [ |
| {"name": "3-bit", "bits": 3, "dim": 1, "expected_bpw": 3.0}, |
| {"name": "2-bit", "bits": 2, "dim": 1, "expected_bpw": 2.0}, |
| {"name": "1-bit", "bits": 1, "dim": 1, "expected_bpw": 1.0}, |
| {"name": "4/4", "bits": 4, "dim": 4, "expected_bpw": 1.0}, |
| {"name": "8/8", "bits": 8, "dim": 8, "expected_bpw": 1.0}, |
| {"name": "4/8", "bits": 4, "dim": 8, "expected_bpw": 0.5}, |
| {"name": "8/16", "bits": 8, "dim": 16, "expected_bpw": 0.5}, |
| ] |
| |
| for cfg in configs: |
| n_clusters = 2 ** cfg["bits"] |
| bpw = compute_effective_bpw(n_clusters, cfg["dim"]) |
| |
| if abs(bpw - cfg["expected_bpw"]) > 1e-6: |
| all_passed = test_failed(cfg["name"], |
| f"Expected bpw={cfg['expected_bpw']}, got {bpw}") |
| else: |
| test_passed(f"config {cfg['name']}: {n_clusters} clusters, dim={cfg['dim']} β {bpw} bpw") |
| |
| return all_passed |
|
|
|
|
| def test_kmeans_plus_plus(): |
| """Test k-means++ initialization produces well-spread centroids.""" |
| print("\n[Test 13] K-means++ Initialization") |
| all_passed = True |
| |
| torch.manual_seed(42) |
| |
| |
| weight = nn.Parameter( |
| torch.cat([ |
| torch.randn(50) * 0.1 - 10, |
| torch.randn(50) * 0.1, |
| torch.randn(50) * 0.1 + 10, |
| ]) |
| ) |
| |
| dkm = DKMLayer(weight, n_clusters=3, tau=1e-5, dim=1, init_method="kmeans++") |
| centroids = dkm.centroids.squeeze().sort().values |
| |
| |
| |
| spread = centroids.max() - centroids.min() |
| if spread < 5.0: |
| all_passed = test_failed("kmeans++ spread", |
| f"Centroids not well-spread: range={spread:.4f}") |
| else: |
| test_passed(f"k-means++ centroids well-spread (range={spread:.2f})") |
| |
| return all_passed |
|
|
|
|
| def test_warm_start(): |
| """ |
| Test that centroids are warm-started across batches (Section 3.2). |
| |
| In real training, weights change between batches due to gradient updates. |
| The warm start means centroids from the previous batch are used as initial |
| centroids for the next batch, accelerating convergence. |
| """ |
| print("\n[Test 14] Warm Start Across Batches") |
| all_passed = True |
| |
| weight = nn.Parameter(torch.randn(50)) |
| dkm = DKMLayer(weight, n_clusters=4, tau=1e-3, dim=1, max_iter=3) |
| dkm.train() |
| |
| |
| compressed = dkm() |
| centroids_after_1 = dkm.centroids.clone() |
| |
| |
| loss = compressed.sum() |
| loss.backward() |
| with torch.no_grad(): |
| weight.data -= 0.01 * weight.grad |
| weight.grad = None |
| |
| |
| compressed = dkm() |
| centroids_after_2 = dkm.centroids.clone() |
| |
| |
| loss = compressed.sum() |
| loss.backward() |
| with torch.no_grad(): |
| weight.data -= 0.01 * weight.grad |
| weight.grad = None |
| |
| |
| _ = dkm() |
| centroids_after_3 = dkm.centroids.clone() |
| |
| |
| delta_1_2 = (centroids_after_2 - centroids_after_1).abs().max().item() |
| delta_2_3 = (centroids_after_3 - centroids_after_2).abs().max().item() |
| |
| test_passed(f"centroid deltas: batch1β2: {delta_1_2:.6f}, batch2β3: {delta_2_3:.6f}") |
| |
| |
| if delta_1_2 == 0 and delta_2_3 == 0: |
| all_passed = test_failed("centroid movement", |
| "Centroids didn't move despite weight updates") |
| else: |
| test_passed("centroids adapt to weight changes (warm start working)") |
| |
| return all_passed |
|
|
|
|
| def test_numerical_stability(): |
| """Test numerical stability with extreme values.""" |
| print("\n[Test 15] Numerical Stability") |
| all_passed = True |
| |
| |
| weight_large = nn.Parameter(torch.randn(100) * 1000) |
| dkm_large = DKMLayer(weight_large, n_clusters=4, tau=1.0, dim=1) |
| dkm_large.train() |
| out = dkm_large() |
| if torch.isnan(out).any() or torch.isinf(out).any(): |
| all_passed = test_failed("large weights", "NaN/Inf with large weights") |
| else: |
| test_passed("stable with large weights") |
| |
| |
| weight_small = nn.Parameter(torch.randn(100) * 1e-8) |
| dkm_small = DKMLayer(weight_small, n_clusters=4, tau=1e-10, dim=1) |
| dkm_small.train() |
| out = dkm_small() |
| if torch.isnan(out).any() or torch.isinf(out).any(): |
| all_passed = test_failed("small weights", "NaN/Inf with small weights") |
| else: |
| test_passed("stable with small weights") |
| |
| |
| weight_uniform = nn.Parameter(torch.ones(100) * 5.0) |
| dkm_uniform = DKMLayer(weight_uniform, n_clusters=4, tau=1e-3, dim=1) |
| dkm_uniform.train() |
| out = dkm_uniform() |
| if torch.isnan(out).any() or torch.isinf(out).any(): |
| all_passed = test_failed("uniform weights", "NaN/Inf with uniform weights") |
| else: |
| test_passed("stable with uniform weights") |
| |
| return all_passed |
|
|
|
|
| def test_resnet_compression(): |
| """Test DKM on a small ResNet-like model end-to-end.""" |
| print("\n[Test 16] ResNet-like Model Compression") |
| all_passed = True |
| |
| |
| class ResBlock(nn.Module): |
| def __init__(self, channels): |
| super().__init__() |
| self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) |
| self.bn1 = nn.BatchNorm2d(channels) |
| self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) |
| self.bn2 = nn.BatchNorm2d(channels) |
| |
| def forward(self, x): |
| residual = x |
| out = torch.relu(self.bn1(self.conv1(x))) |
| out = self.bn2(self.conv2(out)) |
| return torch.relu(out + residual) |
| |
| class SmallResNet(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(3, 16, 3, padding=1) |
| self.bn1 = nn.BatchNorm2d(16) |
| self.block1 = ResBlock(16) |
| self.block2 = ResBlock(16) |
| self.pool = nn.AdaptiveAvgPool2d(1) |
| self.fc = nn.Linear(16, 10) |
| |
| def forward(self, x): |
| x = torch.relu(self.bn1(self.conv1(x))) |
| x = self.block1(x) |
| x = self.block2(x) |
| x = self.pool(x).flatten(1) |
| return self.fc(x) |
| |
| model = SmallResNet() |
| |
| |
| compressor = compress_model( |
| model, bits=2, dim=1, tau=1e-3, skip_first_last=True |
| ) |
| |
| |
| optimizer = optim.SGD(compressor.parameters(), lr=0.01, momentum=0.9) |
| criterion = nn.CrossEntropyLoss() |
| |
| compressor.train() |
| x = torch.randn(4, 3, 32, 32) |
| y = torch.randint(0, 10, (4,)) |
| |
| out = compressor(x) |
| loss = criterion(out, y) |
| loss.backward() |
| optimizer.step() |
| |
| if math.isnan(loss.item()): |
| all_passed = test_failed("resnet train", "NaN loss") |
| else: |
| test_passed(f"ResNet training step: loss={loss.item():.4f}") |
| |
| |
| info = compressor.get_compression_info() |
| test_passed(f"Compression ratio: {info['compression_ratio']:.2f}x, " |
| f"Size: {info['original_size_mb']:.3f}MB β {info['compressed_size_mb']:.3f}MB") |
| |
| return all_passed |
|
|
|
|
| def run_all_tests(): |
| """Run all tests and report results.""" |
| print("=" * 70) |
| print("DKM Implementation Test Suite") |
| print("Based on: 'DKM: Differentiable K-Means Clustering Layer for") |
| print(" Neural Network Compression' (ICLR 2022, arXiv:2108.12659)") |
| print("=" * 70) |
| |
| tests = [ |
| ("DKM Layer Basic", test_dkm_layer_basic), |
| ("Distance Matrix", test_distance_matrix), |
| ("Attention Matrix", test_attention_matrix), |
| ("Centroid Update", test_centroid_update), |
| ("Gradient Flow", test_gradient_flow), |
| ("Multi-Dim Clustering", test_multidim_clustering), |
| ("Convergence", test_convergence), |
| ("Compressor Wrapper", test_compressor_wrapper), |
| ("Weight Snapping", test_snap_weights), |
| ("Export Compressed", test_export_compressed), |
| ("Training Step", test_training_step), |
| ("Paper Configurations", test_paper_configurations), |
| ("K-means++ Init", test_kmeans_plus_plus), |
| ("Warm Start", test_warm_start), |
| ("Numerical Stability", test_numerical_stability), |
| ("ResNet Compression", test_resnet_compression), |
| ] |
| |
| results = {} |
| for name, test_fn in tests: |
| try: |
| passed = test_fn() |
| results[name] = passed |
| except Exception as e: |
| print(f"\n βββ EXCEPTION in {name}: {e}") |
| traceback.print_exc() |
| results[name] = False |
| |
| |
| print("\n" + "=" * 70) |
| print("TEST SUMMARY") |
| print("=" * 70) |
| |
| total = len(results) |
| passed = sum(1 for v in results.values() if v) |
| failed = total - passed |
| |
| for name, result in results.items(): |
| status = "PASS β" if result else "FAIL β" |
| print(f" [{status}] {name}") |
| |
| print(f"\n{passed}/{total} test groups passed, {failed} failed") |
| |
| if failed > 0: |
| print("\nβ Some tests failed! Review the output above for details.") |
| return False |
| else: |
| print("\nβ All tests passed!") |
| return True |
|
|
|
|
| if __name__ == "__main__": |
| torch.manual_seed(42) |
| success = run_all_tests() |
| sys.exit(0 if success else 1) |
|
|