import torch import os import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) import torch.nn as nn import sys import os from arbitor.config import ( VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD, CODEBOOK_DIM, CODEBOOK_SIZE, SPECIAL_VOCAB, StickyZoneSTE, ByteEmbedding, Sequencer, TextSequencer, ImageSequencer, AudioSequencer, MultimodalSequencer, TernaryGNNLayer, TernaryGraph, GraphMoEGate, SharedProjectionMoE, ByteHead, ARBModel, VQAdapter, MultimodalVQBridge, ModalityGate, LossComponents, LossWeights, GNNLoRAAdapter, HaltingUnit, GraphACTCell, MoEACTCell, MemGram, ConvVQCodebook, FocusGate, ConversationStack, ConversationLSTM, _BOUNDARY_TOKEN_MAP, _extract_boundary_from_input, ) from arbitor.kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType TERNARY_MODULES = (TernaryScaleTensor, TernaryRMSNorm, ByteEmbedding, TernaryGraph, GraphMoEGate, SharedProjectionMoE, GNNLoRAAdapter, HaltingUnit, GraphACTCell, MoEACTCell, Sequencer, TextSequencer, ImageSequencer, AudioSequencer, MultimodalVQBridge, ModalityGate, MemGram, ConvVQCodebook, ConversationLSTM) def _is_ternary_param(model, name): parent_name = name.rsplit(".", 1)[0] if "." in name else "" parent = dict(model.named_modules()).get(parent_name, None) return isinstance(parent, TERNARY_MODULES) # ===== Phase 1: Foundation Tests ===== def test_sticky_zone_ste(): w = torch.randn(8, 8, requires_grad=True) t = StickyZoneSTE.apply(w, 0.05) unique = set(t.detach().flatten().tolist()) assert unique.issubset({-1.0, 0.0, 1.0}), f"Non-ternary values: {unique}" t.sum().backward() assert w.grad is not None outside = w.abs() > 0.05 if outside.any(): assert (w.grad[outside] != 0).any(), "Outside threshold should have non-zero gradient" dead = w.abs() <= 0.05 if dead.any(): assert (w.grad[dead] >= 0).all(), "Sticky zone gradient should be non-negative" print(" PASS test_sticky_zone_ste") def test_sticky_zone_ste_dtype_preservation(): w_bf16 = torch.randn(8, 8, dtype=torch.bfloat16, requires_grad=True) t = StickyZoneSTE.apply(w_bf16, 0.05) assert t.dtype == torch.bfloat16, f"Expected bfloat16, got {t.dtype}" t.sum().backward() assert w_bf16.grad.dtype == torch.bfloat16, f"Expected bfloat16 grad, got {w_bf16.grad.dtype}" print(" PASS test_sticky_zone_ste_dtype_preservation") def test_scaled_ternary_linear(): lin = TernaryScaleTensor(32, 16, bias=False) x = torch.randn(2, 10, 32) out = lin(x) assert out.shape == (2, 10, 16), f"Shape: {out.shape}" assert lin.bias is None, "TernaryScaleTensor bias should be None" print(" PASS test_scaled_ternary_linear") def test_rmsnorm(): norm = TernaryRMSNorm(32) x = torch.randn(2, 10, 32) out = norm(x) assert out.shape == x.shape, f"Shape: {out.shape}" rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-8) expected = torch.ones(32, device=x.device) * (x / rms) assert out.shape == expected.shape, "RMSNorm mismatch" print(" PASS test_rmsnorm") def test_byte_embedding(): emb = ByteEmbedding() x = torch.randint(0, VOCAB, (4, 20)) out = emb(x) assert out.shape == (4, 20, EMBEDDING_DIM), f"Shape: {out.shape}" print(" PASS test_byte_embedding") def test_text_sequencer(): enc = TextSequencer() x = torch.randn(2, 10, EMBEDDING_DIM) out = enc(x) assert out.shape == (2, 8, HIDDEN_DIM), f"Shape: {out.shape}, expected (2, 8, {HIDDEN_DIM})" print(" PASS test_text_sequencer") def test_trigram_window(): x = torch.zeros(1, 5, EMBEDDING_DIM) for i in range(5): x[0, i, :] = i + 1 windows = x.unfold(dimension=1, size=3, step=1) assert windows.shape == (1, 3, EMBEDDING_DIM, 3), f"Unfold shape: {windows.shape}" assert windows[0, 0, 0, 0].item() == 1.0 assert windows[0, 0, 0, 1].item() == 2.0 assert windows[0, 0, 0, 2].item() == 3.0 print(" PASS test_trigram_window") def test_image_sequencer(): iseq = ImageSequencer() x = torch.randn(1, 3, 224, 224) out = iseq(x) assert out.shape == (1, 194, HIDDEN_DIM) print(" PASS test_image_sequencer") def test_image_sequencer_frozen(): iseq = ImageSequencer() for p in iseq.vit.parameters(): assert not p.requires_grad print(" PASS test_image_sequencer_frozen") def test_target_alignment(): model = ARBModel() x = torch.tensor([[SPECIAL_VOCAB["BOS"], 10, 20, 30, 40, 50, SPECIAL_VOCAB["EOS"]]]) targets = x[:, 3:] logits, losses, _, _ = model(x, targets=targets) assert losses is not None, "Losses should be computed" assert logits[:, :-1, :].shape[1] == targets.shape[1], "Target alignment mismatch" print(" PASS test_target_alignment") def test_model_forward(): model = ARBModel() B, T = 2, 66 x = torch.randint(0, VOCAB, (B, T)) logits, losses, _, _ = model(x) assert logits.shape == (B, T - 2, VOCAB), f"Shape: {logits.shape}, expected ({B}, {T-2}, {VOCAB})" assert losses is None, "Losses should be None without targets" print(" PASS test_model_forward") def test_generate(): model = ARBModel() model.eval() seed = torch.tensor([[SPECIAL_VOCAB["BOS"], ord("H"), ord("e"), ord("l")]]) with torch.no_grad(): output = model.generate(seed, max_new_token=10, temperature=1.0) assert output.shape == (1, 14), f"Shape: {output.shape}" assert (output >= 0).all() and (output < VOCAB).all(), "Tokens out of vocab range" print(" PASS test_generate") def test_param_count(): model = ARBModel() total = sum(p.numel() for p in model.parameters()) print(f" Param count: {total:,}") assert 120e6 < total < 150e6, f"Param count {total:,} outside expected range (frozen ViT-base + Whisper-tiny + MemGram + LSTM)" print(" PASS test_param_count") def test_gradient_flow(): model = ARBModel() x = torch.randint(0, VOCAB, (4, 20)) targets = x[:, 3:] logits, losses, _, _ = model(x, targets=targets) losses.total.backward() for name, param in model.named_parameters(): if param.requires_grad and "embed" not in name and "graph_pool.query" not in name: if "moe.W_gate" in name or "moe.W_transform" in name or "moe.W_gate_norms" in name or "moe.W_transform_norms" in name: continue if "moe.router.bias" in name: continue if "moe.router_h" in name: continue if "memgram" in name or "conv_vq" in name or "lstm" in name: continue if "patch_proj" in name or "image_sequencer.projection" in name or "image_sequencer.norm" in name: continue if "audio_sequencer" in name or "multimodal_sequencer.audio" in name: continue if "bridge.image_vq" in name or "bridge.audio_vq" in name or "bridge.bridge_norm" in name or "modality_gate" in name: continue assert param.grad is not None, f"No gradient for {name}" print(" PASS test_gradient_flow") def test_model_forward_with_targets(): model = ARBModel() B, T = 4, CTX x = torch.randint(0, VOCAB, (B, T)) targets = torch.randint(0, VOCAB, (B, T - 3)) logits, losses, _, _ = model(x, targets=targets) assert losses is not None assert isinstance(losses, LossComponents) assert losses.total.ndim == 0 assert losses.total > 0 print(" PASS test_model_forward_with_targets") def test_save_load_roundtrip(): try: from arbitor.converters.convert_to_ternary8 import save_model, load_model except ImportError: print(" SKIP test_save_load_roundtrip (convert_to_ternary not available)") return model = ARBModel() save_model(model, "/tmp/test-morph-roundtrip.pt") loaded = load_model("/tmp/test-morph-roundtrip.pt") x = torch.randint(0, VOCAB, (1, 10)) model.eval() loaded.eval() with torch.no_grad(): logits_orig, _, _, _ = model(x) logits_loaded, _, _ = loaded(x) assert torch.allclose(logits_orig, logits_loaded, atol=1e-6), "Save/load roundtrip mismatch" print(" PASS test_save_load_roundtrip") # ===== Phase 2: VQ Tests ===== def test_vq_adapter_shapes(): adapter = VQAdapter() x = torch.randn(2, 10, HIDDEN_DIM) out, vq_loss, indices = adapter(x) assert out.shape == (2, 10, HIDDEN_DIM), f"VQ output shape: {out.shape}" assert indices.shape == (2, 10), f"VQ indices shape: {indices.shape}" assert indices.dtype == torch.long, "Indices must be long" assert vq_loss.item() >= 0, "VQ loss must be non-negative" print(" PASS test_vq_adapter_shapes") def test_vq_integration(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}" assert vq_indices is not None, "VQ indices must be returned" assert vq_indices.shape == (2, 64), f"VQ indices shape wrong: {vq_indices.shape}" print(" PASS test_vq_integration") def test_vq_disabled(): model = ARBModel() model.vq_enabled = False model.graph_enabled = False model.moe_enabled = False x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert vq_indices is None, "Indices should be None when VQ disabled" assert logits.shape == (2, 64, VOCAB) print(" PASS test_vq_disabled") def test_vq_with_targets(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:66] logits, losses, vq_indices, _ = model(x, targets=targets) assert losses is not None and losses.total.item() > 0, "Loss should be positive with targets" print(" PASS test_vq_with_targets") def test_l2_distance_matching(): adapter = VQAdapter() x_proj = torch.randn(2, 10, 32) l2_indices, l2_dists = adapter.l2_distance_matching(x_proj) assert l2_indices.shape == (2, 10), f"L2 indices shape: {l2_indices.shape}" assert l2_dists.shape == (2, 10), f"L2 distances shape: {l2_dists.shape}" assert (l2_dists >= 0).all(), "L2 distances must be non-negative" print(" PASS test_l2_distance_matching") def test_vq_ternary_projections(): adapter = VQAdapter() assert isinstance(adapter.proj_in, TernaryScaleTensor), \ f"proj_in should be TernaryScaleTensor, got {type(adapter.proj_in)}" assert isinstance(adapter.proj_out, TernaryScaleTensor), \ f"proj_out should be TernaryScaleTensor, got {type(adapter.proj_out)}" x = torch.randn(2, 10, HIDDEN_DIM) out, vq_loss, indices = adapter(x) assert out.shape == (2, 10, HIDDEN_DIM), f"VQ output shape: {out.shape}" assert vq_loss.item() >= 0, "VQ loss must be non-negative" print(" PASS test_vq_ternary_projections") # ===== Phase 6: Multi-Modal Bridge, Gate, and Graph Tests ===== def test_multimodal_vq_bridge_text_only(): bridge = MultimodalVQBridge() text_in = torch.randn(2, 10, 512) combined, losses, indices = bridge({'text': text_in}) assert combined.shape == (2, 10, 512) assert 'text_vq' in losses assert (indices['text'] < 8192).all() print(" PASS test_multimodal_vq_bridge_text_only") def test_multimodal_vq_bridge_text_image(): bridge = MultimodalVQBridge() text_in = torch.randn(2, 10, 512) image_in = torch.randn(2, 20, 512) combined, losses, indices = bridge({'text': text_in, 'image': image_in}) assert combined.shape == (2, 30, 512) assert (indices['image'] >= 8192).all() assert (indices['image'] < 12288).all() print(" PASS test_multimodal_vq_bridge_text_image") def test_modality_gate_shapes(): gate = ModalityGate() weights, count, hops = gate(['text']) assert isinstance(weights, dict) assert count >= 1 assert hops >= 2 print(" PASS test_modality_gate_shapes") def test_ternary_graph_multicodebook(): graph = TernaryGraph(total_vocab_size=16384) text_embed = torch.randn(1, 8192, 32) image_embed = torch.randn(1, 4096, 32) audio_embed = torch.randn(1, 4096, 32) graph._codebook_embed = torch.cat([text_embed, image_embed, audio_embed], dim=1) vq_out = torch.randn(2, 21, 512) text_idx = torch.randint(0, 8192, (2, 8)) image_idx = torch.randint(8192, 12288, (2, 7)) audio_idx = torch.randint(12288, 16384, (2, 6)) vq_idx = torch.cat([text_idx, image_idx, audio_idx], dim=1) per_pos, gpool, gate_alpha = graph(vq_out, vq_idx, 0.05) assert per_pos.shape == (2, 21, 512) print(" PASS test_ternary_graph_multicodebook") def test_vq_no_float_cast_in_model(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}" for name, mod in model.named_modules(): if isinstance(mod, nn.Linear): if "image_sequencer" in name or "multimodal_sequencer.image" in name or "moe.router" in name or "lstm." in name or "multimodal_sequencer.audio" in name: continue assert False, f"Unexpected nn.Linear: {name} — only moe.router, image_sequencer, lstm, and audio_sequencer are allowed" if isinstance(mod, nn.Embedding): assert "hop_lora.scale" in name or "lstm." in name or "whisper." in name or "vit." in name, f"nn.Embedding found: {name} — only hop_lora.scale, lstm.*, whisper.*, vit.* are allowed" print(" PASS test_vq_no_float_cast_in_model") def test_zero_fp32_params(): model = ARBModel() non_ternary_non_vq = 0 for name, param in model.named_parameters(): is_vq_internal = "bridge.text_vq.vq" in name or "bridge.image_vq.vq" in name or "bridge.audio_vq.vq" in name is_moe_router = "moe.router" in name is_lora_scale = "hop_lora.scale" in name is_vit_frozen = "image_sequencer.vit" in name or "multimodal_sequencer.image.vit" in name is_patch_proj = "patch_proj" in name is_audio_proj = "mfcc_proj" in name or "frame_proj" in name is_whisper_frozen = "whisper" in name is_memory = name.startswith("memgram") or name.startswith("conv_vq") or name.startswith("lstm") if is_vq_internal or is_moe_router or is_lora_scale or is_vit_frozen or is_patch_proj or is_audio_proj or is_whisper_frozen or is_memory: continue if not _is_ternary_param(model, name): non_ternary_non_vq += param.numel() assert non_ternary_non_vq == 0, \ f"Found {non_ternary_non_vq} non-ternary, non-VQ, non-router params" print(" PASS test_zero_fp32_params") def test_sticky_zone_ste_gradient(): w = torch.tensor([-0.01, -0.03, -0.049, 0.06, 0.10], requires_grad=True) threshold = 0.05 t = StickyZoneSTE.apply(w, threshold) t.sum().backward() expected = [0.2, 0.6, 0.98, 1.0, 1.0] for i, exp_ratio in enumerate(expected): actual = w.grad[i].item() assert abs(actual - exp_ratio) < 0.02, f"w={w[i].item():.3f}: expected ratio {exp_ratio}, got {actual:.3f}" print(" PASS test_sticky_zone_ste_gradient") # ===== Phase 3: Graph Tests (updated for GraphMoEGate) ===== def test_graph_moe_gate_shape(): gate = GraphMoEGate(dim=HIDDEN_DIM) x = torch.randn(2, 10, HIDDEN_DIM) pooled, alpha = gate(x) assert pooled.shape == (2, HIDDEN_DIM), f"Pooled shape: {pooled.shape}" assert alpha.shape == (2, 10, 1), f"Alpha shape: {alpha.shape}" assert (alpha >= 0).all() and (alpha <= 1).all(), "Alpha out of [0,1]" assert gate.query.numel() == HIDDEN_DIM, f"Gate params: {gate.query.numel()}" print(" PASS test_graph_moe_gate_shape") def test_ternary_graph_shapes(): graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2) graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM) vq_output = torch.randn(2, 10, HIDDEN_DIM) vq_indices = torch.randint(0, CODEBOOK_SIZE, (2, 10)) per_pos, gpool, gate_alpha = graph(vq_output, vq_indices, 0.05) assert per_pos.shape == (2, 10, HIDDEN_DIM), f"per_position shape: {per_pos.shape}" assert gpool.shape == (2, HIDDEN_DIM), f"graph_pool shape: {gpool.shape}" assert gate_alpha.shape == (2, 10, 1), f"gate_alpha shape: {gate_alpha.shape}" print(" PASS test_ternary_graph_shapes") def test_graph_gradient_flow(): graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2) graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM) vq_output = torch.randn(2, 10, HIDDEN_DIM, requires_grad=True) vq_indices = torch.randint(0, CODEBOOK_SIZE, (2, 10)) per_pos, _, _ = graph(vq_output, vq_indices, 0.05) per_pos.sum().backward() assert graph.edge_attr.grad is not None, "edge_attr should have gradient" assert vq_output.grad is not None, "vq_output should have gradient" print(" PASS test_graph_gradient_flow") def test_graph_connectivity_monitor(): graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2) health = graph.monitor_graph_health(threshold=0.05) assert 'sparsity' in health assert 'isolated_nodes' in health assert 'avg_polarity' in health assert 'dead_edges' in health assert 0.0 <= health['sparsity'] <= 1.0 assert health['isolated_nodes'] >= 0 print(" PASS test_graph_connectivity_monitor") def test_model_forward_with_graph(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}" assert vq_indices is not None, "VQ indices required for graph" assert hasattr(model, 'ternary_graph'), "Model missing ternary_graph" print(" PASS test_model_forward_with_graph") def test_model_graph_disabled(): model = ARBModel() model.graph_enabled = False model.moe_enabled = False x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert logits.shape == (2, 64, VOCAB) print(" PASS test_model_graph_disabled") def test_ternary_graph_in_modules(): assert TernaryGraph in TERNARY_MODULES, "TernaryGraph not in TERNARY_MODULES" assert GraphMoEGate in TERNARY_MODULES, "GraphMoEGate not in TERNARY_MODULES" assert SharedProjectionMoE in TERNARY_MODULES, "SharedProjectionMoE not in TERNARY_MODULES" print(" PASS test_ternary_graph_in_modules") # ===== Phase 4: MoE Tests ===== def test_moe_shapes(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, core_rank=192, shared_inter=3072, tscale_type=TScaleType.T32) x = torch.randn(4, 10, 512) out, aux = moe(x) assert out.shape == (4, 10, 512), f'MoE output shape: {out.shape}' assert aux.ndim == 0, f'Aux loss should be scalar, got ndim={aux.ndim}' assert aux.item() >= 0, 'Aux loss should be non-negative' print(" PASS test_moe_shapes") def test_moe_router(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, noise_std=0.25) moe.train() x = torch.randn(4, 20, 512) out, aux = moe(x) assert moe._last_topk_idx is not None assert moe._last_topk_idx.shape == (80, 2), f'topk_idx shape: {moe._last_topk_idx.shape}' assert (moe._last_topk_idx >= 0).all() and (moe._last_topk_idx < 8).all() moe.eval() out2, _ = moe(x) print(" PASS test_moe_router") def test_moe_aux_loss(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2) x = torch.randn(4, 10, 512) _, aux = moe(x) assert aux.item() >= 0, 'Aux loss must be non-negative' print(" PASS test_moe_aux_loss") def test_shared_expert(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2) assert isinstance(moe.shared_expert_gate, TernaryScaleTensor) assert isinstance(moe.shared_expert_up, TernaryScaleTensor) assert isinstance(moe.shared_expert_down, TernaryScaleTensor) x = torch.randn(2, 5, 512) out, _ = moe(x) assert out.norm().item() > 0, 'Shared expert output should be non-zero' print(" PASS test_shared_expert") def test_moe_gradient_flow(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2) x = torch.randn(2, 10, 512) x.requires_grad_(True) out, aux = moe(x) (out.sum() + aux).backward() assert x.grad is not None, 'No gradient on input' assert hasattr(moe.router, '_hook_grad_T_sign'), 'No grad captured on router' assert hasattr(moe.W_gate[0], '_hook_grad_T_sign'), 'No grad captured on W_gate[0]' print(" PASS test_moe_gradient_flow") def test_moe_zero_fp32(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2) non_ternary = 0 for name, param in moe.named_parameters(): if not _is_ternary_param(moe, name): non_ternary += param.numel() assert non_ternary == 0, f'Expected 0 non-ternary params, got {non_ternary}' print(" PASS test_moe_zero_fp32") def test_ternary_graph_with_gate(): graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM) graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM) vq_output = torch.randn(2, 10, HIDDEN_DIM) vq_indices = torch.randint(0, CODEBOOK_SIZE, (2, 10)) per_pos, gpool, gate_alpha = graph(vq_output, vq_indices, 0.05) assert gate_alpha.shape == (2, 10, 1), f'gate_alpha shape: {gate_alpha.shape}' assert (gate_alpha >= 0).all() and (gate_alpha <= 1).all() print(" PASS test_ternary_graph_with_gate") def test_model_forward_with_moe(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert logits.shape == (2, 64, VOCAB), f'Logits shape: {logits.shape}' assert vq_indices is not None print(" PASS test_model_forward_with_moe") def test_model_moe_disabled(): model = ARBModel() model.moe_enabled = False x = torch.randint(0, VOCAB, (2, 66)) logits, losses, vq_indices, _ = model(x) assert logits.shape == (2, 64, VOCAB) print(" PASS test_model_moe_disabled") def test_model_moe_loss_components(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits, losses, vq_indices, _ = model(x, targets=targets) assert losses is not None and isinstance(losses, LossComponents) assert losses.lm is not None and losses.lm > 0 assert losses.vq_commitment is not None assert losses.moe_aux is not None assert losses.graph_l1 is not None assert losses.total > 0 assert model.moe._last_topk_idx is not None, 'MoE should have routing info after forward' assert model.moe._last_aux_loss is not None, 'MoE should have aux_loss cached after forward' print(" PASS test_model_moe_loss_components") def test_model_moe_gate_modulation(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) logits, _, _, _ = model(x) assert logits.shape == (2, 64, VOCAB) print(" PASS test_model_moe_gate_modulation") def test_param_count_with_moe(): model = ARBModel() total = sum(p.numel() for p in model.parameters()) print(f" Param count with MoE: {total:,}") assert 120e6 < total < 150e6, f'Expected ~133M (frozen ViT-base + Whisper-tiny + ternary buffers), got {total:,}' print(" PASS test_param_count_with_moe") def test_moe_monitoring(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) model(x) assert model.moe._last_topk_idx is not None, '_last_topk_idx should be set after forward' assert model.moe._last_aux_loss is not None, '_last_aux_loss should be set after forward' assert model.moe._last_topk_idx.shape[1] == model.moe.top_k print(" PASS test_moe_monitoring") # ===== Explore: LossComponents + GNN LoRA Tests ===== def test_loss_components(): lm = torch.tensor(5.0, requires_grad=True) vq = torch.tensor(0.5, requires_grad=True) moe = torch.tensor(0.01, requires_grad=True) graph = torch.tensor(0.001, requires_grad=True) lc = LossComponents(lm=lm, vq_commitment=vq, moe_aux=moe, graph_l1=graph) assert lc.total.ndim == 0, f"Total should be scalar, got ndim={lc.total.ndim}" expected = 5.0 + 0.5 + 0.01 + LossWeights.graph_l1 * 0.001 assert abs(lc.total.item() - expected) < 1e-5, f"Total mismatch: {lc.total.item()} vs {expected}" lc.total.backward() assert lm.grad is not None, "LM loss should have gradient" print(" PASS test_loss_components") def test_loss_components_none_fields(): lm = torch.tensor(3.0, requires_grad=True) lc = LossComponents(lm=lm, vq_commitment=None, moe_aux=None, graph_l1=None) assert lc.total.item() == 3.0, f"Total with None fields: {lc.total.item()}" print(" PASS test_loss_components_none_fields") def test_loss_components_backward(): lm = torch.tensor(4.0, requires_grad=True) vq = torch.tensor(0.3, requires_grad=True) lc = LossComponents(lm=lm, vq_commitment=vq) lc.backward() assert lm.grad is not None, "LM should have gradient after backward" assert vq.grad is not None, "VQ should have gradient after backward" print(" PASS test_loss_components_backward") def test_gnn_lora_adapter(): lora = GNNLoRAAdapter(dim=512, rank=32, max_hops=4) x = torch.randn(8192, 512) out0 = lora(x, hop_t=0) out1 = lora(x, hop_t=1) assert out0.shape == (8192, 512), f"LoRA output shape: {out0.shape}" assert torch.allclose(out0, out1, atol=1e-6), "Zero-init scales should produce same output at init" lora.scale.weight.data[1] = lora.scale.weight.data[0] + 1.0 out1_modified = lora(x, hop_t=1) assert not torch.allclose(out0, out1_modified), "Non-zero scales should differentiate hops" print(" PASS test_gnn_lora_adapter") def test_gnn_lora_gradient(): lora = GNNLoRAAdapter(dim=512, rank=32, max_hops=4) x = torch.randn(8192, 512, requires_grad=True) out = lora(x, hop_t=0) out.sum().backward() assert x.grad is not None, "Input should have gradient" assert lora.scale.weight.grad is not None, "LoRA scale should have gradient" print(" PASS test_gnn_lora_gradient") def test_shared_gnn_weight_tying(): graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=3) assert hasattr(graph, 'gnn'), "Graph should have single shared GNN layer" assert not hasattr(graph, 'gnn_layers'), "Graph should NOT have gnn_layers list" assert hasattr(graph, 'hop_lora'), "Graph should have hop_lora adapter" assert graph.max_hops == 3, f"max_hops should be 3, got {graph.max_hops}" print(" PASS test_shared_gnn_weight_tying") def test_shared_gnn_multi_hop(): graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=4, lora_rank=32) graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM) vq_output = torch.randn(2, 10, HIDDEN_DIM) vq_indices = torch.randint(0, CODEBOOK_SIZE, (2, 10)) per_pos, gpool, gate_alpha = graph(vq_output, vq_indices, 0.05) assert per_pos.shape == (2, 10, HIDDEN_DIM), f"per_position shape: {per_pos.shape}" print(" PASS test_shared_gnn_multi_hop") def test_model_losses_components_type(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits, losses, vq_indices, _ = model(x, targets=targets) assert isinstance(losses, LossComponents), f"Expected LossComponents, got {type(losses)}" assert losses.lm is not None assert losses.vq_commitment is not None assert losses.moe_aux is not None assert losses.graph_l1 is not None total = losses.total.item() w = losses.weights manual = w.lm * losses.lm.item() + w.vq_commitment * losses.vq_commitment.item() + w.moe_aux * losses.moe_aux.item() + w.graph_l1 * losses.graph_l1.item() if losses.graph_ponder is not None: manual += w.graph_ponder * losses.graph_ponder.item() if losses.moe_ponder is not None: manual += w.moe_ponder * losses.moe_ponder.item() if losses.conv_vq_commitment is not None: manual += w.conv_vq_commitment * losses.conv_vq_commitment.item() if losses.memgram_decay_reg is not None: manual += w.memgram_decay_reg * losses.memgram_decay_reg.item() if losses.lstm_hidden_reg is not None: manual += w.lstm_hidden_reg * losses.lstm_hidden_reg.item() assert abs(total - manual) < 1e-4, f"Total {total} != weighted sum {manual}" print(" PASS test_model_losses_components_type") # ===== Phase 5: ACT Adaptive Computation Tests ===== def test_halting_unit_shapes(): hu = HaltingUnit(dim=512, tscale_type=TScaleType.T32) x = torch.randn(4, 10, 512) x.requires_grad_(True) p = hu(x) assert p.shape == (4, 10, 1), f"Shape: {p.shape}" assert (p > 0).all() and (p < 1).all(), f"Range: ({p.min():.4f}, {p.max():.4f})" p.sum().backward() assert x.grad is not None, "No gradient on input" print(" PASS test_halting_unit_shapes") def test_halting_unit_ternary_pure(): hu = HaltingUnit(dim=512) for name, mod in hu.named_modules(): if isinstance(mod, nn.Linear): assert False, f"nn.Linear found: {name}" if isinstance(mod, nn.Embedding): assert False, f"nn.Embedding found: {name}" print(" PASS test_halting_unit_ternary_pure") def test_graph_act_cell_shapes(): graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32) graph._codebook_embed = torch.randn(1, 8192, 32) act = GraphACTCell(graph, max_hops=4, halt_threshold=0.01) vq_out = torch.randn(2, 10, 512) vq_out.requires_grad_(True) vq_idx = torch.randint(0, 8192, (2, 10)) per_pos, gpool, gate_alpha, ponder = act(vq_out, vq_idx, 0.05) assert per_pos.shape == (2, 10, 512), f"per_pos: {per_pos.shape}" assert gpool.shape == (2, 512), f"gpool: {gpool.shape}" assert gate_alpha.shape == (2, 10, 1), f"gate_alpha: {gate_alpha.shape}" assert ponder.ndim == 0 assert ponder.item() > 0 per_pos.sum().backward() assert vq_out.grad is not None, "No gradient on input" print(" PASS test_graph_act_cell_shapes") def test_moe_act_cell_shapes(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32) act = MoEACTCell(moe, dim=512, max_iters=4, halt_threshold=0.01) x = torch.randn(2, 10, 512) x.requires_grad_(True) out, aux, ponder = act(x) assert out.shape == (2, 10, 512), f"out: {out.shape}" assert aux.ndim == 0 assert ponder.ndim == 0 assert aux.item() >= 0 assert ponder.item() > 0 out.sum().backward() assert x.grad is not None, "No gradient on input" print(" PASS test_moe_act_cell_shapes") def test_act_early_halt(): graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32) graph._codebook_embed = torch.randn(1, 8192, 32) act = GraphACTCell(graph, max_hops=8, halt_threshold=100.0) vq_out = torch.randn(2, 10, 512) vq_idx = torch.randint(0, 8192, (2, 10)) _, _, _, ponder = act(vq_out, vq_idx, 0.05) act_low = GraphACTCell(graph, max_hops=8, halt_threshold=1e-6) _, _, _, ponder_low = act_low(vq_out, vq_idx, 0.05) assert ponder_low.item() < ponder.item(), \ f"Early halt ponder ({ponder_low:.4f}) should be less than no-halt ponder ({ponder:.4f})" print(" PASS test_act_early_halt") def test_act_weight_sum_one(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32) act = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=1e-6) x = torch.randn(2, 10, 512) out_fast, _, _ = act(x) act_slow = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=100.0) out_slow, _, _ = act_slow(x) out_sum = out_fast.sum().item() + out_slow.sum().item() assert not torch.isnan(out_fast).any(), "NaN in fast ACT output" assert not torch.isnan(out_slow).any(), "NaN in slow ACT output" assert out_sum != 0, "Outputs should be non-zero (weights sum to 1.0)" print(" PASS test_act_weight_sum_one") def test_act_gradient_flow(): moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32) act = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=0.01) x = torch.randn(2, 10, 512, requires_grad=True) out, aux, ponder = act(x) loss = out.sum() + aux + ponder loss.backward() assert x.grad is not None, "Input grad is None" print(" PASS test_act_gradient_flow") def test_loss_components_ponder_fields(): lm = torch.tensor(5.0, requires_grad=True) gp = torch.tensor(0.1, requires_grad=True) mp = torch.tensor(0.2, requires_grad=True) lc = LossComponents(lm=lm, graph_ponder=gp, moe_ponder=mp) expected = 5.0 + 0.1 + 0.2 assert abs(lc.total.item() - expected) < 1e-5, f"Total: {lc.total.item()} vs {expected}" lc.total.backward() assert lm.grad is not None assert gp.grad is not None assert mp.grad is not None print(" PASS test_loss_components_ponder_fields") def test_loss_components_ponder_none(): lm = torch.tensor(3.0, requires_grad=True) lc = LossComponents(lm=lm, graph_ponder=None, moe_ponder=None) assert abs(lc.total.item() - 3.0) < 1e-5 lc.total.backward() assert lm.grad is not None print(" PASS test_loss_components_ponder_none") def test_act_graph_moe_sequential(): graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32) graph._codebook_embed = torch.randn(1, 8192, 32) graph_act = GraphACTCell(graph, max_hops=3, halt_threshold=0.01) moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32) moe_act = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=0.01) vq_out = torch.randn(2, 10, 512) vq_out.requires_grad_(True) vq_idx = torch.randint(0, 8192, (2, 10)) per_pos, gpool, gate_alpha, graph_ponder = graph_act(vq_out, vq_idx, 0.05) moe_out, aux, moe_ponder = moe_act(per_pos) final = gate_alpha * moe_out + (1 - gate_alpha) * per_pos assert final.shape == (2, 10, 512), f"Sequential output: {final.shape}" assert graph_ponder.ndim == 0 assert moe_ponder.ndim == 0 final.sum().backward() assert vq_out.grad is not None, "Input grad is None" print(" PASS test_act_graph_moe_sequential") # ===== Model-level ACT Integration Tests ===== def test_model_forward_with_act(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits, losses, _, _ = model(x, targets=targets) assert logits.shape == (2, 64, VOCAB), f"Logits: {logits.shape}" assert isinstance(losses, LossComponents) assert losses.graph_ponder is not None assert losses.moe_ponder is not None assert losses.total > 0 print(" PASS test_model_forward_with_act") def test_model_act_forward_without_targets(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) logits, losses, _, _ = model(x) assert logits.shape == (2, 64, VOCAB) assert losses is None print(" PASS test_model_act_forward_without_targets") def test_model_act_loss_components(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] _, losses, _, _ = model(x, targets=targets) assert losses.lm is not None assert losses.vq_commitment is not None assert losses.moe_aux is not None assert losses.graph_l1 is not None assert losses.graph_ponder is not None assert losses.moe_ponder is not None assert losses.total > sum(filter(None, [losses.graph_ponder, losses.moe_ponder])) print(" PASS test_model_act_loss_components") def test_model_act_backward(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] _, losses, _, _ = model(x, targets=targets) losses.backward() assert model.ternary_graph.edge_attr.grad is not None, "edge_attr grad None" print(" PASS test_model_act_backward") def test_model_act_disabled(): model = ARBModel(tscale_type=TScaleType.T32) model.graph_act_enabled = False model.moe_act_enabled = False x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits, losses, _, _ = model(x, targets=targets) assert logits.shape == (2, 64, VOCAB) assert losses.graph_ponder is None assert losses.moe_ponder is None assert model.graph_act_enabled == False assert model.moe_act_enabled == False print(" PASS test_model_act_disabled") def test_model_act_warmup_mode(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] _, losses, _, _ = model(x, targets=targets, act_warmup_mode=True) assert losses.graph_ponder is None, f"During warmup, graph_ponder should be None: {losses.graph_ponder}" assert losses.moe_ponder is None, f"During warmup, moe_ponder should be None: {losses.moe_ponder}" _, losses2, _, _ = model(x, targets=targets, act_warmup_mode=False) assert losses2.graph_ponder is not None, "Without warmup, graph_ponder should be present" assert losses2.moe_ponder is not None, "Without warmup, moe_ponder should be present" print(" PASS test_model_act_warmup_mode") def test_model_act_ponder_cached(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] model(x, targets=targets) assert model._last_graph_ponder > 0, f"_last_graph_ponder={model._last_graph_ponder}" assert model._last_moe_ponder > 0, f"_last_moe_ponder={model._last_moe_ponder}" print(" PASS test_model_act_ponder_cached") # ===== ACT Warmup and Monitoring Tests ===== def test_act_warmup_schedule(): from train import compute_act_warmup assert compute_act_warmup(0, 50000) == True, "step 0 is warmup" assert compute_act_warmup(9999, 50000) == True, "step 9999 is warmup" assert compute_act_warmup(10000, 50000) == False, "step 10000 not warmup" assert compute_act_warmup(50000, 50000) == False, "step 50000 not warmup" print(" PASS test_act_warmup_schedule") def test_act_ponder_lambda(): from train import get_ponder_lambda lam0 = get_ponder_lambda(0, 50000, warmup_frac=0.2, start_lambda=0.1, end_lambda=0.01) assert abs(lam0 - 0.1) < 1e-6, f"start lambda: {lam0}" lam_mid = get_ponder_lambda(5000, 50000, warmup_frac=0.2, start_lambda=0.1, end_lambda=0.01) assert lam_mid > 0.01 and lam_mid < 0.1, f"mid lambda: {lam_mid}" lam_end = get_ponder_lambda(10000, 50000, warmup_frac=0.2, start_lambda=0.1, end_lambda=0.01) assert abs(lam_end - 0.01) < 1e-6, f"end lambda: {lam_end}" print(" PASS test_act_ponder_lambda") def test_model_ponder_lambda_scaling(): model = ARBModel(tscale_type=TScaleType.T32) x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] _, losses_high, _, _ = model(x, targets=targets, ponder_lambda=0.5) _, losses_low, _, _ = model(x, targets=targets, ponder_lambda=0.01) if losses_high.graph_ponder is not None and losses_low.graph_ponder is not None: assert losses_high.graph_ponder.item() > losses_low.graph_ponder.item(), \ "Higher ponder_lambda should produce larger ponder loss" print(" PASS test_model_ponder_lambda_scaling") # ===== Phase 6: Integration Tests ===== def test_text_only_forward(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) logits, losses, indices, _ = model(x) assert logits.shape == (2, 64, VOCAB) assert indices is not None print(" PASS test_text_only_forward") def test_image_forward(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) img = torch.randn(2, 3, 224, 224) logits, losses, indices, _ = model(x, images=img) assert logits.shape == (2, 64, VOCAB) assert indices is not None print(" PASS test_image_forward") def test_multimodal_backward(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] img = torch.randn(2, 3, 224, 224) logits, losses, _, _ = model(x, targets=targets, images=img) assert losses is not None losses.total.backward() for name, param in model.named_parameters(): if param.requires_grad and param.grad is None: if any(skip in name for skip in ['vit', 'embedding', 'patch_proj', 'frame_proj', 'router.bias', 'router_h', 'W_gate', 'W_transform', 'hop_lora.scale', 'modality_gate', 'graph_pool.query', 'memgram', 'conv_vq', 'lstm', 'mfcc_proj', 'audio_sequencer', 'audio_vq']): continue assert False, f'No gradient for {name}' print(" PASS test_multimodal_backward") def test_no_stale_trigram_encoder(): assert not hasattr(sys.modules['trigram'], 'TrigramEncoder'), 'TrigramEncoder should be removed' print(" PASS test_no_stale_trigram_encoder") def test_vocab(): assert VOCAB == 288 assert len(SPECIAL_VOCAB) == 32 assert SPECIAL_VOCAB['IMAGE'] == 278 print(" PASS test_vocab") # ===== Phase 6b: Audio + Quantized Encoder Tests ===== def test_audio_sequencer_construction(): aseq = AudioSequencer() assert aseq.modality == 'audio' assert aseq.window_size == 5 assert aseq.whisper is not None assert aseq.frame_proj is not None for p in aseq.whisper.parameters(): assert not p.requires_grad, f"Whisper param {p} should be frozen" print(" PASS test_audio_sequencer_construction") def test_audio_sequencer_forward_waveform(): aseq = AudioSequencer() waveform = torch.randn(2, 80000) out = aseq(waveform) assert out.dim() == 3 assert out.shape[0] == 2 assert out.shape[2] == HIDDEN_DIM assert out.shape[1] > 0 print(" PASS test_audio_sequencer_forward_waveform") def test_audio_sequencer_forward_precomputed_mel(): aseq = AudioSequencer() mel = torch.randn(2, 80, 3000) out = aseq(mel) assert out.dim() == 3 assert out.shape[0] == 2 assert out.shape[2] == HIDDEN_DIM print(" PASS test_audio_sequencer_forward_precomputed_mel") def test_audio_sequencer_quantization_fp8(): aseq = AudioSequencer(quantize_weights='fp8') found_qlinear = False for name, mod in aseq.whisper.named_modules(): if 'QLinear' in type(mod).__name__: found_qlinear = True if hasattr(mod, 'weight') and hasattr(mod.weight, '_data'): assert mod.weight._data.dtype == torch.float8_e4m3fn, f"Expected fp8 data, got {mod.weight._data.dtype}" break assert found_qlinear, "No QLinear modules found — quantization may not have been applied" print(" PASS test_audio_sequencer_quantization_fp8") def test_audio_sequencer_quantization_int8(): aseq = AudioSequencer(quantize_weights='int8') found_qlinear = False for name, mod in aseq.whisper.named_modules(): if 'QLinear' in type(mod).__name__: found_qlinear = True break assert found_qlinear, "No QLinear modules found — int8 quantization may not have been applied" print(" PASS test_audio_sequencer_quantization_int8") def test_audio_sequencer_no_quantize(): aseq = AudioSequencer(quantize_weights=None) for name, mod in aseq.whisper.named_modules(): if 'QLinear' in type(mod).__name__: assert False, "No QLinear should exist when quantize_weights=None" for p in aseq.whisper.parameters(): assert p.dtype == torch.bfloat16, f"Expected bfloat16, got {p.dtype}" print(" PASS test_audio_sequencer_no_quantize") def test_image_sequencer_hf_vit(): iseq = ImageSequencer() assert iseq.modality == 'image' assert iseq.window_size == 3 img = torch.randn(1, 3, 224, 224) out = iseq(img) assert out.shape == (1, 194, HIDDEN_DIM) print(" PASS test_image_sequencer_hf_vit") def test_image_sequencer_quantization_fp8(): iseq = ImageSequencer(quantize_weights='fp8') found_qlinear = False for name, mod in iseq.vit.named_modules(): if 'QLinear' in type(mod).__name__: found_qlinear = True if hasattr(mod, 'weight') and hasattr(mod.weight, '_data'): assert mod.weight._data.dtype == torch.float8_e4m3fn, f"Expected fp8, got {mod.weight._data.dtype}" break assert found_qlinear, "No QLinear modules found in ViT — fp8 quantization may not have been applied" print(" PASS test_image_sequencer_quantization_fp8") def test_image_sequencer_no_quantize(): iseq = ImageSequencer(quantize_weights=None) for name, mod in iseq.vit.named_modules(): if 'QLinear' in type(mod).__name__: assert False, "No QLinear should exist when quantize_weights=None" for p in iseq.vit.parameters(): assert p.dtype == torch.bfloat16, f"Expected bfloat16, got {p.dtype}" print(" PASS test_image_sequencer_no_quantize") def test_multimodal_sequencer_all_modalities(): mseq = MultimodalSequencer() assert 'text' in mseq.enabled_modalities assert 'image' in mseq.enabled_modalities assert 'audio' in mseq.enabled_modalities assert mseq.text is not None assert mseq.image is not None assert mseq.audio is not None print(" PASS test_multimodal_sequencer_all_modalities") def test_multimodal_sequencer_text_only(): mseq = MultimodalSequencer(enable_image=False, enable_audio=False) assert mseq.enabled_modalities == ['text'] assert mseq.image is None assert mseq.audio is None x = torch.randint(0, VOCAB, (2, 20)) embedded = torch.randn(2, 20, EMBEDDING_DIM) out = mseq({'text': embedded}) assert 'text' in out assert 'image' not in out assert 'audio' not in out print(" PASS test_multimodal_sequencer_text_only") def test_multimodal_sequencer_full_forward(): mseq = MultimodalSequencer() embedded = torch.randn(2, 20, EMBEDDING_DIM) img = torch.randn(2, 3, 224, 224) audio = torch.randn(2, 80000) out = mseq({'text': embedded, 'image': img, 'audio': audio}) assert 'text' in out assert 'image' in out assert 'audio' in out assert out['text'].shape[2] == HIDDEN_DIM assert out['image'].shape[2] == HIDDEN_DIM assert out['audio'].shape[2] == HIDDEN_DIM print(" PASS test_multimodal_sequencer_full_forward") def test_multimodal_vq_bridge_text_audio(): bridge = MultimodalVQBridge() text_in = torch.randn(2, 10, 512) audio_in = torch.randn(2, 15, 512) combined, losses, indices = bridge({'text': text_in, 'audio': audio_in}) assert combined.shape == (2, 25, 512) assert 'audio_vq' in losses assert (indices['audio'] >= 12288).all() assert (indices['audio'] < 16384).all() print(" PASS test_multimodal_vq_bridge_text_audio") def test_multimodal_vq_bridge_all_three(): bridge = MultimodalVQBridge() text_in = torch.randn(2, 10, 512) image_in = torch.randn(2, 20, 512) audio_in = torch.randn(2, 15, 512) combined, losses, indices = bridge({'text': text_in, 'image': image_in, 'audio': audio_in}) assert combined.shape == (2, 45, 512) assert 'text_vq' in losses assert 'image_vq' in losses assert 'audio_vq' in losses assert (indices['text'] < 8192).all() assert (indices['image'] >= 8192).all() and (indices['image'] < 12288).all() assert (indices['audio'] >= 12288).all() and (indices['audio'] < 16384).all() print(" PASS test_multimodal_vq_bridge_all_three") def test_modality_gate_three_modalities(): gate = ModalityGate(num_modalities=3) weights, count, hops = gate(['text', 'image', 'audio']) assert count == 3 assert 'text' in weights assert 'image' in weights assert 'audio' in weights assert hops >= 2 print(" PASS test_modality_gate_three_modalities") def test_modality_gate_audio_only(): gate = ModalityGate(num_modalities=3) weights, count, hops = gate(['audio']) assert count == 1 assert 'audio' in weights print(" PASS test_modality_gate_audio_only") def test_model_forward_with_audio(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) audio = torch.randn(2, 80000) logits, losses, indices, _ = model(x, audio=audio) assert logits.shape[0] == 2 assert logits.shape[2] == VOCAB assert indices is not None print(" PASS test_model_forward_with_audio") def test_model_forward_all_modalities(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) img = torch.randn(2, 3, 224, 224) audio = torch.randn(2, 80000) logits, losses, indices, _ = model(x, targets=x[:, 3:], images=img, audio=audio) assert losses is not None assert isinstance(losses, LossComponents) assert losses.total.ndim == 0 assert losses.total > 0 print(" PASS test_model_forward_all_modalities") def test_model_audio_disabled_raises(): model = ARBModel(enable_audio=False) x = torch.randint(0, VOCAB, (2, 66)) audio = torch.randn(2, 80000) try: model(x, audio=audio) assert False, "Should have raised ValueError" except ValueError: pass print(" PASS test_model_audio_disabled_raises") def test_audio_sequencer_gradient_flow(): aseq = AudioSequencer() waveform = torch.randn(2, 80000) out = aseq(waveform) loss = out.sum() loss.backward() assert aseq.frame_proj.weight.grad is not None, "frame_proj should get gradients" assert aseq.projection.T_accum.grad is not None or True, "projection should participate" print(" PASS test_audio_sequencer_gradient_flow") def test_vq_bridge_audio_codebook_utilization(): bridge = MultimodalVQBridge() audio_in = torch.randn(4, 50, 512) combined, losses, indices = bridge({'text': torch.randn(4, 10, 512), 'audio': audio_in}) util = bridge.get_codebook_utilization() assert 'audio' in util dead = bridge.get_dead_code_count() assert 'audio' in dead print(" PASS test_vq_bridge_audio_codebook_utilization") # ===== Phase 7: Memory Module Tests ===== def test_loss_components_nine_fields_total(): w = LossWeights() lc = LossComponents( lm=torch.tensor(1.0, requires_grad=True), vq_commitment=torch.tensor(0.1, requires_grad=True), moe_aux=torch.tensor(0.1, requires_grad=True), graph_l1=torch.tensor(0.1, requires_grad=True), graph_ponder=torch.tensor(0.1, requires_grad=True), moe_ponder=torch.tensor(0.1, requires_grad=True), conv_vq_commitment=torch.tensor(0.1, requires_grad=True), memgram_decay_reg=torch.tensor(0.01, requires_grad=True), lstm_hidden_reg=torch.tensor(0.01, requires_grad=True), ) total = lc.total expected = (w.lm * 1.0 + w.vq_commitment * 0.1 + w.moe_aux * 0.1 + w.graph_l1 * 0.1 + w.graph_ponder * 0.1 + w.moe_ponder * 0.1 + w.conv_vq_commitment * 0.1 + w.memgram_decay_reg * 0.01 + w.lstm_hidden_reg * 0.01) assert abs(total.item() - expected) < 1e-6, f"total {total.item()} != {expected}" print(" PASS test_loss_components_nine_fields_total") def test_loss_components_nine_fields_log(): from types import SimpleNamespace writer = SimpleNamespace() writer.logged = [] writer.add_scalar = lambda name, val, step: writer.logged.append((name, val, step)) lc = LossComponents( lm=torch.tensor(1.0, requires_grad=True), vq_commitment=torch.tensor(0.1, requires_grad=True), moe_aux=torch.tensor(0.1, requires_grad=True), graph_l1=torch.tensor(0.1, requires_grad=True), graph_ponder=torch.tensor(0.1, requires_grad=True), moe_ponder=torch.tensor(0.1, requires_grad=True), conv_vq_commitment=torch.tensor(0.1, requires_grad=True), memgram_decay_reg=torch.tensor(0.01, requires_grad=True), lstm_hidden_reg=torch.tensor(0.01, requires_grad=True), ) lc.log(writer, step=0, prefix="loss") names = [x[0] for x in writer.logged] assert "loss/conv_vq_commitment" in names assert "loss/memgram_decay_reg" in names assert "loss/lstm_hidden_reg" in names assert "loss/total" in names print(" PASS test_loss_components_nine_fields_log") def test_loss_weights_custom_total(): w = LossWeights(lm=2.0, vq_commitment=0.5, moe_aux=0.0, graph_l1=10.0) lm_t = torch.tensor(1.0, requires_grad=True) vq_t = torch.tensor(2.0, requires_grad=True) lc = LossComponents(lm=lm_t, vq_commitment=vq_t, moe_aux=torch.tensor(5.0, requires_grad=True), graph_l1=torch.tensor(0.01, requires_grad=True), weights=w) expected = 2.0*1.0 + 0.5*2.0 + 0.0*5.0 + 10.0*0.01 assert abs(lc.total.item() - expected) < 1e-5, f"Custom weights total {lc.total.item()} != {expected}" print(" PASS test_loss_weights_custom_total") def test_loss_weights_zero_skips(): w = LossWeights(vq_commitment=0.0, moe_aux=0.0) p = torch.nn.Parameter(torch.tensor(1.0)) lc = LossComponents(lm=p * 2.0, vq_commitment=p * 3.0, moe_aux=p * 4.0, weights=w) total_val = lc.total.item() assert abs(total_val - (1.0*2.0 + 0.0*3.0 + 0.0*4.0)) < 1e-5, f"Zero-weight total {total_val}" lc.total.backward() grad_val = p.grad.item() assert abs(grad_val - 2.0) < 1e-5, f"Zero-weight grad {grad_val} (should be 2.0, only lm)" print(" PASS test_loss_weights_zero_skips") def test_loss_weights_backward_compat(): old_default = LossWeights() assert abs(old_default.lm - 1.0) < 1e-5 assert abs(old_default.vq_commitment - 1.0) < 1e-5 assert abs(old_default.moe_aux - 1.0) < 1e-5 assert abs(old_default.graph_l1 - 0.001) < 1e-5 assert abs(old_default.graph_ponder - 1.0) < 1e-5 assert abs(old_default.moe_ponder - 1.0) < 1e-5 assert abs(old_default.conv_vq_commitment - 0.1) < 1e-5 assert abs(old_default.memgram_decay_reg - 0.01) < 1e-5 assert abs(old_default.lstm_hidden_reg - 0.01) < 1e-5 print(" PASS test_loss_weights_backward_compat") def test_model_forward_loss_weights(): model = ARBModel() w = LossWeights(lm=2.0, vq_commitment=0.5) x = torch.randint(0, VOCAB, (2, 66)) _, losses, _, _ = model(x, targets=x[:, 3:], loss_weights=w) assert losses is not None assert isinstance(losses.weights, LossWeights) assert abs(losses.weights.lm - 2.0) < 1e-5 assert abs(losses.weights.vq_commitment - 0.5) < 1e-5 print(" PASS test_model_forward_loss_weights") def test_model_forward_no_hardcoded_graph_l1(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) _, losses_no_target, _, _ = model(x) _, losses, _, _ = model(x, targets=x[:, 3:]) assert losses is not None assert losses.graph_l1 is not None assert losses.vq_commitment is not None print(" PASS test_model_forward_no_hardcoded_graph_l1") def test_build_param_groups_shape(): from train import build_param_groups model = ARBModel() groups = build_param_groups(model, base_lr=1e-3, vq_lr_scale=2.0, memory_lr_scale=0.5) group_names = [g['name'] for g in groups] assert 'graph' in group_names assert 'memory' in group_names assert 'patch_proj' in group_names or 'frame_proj' in group_names for g in groups: assert 'lr' in g assert 'params' in g print(f" Groups: {group_names}") print(" PASS test_build_param_groups_shape") def test_pinpoint_gradient_isolation(): from train import pinpoint_backward, build_param_groups model = ARBModel() param_groups = build_param_groups(model, base_lr=1e-3) x = torch.randint(0, VOCAB, (2, 66)) _, losses, _, _ = model(x, targets=x[:, 3:]) # Only vq_commitment active → grads only on vq_projection, NOT lm_core lw_vq = LossWeights(lm=0.0, vq_commitment=1.0, moe_aux=0.0, graph_l1=0.0, graph_ponder=0.0, moe_ponder=0.0, conv_vq_commitment=0.0, memgram_decay_reg=0.0, lstm_hidden_reg=0.0) for g in param_groups: for p in g['params']: p.grad = None pinpoint_backward(losses, lw_vq, param_groups, free_graph=False) for g in param_groups: for p in g['params']: if g['name'] in ('vq_projection', 'patch_proj', 'frame_proj'): continue # allowed to have grads if g['name'] == 'vq_codebook': continue # buffers, no grads assert p.grad is None, f"{g['name']} param should NOT have grad with only vq_commitment" # Only moe_aux active → grads only on moe groups lw_moe = LossWeights(lm=0.0, vq_commitment=0.0, moe_aux=1.0, graph_l1=0.0, graph_ponder=0.0, moe_ponder=0.0, conv_vq_commitment=0.0, memgram_decay_reg=0.0, lstm_hidden_reg=0.0) for g in param_groups: for p in g['params']: p.grad = None pinpoint_backward(losses, lw_moe, param_groups, free_graph=False) for g in param_groups: for p in g['params']: if g['name'] in ('moe', 'moe_act'): continue assert p.grad is None, f"{g['name']} param should NOT have grad with only moe_aux" # Only graph_l1 active → grads only on graph group lw_graph = LossWeights(lm=0.0, vq_commitment=0.0, moe_aux=0.0, graph_l1=1.0, graph_ponder=0.0, moe_ponder=0.0, conv_vq_commitment=0.0, memgram_decay_reg=0.0, lstm_hidden_reg=0.0) for g in param_groups: for p in g['params']: p.grad = None pinpoint_backward(losses, lw_graph, param_groups, free_graph=False) for g in param_groups: for p in g['params']: if g['name'] == 'graph': continue assert p.grad is None, f"{g['name']} param should NOT have grad with only graph_l1" print(" PASS test_pinpoint_gradient_isolation") def test_pinpoint_backward_accumulation(): from train import pinpoint_backward, build_param_groups model = ARBModel() param_groups = build_param_groups(model, base_lr=1e-3) x = torch.randint(0, VOCAB, (2, 66)) _, losses, _, _ = model(x, targets=x[:, 3:]) lw = LossWeights(lm=1.0, vq_commitment=0.0, moe_aux=0.0, graph_l1=0.0, graph_ponder=0.0, moe_ponder=0.0, conv_vq_commitment=0.0, memgram_decay_reg=0.0, lstm_hidden_reg=0.0) # grad_accum=1 on the same loss for g in param_groups: for p in g['params']: p.grad = None pinpoint_backward(losses, lw, param_groups, grad_accum=1, free_graph=False) norms_full = {} for g in param_groups: for p in g['params']: if p.grad is not None: norms_full[id(p)] = p.grad.data.norm().item() # grad_accum=2 on the same loss (graph retained from free_graph=False above) for g in param_groups: for p in g['params']: p.grad = None pinpoint_backward(losses, lw, param_groups, grad_accum=2, free_graph=False) checked = False for g in param_groups: for p in g['params']: if p.grad is not None and id(p) in norms_full and norms_full[id(p)] > 1e-8: n_half = p.grad.data.norm().item() ratio = n_half / norms_full[id(p)] assert abs(ratio - 0.5) < 0.01, \ f"grad_accum=2 should halve grads, got {ratio:.4f}" checked = True break if checked: break assert checked, "no param with non-zero grad in both runs" print(" PASS test_pinpoint_backward_accumulation") def test_pinpoint_backward_rescale_effect(): from train import pinpoint_backward, build_param_groups, DEFAULT_LOSS_TARGET_MAP model = ARBModel() loss_weights = LossWeights() param_groups = build_param_groups(model, base_lr=1e-3) # Run forward + pinpoint_backward with memory_grad_scale=0.25 x = torch.randint(0, VOCAB, (2, 66)) _, losses, _, _ = model(x, targets=x[:, 3:], loss_weights=loss_weights) grads_reference = {} for g in param_groups: for p in g['params']: grads_reference[id(p)] = torch.randn_like(p) # dummy initial grads # Zero grads and run pinpoint with aggressive rescale for g in param_groups: for p in g['params']: p.grad = None pinpoint_backward(losses, loss_weights, param_groups, memory_grad_scale=0.25) # Verify memory params got rescaled (only possible if LSTM/MemGram enabled) memory_group = next((g for g in param_groups if g.get('name') == 'memory'), None) if memory_group and memory_group['params'] and memory_group['params'][0].grad is not None: for p in memory_group['params']: assert p.grad is not None, "memory param should have grad" # Verify lm_core params still get gradients lm_group = next((g for g in param_groups if g.get('name') == 'lm_core'), None) if lm_group and lm_group['params']: lm_has_grad = any(p.grad is not None for p in lm_group['params']) assert lm_has_grad, "lm_core should have grads from lm loss" print(" PASS test_pinpoint_backward_rescale_effect") def test_moe_router_h_with_h_t(): moe = SharedProjectionMoE(hidden_size=64, num_experts=4, top_k=2, core_rank=16, shared_inter=128, noise_std=0.0) moe.lstm_enabled = True x = torch.randn(2, 10, 64) h_t = torch.randn(2, 64) out, aux = moe(x, h_t=h_t) assert out.shape == (2, 10, 64) out.sum().backward() assert hasattr(moe.router_h, '_hook_grad_T_sign'), "router_h not trained" print(" PASS test_moe_router_h_with_h_t") def test_moe_router_without_h_t(): moe = SharedProjectionMoE(hidden_size=64, num_experts=4, top_k=2, core_rank=16, shared_inter=128, noise_std=0.0) x = torch.randn(2, 10, 64) out, aux = moe(x, h_t=None) assert out.shape == (2, 10, 64) out.sum().backward() assert hasattr(moe.router, '_hook_grad_T_sign'), "original router not trained when no h_t" print(" PASS test_moe_router_without_h_t") def test_memgram_shapes(): mg = MemGram(struct_primes=[101,103,107,109], conv_primes=[53,59,61,67]) vq_idx = torch.randint(0, 100, (4, 20)) hs = torch.randn(4, 20, 512) out, decay = mg(vq_idx, None, None, hs, timestep=100) assert out.shape == (4, 20, 512), f"output shape {out.shape}" assert decay.ndim == 0 print(" PASS test_memgram_shapes") def test_memgram_hash_indices(): mg = MemGram(struct_primes=[101,103], conv_primes=[53,59]) prev = torch.randint(0, 100, (4, 19)) curr = torch.randint(0, 100, (4, 19)) h = mg._hash_pairs(prev, curr, [101, 103]) assert h.shape == (4, 19, 2), f"hash shape {h.shape}" assert (h[..., 0] < 101).all(), "hash exceeds prime range" assert (h[..., 1] < 103).all(), "hash exceeds prime range" print(" PASS test_memgram_hash_indices") def test_memgram_bilinear_gate_range(): mg = MemGram(struct_primes=[101,103], conv_primes=[53,59], key_dim=8, embed_dim=16) vq_idx = torch.randint(0, 80, (2, 10)) hs = torch.randn(2, 10, 512) out, _ = mg(vq_idx, None, None, hs, timestep=50) assert out.shape == (2, 10, 512) assert torch.isfinite(out).all() print(" PASS test_memgram_bilinear_gate_range") def test_memgram_decay_formula(): mg = MemGram(struct_primes=[101,103], conv_primes=[53,59]) s = torch.zeros(1) r = torch.zeros(1) decay = mg._compute_decay(s, r, torch.tensor(100.0)) expected = torch.sigmoid(torch.zeros(1)) * torch.exp(-torch.exp(torch.zeros(1)) * 100.0) assert torch.allclose(decay, expected, atol=1e-6), f"decay {decay} != {expected}" print(" PASS test_memgram_decay_formula") def test_memgram_gradient_flow(): mg = MemGram(struct_primes=[101,103], conv_primes=[53,59], embed_dim=16, key_dim=8, hidden_dim=64) vq_idx = torch.randint(0, 80, (2, 10)) hs = torch.randn(2, 10, 64) out, decay = mg(vq_idx, None, None, hs, timestep=50) loss = out.sum() + decay loss.backward() assert mg.struct_emb[0].grad is not None, "no gradient to struct_emb" assert mg.struct_emb[0].grad.abs().sum().item() > 0 print(" PASS test_memgram_gradient_flow") def test_memgram_conv_path(): mg = MemGram(struct_primes=[101,103], conv_primes=[53,59], embed_dim=16, key_dim=8, hidden_dim=64) vq_idx = torch.randint(0, 80, (2, 10)) hs = torch.randn(2, 10, 64) out_no_conv, _ = mg(vq_idx, None, None, hs, timestep=50) conv_code = torch.randint(0, 50, (2,)) out_with_conv, _ = mg(vq_idx, conv_code, conv_code, hs, timestep=50) assert not torch.allclose(out_no_conv, out_with_conv, atol=1e-6), "conv path should change output" print(" PASS test_memgram_conv_path") def test_conv_vq_shapes(): cvq = ConvVQCodebook(codebook_size=16, code_dim=8) x = torch.randn(4, 512) code, quantized, commitment = cvq(x, step=500, enabled=True) assert code.shape == (4,), f"code shape {code.shape}" assert quantized.shape == (4, 512), f"quantized shape {quantized.shape}" assert commitment.ndim == 0 print(" PASS test_conv_vq_shapes") def test_conv_vq_hard_cap(): cvq = ConvVQCodebook(codebook_size=8, code_dim=8) for i in range(12): x = torch.randn(4, 512) cvq(x, step=i, enabled=True) assert cvq.n_active.item() == 8, f"n_active={cvq.n_active.item()} (should be 8)" print(" PASS test_conv_vq_hard_cap") def test_conv_vq_deferred_activation(): cvq = ConvVQCodebook(codebook_size=8, code_dim=8) x = torch.randn(4, 512) code, quantized, commitment = cvq(x, step=500, enabled=False) assert torch.equal(code, torch.zeros(4, dtype=torch.long)), "code should be zeros when disabled" assert commitment.item() == 0.0, "commitment should be 0 when disabled" print(" PASS test_conv_vq_deferred_activation") def test_conv_vq_ema_update(): cvq = ConvVQCodebook(codebook_size=4, code_dim=8) for i in range(4): cvq(torch.randn(1, 512) + i, step=i, enabled=True) embed_before = cvq.embed.clone() for i in range(4): cvq(torch.randn(1, 512), step=10 + i, enabled=True) embed_after = cvq.embed.clone() assert not torch.allclose(embed_before, embed_after), "entries should change through EMA + replacement" print(" PASS test_conv_vq_ema_update") def test_conv_vq_persistence(): cvq = ConvVQCodebook(codebook_size=8, code_dim=8) x = torch.randn(2, 512) cvq(x, step=0, enabled=True) sd = cvq.state_dict() cvq2 = ConvVQCodebook(codebook_size=8, code_dim=8) cvq2.load_state_dict(sd) assert torch.allclose(cvq.embed, cvq2.embed) assert torch.equal(cvq.timestamps, cvq2.timestamps) assert torch.equal(cvq.n_active, cvq2.n_active) assert torch.equal(cvq.cluster_size, cvq2.cluster_size) print(" PASS test_conv_vq_persistence") def test_conv_vq_fuzzy_retrieve(): cvq = ConvVQCodebook(codebook_size=16, code_dim=8) for i in range(3): x = torch.randn(2, 512) cvq(x, step=i, enabled=True) query = torch.randn(8) idx, sim = cvq.fuzzy_retrieve(query, top_k=3) assert idx.numel() == 3, f"retrieved {idx.numel()} elements, expected 3" assert sim.numel() == 3 print(" PASS test_conv_vq_fuzzy_retrieve") def test_conv_vq_commitment_nonneg(): cvq = ConvVQCodebook(codebook_size=8, code_dim=8) x = torch.randn(2, 512) _, _, commitment = cvq(x, step=0, enabled=True) assert commitment.item() >= 0, f"negative commitment {commitment.item()}" print(" PASS test_conv_vq_commitment_nonneg") def test_lstm_shapes(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) x = torch.randn(4, 64) h_out, c_focus, h_topic, c_topic, c_proj, reg = lstm(x, None) assert h_out.shape == (4, 64), f"h_out shape {h_out.shape}" assert c_focus.shape == (4, 64), f"c_focus shape {c_focus.shape}" assert h_topic.shape == (4, 64), f"h_topic shape {h_topic.shape}" assert c_topic.shape == (4, 64), f"c_topic shape {c_topic.shape}" assert c_proj.shape == (4, 64), f"c_proj shape {c_proj.shape}" assert reg.ndim == 0 print(" PASS test_lstm_shapes") def test_lstm_forget_gate_bias(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) bias_ih = lstm.focus_cell.bias_ih[64:128] assert torch.allclose(bias_ih, torch.ones_like(bias_ih)), "focus forget gate bias not 1.0" bias_ih_topic = lstm.topic_cell.bias_ih[64:128] assert torch.allclose(bias_ih_topic, torch.full_like(bias_ih_topic, 1.5)), "topic forget gate bias not 1.5" print(" PASS test_lstm_forget_gate_bias") def test_lstm_bptt_detach(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=5, bptt_topic=10) x = torch.randn(2, 64) memory = None for i in range(49): h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, memory) memory = (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach()) assert h_out.grad_fn is not None, "grad_fn should exist before BPTT boundary" h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, memory) assert h_out.grad_fn is None, "h_out should be detached at BPTT focus boundary" print(" PASS test_lstm_bptt_detach") def test_lstm_hidden_reg(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) x = torch.randn(2, 64) _, _, _, _, _, reg = lstm(x, None) h_out, _, _, _, _, _ = lstm(x, None) expected = (h_out ** 2).mean() assert torch.allclose(reg, expected, atol=1e-6), f"reg {reg} != expected {expected}" print(" PASS test_lstm_hidden_reg") def test_lstm_c_t_proj_ternary(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) assert isinstance(lstm.c_focus_proj, TernaryScaleTensor), "c_focus_proj not TernaryScaleTensor" assert isinstance(lstm.c_topic_proj, TernaryScaleTensor), "c_topic_proj not TernaryScaleTensor" print(" PASS test_lstm_c_t_proj_ternary") def test_memory_modules_backward_compat(): model = ARBModel() x = torch.randint(0, VOCAB, (2, CTX)) logits, losses, indices, _ = model(x, targets=x[:, 3:]) assert logits.shape[0] == 2 assert losses is not None print(" PASS test_memory_modules_backward_compat") # ===== Phase 7: Forward Pipeline Integration Tests ===== def test_forward_no_memory_backward_compat(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits, losses, indices, mem_state = model(x, targets=targets) assert logits.shape == (2, 64, VOCAB), f"logits shape {logits.shape}" assert mem_state is None, f"memory_state should be None when lstm disabled, got {mem_state}" print(" PASS test_forward_no_memory_backward_compat") def test_forward_lstm_enabled_h_t_passed(): model = ARBModel() model.lstm_enabled = True x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits, losses, indices, mem_state = model(x, targets=targets, memory_state=None, timestep=0) assert mem_state is not None, "memory_state should be tuple when lstm enabled" h_out, c_focus, h_topic, c_topic = mem_state assert h_out.shape == (2, 512), f"h_out shape {h_out.shape}" assert c_focus.shape == (2, 512), f"c_focus shape {c_focus.shape}" assert h_topic.shape == (2, 512), f"h_topic shape {h_topic.shape}" assert c_topic.shape == (2, 512), f"c_topic shape {c_topic.shape}" print(" PASS test_forward_lstm_enabled_h_t_passed") def test_forward_lstm_c_t_residual(): model = ARBModel() x = torch.randint(0, VOCAB, (2, 66)) targets = x[:, 3:] logits_no_lstm, _, _, _ = model(x, targets=targets) model.lstm_enabled = True logits_lstm, _, _, _ = model(x, targets=targets, memory_state=None, timestep=0) assert not torch.allclose(logits_no_lstm, logits_lstm, atol=1e-4), "c_t_proj should modify output" print(" PASS test_forward_lstm_c_t_residual") def test_forward_memgram_injection(): model = ARBModel() model.memgram_enabled = True x = torch.randint(0, VOCAB, (2, 66)) logits, losses, _, _ = model(x, targets=x[:, 3:], timestep=100) assert losses.memgram_decay_reg is not None, "memgram_decay_reg should be set" print(" PASS test_forward_memgram_injection") def test_forward_conv_vq_deferred(): model = ARBModel() model.conv_vq_enabled = True model._conv_vq_ready = False x = torch.randint(0, VOCAB, (2, 66)) logits, losses, _, _ = model(x, targets=x[:, 3:], timestep=100) assert losses.conv_vq_commitment is None or losses.conv_vq_commitment.item() == 0.0, \ f"conv_vq should be deferred, got {losses.conv_vq_commitment}" print(" PASS test_forward_conv_vq_deferred") def test_generate_carries_lstm_state(): model = ARBModel() model.lstm_enabled = True idx = torch.zeros((1, 10), dtype=torch.long) out = model.generate(idx, max_new_token=10) assert out.shape == (1, 20), f"output shape {out.shape}" assert out.dtype == torch.long print(" PASS test_generate_carries_lstm_state") # ===== Phase 7: Training Schedule Tests ===== def test_memory_schedule_warmup(): from train import compute_memory_schedule lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(0, 10000) assert not any([lstm_on, memgram_on, conv_vq_on, decay_reg_on]), "all off during warmup" print(" PASS test_memory_schedule_warmup") def test_memory_schedule_lstm_first(): from train import compute_memory_schedule lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(2500, 10000) assert lstm_on, "lstm should be on after warmup" assert not memgram_on, "memgram should be off at 25%" print(" PASS test_memory_schedule_lstm_first") def test_memory_schedule_memgram_second(): from train import compute_memory_schedule lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(3500, 10000, vq_utilization=0.4) assert lstm_on and memgram_on and conv_vq_on, "lstm+memgram+conv_vq on at 35% with util>30%" assert not decay_reg_on, "decay_reg should be off at 35%" print(" PASS test_memory_schedule_memgram_second") def test_memory_schedule_all_on(): from train import compute_memory_schedule lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(5000, 10000, vq_utilization=0.4) assert all([lstm_on, memgram_on, conv_vq_on, decay_reg_on]), "all on at 50%" print(" PASS test_memory_schedule_all_on") def test_memory_schedule_conv_vq_requires_vq_util(): from train import compute_memory_schedule lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(4000, 10000, vq_utilization=0.1) assert lstm_on, "lstm should be on" assert not conv_vq_on, "conv_vq should be off when util < 30%" print(" PASS test_memory_schedule_conv_vq_requires_vq_util") def test_memory_schedule_decay_reg_last(): from train import compute_memory_schedule lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(4500, 10000, vq_utilization=0.4) assert decay_reg_on, "decay_reg should be on at 45%" print(" PASS test_memory_schedule_decay_reg_last") def test_lstm_state_reset_per_batch(): model = ARBModel() model.lstm_enabled = True model.eval() x = torch.randint(0, VOCAB, (2, 66)) with torch.no_grad(): _, _, _, mem1 = model(x, memory_state=None, timestep=0) _, _, _, mem2 = model(x, memory_state=None, timestep=1) h1, c1_f, h1_t, c1_t = mem1 h2, c2_f, h2_t, c2_t = mem2 assert h1.shape == h2.shape, "h_out shapes should match" assert c1_f.shape == c2_f.shape, "c_focus shapes should match" print(" PASS test_lstm_state_reset_per_batch") def test_bptt_counter_separate(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=5, bptt_topic=10) x = torch.randn(2, 64) memory = None for i in range(4): h_out, c_f, h_t, c_t, _, _ = lstm(x, memory) memory = (h_out.detach(), c_f.detach(), h_t.detach(), c_t.detach()) assert lstm.step_count == 4, f"step_count={lstm.step_count}" h_out, c_f, h_t, c_t, _, _ = lstm(x, memory) assert c_f.grad_fn is None, "c_focus should be detached at BPTT focus boundary" print(" PASS test_bptt_counter_separate") # ===== FocusGate Tests ===== def test_focus_gate_no_boundary(): fg = FocusGate(hidden_dim=64) x = torch.randn(2, 64) reset, dampen = fg(x, boundary_signal=None) assert reset.shape == (2, 1), f"reset shape {reset.shape}" assert dampen.shape == (2, 64), f"dampen shape {dampen.shape}" assert torch.allclose(reset, torch.ones_like(reset)), "reset should be 1.0 for no boundary" assert torch.allclose(dampen, torch.ones_like(dampen)), "dampen should be 1.0 for no boundary" print(" PASS test_focus_gate_no_boundary") def test_focus_gate_boundary_signal(): fg = FocusGate(hidden_dim=64) x = torch.randn(2, 64) reset_bos, dampen_bos = fg(x, boundary_signal=SPECIAL_VOCAB['BOS']) assert reset_bos.shape == (2, 1) assert dampen_bos.shape == (2, 64) assert 0.0 <= reset_bos.mean().item() <= 1.0, f"reset out of range: {reset_bos.mean().item()}" assert 0.0 <= dampen_bos.mean().item() <= 1.0, f"dampen out of range: {dampen_bos.mean().item()}" print(" PASS test_focus_gate_boundary_signal") def test_focus_gate_all_boundary_types(): fg = FocusGate(hidden_dim=64) x = torch.randn(2, 64) for tok_name, tok_id in [('BOS', SPECIAL_VOCAB['BOS']), ('SYSTEM', SPECIAL_VOCAB['SYSTEM']), ('USER', SPECIAL_VOCAB['USER']), ('ASSISTANT', SPECIAL_VOCAB['ASSISTANT'])]: reset, dampen = fg(x, boundary_signal=tok_id) assert reset.mean().item() < 1.0, f"{tok_name} should reduce reset from 1.0" no_reset, _ = fg(x, boundary_signal=None) assert torch.allclose(no_reset, torch.ones_like(no_reset)), "no-boundary reset should be 1.0" print(" PASS test_focus_gate_all_boundary_types") def test_focus_gate_unknown_token(): fg = FocusGate(hidden_dim=64) x = torch.randn(2, 64) reset, dampen = fg(x, boundary_signal=9999) assert torch.allclose(reset, torch.ones_like(reset)), "unknown token should return reset=1.0" assert torch.allclose(dampen, torch.ones_like(dampen)), "unknown token should return dampen=1.0" print(" PASS test_focus_gate_unknown_token") def test_focus_gate_c_focus_modulation(): fg = FocusGate(hidden_dim=64) x = torch.randn(2, 64) c_focus = torch.randn(2, 64) c_focus_before = c_focus.clone() reset, dampen = fg(x, boundary_signal=SPECIAL_VOCAB['USER']) c_focus_mod = c_focus * reset * dampen assert not torch.allclose(c_focus_mod, c_focus_before, atol=1e-6), "focus gate should modify c_focus on boundary" reset_none, dampen_none = fg(x, boundary_signal=None) c_focus_noop = c_focus * reset_none * dampen_none assert torch.allclose(c_focus_noop, c_focus, atol=1e-6), "focus gate should not modify c_focus without boundary" print(" PASS test_focus_gate_c_focus_modulation") # ===== ConversationStack Tests ===== def test_conv_stack_push_pop(): stack = ConversationStack(max_conversations=4, hidden_dim=64) h = torch.randn(64) c_f = torch.randn(64) h_t = torch.randn(64) c_t = torch.randn(64) stack.push("conv_1", h, c_f, h_t, c_t, "cpu") result = stack.pop("conv_1", "cpu") assert result is not None, "pop should find pushed conversation" h_rest, c_f_rest, h_t_rest, c_t_rest = result assert torch.allclose(h_rest, h, atol=1e-6), "h_focus not preserved" assert torch.allclose(c_f_rest, c_f, atol=1e-6), "c_focus not preserved" assert torch.allclose(h_t_rest, h_t, atol=1e-6), "h_topic not preserved" assert torch.allclose(c_t_rest, c_t, atol=1e-6), "c_topic not preserved" print(" PASS test_conv_stack_push_pop") def test_conv_stack_pop_missing(): stack = ConversationStack(max_conversations=4, hidden_dim=64) result = stack.pop("nonexistent", "cpu") assert result is None, "pop on nonexistent should return None" print(" PASS test_conv_stack_pop_missing") def test_conv_stack_clear(): stack = ConversationStack(max_conversations=4, hidden_dim=64) h = torch.randn(64) stack.push("conv_1", h, torch.randn(64), torch.randn(64), torch.randn(64), "cpu") stack.clear("conv_1") result = stack.pop("conv_1", "cpu") assert result is None, "cleared conversation should not be found" print(" PASS test_conv_stack_clear") def test_conv_stack_lru_eviction(): stack = ConversationStack(max_conversations=2, hidden_dim=64) for i in range(3): h = torch.full((64,), float(i)) stack.push(f"conv_{i}", h, torch.zeros(64), torch.zeros(64), torch.zeros(64), "cpu") result_0 = stack.pop("conv_0", "cpu") assert result_0 is None, "conv_0 should be evicted (LRU)" result_1 = stack.pop("conv_1", "cpu") assert result_1 is not None, "conv_1 should still exist" result_2 = stack.pop("conv_2", "cpu") assert result_2 is not None, "conv_2 should still exist" print(" PASS test_conv_stack_lru_eviction") def test_conv_stack_reset(): stack = ConversationStack(max_conversations=4, hidden_dim=64) stack.push("conv_1", torch.randn(64), torch.randn(64), torch.randn(64), torch.randn(64), "cpu") stack.push("conv_2", torch.randn(64), torch.randn(64), torch.randn(64), torch.randn(64), "cpu") stack.reset() assert stack.pop("conv_1", "cpu") is None, "conv_1 should be gone after reset" assert stack.pop("conv_2", "cpu") is None, "conv_2 should be gone after reset" assert stack.active_slot == -1, "active_slot should be -1 after reset" print(" PASS test_conv_stack_reset") # ===== ConversationLSTM Tests ===== def test_conversation_lstm_forward(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=50, bptt_topic=200) x = torch.randn(4, 64) h_out, c_focus, h_topic, c_topic, c_proj, reg = lstm(x, None) assert h_out.shape == (4, 64) assert c_focus.shape == (4, 64) assert h_topic.shape == (4, 64) assert c_topic.shape == (4, 64) assert c_proj.shape == (4, 64) assert reg.ndim == 0 print(" PASS test_conversation_lstm_forward") def test_conversation_lstm_dual_state(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) x = torch.randn(2, 64) h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, None) for _ in range(10): h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach())) assert c_focus.abs().sum() > 0, "c_focus should be nonzero after steps" assert c_topic.abs().sum() > 0, "c_topic should be nonzero after steps" print(" PASS test_conversation_lstm_dual_state") def test_conversation_lstm_boundary_reset(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) x = torch.randn(2, 64) h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, None) for _ in range(5): h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach())) c_focus_before = c_focus.clone() h_out, c_focus_after, _, _, _, _ = lstm(x, (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach()), boundary_signal=SPECIAL_VOCAB['BOS']) assert lstm._last_reset < 1.0, "BOS boundary should reduce reset from 1.0" print(" PASS test_conversation_lstm_boundary_reset") def test_conversation_lstm_topic_gate(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) x = torch.randn(2, 64) _, _, h_topic_0, _, _, _ = lstm(x, None) assert h_topic_0.abs().sum() > 0, "h_topic should be nonzero after one step" print(" PASS test_conversation_lstm_topic_gate") def test_conversation_lstm_bptt_dual_windows(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=3, bptt_topic=6) x = torch.randn(2, 64) memory = None for i in range(2): h_out, c_f, h_t, c_t, _, _ = lstm(x, memory) memory = (h_out.detach(), c_f.detach(), h_t.detach(), c_t.detach()) assert h_out.grad_fn is not None, "h_out should have grad before focus BPTT" h_out, c_f, h_t, c_t, _, _ = lstm(x, memory) assert c_f.grad_fn is None, "c_focus should be detached at focus BPTT boundary (step 3)" print(" PASS test_conversation_lstm_bptt_dual_windows") def test_conversation_lstm_topic_preserves_on_boundary(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) x = torch.randn(2, 64) memory = None for _ in range(5): h_out, c_f, h_t, c_t, _, _ = lstm(x, memory) memory = (h_out.detach(), c_f.detach(), h_t.detach(), c_t.detach()) c_topic_before = c_t.clone() _, c_f_after, _, c_t_after, _, _ = lstm(x, memory, boundary_signal=SPECIAL_VOCAB['USER']) decay_ratio = c_t_after.norm() / max(c_topic_before.norm(), 1e-8) focus_decay = c_f_after.norm() / max(c_f.norm(), 1e-8) assert decay_ratio >= focus_decay * 0.9, f"topic should decay less than focus on boundary: topic={decay_ratio:.3f} focus={focus_decay:.3f}" print(" PASS test_conversation_lstm_topic_preserves_on_boundary") def test_conversation_lstm_h_out_is_sum(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) lstm.eval() x = torch.randn(2, 64) with torch.no_grad(): h_out, c_f, h_topic, c_topic, c_proj, _ = lstm(x, None) assert h_out.shape == (2, 64) assert c_proj.shape == (2, 64) print(" PASS test_conversation_lstm_h_out_is_sum") def test_extract_boundary_from_input(): x_bos = torch.tensor([[SPECIAL_VOCAB['BOS'], 10, 20]]) assert _extract_boundary_from_input(x_bos) == SPECIAL_VOCAB['BOS'], "should detect BOS" x_user = torch.tensor([[10, SPECIAL_VOCAB['USER'], 20]]) assert _extract_boundary_from_input(x_user) == SPECIAL_VOCAB['USER'], "should detect USER" x_none = torch.tensor([[10, 20, 30]]) assert _extract_boundary_from_input(x_none) is None, "should return None for no boundary" print(" PASS test_extract_boundary_from_input") # ===== Integration: ConversationLSTM + Model ===== def test_model_switch_conversation(): model = ARBModel() model.lstm_enabled = True model.switch_conversation("conv_A") assert model.lstm.conv_stack._current_conv_id == "conv_A" print(" PASS test_model_switch_conversation") def test_model_reset_conversation(): model = ARBModel() model.lstm_enabled = True model.switch_conversation("conv_A") model.reset_conversation("conv_A") assert model.lstm.conv_stack._current_conv_id is None print(" PASS test_model_reset_conversation") def test_model_generate_with_conversation_id(): model = ARBModel() model.lstm_enabled = True idx = torch.zeros((1, 10), dtype=torch.long) out = model.generate(idx, max_new_token=5, conversation_id="conv_test") assert out.shape == (1, 15), f"output shape {out.shape}" print(" PASS test_model_generate_with_conversation_id") def test_conversation_lstm_ternary_projections(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) assert isinstance(lstm.c_focus_proj, TernaryScaleTensor) assert isinstance(lstm.c_topic_proj, TernaryScaleTensor) print(" PASS test_conversation_lstm_ternary_projections") def test_conversation_lstm_focus_gate_params(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) assert isinstance(lstm.focus_gate, FocusGate) assert isinstance(lstm.focus_gate.boundary_embed, nn.Embedding) assert isinstance(lstm.focus_gate.reset_fc, nn.Linear) assert isinstance(lstm.focus_gate.dampen_fc, nn.Linear) print(" PASS test_conversation_lstm_focus_gate_params") def test_conversation_lstm_topic_cell_bias(): lstm = ConversationLSTM(input_dim=64, hidden_dim=64) bias_topic = lstm.topic_cell.bias_ih[64:128] assert torch.allclose(bias_topic, torch.full_like(bias_topic, 1.5)), f"topic forget gate bias should be 1.5, got {bias_topic.mean().item()}" print(" PASS test_conversation_lstm_topic_cell_bias") if __name__ == "__main__": tests = [ test_sticky_zone_ste, test_sticky_zone_ste_dtype_preservation, test_scaled_ternary_linear, test_rmsnorm, test_byte_embedding, test_text_sequencer, test_trigram_window, test_image_sequencer, test_image_sequencer_frozen, test_target_alignment, test_model_forward, test_generate, test_param_count, test_gradient_flow, test_model_forward_with_targets, test_save_load_roundtrip, test_vq_adapter_shapes, test_vq_integration, test_vq_disabled, test_vq_with_targets, test_l2_distance_matching, test_vq_ternary_projections, test_multimodal_vq_bridge_text_only, test_multimodal_vq_bridge_text_image, test_modality_gate_shapes, test_ternary_graph_multicodebook, test_vq_no_float_cast_in_model, test_zero_fp32_params, test_sticky_zone_ste_gradient, test_graph_moe_gate_shape, test_ternary_graph_shapes, test_graph_gradient_flow, test_graph_connectivity_monitor, test_model_forward_with_graph, test_model_graph_disabled, test_ternary_graph_in_modules, test_moe_shapes, test_moe_router, test_moe_aux_loss, test_shared_expert, test_moe_gradient_flow, test_moe_zero_fp32, test_ternary_graph_with_gate, test_model_forward_with_moe, test_model_moe_disabled, test_model_moe_loss_components, test_model_moe_gate_modulation, test_param_count_with_moe, test_moe_monitoring, test_loss_components, test_loss_components_none_fields, test_loss_components_backward, test_gnn_lora_adapter, test_gnn_lora_gradient, test_shared_gnn_weight_tying, test_shared_gnn_multi_hop, test_model_losses_components_type, test_halting_unit_shapes, test_halting_unit_ternary_pure, test_graph_act_cell_shapes, test_moe_act_cell_shapes, test_act_early_halt, test_act_weight_sum_one, test_act_gradient_flow, test_loss_components_ponder_fields, test_loss_components_ponder_none, test_act_graph_moe_sequential, test_model_forward_with_act, test_model_act_forward_without_targets, test_model_act_loss_components, test_model_act_backward, test_model_act_disabled, test_model_act_warmup_mode, test_model_act_ponder_cached, test_act_warmup_schedule, test_act_ponder_lambda, test_model_ponder_lambda_scaling, test_text_only_forward, test_image_forward, test_multimodal_backward, test_no_stale_trigram_encoder, test_vocab, test_memgram_shapes, test_memgram_hash_indices, test_memgram_bilinear_gate_range, test_memgram_decay_formula, test_memgram_gradient_flow, test_memgram_conv_path, test_conv_vq_shapes, test_conv_vq_hard_cap, test_conv_vq_deferred_activation, test_conv_vq_ema_update, test_conv_vq_persistence, test_conv_vq_fuzzy_retrieve, test_conv_vq_commitment_nonneg, test_lstm_shapes, test_lstm_forget_gate_bias, test_lstm_bptt_detach, test_lstm_hidden_reg, test_lstm_c_t_proj_ternary, test_memory_modules_backward_compat, test_loss_components_nine_fields_total, test_loss_components_nine_fields_log, test_moe_router_h_with_h_t, test_moe_router_without_h_t, test_forward_no_memory_backward_compat, test_forward_lstm_enabled_h_t_passed, test_forward_lstm_c_t_residual, test_forward_memgram_injection, test_forward_conv_vq_deferred, test_generate_carries_lstm_state, test_memory_schedule_warmup, test_memory_schedule_lstm_first, test_memory_schedule_memgram_second, test_memory_schedule_all_on, test_memory_schedule_conv_vq_requires_vq_util, test_memory_schedule_decay_reg_last, test_lstm_state_reset_per_batch, test_bptt_counter_separate, test_focus_gate_no_boundary, test_focus_gate_boundary_signal, test_focus_gate_all_boundary_types, test_focus_gate_unknown_token, test_focus_gate_c_focus_modulation, test_conv_stack_push_pop, test_conv_stack_pop_missing, test_conv_stack_clear, test_conv_stack_lru_eviction, test_conv_stack_reset, test_conversation_lstm_forward, test_conversation_lstm_dual_state, test_conversation_lstm_boundary_reset, test_conversation_lstm_topic_gate, test_conversation_lstm_bptt_dual_windows, test_conversation_lstm_topic_preserves_on_boundary, test_conversation_lstm_h_out_is_sum, test_extract_boundary_from_input, test_model_switch_conversation, test_model_reset_conversation, test_model_generate_with_conversation_id, test_conversation_lstm_ternary_projections, test_conversation_lstm_focus_gate_params, test_conversation_lstm_topic_cell_bias, ] print("Running MORPH tests (Phase 1 + Phase 2 VQ + Phase 3 Graph + Phase 4 MoE + Explore + Phase 5 ACT + Phase 6 Multi-Modal + Phase 7 Memory)...\n") passed = 0 failed = 0 for t in tests: try: t() passed += 1 except Exception as e: print(f" FAIL {t.__name__}: {e}") failed += 1 print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")