| |
| """ |
| Simple WrinkleBrane Demo |
Demonstrates basic store/retrieve fidelity, compares code families, measures capacity scaling, and inspects the membrane state.
| """ |
|
|
| import sys |
| from pathlib import Path |
| sys.path.append(str(Path(__file__).resolve().parent / "src")) |
|
|
| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from wrinklebrane.membrane_bank import MembraneBank |
| from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats |
| from wrinklebrane.slicer import make_slicer |
| from wrinklebrane.write_ops import store_pairs |
| from wrinklebrane.metrics import psnr, ssim |
|
|
def create_test_patterns(K, H, W, device):
    """Create diverse test patterns for demonstration.

    Pattern ``i`` cycles through four shape families whose size/offset grows
    with ``i // 4``:

    * ``i % 4 == 0``: filled disc centred at ``(H // 2, W // 2)`` with radius
      ``2 + i // 4``.
    * ``i % 4 == 1``: filled square of side ``4 + i // 4`` centred
      independently along each axis; left blank if it does not fit in
      ``H`` x ``W``.
    * ``i % 4 == 2``: horizontal line at row ``H // 2 + i // 4 - 1``; left
      blank if that row is out of range.
    * ``i % 4 == 3``: vertical line at column ``W // 2 + i // 4 - 1``; left
      blank if that column is out of range.

    Args:
        K: Number of patterns to generate (0 yields an empty stack).
        H: Pattern height (number of rows).
        W: Pattern width (number of columns).
        device: Torch device on which the patterns are allocated.

    Returns:
        Tensor of shape ``(K, H, W)`` in torch's default floating dtype,
        containing only 0.0 and 1.0.
    """
    patterns = torch.zeros(K, H, W, device=device)

    # Coordinate grids shaped (H, 1) and (1, W); they broadcast to (H, W).
    rows = torch.arange(H, device=device).view(-1, 1)
    cols = torch.arange(W, device=device).view(1, -1)
    cy, cx = H // 2, W // 2  # centre as (row, column)

    for i in range(K):
        step = i // 4
        kind = i % 4
        pattern = patterns[i]  # view: in-place writes land in `patterns`

        if kind == 0:
            radius = 2 + step
            # Compare rows with the row centre and columns with the column
            # centre so non-square images get a correctly placed disc.
            disc = (rows - cy) ** 2 + (cols - cx) ** 2 <= radius ** 2
            pattern.masked_fill_(disc, 1.0)
        elif kind == 1:
            size = 4 + step
            # Skip squares that do not fit; a negative start offset would
            # otherwise wrap around and paint a stray corner pixel.
            if size <= H and size <= W:
                top = (H - size) // 2
                left = (W - size) // 2
                pattern[top:top + size, left:left + size] = 1.0
        elif kind == 2:
            y = cy + step - 1
            if 0 <= y < H:
                pattern[y, :] = 1.0
        else:
            x = cx + step - 1
            if 0 <= x < W:
                pattern[:, x] = 1.0

    return patterns
|
|
def demonstrate_basic_functionality():
    """Store a few test patterns in a membrane bank and report recall fidelity.

    ``K`` patterns from ``create_test_patterns`` are written under Hadamard
    codes, read back through a slicer built from the same codes, and compared
    with the originals using PSNR and SSIM. Per-pattern scores, their averages
    and a coarse quality verdict are printed.

    Returns:
        Average PSNR in dB over the stored patterns.
    """
    print("WrinkleBrane Basic Functionality Demo")
    print("=" * 40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8

    print(f"Configuration: L={L}, H={H}, W={W}, K={K} patterns")
    print(f"Device: {device}")

    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    patterns = create_test_patterns(K, H, W, device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    print("\nStoring patterns...")
    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    # Presumably bank.write() applies an increment, hence the delta between
    # the target state M and the current state. TODO confirm in MembraneBank.
    bank.write(M - bank.read())

    print("Retrieving patterns...")
    readouts = slicer(bank.read()).squeeze(0)

    # Move to host once rather than once per pattern per metric.
    originals_np = patterns.cpu().numpy()
    readouts_np = readouts.cpu().numpy()

    print("\nFidelity Results:")
    total_psnr = 0.0
    total_ssim = 0.0
    for i in range(K):
        psnr_val = psnr(originals_np[i], readouts_np[i])
        ssim_val = ssim(originals_np[i], readouts_np[i])
        total_psnr += psnr_val
        total_ssim += ssim_val
        print(f" Pattern {i}: PSNR={psnr_val:.1f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = total_psnr / K
    avg_ssim = total_ssim / K

    print("\nSummary:")
    print(f" Average PSNR: {avg_psnr:.1f}dB")
    print(f" Average SSIM: {avg_ssim:.4f}")

    if avg_psnr > 100:
        print("✅ EXCELLENT: >100dB PSNR (near-perfect recall)")
    elif avg_psnr > 50:
        print("✅ GOOD: >50dB PSNR (high-quality recall)")
    else:
        # Also reached at exactly 50 dB (and for NaN), hence "<=".
        print("⚠️ LOW: <=50dB PSNR (may need optimization)")

    return avg_psnr
|
|
def compare_code_types():
    """Compare the Hadamard, DCT and Gaussian code families on one workload.

    For each family, prints the coherence statistics of its ``K`` codes, then
    stores ``K`` test patterns in a fresh membrane bank and reports the mean
    and standard deviation of the recall PSNR.

    Returns:
        Dict mapping each family name to a dict with keys
        ``'orthogonality'`` (the ``max_abs_offdiag`` coherence statistic) and
        ``'performance'`` (mean recall PSNR in dB).
    """
    print("\nCode Types Comparison")
    print("=" * 40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, K = 32, 16
    B, H, W = 1, 16, 16

    code_types = {
        "Hadamard": hadamard_codes(L, K).to(device),
        "DCT": dct_codes(L, K).to(device),
        "Gaussian": gaussian_codes(L, K).to(device),
    }

    # The workload is identical for every code family, so build it once.
    patterns = create_test_patterns(K, H, W, device)
    patterns_np = patterns.cpu().numpy()
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    results = {}
    for name, codes in code_types.items():
        print(f"\n{name} Codes:")

        stats = coherence_stats(codes)
        print(f" Max off-diagonal correlation: {stats['max_abs_offdiag']:.6f}")
        print(f" Mean off-diagonal correlation: {stats['mean_abs_offdiag']:.6f}")

        # A fresh bank per family so earlier writes cannot leak into this run.
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)
        slicer = make_slicer(codes)

        M = store_pairs(bank.read(), codes, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts_np = slicer(bank.read()).squeeze(0).cpu().numpy()

        psnr_values = [psnr(patterns_np[i], readouts_np[i]) for i in range(K)]
        avg_psnr = np.mean(psnr_values)
        std_psnr = np.std(psnr_values)

        print(f" Performance: {avg_psnr:.1f}±{std_psnr:.1f}dB PSNR")

        results[name] = {
            'orthogonality': stats['max_abs_offdiag'],
            'performance': avg_psnr,
        }

    best_name, best_stats = max(results.items(), key=lambda item: item[1]['performance'])
    print(f"\nBest Performing: {best_name} ({best_stats['performance']:.1f}dB)")

    return results
|
|
def test_capacity_scaling():
    """Measure how recall PSNR changes as more patterns share one membrane.

    With ``L = 64`` layers, a fresh bank stores ``K`` Hadamard-keyed patterns
    for each ``K`` in ``[8, 16, 32, 64]`` (a load ``K / L`` of up to 100%),
    and the average and minimum recall PSNR are recorded.

    Returns:
        List of dicts with keys ``'K'``, ``'capacity_ratio'``, ``'avg_psnr'``
        and ``'min_psnr'``, one per tested ``K`` in increasing order.
    """
    print("\nCapacity Scaling Test")
    print("=" * 40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 64, 16, 16

    pattern_counts = [8, 16, 32, 64]
    results = []

    for K in pattern_counts:
        print(f"\nTesting {K} patterns (capacity: {K/L:.1%})...")

        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(1)

        C = hadamard_codes(L, K).to(device)
        slicer = make_slicer(C)

        patterns = create_test_patterns(K, H, W, device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts_np = slicer(bank.read()).squeeze(0).cpu().numpy()
        patterns_np = patterns.cpu().numpy()

        # NOTE(review): with H = W = 16, create_test_patterns leaves some of the
        # larger-index shapes blank (e.g. lines past the last row). How psnr()
        # scores an all-zero reference is up to wrinklebrane.metrics; check for
        # inf/nan if these averages look odd.
        psnr_values = [psnr(patterns_np[i], readouts_np[i]) for i in range(K)]
        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)

        print(f" PSNR: {avg_psnr:.1f}dB average, {min_psnr:.1f}dB minimum")

        results.append({
            'K': K,
            'capacity_ratio': K / L,
            'avg_psnr': avg_psnr,
            'min_psnr': min_psnr,
        })

    print("\nCapacity Scaling Summary:")
    for result in results:
        if result['avg_psnr'] > 100:
            status = "✅"
        elif result['avg_psnr'] > 50:
            status = "⚠️"
        else:
            status = "❌"
        # One decimal place so e.g. 12.5% matches the per-run header above.
        print(f" {result['capacity_ratio']:6.1%} capacity: {result['avg_psnr']:5.1f}dB {status}")

    return results
|
|
def demonstrate_wave_interference():
    """Store two overlapping patterns and inspect the resulting membrane state.

    A centre dot and a half-intensity cross are written under two Hadamard
    codes into one bank. Prints the membrane shape, the pattern norms and
    per-layer norms (labelled "energy"), the recall PSNR of both patterns, and
    the ratio of the membrane's total norm to the sum of the pattern norms.

    Returns:
        The membrane state read back from the bank, with dimension 0 squeezed.
    """
    print("\nWave Interference Demonstration")
    print("=" * 40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 16, 8, 8

    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(1)

    K = 2
    C = hadamard_codes(L, K).to(device)

    # Pattern 1: a single bright pixel at the centre.
    pattern1 = torch.zeros(H, W, device=device)
    pattern1[H // 2, W // 2] = 1.0

    # Pattern 2: a half-intensity cross through the same centre, so the two
    # patterns overlap at exactly that pixel.
    pattern2 = torch.zeros(H, W, device=device)
    pattern2[H // 2, :] = 0.5
    pattern2[:, W // 2] = 0.5

    patterns = torch.stack([pattern1, pattern2])
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    bank.write(M - bank.read())

    membrane_state = bank.read().squeeze(0)

    print(f"Membrane state shape: {membrane_state.shape}")
    print(f"Pattern 1 energy: {torch.norm(pattern1):.3f}")
    print(f"Pattern 2 energy: {torch.norm(pattern2):.3f}")

    # "Energy" here is the plain Frobenius norm of each layer, not its square.
    layer_energies = [torch.norm(membrane_state[layer]).item() for layer in range(L)]
    print(f"Layer energies (first 8): {[f'{e:.3f}' for e in layer_energies[:8]]}")

    slicer = make_slicer(C)
    readouts = slicer(bank.read()).squeeze(0)

    psnr1 = psnr(pattern1.cpu().numpy(), readouts[0].cpu().numpy())
    psnr2 = psnr(pattern2.cpu().numpy(), readouts[1].cpu().numpy())

    print("\nRetrieval fidelity:")
    print(f" Pattern 1: {psnr1:.1f}dB PSNR")
    print(f" Pattern 2: {psnr2:.1f}dB PSNR")

    # NOTE(review): this compares norms rather than squared norms, and the
    # "no interference" baseline does not involve the codes C at all, so the
    # factor depends on code normalisation; treat it as a rough indicator only.
    total_membrane_energy = torch.norm(membrane_state).item()
    expected_energy = torch.norm(pattern1).item() + torch.norm(pattern2).item()

    print("\nWave interference analysis:")
    print(f" Total membrane energy: {total_membrane_energy:.3f}")
    print(f" Expected (no interference): {expected_energy:.3f}")
    print(f" Interference factor: {total_membrane_energy/expected_energy:.3f}")

    return membrane_state
|
|
def main():
    """Run every demo section in order and print a short summary.

    Seeds PyTorch and NumPy for reproducibility. Any exception raised while
    running the sections or printing the summary is reported with its
    traceback instead of propagating.

    Returns:
        True if everything completed, False if an exception was raised.
    """
    print("WrinkleBrane Complete Demonstration")
    print("=" * 50)

    torch.manual_seed(42)
    np.random.seed(42)

    try:
        basic_psnr = demonstrate_basic_functionality()
        code_results = compare_code_types()
        capacity_results = test_capacity_scaling()
        membrane_state = demonstrate_wave_interference()

        print("\n" + "=" * 50)
        print("WrinkleBrane Demonstration Complete!")
        print("=" * 50)

        best_code_name = max(code_results.items(), key=lambda item: item[1]['performance'])[0]
        largest_run = capacity_results[-1]

        print("\nKey Results:")
        print(f"• Basic fidelity: {basic_psnr:.1f}dB PSNR")
        print(f"• Best code type: {best_code_name}")
        print(f"• Maximum capacity: {largest_run['K']} patterns at {largest_run['avg_psnr']:.1f}dB")
        print(f"• Membrane state shape: {membrane_state.shape}")

        if basic_psnr > 100:
            print("\nWrinkleBrane is performing EXCELLENTLY!")
            print(" Wave-interference associative memory working at near-perfect fidelity!")
        else:
            print(f"\n✅ WrinkleBrane is working correctly with {basic_psnr:.1f}dB fidelity")

    except Exception as e:
        # Top-level boundary of a demo script: report with traceback, signal failure.
        import traceback
        print(f"\n❌ Demo failed with error: {e}")
        traceback.print_exc()
        return False

    return True
|
|
if __name__ == "__main__":
    # main() returns False when the demo raised; surface that as a non-zero
    # exit status so shells and CI can detect the failure.
    sys.exit(0 if main() else 1)