"""
Test WrinkleBrane Optimizations
Validate performance and fidelity improvements from optimizations.
"""
|
|
| import sys |
| from pathlib import Path |
| sys.path.append(str(Path(__file__).resolve().parent / "src")) |
|
|
| import torch |
| import numpy as np |
| import time |
| from wrinklebrane.membrane_bank import MembraneBank |
| from wrinklebrane.codes import hadamard_codes |
| from wrinklebrane.slicer import make_slicer |
| from wrinklebrane.write_ops import store_pairs |
| from wrinklebrane.metrics import psnr, ssim |
| from wrinklebrane.optimizations import ( |
| compute_adaptive_alphas, |
| generate_extended_codes, |
| HierarchicalMembraneBank, |
| optimized_store_pairs |
| ) |
|
|
def test_adaptive_alphas():
    """Test adaptive alpha scaling vs uniform alphas.

    Builds K synthetic patterns with increasing energy (circles, squares,
    diagonals), stores them into two identically-shaped membrane banks —
    one with uniform alphas, one with alphas from compute_adaptive_alphas()
    — and compares per-pattern readout PSNR between the two.

    Returns:
        bool: True if adaptive alphas improve the average PSNR.
    """
    print("π§ͺ Testing Adaptive Alpha Scaling...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8

    # Two identical banks so the only difference is the alpha scheme.
    bank_uniform = MembraneBank(L, H, W, device=device)
    bank_adaptive = MembraneBank(L, H, W, device=device)
    bank_uniform.allocate(B)
    bank_adaptive.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Synthetic patterns with deliberately varied energies so adaptive
    # alpha scaling has something to exploit.
    patterns = []
    for i in range(K):
        pattern = torch.zeros(H, W, device=device)
        energy_scale = 0.1 + i * 0.3

        if i % 3 == 0:
            # Filled circle. BUG FIX: the original compared x against H//2
            # and y against W//2 (swapped axes; x spans W, y spans H).
            # Harmless here only because H == W.
            radius_sq = (3 + i // 3) ** 2
            for y in range(H):
                for x in range(W):
                    if (x - W // 2) ** 2 + (y - H // 2) ** 2 <= radius_sq:
                        pattern[y, x] = energy_scale
        elif i % 3 == 1:
            # Centered solid square.
            size = 4 + i // 3
            start = (H - size) // 2
            pattern[start:start + size, start:start + size] = energy_scale * 0.5
        else:
            # Shifted diagonal line.
            for d in range(min(H, W)):
                if d + i // 3 < H and d + i // 3 < W:
                    pattern[d + i // 3, d] = energy_scale * 0.1

        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Baseline: uniform alphas.
    uniform_alphas = torch.ones(K, device=device)
    M_uniform = store_pairs(bank_uniform.read(), C, keys, patterns, uniform_alphas)
    bank_uniform.write(M_uniform - bank_uniform.read())
    uniform_readouts = slicer(bank_uniform.read()).squeeze(0)

    # Candidate: per-pattern adaptive alphas.
    adaptive_alphas = compute_adaptive_alphas(patterns, C, keys)
    M_adaptive = store_pairs(bank_adaptive.read(), C, keys, patterns, adaptive_alphas)
    bank_adaptive.write(M_adaptive - bank_adaptive.read())
    adaptive_readouts = slicer(bank_adaptive.read()).squeeze(0)

    uniform_psnr = []
    adaptive_psnr = []

    print(" Pattern-by-pattern comparison:")
    for i in range(K):
        u_psnr = psnr(patterns[i].cpu().numpy(), uniform_readouts[i].cpu().numpy())
        a_psnr = psnr(patterns[i].cpu().numpy(), adaptive_readouts[i].cpu().numpy())

        uniform_psnr.append(u_psnr)
        adaptive_psnr.append(a_psnr)

        energy = torch.norm(patterns[i]).item()
        print(f" Pattern {i}: Energy={energy:.3f}, Alpha={adaptive_alphas[i]:.3f}")
        print(f" Uniform PSNR: {u_psnr:.1f}dB, Adaptive PSNR: {a_psnr:.1f}dB")

    avg_uniform = np.mean(uniform_psnr)
    avg_adaptive = np.mean(adaptive_psnr)
    improvement = avg_adaptive - avg_uniform

    print(f"\n Results Summary:")
    print(f" Uniform alphas: {avg_uniform:.1f}dB average PSNR")
    print(f" Adaptive alphas: {avg_adaptive:.1f}dB average PSNR")
    print(f" Improvement: {improvement:.1f}dB ({improvement/avg_uniform*100:.1f}%)")

    return improvement > 0
|
|
|
|
def test_extended_codes():
    """Test extended code generation for K > L scenarios."""
    print("\nπ§ͺ Testing Extended Code Generation...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L = 32
    test_Ks = [16, 32, 64, 128]

    results = {}

    for K in test_Ks:
        print(f" Testing L={L}, K={K} (capacity: {K/L:.1f}x)")

        C = generate_extended_codes(L, K, method="auto", device=device)

        # Orthogonality check: the full Gram matrix when K <= L, otherwise
        # only the first L columns (which should remain orthogonal).
        if K <= L:
            n_check, subset = K, C
        else:
            n_check, subset = L, C[:, :L]
        gram = subset.T @ subset
        identity = torch.eye(n_check, device=device, dtype=C.dtype)
        orthogonality_error = torch.norm(gram - identity).item()

        # Small bank for a quick store/readout fidelity sweep.
        B, H, W = 1, 8, 8
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)

        slicer = make_slicer(C)

        actual_K = min(K, C.shape[1])
        patterns = torch.rand(actual_K, H, W, device=device)
        keys = torch.arange(actual_K, device=device)
        alphas = torch.ones(actual_K, device=device)

        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        psnr_values = [
            psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
            for i in range(actual_K)
        ]

        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        std_psnr = np.std(psnr_values)

        results[K] = {
            "orthogonality_error": orthogonality_error,
            "avg_psnr": avg_psnr,
            "min_psnr": min_psnr,
            "std_psnr": std_psnr,
        }

        print(f" Orthogonality error: {orthogonality_error:.6f}")
        print(f" PSNR: {avg_psnr:.1f}Β±{std_psnr:.1f}dB (min: {min_psnr:.1f}dB)")

    return results
|
|
|
|
def test_hierarchical_memory():
    """Test hierarchical memory bank organization."""
    print("\nπ§ͺ Testing Hierarchical Memory Bank...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 64, 32, 32
    K = 32

    # Hierarchical bank under test.
    hierarchical_bank = HierarchicalMembraneBank(L, H, W, levels=3, device=device)
    hierarchical_bank.allocate(1)

    # Flat bank as the baseline.
    regular_bank = MembraneBank(L, H, W, device=device)
    regular_bank.allocate(1)

    # Three pattern families: dense noise, centered blocks, small squares.
    def _make_pattern(idx):
        if idx < K // 3:
            return torch.rand(H, W, device=device)
        canvas = torch.zeros(H, W, device=device)
        if idx < 2 * K // 3:
            canvas[H//4:3*H//4, W//4:3*W//4] = torch.rand(H//2, W//2, device=device)
        else:
            canvas[H//2-2:H//2+2, W//2-2:W//2+2] = torch.ones(4, 4, device=device)
        return canvas

    patterns = torch.stack([_make_pattern(i) for i in range(K)])
    keys = torch.arange(K, device=device)

    # Baseline path: Hadamard codes + store_pairs into the flat bank.
    C_regular = hadamard_codes(L, K).to(device)
    slicer_regular = make_slicer(C_regular)
    alphas_regular = torch.ones(K, device=device)

    start_time = time.time()
    M_regular = store_pairs(regular_bank.read(), C_regular, keys, patterns, alphas_regular)
    regular_bank.write(M_regular - regular_bank.read())
    regular_readouts = slicer_regular(regular_bank.read()).squeeze(0)
    regular_time = time.time() - start_time

    # Hierarchical path.
    start_time = time.time()
    hierarchical_bank.store_hierarchical(patterns, keys)
    hierarchical_time = time.time() - start_time

    # Memory footprint in bytes (assumes float32 => 4 bytes per element).
    regular_memory = L * H * W * 4
    hierarchical_memory = sum(bank.L * H * W * 4 for bank in hierarchical_bank.banks)
    memory_savings = (regular_memory - hierarchical_memory) / regular_memory * 100

    regular_psnr = [
        psnr(patterns[i].cpu().numpy(), regular_readouts[i].cpu().numpy())
        for i in range(K)
    ]
    avg_regular_psnr = np.mean(regular_psnr)

    print(f" Regular Bank:")
    print(f" Storage time: {regular_time*1000:.2f}ms")
    print(f" Memory usage: {regular_memory/1e6:.2f}MB")
    print(f" Average PSNR: {avg_regular_psnr:.1f}dB")

    print(f" Hierarchical Bank:")
    print(f" Storage time: {hierarchical_time*1000:.2f}ms")
    print(f" Memory usage: {hierarchical_memory/1e6:.2f}MB")
    print(f" Memory savings: {memory_savings:.1f}%")
    print(f" Levels: {hierarchical_bank.levels}")

    for i, bank in enumerate(hierarchical_bank.banks):
        level_fraction = bank.L / hierarchical_bank.total_L
        print(f" Level {i}: L={bank.L} ({level_fraction:.1%})")

    return memory_savings > 0
|
|
|
|
def test_optimized_storage():
    """Test the complete optimized storage pipeline.

    Compares the baseline store_pairs() path (uniform alphas) against
    optimized_store_pairs() (adaptive alphas + sparsity threshold) on K
    mixed-energy patterns, reporting timing and per-pattern PSNR.

    Returns:
        bool: True if the optimized pipeline improves average PSNR.
    """
    print("\nπ§ͺ Testing Optimized Storage Pipeline...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 64, 32, 32, 48

    bank_original = MembraneBank(L, H, W, device=device)
    bank_optimized = MembraneBank(L, H, W, device=device)
    bank_original.allocate(B)
    bank_optimized.allocate(B)

    C = generate_extended_codes(L, K, method="auto", device=device)
    slicer = make_slicer(C)

    # Mixed workload: dense patterns at three energy levels plus sparse ones.
    patterns = []
    for i in range(K):
        if i % 4 == 0:
            pattern = torch.rand(H, W, device=device) * 2.0
        elif i % 4 == 1:
            pattern = torch.rand(H, W, device=device) * 1.0
        elif i % 4 == 2:
            pattern = torch.rand(H, W, device=device) * 0.5
        else:
            # Sparse pattern: ~5% of pixels set to random values.
            # BUG FIX: the original drew the boolean mask twice with two
            # independent torch.rand calls, so the number of indexed pixels
            # and the number of replacement values disagreed (RuntimeError
            # on almost every run). Draw the mask once and reuse it.
            pattern = torch.zeros(H, W, device=device)
            sparse_mask = torch.rand(H, W, device=device) > 0.95
            pattern[sparse_mask] = torch.rand(int(sparse_mask.sum()), device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Baseline pipeline: uniform alphas through store_pairs().
    start_time = time.time()
    alphas_original = torch.ones(K, device=device)
    M_original = store_pairs(bank_original.read(), C, keys, patterns, alphas_original)
    bank_original.write(M_original - bank_original.read())
    original_readouts = slicer(bank_original.read()).squeeze(0)
    original_time = time.time() - start_time

    # Optimized pipeline: adaptive alphas + sparsity-aware storage.
    start_time = time.time()
    M_optimized = optimized_store_pairs(
        bank_optimized.read(), C, keys, patterns,
        adaptive_alphas=True, sparsity_threshold=0.01
    )
    bank_optimized.write(M_optimized - bank_optimized.read())
    optimized_readouts = slicer(bank_optimized.read()).squeeze(0)
    optimized_time = time.time() - start_time

    original_psnr = []
    optimized_psnr = []

    for i in range(K):
        o_psnr = psnr(patterns[i].cpu().numpy(), original_readouts[i].cpu().numpy())
        opt_psnr = psnr(patterns[i].cpu().numpy(), optimized_readouts[i].cpu().numpy())

        original_psnr.append(o_psnr)
        optimized_psnr.append(opt_psnr)

    avg_original = np.mean(original_psnr)
    avg_optimized = np.mean(optimized_psnr)
    fidelity_improvement = avg_optimized - avg_original
    speed_improvement = (original_time - optimized_time) / original_time * 100

    print(f" Original Pipeline:")
    print(f" Time: {original_time*1000:.2f}ms")
    print(f" Average PSNR: {avg_original:.1f}dB")

    print(f" Optimized Pipeline:")
    print(f" Time: {optimized_time*1000:.2f}ms")
    print(f" Average PSNR: {avg_optimized:.1f}dB")

    print(f" Improvements:")
    print(f" Fidelity: +{fidelity_improvement:.1f}dB ({fidelity_improvement/avg_original*100:.1f}%)")
    print(f" Speed: {speed_improvement:.1f}% {'faster' if speed_improvement > 0 else 'slower'}")

    return fidelity_improvement > 0
|
|
|
|
def main():
    """Run complete optimization test suite.

    Executes the four optimization tests, tallies successes, and prints a
    summary. Any exception aborts the suite with a traceback.

    Returns:
        bool: True if at least one optimization showed improvement,
        False if the suite raised an exception.
    """
    print("π WrinkleBrane Optimization Test Suite")
    print("="*50)

    # Deterministic runs.
    torch.manual_seed(42)
    np.random.seed(42)

    success_count = 0
    total_tests = 4

    try:
        # BUG FIX: the status strings below were split across physical
        # lines in the original file (a syntax error); they are rejoined
        # into single string literals.
        if test_adaptive_alphas():
            print("β Adaptive alpha scaling: IMPROVED PERFORMANCE")
            success_count += 1
        else:
            print("β οΈ Adaptive alpha scaling: NO IMPROVEMENT")

        extended_results = test_extended_codes()
        if all(r['avg_psnr'] > 50 for r in extended_results.values()):
            print("β Extended code generation: WORKING")
            success_count += 1
        else:
            print("β οΈ Extended code generation: QUALITY ISSUES")

        if test_hierarchical_memory():
            print("β Hierarchical memory: MEMORY SAVINGS")
            success_count += 1
        else:
            print("β οΈ Hierarchical memory: NO SAVINGS")

        if test_optimized_storage():
            print("β Optimized storage pipeline: IMPROVED FIDELITY")
            success_count += 1
        else:
            print("β οΈ Optimized storage pipeline: NO IMPROVEMENT")

        print("\n" + "="*50)
        print(f"π― Optimization Results: {success_count}/{total_tests} improvements successful")

        if success_count == total_tests:
            print("π ALL OPTIMIZATIONS WORKING PERFECTLY!")
        elif success_count > total_tests // 2:
            print("β MAJORITY OF OPTIMIZATIONS SUCCESSFUL")
        else:
            print("β οΈ Mixed results - some optimizations need work")

    except Exception as e:
        print(f"\nβ Optimization tests failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return success_count > 0
|
|
|
|
if __name__ == "__main__":
    # BUG FIX: the script previously always exited with status 0, so CI
    # could not detect failures. Propagate the suite result as the exit
    # code (sys is imported at the top of the file).
    sys.exit(0 if main() else 1)