| import torch |
| import os |
| import sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch.nn as nn |
| import sys |
| import os |
|
|
|
|
| from arbitor.config import ( |
| VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD, |
| CODEBOOK_DIM, CODEBOOK_SIZE, |
| SPECIAL_VOCAB, |
| StickyZoneSTE, |
| ByteEmbedding, Sequencer, TextSequencer, ImageSequencer, AudioSequencer, |
| MultimodalSequencer, |
| TernaryGNNLayer, TernaryGraph, GraphMoEGate, SharedProjectionMoE, |
| ByteHead, ARBModel, VQAdapter, MultimodalVQBridge, ModalityGate, |
| LossComponents, LossWeights, GNNLoRAAdapter, |
| HaltingUnit, GraphACTCell, MoEACTCell, |
| MemGram, ConvVQCodebook, |
| FocusGate, ConversationStack, ConversationLSTM, |
| _BOUNDARY_TOKEN_MAP, _extract_boundary_from_input, |
| ) |
| from arbitor.kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType |
|
|
# Module types whose directly-owned parameters _is_ternary_param treats as
# ternary; the zero-fp32 audit tests rely on this allow-list.
TERNARY_MODULES = (
    TernaryScaleTensor,
    TernaryRMSNorm,
    ByteEmbedding,
    TernaryGraph,
    GraphMoEGate,
    SharedProjectionMoE,
    GNNLoRAAdapter,
    HaltingUnit,
    GraphACTCell,
    MoEACTCell,
    Sequencer,
    TextSequencer,
    ImageSequencer,
    AudioSequencer,
    MultimodalVQBridge,
    ModalityGate,
    MemGram,
    ConvVQCodebook,
    ConversationLSTM,
)
|
|
|
|
def _is_ternary_param(model, name):
    """Return True if parameter ``name`` is owned directly by a ternary module.

    ``name`` is a dotted parameter name as yielded by
    ``model.named_parameters()``. Its owner is the module at the dotted prefix,
    or ``model`` itself for a top-level parameter. Ownership is checked
    against ``TERNARY_MODULES``.

    The owner is resolved with ``get_submodule``, which walks the dotted path.
    The previous version rebuilt ``dict(model.named_modules())`` on every call,
    which made auditing every parameter of a large model quadratic.
    """
    parent_name = name.rpartition(".")[0]  # "" when the name has no dot
    try:
        parent = model.get_submodule(parent_name)  # "" resolves to model itself
    except AttributeError:
        # Unknown path: same outcome as the old dict lookup returning None.
        return False
    return isinstance(parent, TERNARY_MODULES)
|
|
|
|
| |
|
|
def test_sticky_zone_ste():
    """StickyZoneSTE emits only {-1, 0, 1} and sends gradients back to its input."""
    thr = 0.05
    weights = torch.randn(8, 8, requires_grad=True)
    ternary = StickyZoneSTE.apply(weights, thr)
    seen = set(ternary.detach().flatten().tolist())
    assert seen.issubset({-1.0, 0.0, 1.0}), f"Non-ternary values: {seen}"
    ternary.sum().backward()
    assert weights.grad is not None
    grad = weights.grad
    magnitude = weights.abs()
    # Weights beyond the threshold must receive at least some non-zero gradient.
    live = magnitude > thr
    if live.any():
        assert (grad[live] != 0).any(), "Outside threshold should have non-zero gradient"
    # Inside the sticky zone the gradient must never be negative.
    sticky = ~live
    if sticky.any():
        assert (grad[sticky] >= 0).all(), "Sticky zone gradient should be non-negative"
    print(" PASS test_sticky_zone_ste")
|
|
|
|
def test_sticky_zone_ste_dtype_preservation():
    """bfloat16 weights stay bfloat16 through both forward and backward."""
    target = torch.bfloat16
    src = torch.randn(8, 8, dtype=target, requires_grad=True)
    quantized = StickyZoneSTE.apply(src, 0.05)
    assert quantized.dtype == target, f"Expected bfloat16, got {quantized.dtype}"
    quantized.sum().backward()
    grad_dtype = src.grad.dtype
    assert grad_dtype == target, f"Expected bfloat16 grad, got {grad_dtype}"
    print(" PASS test_sticky_zone_ste_dtype_preservation")
|
|
|
|
def test_scaled_ternary_linear():
    """TernaryScaleTensor maps (..., 32) -> (..., 16) and honours bias=False."""
    layer = TernaryScaleTensor(32, 16, bias=False)
    result = layer(torch.randn(2, 10, 32))
    assert result.shape == (2, 10, 16), f"Shape: {result.shape}"
    assert layer.bias is None, "TernaryScaleTensor bias should be None"
    print(" PASS test_scaled_ternary_linear")
|
|
|
|
def test_rmsnorm():
    """TernaryRMSNorm preserves shape, stays finite and ignores input scale.

    The previous version built a reference ``x / rms(x)`` tensor but compared
    only its *shape*, which had already been checked, so no values were ever
    verified. Dividing each row by its own root-mean-square makes the output
    (nearly) invariant to a positive rescaling of the input, whatever the
    per-channel weight is. That property is checked instead. This assumes the
    module computes ``weight * x / rms(x)`` with a small epsilon; the tolerance
    absorbs the epsilon's effect.
    """
    norm = TernaryRMSNorm(32)
    x = torch.randn(2, 10, 32)
    out = norm(x)
    assert out.shape == x.shape, f"Shape: {out.shape}"
    assert torch.isfinite(out).all(), "RMSNorm produced non-finite values"
    out_scaled = norm(3.0 * x)
    assert torch.allclose(out, out_scaled, rtol=1e-3, atol=1e-3), \
        "RMSNorm output should not depend on the input's overall scale"
    print(" PASS test_rmsnorm")
|
|
|
|
def test_byte_embedding():
    """ByteEmbedding maps (batch, seq) byte ids to (batch, seq, EMBEDDING_DIM)."""
    embedder = ByteEmbedding()
    batch, seq = 4, 20
    vectors = embedder(torch.randint(0, VOCAB, (batch, seq)))
    assert vectors.shape == (batch, seq, EMBEDDING_DIM), f"Shape: {vectors.shape}"
    print(" PASS test_byte_embedding")
|
|
|
|
def test_text_sequencer():
    """TextSequencer turns 10 embedded positions into 8 (10 - 2) hidden states."""
    sequencer = TextSequencer()
    hidden = sequencer(torch.randn(2, 10, EMBEDDING_DIM))
    want = (2, 8, HIDDEN_DIM)
    assert hidden.shape == want, f"Shape: {hidden.shape}, expected (2, 8, {HIDDEN_DIM})"
    print(" PASS test_text_sequencer")
|
|
|
|
def test_trigram_window():
    """Tensor.unfold(size=3, step=1) yields overlapping trigram windows in order."""
    seq_len = 5
    # Position p holds the constant value p + 1 across every embedding channel.
    x = (
        torch.arange(1, seq_len + 1, dtype=torch.float32)
        .view(1, seq_len, 1)
        .repeat(1, 1, EMBEDDING_DIM)
    )
    windows = x.unfold(dimension=1, size=3, step=1)
    # unfold appends the window axis last: (batch, n_windows, dim, window).
    assert windows.shape == (1, 3, EMBEDDING_DIM, 3), f"Unfold shape: {windows.shape}"
    first = windows[0, 0, 0]
    for k, want in enumerate((1.0, 2.0, 3.0)):
        assert first[k].item() == want
    print(" PASS test_trigram_window")
|
|
|
|
def test_image_sequencer():
    """ImageSequencer turns one 224x224 RGB image into 194 HIDDEN_DIM tokens."""
    sequencer = ImageSequencer()
    tokens = sequencer(torch.randn(1, 3, 224, 224))
    assert tokens.shape == (1, 194, HIDDEN_DIM)
    print(" PASS test_image_sequencer")
|
|
|
|
def test_image_sequencer_frozen():
    """Every parameter of the wrapped ViT backbone has requires_grad=False."""
    backbone = ImageSequencer().vit
    assert all(not param.requires_grad for param in backbone.parameters())
    print(" PASS test_image_sequencer_frozen")
|
|
|
|
def test_target_alignment():
    """With targets = x[:, 3:], logits minus the last position line up with targets."""
    model = ARBModel()
    bos, eos = SPECIAL_VOCAB["BOS"], SPECIAL_VOCAB["EOS"]
    tokens = torch.tensor([[bos, 10, 20, 30, 40, 50, eos]])
    targets = tokens[:, 3:]
    logits, losses, _, _ = model(tokens, targets=targets)
    assert losses is not None, "Losses should be computed"
    aligned = logits[:, :-1, :]
    assert aligned.shape[1] == targets.shape[1], "Target alignment mismatch"
    print(" PASS test_target_alignment")
|
|
|
|
def test_model_forward():
    """Forward without targets yields T - 2 logit positions and no losses."""
    model = ARBModel()
    batch, length = 2, 66
    tokens = torch.randint(0, VOCAB, (batch, length))
    logits, losses, _, _ = model(tokens)
    expected = (batch, length - 2, VOCAB)
    assert logits.shape == expected, f"Shape: {logits.shape}, expected ({batch}, {length-2}, {VOCAB})"
    assert losses is None, "Losses should be None without targets"
    print(" PASS test_model_forward")
|
|
|
|
def test_generate():
    """generate() extends a 4-token prompt by 10 ids, all inside the vocabulary."""
    model = ARBModel()
    model.eval()
    prompt = [SPECIAL_VOCAB["BOS"]] + [ord(ch) for ch in "Hel"]
    seed = torch.tensor([prompt])
    with torch.no_grad():
        output = model.generate(seed, max_new_token=10, temperature=1.0)
    assert output.shape == (1, len(prompt) + 10), f"Shape: {output.shape}"
    in_range = (output >= 0) & (output < VOCAB)
    assert in_range.all(), "Tokens out of vocab range"
    print(" PASS test_generate")
|
|
|
|
def test_param_count():
    """Total parameter count, frozen backbones included, stays within 120M-150M."""
    total = sum(param.numel() for param in ARBModel().parameters())
    print(f" Param count: {total:,}")
    low, high = 120e6, 150e6
    assert low < total < high, f"Param count {total:,} outside expected range (frozen ViT-base + Whisper-tiny + MemGram + LSTM)"
    print(" PASS test_param_count")
|
|
|
|
def test_gradient_flow():
    """A byte-token backward pass reaches every trainable, non-exempt parameter."""
    # Parameter-name substrings this test deliberately does not check.
    exempt = (
        "embed", "graph_pool.query",
        "moe.W_gate", "moe.W_transform", "moe.W_gate_norms", "moe.W_transform_norms",
        "moe.router.bias", "moe.router_h",
        "memgram", "conv_vq", "lstm",
        "patch_proj", "image_sequencer.projection", "image_sequencer.norm",
        "audio_sequencer", "multimodal_sequencer.audio",
        "bridge.image_vq", "bridge.audio_vq", "bridge.bridge_norm", "modality_gate",
    )
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (4, 20))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    losses.total.backward()
    for name, param in model.named_parameters():
        if not param.requires_grad or any(tag in name for tag in exempt):
            continue
        assert param.grad is not None, f"No gradient for {name}"
    print(" PASS test_gradient_flow")
|
|
|
|
def test_model_forward_with_targets():
    """A full-context forward with targets returns a positive scalar LossComponents total."""
    model = ARBModel()
    batch, length = 4, CTX
    tokens = torch.randint(0, VOCAB, (batch, length))
    targets = torch.randint(0, VOCAB, (batch, length - 3))
    _, losses, _, _ = model(tokens, targets=targets)
    assert losses is not None
    assert isinstance(losses, LossComponents)
    total = losses.total
    assert total.ndim == 0
    assert total > 0
    print(" PASS test_model_forward_with_targets")
|
|
|
|
def test_save_load_roundtrip():
    """A saved-then-loaded model reproduces the original model's logits.

    The test is skipped, with a message, when the ternary8 converter cannot be
    imported.

    Fixes over the previous version:
    - The loaded model's output was unpacked into three names, while ARBModel
      returns four values. That raised ValueError before any comparison. The
      logits are now taken with ``[0]``.
    - The checkpoint was written to a fixed /tmp path that was never removed.
      It now goes into a private temporary directory that is cleaned up
      afterwards, and parallel runs no longer collide.
    """
    import tempfile

    try:
        from arbitor.converters.convert_to_ternary8 import save_model, load_model
    except ImportError:
        print(" SKIP test_save_load_roundtrip (convert_to_ternary8 not available)")
        return
    model = ARBModel()
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test-morph-roundtrip.pt")
        save_model(model, path)
        loaded = load_model(path)
        x = torch.randint(0, VOCAB, (1, 10))
        model.eval()
        loaded.eval()
        with torch.no_grad():
            logits_orig = model(x)[0]
            logits_loaded = loaded(x)[0]
    assert torch.allclose(logits_orig, logits_loaded, atol=1e-6), "Save/load roundtrip mismatch"
    print(" PASS test_save_load_roundtrip")
|
|
|
|
| |
|
|
def test_vq_adapter_shapes():
    """VQAdapter keeps (B, T, HIDDEN_DIM), returns long (B, T) codes and a non-negative loss."""
    adapter = VQAdapter()
    hidden = torch.randn(2, 10, HIDDEN_DIM)
    quantized, commit_loss, codes = adapter(hidden)
    assert quantized.shape == (2, 10, HIDDEN_DIM), f"VQ output shape: {quantized.shape}"
    assert codes.shape == (2, 10), f"VQ indices shape: {codes.shape}"
    assert codes.dtype == torch.long, "Indices must be long"
    assert commit_loss.item() >= 0, "VQ loss must be non-negative"
    print(" PASS test_vq_adapter_shapes")
|
|
|
|
def test_vq_integration():
    """The default model forward returns (2, 64) VQ indices alongside the logits."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, codes, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}"
    assert codes is not None, "VQ indices must be returned"
    assert codes.shape == (2, 64), f"VQ indices shape wrong: {codes.shape}"
    print(" PASS test_vq_integration")
|
|
|
|
def test_vq_disabled():
    """With VQ, graph and MoE switched off, the model returns no VQ indices."""
    model = ARBModel()
    for flag in ("vq_enabled", "graph_enabled", "moe_enabled"):
        setattr(model, flag, False)
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, codes, _ = model(tokens)
    assert codes is None, "Indices should be None when VQ disabled"
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_vq_disabled")
|
|
|
|
def test_vq_with_targets():
    """Supplying targets to the VQ-enabled model yields a positive total loss."""
    model = ARBModel()
    length = 66
    tokens = torch.randint(0, VOCAB, (2, length))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:length])
    positive = losses is not None and losses.total.item() > 0
    assert positive, "Loss should be positive with targets"
    print(" PASS test_vq_with_targets")
|
|
|
|
def test_l2_distance_matching():
    """l2_distance_matching returns one index and one non-negative distance per position."""
    adapter = VQAdapter()
    projected = torch.randn(2, 10, 32)
    idx, dist = adapter.l2_distance_matching(projected)
    expected_shape = (2, 10)
    assert idx.shape == expected_shape, f"L2 indices shape: {idx.shape}"
    assert dist.shape == expected_shape, f"L2 distances shape: {dist.shape}"
    assert dist.ge(0).all(), "L2 distances must be non-negative"
    print(" PASS test_l2_distance_matching")
|
|
|
|
def test_vq_ternary_projections():
    """VQAdapter's in/out projections are TernaryScaleTensor and the adapter still runs."""
    adapter = VQAdapter()
    for attr in ("proj_in", "proj_out"):
        layer = getattr(adapter, attr)
        assert isinstance(layer, TernaryScaleTensor), \
            f"{attr} should be TernaryScaleTensor, got {type(layer)}"
    quantized, commit_loss, _ = adapter(torch.randn(2, 10, HIDDEN_DIM))
    assert quantized.shape == (2, 10, HIDDEN_DIM), f"VQ output shape: {quantized.shape}"
    assert commit_loss.item() >= 0, "VQ loss must be non-negative"
    print(" PASS test_vq_ternary_projections")
|
|
|
|
| |
|
|
def test_multimodal_vq_bridge_text_only():
    """Text-only bridge input keeps its shape, reports text_vq and uses codes below 8192."""
    bridge = MultimodalVQBridge()
    text = torch.randn(2, 10, 512)
    fused, vq_losses, codes = bridge({'text': text})
    assert fused.shape == (2, 10, 512)
    assert 'text_vq' in vq_losses
    assert codes['text'].lt(8192).all()
    print(" PASS test_multimodal_vq_bridge_text_only")
|
|
|
|
def test_multimodal_vq_bridge_text_image():
    """Text + image bridge output spans 10 + 20 positions; image codes lie in [8192, 12288)."""
    bridge = MultimodalVQBridge()
    text = torch.randn(2, 10, 512)
    image = torch.randn(2, 20, 512)
    fused, _, codes = bridge({'text': text, 'image': image})
    assert fused.shape == (2, 30, 512)
    img_codes = codes['image']
    assert (img_codes >= 8192).all()
    assert (img_codes < 12288).all()
    print(" PASS test_multimodal_vq_bridge_text_image")
|
|
|
|
def test_modality_gate_shapes():
    """ModalityGate on ['text'] returns a weight dict, a count >= 1 and hops >= 2."""
    weights, count, hops = ModalityGate()(['text'])
    assert isinstance(weights, dict)
    assert count >= 1
    assert hops >= 2
    print(" PASS test_modality_gate_shapes")
|
|
|
|
def test_ternary_graph_multicodebook():
    """TernaryGraph over a 16384-entry text/image/audio codebook keeps (B, T, 512)."""
    graph = TernaryGraph(total_vocab_size=16384)
    # Codebook embedding: 8192 text + 4096 image + 4096 audio entries, dim 32.
    graph._codebook_embed = torch.cat(
        [torch.randn(1, n, 32) for n in (8192, 4096, 4096)], dim=1
    )
    vq_out = torch.randn(2, 21, 512)
    # (low, high, count) per modality: text, image, then audio id ranges.
    spans = ((0, 8192, 8), (8192, 12288, 7), (12288, 16384, 6))
    vq_idx = torch.cat([torch.randint(lo, hi, (2, n)) for lo, hi, n in spans], dim=1)
    per_pos, _, _ = graph(vq_out, vq_idx, 0.05)
    assert per_pos.shape == (2, 21, 512)
    print(" PASS test_ternary_graph_multicodebook")
|
|
|
|
def test_vq_no_float_cast_in_model():
    """Only allow-listed submodules may be plain nn.Linear or nn.Embedding layers."""
    linear_ok = (
        "image_sequencer", "multimodal_sequencer.image", "moe.router",
        "lstm.", "multimodal_sequencer.audio",
    )
    embedding_ok = ("hop_lora.scale", "lstm.", "whisper.", "vit.")
    model = ARBModel()
    logits, _, _, _ = model(torch.randint(0, VOCAB, (2, 66)))
    assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}"
    for name, mod in model.named_modules():
        if isinstance(mod, nn.Linear):
            assert any(tag in name for tag in linear_ok), f"Unexpected nn.Linear: {name} — only moe.router, image_sequencer, lstm, and audio_sequencer are allowed"
        elif isinstance(mod, nn.Embedding):
            assert any(tag in name for tag in embedding_ok), f"nn.Embedding found: {name} — only hop_lora.scale, lstm.*, whisper.*, vit.* are allowed"
    print(" PASS test_vq_no_float_cast_in_model")
|
|
|
|
def test_zero_fp32_params():
    """Every parameter is ternary-owned unless it matches an explicit exemption."""
    exempt_substrings = (
        "bridge.text_vq.vq", "bridge.image_vq.vq", "bridge.audio_vq.vq",  # VQ internals
        "moe.router",
        "hop_lora.scale",
        "image_sequencer.vit", "multimodal_sequencer.image.vit",  # frozen ViT
        "patch_proj",
        "mfcc_proj", "frame_proj",  # audio projections
        "whisper",  # frozen Whisper
    )
    exempt_prefixes = ("memgram", "conv_vq", "lstm")  # memory modules
    model = ARBModel()
    offending = 0
    for name, param in model.named_parameters():
        if name.startswith(exempt_prefixes) or any(s in name for s in exempt_substrings):
            continue
        if not _is_ternary_param(model, name):
            offending += param.numel()
    assert offending == 0, \
        f"Found {offending} non-ternary, non-VQ, non-router params"
    print(" PASS test_zero_fp32_params")
|
|
|
|
def test_sticky_zone_ste_gradient():
    """Inside the sticky zone the gradient is |w| / threshold; outside it is 1.0."""
    threshold = 0.05
    w = torch.tensor([-0.01, -0.03, -0.049, 0.06, 0.10], requires_grad=True)
    StickyZoneSTE.apply(w, threshold).sum().backward()
    expected = [0.2, 0.6, 0.98, 1.0, 1.0]
    for idx, (grad, ratio) in enumerate(zip(w.grad.tolist(), expected)):
        assert abs(grad - ratio) < 0.02, f"w={w[idx].item():.3f}: expected ratio {ratio}, got {grad:.3f}"
    print(" PASS test_sticky_zone_ste_gradient")
|
|
|
|
| |
|
|
def test_graph_moe_gate_shape():
    """GraphMoEGate pools (B, T, D) to (B, D) with per-position alpha in [0, 1]."""
    gate = GraphMoEGate(dim=HIDDEN_DIM)
    batch, steps = 2, 10
    pooled, alpha = gate(torch.randn(batch, steps, HIDDEN_DIM))
    assert pooled.shape == (batch, HIDDEN_DIM), f"Pooled shape: {pooled.shape}"
    assert alpha.shape == (batch, steps, 1), f"Alpha shape: {alpha.shape}"
    assert ((alpha >= 0) & (alpha <= 1)).all(), "Alpha out of [0,1]"
    n_query = gate.query.numel()
    assert n_query == HIDDEN_DIM, f"Gate params: {n_query}"
    print(" PASS test_graph_moe_gate_shape")
|
|
|
|
def test_ternary_graph_shapes():
    """A two-hop TernaryGraph returns per-position, pooled and gate tensors of the right shapes."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    B, T = 2, 10
    vq_output = torch.randn(B, T, HIDDEN_DIM)
    vq_indices = torch.randint(0, CODEBOOK_SIZE, (B, T))
    per_pos, gpool, gate_alpha = graph(vq_output, vq_indices, 0.05)
    checks = (
        (per_pos, (B, T, HIDDEN_DIM), "per_position"),
        (gpool, (B, HIDDEN_DIM), "graph_pool"),
        (gate_alpha, (B, T, 1), "gate_alpha"),
    )
    for tensor, shape, label in checks:
        assert tensor.shape == shape, f"{label} shape: {tensor.shape}"
    print(" PASS test_ternary_graph_shapes")
|
|
|
|
def test_graph_gradient_flow():
    """Backprop through TernaryGraph reaches both edge_attr and the VQ input."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    vq_output = torch.randn(2, 10, HIDDEN_DIM).requires_grad_()
    codes = torch.randint(0, CODEBOOK_SIZE, (2, 10))
    graph(vq_output, codes, 0.05)[0].sum().backward()
    assert graph.edge_attr.grad is not None, "edge_attr should have gradient"
    assert vq_output.grad is not None, "vq_output should have gradient"
    print(" PASS test_graph_gradient_flow")
|
|
|
|
def test_graph_connectivity_monitor():
    """monitor_graph_health reports sparsity in [0, 1] and non-negative isolated nodes."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2)
    health = graph.monitor_graph_health(threshold=0.05)
    for key in ('sparsity', 'isolated_nodes', 'avg_polarity', 'dead_edges'):
        assert key in health
    assert 0.0 <= health['sparsity'] <= 1.0
    assert health['isolated_nodes'] >= 0
    print(" PASS test_graph_connectivity_monitor")
|
|
|
|
def test_model_forward_with_graph():
    """A model with the ternary graph attached produces logits and VQ indices."""
    model = ARBModel()
    logits, _, codes, _ = model(torch.randint(0, VOCAB, (2, 66)))
    assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}"
    assert codes is not None, "VQ indices required for graph"
    assert hasattr(model, 'ternary_graph'), "Model missing ternary_graph"
    print(" PASS test_model_forward_with_graph")
|
|
|
|
def test_model_graph_disabled():
    """Disabling graph and MoE still yields correctly shaped logits."""
    model = ARBModel()
    model.graph_enabled = model.moe_enabled = False
    logits, _, _, _ = model(torch.randint(0, VOCAB, (2, 66)))
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_model_graph_disabled")
|
|
|
|
def test_ternary_graph_in_modules():
    """The graph, gate and MoE classes are registered in TERNARY_MODULES."""
    required = (
        (TernaryGraph, "TernaryGraph"),
        (GraphMoEGate, "GraphMoEGate"),
        (SharedProjectionMoE, "SharedProjectionMoE"),
    )
    for cls, label in required:
        assert cls in TERNARY_MODULES, f"{label} not in TERNARY_MODULES"
    print(" PASS test_ternary_graph_in_modules")
|
|
|
|
| |
|
|
def test_moe_shapes():
    """SharedProjectionMoE preserves (B, T, 512) and returns a non-negative scalar aux loss."""
    moe = SharedProjectionMoE(
        hidden_size=512, num_experts=8, top_k=2, core_rank=192,
        shared_inter=3072, tscale_type=TScaleType.T32,
    )
    shape = (4, 10, 512)
    out, aux = moe(torch.randn(*shape))
    assert out.shape == shape, f'MoE output shape: {out.shape}'
    assert aux.ndim == 0, f'Aux loss should be scalar, got ndim={aux.ndim}'
    assert aux.item() >= 0, 'Aux loss should be non-negative'
    print(" PASS test_moe_shapes")
|
|
|
|
def test_moe_router():
    """The router records top-2 expert ids per token in train mode; eval forward works too.

    Previously both the train-mode and eval-mode outputs were computed but
    never inspected, so a broken eval path (or a shape regression) went
    unnoticed. Both outputs are now required to keep the input's shape,
    matching the contract asserted in test_moe_shapes.
    """
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, noise_std=0.25)
    moe.train()
    x = torch.randn(4, 20, 512)
    out, _ = moe(x)
    assert out.shape == x.shape, f'train-mode output shape: {out.shape}'
    topk = moe._last_topk_idx
    assert topk is not None
    # One row per flattened token (4 * 20 = 80), one column per selected expert.
    assert topk.shape == (80, 2), f'topk_idx shape: {topk.shape}'
    assert (topk >= 0).all() and (topk < 8).all()
    moe.eval()
    out_eval, _ = moe(x)
    assert out_eval.shape == x.shape, f'eval-mode output shape: {out_eval.shape}'
    print(" PASS test_moe_router")
|
|
|
|
def test_moe_aux_loss():
    """The MoE load-balancing aux loss is non-negative."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    aux = moe(torch.randn(4, 10, 512))[1]
    assert aux.item() >= 0, 'Aux loss must be non-negative'
    print(" PASS test_moe_aux_loss")
|
|
|
|
def test_shared_expert():
    """The shared expert uses ternary projections and contributes a non-zero output."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    for part in ("shared_expert_gate", "shared_expert_up", "shared_expert_down"):
        assert isinstance(getattr(moe, part), TernaryScaleTensor)
    out, _ = moe(torch.randn(2, 5, 512))
    assert out.norm().item() > 0, 'Shared expert output should be non-zero'
    print(" PASS test_shared_expert")
|
|
|
|
def test_moe_gradient_flow():
    """Backward through the MoE reaches the input and sets grad-sign hooks on router and W_gate[0]."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    x = torch.randn(2, 10, 512).requires_grad_(True)
    out, aux = moe(x)
    objective = out.sum() + aux
    objective.backward()
    assert x.grad is not None, 'No gradient on input'
    assert hasattr(moe.router, '_hook_grad_T_sign'), 'No grad captured on router'
    first_gate = moe.W_gate[0]
    assert hasattr(first_gate, '_hook_grad_T_sign'), 'No grad captured on W_gate[0]'
    print(" PASS test_moe_gradient_flow")
|
|
|
|
def test_moe_zero_fp32():
    """Every SharedProjectionMoE parameter is owned by a ternary module."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    non_ternary = sum(
        param.numel()
        for name, param in moe.named_parameters()
        if not _is_ternary_param(moe, name)
    )
    assert non_ternary == 0, f'Expected 0 non-ternary params, got {non_ternary}'
    print(" PASS test_moe_zero_fp32")
|
|
|
|
def test_ternary_graph_with_gate():
    """TernaryGraph's gate alpha is (B, T, 1) with every value in [0, 1]."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    result = graph(
        torch.randn(2, 10, HIDDEN_DIM),
        torch.randint(0, CODEBOOK_SIZE, (2, 10)),
        0.05,
    )
    _, _, alpha = result
    assert alpha.shape == (2, 10, 1), f'gate_alpha shape: {alpha.shape}'
    assert ((alpha >= 0) & (alpha <= 1)).all()
    print(" PASS test_ternary_graph_with_gate")
|
|
|
|
def test_model_forward_with_moe():
    """A model with the MoE enabled produces logits and VQ indices."""
    model = ARBModel()
    logits, _, codes, _ = model(torch.randint(0, VOCAB, (2, 66)))
    assert logits.shape == (2, 64, VOCAB), f'Logits shape: {logits.shape}'
    assert codes is not None
    print(" PASS test_model_forward_with_moe")
|
|
|
|
def test_model_moe_disabled():
    """Turning the MoE off leaves the logits shape unchanged."""
    model = ARBModel()
    model.moe_enabled = False
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, _, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_model_moe_disabled")
|
|
|
|
def test_model_moe_loss_components():
    """A forward with targets fills every core loss term and caches MoE routing info."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert losses is not None and isinstance(losses, LossComponents)
    assert losses.lm is not None and losses.lm > 0
    for term in (losses.vq_commitment, losses.moe_aux, losses.graph_l1):
        assert term is not None
    assert losses.total > 0
    moe = model.moe
    assert moe._last_topk_idx is not None, 'MoE should have routing info after forward'
    assert moe._last_aux_loss is not None, 'MoE should have aux_loss cached after forward'
    print(" PASS test_model_moe_loss_components")
|
|
|
|
def test_model_moe_gate_modulation():
    """A forward pass with the graph-modulated MoE gate keeps the logits shape."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    outputs = model(tokens)
    logits, _, _, _ = outputs
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_model_moe_gate_modulation")
|
|
|
|
def test_param_count_with_moe():
    """The parameter count with the MoE included stays within 120M-150M."""
    model = ARBModel()
    total = 0
    for param in model.parameters():
        total += param.numel()
    print(f" Param count with MoE: {total:,}")
    assert 120e6 < total < 150e6, f'Expected ~133M (frozen ViT-base + Whisper-tiny + ternary buffers), got {total:,}'
    print(" PASS test_param_count_with_moe")
|
|
|
|
def test_moe_monitoring():
    """After a forward pass the MoE exposes top-k routing indices and its aux loss."""
    model = ARBModel()
    model(torch.randint(0, VOCAB, (2, 66)))
    moe = model.moe
    topk = moe._last_topk_idx
    assert topk is not None, '_last_topk_idx should be set after forward'
    assert moe._last_aux_loss is not None, '_last_aux_loss should be set after forward'
    assert topk.shape[1] == moe.top_k
    print(" PASS test_moe_monitoring")
|
|
|
|
| |
|
|
def test_loss_components():
    """LossComponents.total is the weighted sum of its terms and backpropagates.

    Every term is now scaled by its ``LossWeights`` default. The previous
    version scaled only graph_l1 and silently assumed a weight of 1.0 for lm,
    vq_commitment and moe_aux. test_model_losses_components_type, by
    contrast, reads every weight explicitly. The VQ term's gradient is now
    checked alongside the LM term's.
    """
    lm = torch.tensor(5.0, requires_grad=True)
    vq = torch.tensor(0.5, requires_grad=True)
    moe = torch.tensor(0.01, requires_grad=True)
    graph = torch.tensor(0.001, requires_grad=True)
    lc = LossComponents(lm=lm, vq_commitment=vq, moe_aux=moe, graph_l1=graph)
    assert lc.total.ndim == 0, f"Total should be scalar, got ndim={lc.total.ndim}"
    expected = (
        LossWeights.lm * 5.0
        + LossWeights.vq_commitment * 0.5
        + LossWeights.moe_aux * 0.01
        + LossWeights.graph_l1 * 0.001
    )
    assert abs(lc.total.item() - expected) < 1e-5, f"Total mismatch: {lc.total.item()} vs {expected}"
    lc.total.backward()
    assert lm.grad is not None, "LM loss should have gradient"
    assert vq.grad is not None, "VQ loss should have gradient"
    print(" PASS test_loss_components")
|
|
|
|
def test_loss_components_none_fields():
    """Terms set to None are left out of the total, so only the LM loss remains."""
    lm = torch.tensor(3.0, requires_grad=True)
    components = LossComponents(lm=lm, vq_commitment=None, moe_aux=None, graph_l1=None)
    total = components.total.item()
    assert total == 3.0, f"Total with None fields: {total}"
    print(" PASS test_loss_components_none_fields")
|
|
|
|
def test_loss_components_backward():
    """LossComponents.backward() propagates gradients to every supplied term."""
    lm = torch.tensor(4.0, requires_grad=True)
    vq = torch.tensor(0.3, requires_grad=True)
    LossComponents(lm=lm, vq_commitment=vq).backward()
    for leaf, label in ((lm, "LM"), (vq, "VQ")):
        assert leaf.grad is not None, f"{label} should have gradient after backward"
    print(" PASS test_loss_components_backward")
|
|
|
|
def test_gnn_lora_adapter():
    """Hop LoRA gives identical outputs across hops at init, and different ones once scales differ."""
    lora = GNNLoRAAdapter(dim=512, rank=32, max_hops=4)
    feats = torch.randn(8192, 512)
    hop0 = lora(feats, hop_t=0)
    hop1 = lora(feats, hop_t=1)
    assert hop0.shape == (8192, 512), f"LoRA output shape: {hop0.shape}"
    assert torch.allclose(hop0, hop1, atol=1e-6), "Zero-init scales should produce same output at init"
    # Push hop 1's scale row away from hop 0's.
    scale_rows = lora.scale.weight.data
    scale_rows[1] = scale_rows[0] + 1.0
    hop1_shifted = lora(feats, hop_t=1)
    assert not torch.allclose(hop0, hop1_shifted), "Non-zero scales should differentiate hops"
    print(" PASS test_gnn_lora_adapter")
|
|
|
|
def test_gnn_lora_gradient():
    """Backward through the hop LoRA reaches both the input and the scale table."""
    lora = GNNLoRAAdapter(dim=512, rank=32, max_hops=4)
    inputs = torch.randn(8192, 512, requires_grad=True)
    lora(inputs, hop_t=0).sum().backward()
    assert inputs.grad is not None, "Input should have gradient"
    assert lora.scale.weight.grad is not None, "LoRA scale should have gradient"
    print(" PASS test_gnn_lora_gradient")
|
|
|
|
def test_shared_gnn_weight_tying():
    """TernaryGraph uses one shared GNN layer plus a hop LoRA, not a per-hop layer list."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=3)
    expectations = (
        ('gnn', True, "Graph should have single shared GNN layer"),
        ('gnn_layers', False, "Graph should NOT have gnn_layers list"),
        ('hop_lora', True, "Graph should have hop_lora adapter"),
    )
    for attr, present, msg in expectations:
        assert hasattr(graph, attr) is present, msg
    hops = graph.max_hops
    assert hops == 3, f"max_hops should be 3, got {hops}"
    print(" PASS test_shared_gnn_weight_tying")
|
|
|
|
def test_shared_gnn_multi_hop():
    """A four-hop shared-GNN graph with a rank-32 LoRA keeps the per-position shape."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=4, lora_rank=32)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    per_pos, _, _ = graph(
        torch.randn(2, 10, HIDDEN_DIM),
        torch.randint(0, CODEBOOK_SIZE, (2, 10)),
        0.05,
    )
    assert per_pos.shape == (2, 10, HIDDEN_DIM), f"per_position shape: {per_pos.shape}"
    print(" PASS test_shared_gnn_multi_hop")
|
|
|
|
def test_model_losses_components_type():
    """The model's LossComponents total equals the weighted sum of every populated term."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert isinstance(losses, LossComponents), f"Expected LossComponents, got {type(losses)}"
    required = ("lm", "vq_commitment", "moe_aux", "graph_l1")
    optional = ("graph_ponder", "moe_ponder", "conv_vq_commitment", "memgram_decay_reg", "lstm_hidden_reg")
    for key in required:
        assert getattr(losses, key) is not None
    total = losses.total.item()
    weights = losses.weights
    # Same addition order as the fields are listed: required first, then optional.
    manual = 0.0
    for key in required + optional:
        value = getattr(losses, key)
        if value is not None:
            manual += getattr(weights, key) * value.item()
    assert abs(total - manual) < 1e-4, f"Total {total} != weighted sum {manual}"
    print(" PASS test_model_losses_components_type")
|
|
|
|
| |
|
|
def test_halting_unit_shapes():
    """HaltingUnit emits per-position probabilities strictly inside (0, 1) and is differentiable."""
    unit = HaltingUnit(dim=512, tscale_type=TScaleType.T32)
    x = torch.randn(4, 10, 512).requires_grad_(True)
    probs = unit(x)
    assert probs.shape == (4, 10, 1), f"Shape: {probs.shape}"
    inside = (probs > 0) & (probs < 1)
    assert inside.all(), f"Range: ({probs.min():.4f}, {probs.max():.4f})"
    probs.sum().backward()
    assert x.grad is not None, "No gradient on input"
    print(" PASS test_halting_unit_shapes")
|
|
|
|
def test_halting_unit_ternary_pure():
    """HaltingUnit contains no plain nn.Linear or nn.Embedding submodules."""
    for name, mod in HaltingUnit(dim=512).named_modules():
        assert not isinstance(mod, nn.Linear), f"nn.Linear found: {name}"
        assert not isinstance(mod, nn.Embedding), f"nn.Embedding found: {name}"
    print(" PASS test_halting_unit_ternary_pure")
|
|
|
|
def test_graph_act_cell_shapes():
    """GraphACTCell returns graph outputs plus a positive scalar ponder cost and is differentiable."""
    graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32)
    graph._codebook_embed = torch.randn(1, 8192, 32)
    cell = GraphACTCell(graph, max_hops=4, halt_threshold=0.01)
    vq_out = torch.randn(2, 10, 512).requires_grad_(True)
    vq_idx = torch.randint(0, 8192, (2, 10))
    per_pos, gpool, gate_alpha, ponder = cell(vq_out, vq_idx, 0.05)
    checks = (
        (per_pos, (2, 10, 512), "per_pos"),
        (gpool, (2, 512), "gpool"),
        (gate_alpha, (2, 10, 1), "gate_alpha"),
    )
    for tensor, shape, label in checks:
        assert tensor.shape == shape, f"{label}: {tensor.shape}"
    assert ponder.ndim == 0
    assert ponder.item() > 0
    per_pos.sum().backward()
    assert vq_out.grad is not None, "No gradient on input"
    print(" PASS test_graph_act_cell_shapes")
|
|
|
|
def test_moe_act_cell_shapes():
    """MoEACTCell keeps the input shape and returns scalar aux (>= 0) and ponder (> 0) terms."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    cell = MoEACTCell(moe, dim=512, max_iters=4, halt_threshold=0.01)
    x = torch.randn(2, 10, 512).requires_grad_(True)
    out, aux, ponder = cell(x)
    assert out.shape == (2, 10, 512), f"out: {out.shape}"
    assert aux.ndim == 0 and ponder.ndim == 0
    assert aux.item() >= 0
    assert ponder.item() > 0
    out.sum().backward()
    assert x.grad is not None, "No gradient on input"
    print(" PASS test_moe_act_cell_shapes")
|
|
|
|
def test_act_early_halt():
    """The ponder cost at halt_threshold=1e-6 is lower than at halt_threshold=100.0."""
    graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32)
    graph._codebook_embed = torch.randn(1, 8192, 32)
    cell_high = GraphACTCell(graph, max_hops=8, halt_threshold=100.0)
    vq_out = torch.randn(2, 10, 512)
    vq_idx = torch.randint(0, 8192, (2, 10))
    _, _, _, ponder = cell_high(vq_out, vq_idx, 0.05)
    cell_low = GraphACTCell(graph, max_hops=8, halt_threshold=1e-6)
    _, _, _, ponder_low = cell_low(vq_out, vq_idx, 0.05)
    assert ponder_low.item() < ponder.item(), \
        f"Early halt ponder ({ponder_low:.4f}) should be less than no-halt ponder ({ponder:.4f})"
    print(" PASS test_act_early_halt")
|
|
|
|
def test_act_weight_sum_one():
    """Low- and high-threshold MoE ACT cells both give finite (non-NaN), non-zero outputs."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    fast_cell = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=1e-6)
    x = torch.randn(2, 10, 512)
    out_fast, _, _ = fast_cell(x)
    slow_cell = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=100.0)
    out_slow, _, _ = slow_cell(x)
    out_sum = out_fast.sum().item() + out_slow.sum().item()
    for out, label in ((out_fast, "fast"), (out_slow, "slow")):
        assert not torch.isnan(out).any(), f"NaN in {label} ACT output"
    assert out_sum != 0, "Outputs should be non-zero (weights sum to 1.0)"
    print(" PASS test_act_weight_sum_one")
|
|
|
|
def test_act_gradient_flow():
    """Backprop of output + aux + ponder through a MoEACTCell reaches the input."""
    experts = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    cell = MoEACTCell(experts, dim=512, max_iters=3, halt_threshold=0.01)
    inp = torch.randn(2, 10, 512, requires_grad=True)

    result, aux_loss, ponder_cost = cell(inp)
    objective = result.sum() + aux_loss + ponder_cost
    objective.backward()

    assert inp.grad is not None, "Input grad is None"
    print(" PASS test_act_gradient_flow")
|
|
|
|
def test_loss_components_ponder_fields():
    """graph_ponder and moe_ponder are added into total and receive gradients.

    With the default weights, total is expected to equal lm + graph_ponder +
    moe_ponder, and backprop must reach all three leaves.
    """
    values = (5.0, 0.1, 0.2)
    lm, gp, mp = (torch.tensor(v, requires_grad=True) for v in values)
    lc = LossComponents(lm=lm, graph_ponder=gp, moe_ponder=mp)

    expected = sum(values)
    assert abs(lc.total.item() - expected) < 1e-5, f"Total: {lc.total.item()} vs {expected}"

    lc.total.backward()
    for leaf in (lm, gp, mp):
        assert leaf.grad is not None
    print(" PASS test_loss_components_ponder_fields")
|
|
|
|
def test_loss_components_ponder_none():
    """With both ponder terms None, total equals lm alone and backprops to it."""
    lm_term = torch.tensor(3.0, requires_grad=True)
    components = LossComponents(lm=lm_term, graph_ponder=None, moe_ponder=None)

    assert abs(components.total.item() - 3.0) < 1e-5
    components.total.backward()
    assert lm_term.grad is not None
    print(" PASS test_loss_components_ponder_none")
|
|
|
|
def test_act_graph_moe_sequential():
    """Chain GraphACTCell into MoEACTCell and blend with the graph gate.

    The graph cell's per-position output feeds the MoE cell. The results are
    mixed as ``gate_alpha * moe_out + (1 - gate_alpha) * per_pos``. The test
    checks the blended shape, that both ponder costs are scalars, and that
    gradients reach the original input.
    """
    graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32)
    graph._codebook_embed = torch.randn(1, 8192, 32)
    graph_cell = GraphACTCell(graph, max_hops=3, halt_threshold=0.01)
    experts = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    moe_cell = MoEACTCell(experts, dim=512, max_iters=3, halt_threshold=0.01)

    source = torch.randn(2, 10, 512, requires_grad=True)
    codes = torch.randint(0, 8192, (2, 10))

    per_pos, _pooled, gate_alpha, graph_ponder = graph_cell(source, codes, 0.05)
    moe_out, _aux, moe_ponder = moe_cell(per_pos)
    blended = gate_alpha * moe_out + (1 - gate_alpha) * per_pos

    assert blended.shape == (2, 10, 512), f"Sequential output: {blended.shape}"
    assert graph_ponder.ndim == 0
    assert moe_ponder.ndim == 0

    blended.sum().backward()
    assert source.grad is not None, "Input grad is None"
    print(" PASS test_act_graph_moe_sequential")
|
|
|
|
| |
|
|
def test_model_forward_with_act():
    """A full ARBModel forward with targets yields logits and LossComponents with ponder terms."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))

    logits, losses, _, _ = model(tokens, targets=tokens[:, 3:])

    assert logits.shape == (2, 64, VOCAB), f"Logits: {logits.shape}"
    assert isinstance(losses, LossComponents)
    for ponder in (losses.graph_ponder, losses.moe_ponder):
        assert ponder is not None
    assert losses.total > 0
    print(" PASS test_model_forward_with_act")
|
|
|
|
def test_model_act_forward_without_targets():
    """Without targets the model still returns logits but no LossComponents."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))

    outputs = model(tokens)
    logits, losses = outputs[0], outputs[1]
    _, _, _, _ = outputs

    assert logits.shape == (2, 64, VOCAB)
    assert losses is None
    print(" PASS test_model_act_forward_without_targets")
|
|
|
|
def test_model_act_loss_components():
    """Every ACT-era loss component is populated, and total exceeds the raw ponder sum."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])

    for field in ("lm", "vq_commitment", "moe_aux", "graph_l1", "graph_ponder", "moe_ponder"):
        assert getattr(losses, field) is not None

    # Falsy terms are skipped, exactly as filter(None, ...) would do.
    ponder_sum = sum(term for term in (losses.graph_ponder, losses.moe_ponder) if term)
    assert losses.total > ponder_sum
    print(" PASS test_model_act_loss_components")
|
|
|
|
def test_model_act_backward():
    """Backprop of the combined loss must reach the TernaryGraph edge attributes.

    ``losses`` is a LossComponents container. The differentiable scalar is
    ``losses.total``, so that tensor is what gets backpropagated.
    """
    model = ARBModel(tscale_type=TScaleType.T32)
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    _, losses, _, _ = model(x, targets=targets)
    losses.total.backward()
    assert model.ternary_graph.edge_attr.grad is not None, "edge_attr grad None"
    print(" PASS test_model_act_backward")
|
|
|
|
def test_model_act_disabled():
    """With both ACT flags off, forward reports no ponder losses and leaves the flags off.

    The flags are re-checked after forward() so that a forward pass which
    re-enables ACT is caught. They are tested by truthiness rather than
    ``== False`` (PEP 8 / E712).
    """
    model = ARBModel(tscale_type=TScaleType.T32)
    model.graph_act_enabled = False
    model.moe_act_enabled = False
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    logits, losses, _, _ = model(x, targets=targets)
    assert logits.shape == (2, 64, VOCAB)
    assert losses.graph_ponder is None
    assert losses.moe_ponder is None
    assert not model.graph_act_enabled, "forward() should not re-enable graph ACT"
    assert not model.moe_act_enabled, "forward() should not re-enable MoE ACT"
    print(" PASS test_model_act_disabled")
|
|
|
|
def test_model_act_warmup_mode():
    """act_warmup_mode=True suppresses both ponder losses; False brings them back."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    labels = tokens[:, 3:]

    _, warm, _, _ = model(tokens, targets=labels, act_warmup_mode=True)
    assert warm.graph_ponder is None, f"During warmup, graph_ponder should be None: {warm.graph_ponder}"
    assert warm.moe_ponder is None, f"During warmup, moe_ponder should be None: {warm.moe_ponder}"

    _, live, _, _ = model(tokens, targets=labels, act_warmup_mode=False)
    assert live.graph_ponder is not None, "Without warmup, graph_ponder should be present"
    assert live.moe_ponder is not None, "Without warmup, moe_ponder should be present"
    print(" PASS test_model_act_warmup_mode")
|
|
|
|
def test_model_act_ponder_cached():
    """After a forward pass the model caches positive graph and MoE ponder values."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    model(tokens, targets=tokens[:, 3:])

    graph_cached = model._last_graph_ponder
    assert graph_cached > 0, f"_last_graph_ponder={model._last_graph_ponder}"
    moe_cached = model._last_moe_ponder
    assert moe_cached > 0, f"_last_moe_ponder={model._last_moe_ponder}"
    print(" PASS test_model_act_ponder_cached")
|
|
|
|
| |
|
|
def test_act_warmup_schedule():
    """compute_act_warmup is true before step 10000 of 50000 and false from then on.

    Results are compared by truthiness instead of ``== True/False``
    (PEP 8 / E712).
    """
    from train import compute_act_warmup

    total_steps = 50000
    # (step, expected warmup state, failure message)
    cases = (
        (0, True, "step 0 is warmup"),
        (9999, True, "step 9999 is warmup"),
        (10000, False, "step 10000 not warmup"),
        (50000, False, "step 50000 not warmup"),
    )
    for step, expected, message in cases:
        assert bool(compute_act_warmup(step, total_steps)) is expected, message
    print(" PASS test_act_warmup_schedule")
|
|
|
|
def test_act_ponder_lambda():
    """get_ponder_lambda moves from start_lambda at step 0 to end_lambda at step 10000.

    With 50000 total steps and warmup_frac=0.2, the midpoint (step 5000) must
    lie strictly between end_lambda and start_lambda.
    """
    from train import get_ponder_lambda

    schedule = dict(warmup_frac=0.2, start_lambda=0.1, end_lambda=0.01)

    lam0 = get_ponder_lambda(0, 50000, **schedule)
    assert abs(lam0 - 0.1) < 1e-6, f"start lambda: {lam0}"

    lam_mid = get_ponder_lambda(5000, 50000, **schedule)
    assert 0.01 < lam_mid < 0.1, f"mid lambda: {lam_mid}"

    lam_end = get_ponder_lambda(10000, 50000, **schedule)
    assert abs(lam_end - 0.01) < 1e-6, f"end lambda: {lam_end}"
    print(" PASS test_act_ponder_lambda")
|
|
|
|
def test_model_ponder_lambda_scaling():
    """A larger ponder_lambda must give a larger graph ponder loss.

    The same model and input are run with ponder_lambda=0.5 and 0.01. In the
    default forward mode graph_ponder is expected to be populated. A missing
    value is therefore reported as a failure instead of letting the comparison
    be skipped silently.
    """
    model = ARBModel(tscale_type=TScaleType.T32)
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    _, losses_high, _, _ = model(x, targets=targets, ponder_lambda=0.5)
    _, losses_low, _, _ = model(x, targets=targets, ponder_lambda=0.01)
    assert losses_high.graph_ponder is not None, "graph_ponder missing for ponder_lambda=0.5"
    assert losses_low.graph_ponder is not None, "graph_ponder missing for ponder_lambda=0.01"
    assert losses_high.graph_ponder.item() > losses_low.graph_ponder.item(), \
        "Higher ponder_lambda should produce larger ponder loss"
    print(" PASS test_model_ponder_lambda_scaling")
|
|
|
|
| |
|
|
def test_text_only_forward():
    """A text-only forward pass yields (2, 64, VOCAB) logits and VQ indices."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))

    logits, _losses, vq_indices, _ = model(tokens)

    assert logits.shape == (2, 64, VOCAB)
    assert vq_indices is not None
    print(" PASS test_text_only_forward")
|
|
|
|
def test_image_forward():
    """Text plus a 224x224 RGB image batch still yields (2, 64, VOCAB) logits and indices."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    pixels = torch.randn(2, 3, 224, 224)

    logits, _losses, vq_indices, _ = model(tokens, images=pixels)

    assert logits.shape == (2, 64, VOCAB)
    assert vq_indices is not None
    print(" PASS test_image_forward")
|
|
|
|
def test_multimodal_backward():
    """Backprop of a text+image loss gives every trainable parameter a gradient.

    Parameters whose name contains any substring in ``skip_substrings`` are
    exempt. Those are presumably frozen components, or paths a text+image
    forward pass does not exercise. All offending parameters are reported at
    once instead of stopping at the first.
    """
    model = ARBModel()
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    img = torch.randn(2, 3, 224, 224)
    logits, losses, _, _ = model(x, targets=targets, images=img)
    assert losses is not None
    losses.total.backward()

    # Built once, not per parameter.
    skip_substrings = (
        'vit', 'embedding', 'patch_proj', 'frame_proj', 'router.bias', 'router_h',
        'W_gate', 'W_transform', 'hop_lora.scale', 'modality_gate', 'graph_pool.query',
        'memgram', 'conv_vq', 'lstm', 'mfcc_proj', 'audio_sequencer', 'audio_vq',
    )
    missing = [
        name
        for name, param in model.named_parameters()
        if param.requires_grad
        and param.grad is None
        and not any(skip in name for skip in skip_substrings)
    ]
    assert not missing, f"No gradient for {', '.join(missing)}"
    print(" PASS test_multimodal_backward")
|
|
|
|
def test_no_stale_trigram_encoder():
    """The legacy ``trigram`` module must no longer expose TrigramEncoder.

    The module is looked up in sys.modules first and imported on demand
    otherwise; indexing sys.modules directly would raise KeyError if nothing
    had imported it yet. If no ``trigram`` module exists at all, there is
    nothing stale to find. A ModuleNotFoundError raised for some other
    dependency inside ``trigram`` is re-raised.
    """
    import importlib

    trigram = sys.modules.get('trigram')
    if trigram is None:
        try:
            trigram = importlib.import_module('trigram')
        except ModuleNotFoundError as err:
            if err.name != 'trigram':
                raise
            trigram = None
    assert trigram is None or not hasattr(trigram, 'TrigramEncoder'), 'TrigramEncoder should be removed'
    print(" PASS test_no_stale_trigram_encoder")
|
|
|
|
def test_vocab():
    """Vocabulary layout: 288 ids in total, 32 special tokens, IMAGE at id 278."""
    assert VOCAB == 288
    special_count = len(SPECIAL_VOCAB)
    assert special_count == 32
    image_id = SPECIAL_VOCAB['IMAGE']
    assert image_id == 278
    print(" PASS test_vocab")
|
|
|
|
| |
|
|
def test_audio_sequencer_construction():
    """Default AudioSequencer: audio modality, window 5, and a frozen Whisper module.

    Parameters that are still trainable are reported by name. Formatting the
    parameter tensor itself would dump its entire contents into the assertion
    message.
    """
    aseq = AudioSequencer()
    assert aseq.modality == 'audio'
    assert aseq.window_size == 5
    assert aseq.whisper is not None
    assert aseq.frame_proj is not None
    trainable = [name for name, p in aseq.whisper.named_parameters() if p.requires_grad]
    assert not trainable, f"Whisper params should be frozen: {trainable}"
    print(" PASS test_audio_sequencer_construction")
|
|
|
|
def test_audio_sequencer_forward_waveform():
    """A raw (2, 80000) waveform maps to a (2, frames > 0, HIDDEN_DIM) sequence."""
    aseq = AudioSequencer()
    wave = torch.randn(2, 80000)

    seq = aseq(wave)

    assert seq.dim() == 3
    batch, frames, width = seq.shape
    assert batch == 2
    assert width == HIDDEN_DIM
    assert frames > 0
    print(" PASS test_audio_sequencer_forward_waveform")
|
|
|
|
def test_audio_sequencer_forward_precomputed_mel():
    """An (2, 80, 3000) input (mel-style layout) is also accepted and mapped to HIDDEN_DIM."""
    aseq = AudioSequencer()
    spectrogram = torch.randn(2, 80, 3000)

    seq = aseq(spectrogram)

    assert seq.dim() == 3
    batch, _, width = seq.shape
    assert batch == 2
    assert width == HIDDEN_DIM
    print(" PASS test_audio_sequencer_forward_precomputed_mel")
|
|
|
|
def test_audio_sequencer_quantization_fp8():
    """quantize_weights='fp8' installs QLinear modules with float8_e4m3fn storage.

    Only the first QLinear found is inspected. Its dtype is checked only when
    it exposes ``weight._data``.
    """
    aseq = AudioSequencer(quantize_weights='fp8')
    first_qlinear = next(
        (mod for _, mod in aseq.whisper.named_modules() if 'QLinear' in type(mod).__name__),
        None,
    )
    assert first_qlinear is not None, "No QLinear modules found — quantization may not have been applied"
    if hasattr(first_qlinear, 'weight') and hasattr(first_qlinear.weight, '_data'):
        stored = first_qlinear.weight._data.dtype
        assert stored == torch.float8_e4m3fn, f"Expected fp8 data, got {stored}"
    print(" PASS test_audio_sequencer_quantization_fp8")
|
|
|
|
def test_audio_sequencer_quantization_int8():
    """quantize_weights='int8' installs at least one QLinear module in Whisper."""
    aseq = AudioSequencer(quantize_weights='int8')
    found_qlinear = any('QLinear' in type(mod).__name__ for _, mod in aseq.whisper.named_modules())
    assert found_qlinear, "No QLinear modules found — int8 quantization may not have been applied"
    print(" PASS test_audio_sequencer_quantization_int8")
|
|
|
|
def test_audio_sequencer_no_quantize():
    """quantize_weights=None leaves Whisper without QLinear and with bfloat16 parameters."""
    aseq = AudioSequencer(quantize_weights=None)
    has_qlinear = any('QLinear' in type(module).__name__ for _, module in aseq.whisper.named_modules())
    assert not has_qlinear, "No QLinear should exist when quantize_weights=None"
    for weight in aseq.whisper.parameters():
        assert weight.dtype == torch.bfloat16, f"Expected bfloat16, got {weight.dtype}"
    print(" PASS test_audio_sequencer_no_quantize")
|
|
|
|
def test_image_sequencer_hf_vit():
    """Default ImageSequencer: image modality, window 3, (1, 194, HIDDEN_DIM) output for one 224x224 image."""
    iseq = ImageSequencer()
    assert iseq.modality == 'image'
    assert iseq.window_size == 3

    picture = torch.randn(1, 3, 224, 224)
    seq = iseq(picture)
    assert seq.shape == (1, 194, HIDDEN_DIM)
    print(" PASS test_image_sequencer_hf_vit")
|
|
|
|
def test_image_sequencer_quantization_fp8():
    """quantize_weights='fp8' installs QLinear modules in the ViT with float8_e4m3fn storage.

    Only the first QLinear found is inspected. Its dtype is checked only when
    it exposes ``weight._data``.
    """
    iseq = ImageSequencer(quantize_weights='fp8')
    first_qlinear = next(
        (mod for _, mod in iseq.vit.named_modules() if 'QLinear' in type(mod).__name__),
        None,
    )
    assert first_qlinear is not None, "No QLinear modules found in ViT — fp8 quantization may not have been applied"
    if hasattr(first_qlinear, 'weight') and hasattr(first_qlinear.weight, '_data'):
        stored = first_qlinear.weight._data.dtype
        assert stored == torch.float8_e4m3fn, f"Expected fp8, got {stored}"
    print(" PASS test_image_sequencer_quantization_fp8")
|
|
|
|
def test_image_sequencer_no_quantize():
    """quantize_weights=None leaves the ViT without QLinear and with bfloat16 parameters."""
    iseq = ImageSequencer(quantize_weights=None)
    has_qlinear = any('QLinear' in type(module).__name__ for _, module in iseq.vit.named_modules())
    assert not has_qlinear, "No QLinear should exist when quantize_weights=None"
    for weight in iseq.vit.parameters():
        assert weight.dtype == torch.bfloat16, f"Expected bfloat16, got {weight.dtype}"
    print(" PASS test_image_sequencer_no_quantize")
|
|
|
|
def test_multimodal_sequencer_all_modalities():
    """By default MultimodalSequencer enables and builds text, image and audio branches."""
    mseq = MultimodalSequencer()
    modalities = ('text', 'image', 'audio')
    for modality in modalities:
        assert modality in mseq.enabled_modalities
    for modality in modalities:
        assert getattr(mseq, modality) is not None
    print(" PASS test_multimodal_sequencer_all_modalities")
|
|
|
|
def test_multimodal_sequencer_text_only():
    """A text-only MultimodalSequencer builds only the text branch and returns only 'text'.

    The sequencer is fed an already-embedded text tensor of shape
    (2, 20, EMBEDDING_DIM). No token ids are needed, so none are generated.
    """
    mseq = MultimodalSequencer(enable_image=False, enable_audio=False)
    assert mseq.enabled_modalities == ['text']
    assert mseq.image is None
    assert mseq.audio is None
    embedded = torch.randn(2, 20, EMBEDDING_DIM)
    out = mseq({'text': embedded})
    assert 'text' in out
    assert 'image' not in out
    assert 'audio' not in out
    print(" PASS test_multimodal_sequencer_text_only")
|
|
|
|
def test_multimodal_sequencer_full_forward():
    """Feeding text, image and audio returns all three streams with width HIDDEN_DIM."""
    mseq = MultimodalSequencer()
    batch = {
        'text': torch.randn(2, 20, EMBEDDING_DIM),
        'image': torch.randn(2, 3, 224, 224),
        'audio': torch.randn(2, 80000),
    }
    streams = mseq(batch)

    for modality in ('text', 'image', 'audio'):
        assert modality in streams
    for modality in ('text', 'image', 'audio'):
        assert streams[modality].shape[2] == HIDDEN_DIM
    print(" PASS test_multimodal_sequencer_full_forward")
|
|
|
|
def test_multimodal_vq_bridge_text_audio():
    """Text+audio through MultimodalVQBridge: sequences are concatenated and audio codes fall in [12288, 16384)."""
    bridge = MultimodalVQBridge()
    streams = {'text': torch.randn(2, 10, 512), 'audio': torch.randn(2, 15, 512)}

    combined, losses, indices = bridge(streams)

    assert combined.shape == (2, 25, 512)
    assert 'audio_vq' in losses
    audio_codes = indices['audio']
    assert (audio_codes >= 12288).all()
    assert (audio_codes < 16384).all()
    print(" PASS test_multimodal_vq_bridge_text_audio")
|
|
|
|
def test_multimodal_vq_bridge_all_three():
    """All three modalities: a concatenated sequence, per-modality VQ losses, and disjoint code ranges.

    Expected ranges: text < 8192, image in [8192, 12288), audio in [12288, 16384).
    """
    bridge = MultimodalVQBridge()
    streams = {
        'text': torch.randn(2, 10, 512),
        'image': torch.randn(2, 20, 512),
        'audio': torch.randn(2, 15, 512),
    }

    combined, losses, indices = bridge(streams)

    assert combined.shape == (2, 45, 512)
    for key in ('text_vq', 'image_vq', 'audio_vq'):
        assert key in losses
    assert (indices['text'] < 8192).all()
    for modality, low, high in (('image', 8192, 12288), ('audio', 12288, 16384)):
        codes = indices[modality]
        assert (codes >= low).all() and (codes < high).all()
    print(" PASS test_multimodal_vq_bridge_all_three")
|
|
|
|
def test_modality_gate_three_modalities():
    """With all three modalities active, the gate weights each one and requests at least 2 hops."""
    gate = ModalityGate(num_modalities=3)
    active = ['text', 'image', 'audio']
    weights, count, hops = gate(active)

    assert count == 3
    for modality in active:
        assert modality in weights
    assert hops >= 2
    print(" PASS test_modality_gate_three_modalities")
|
|
|
|
def test_modality_gate_audio_only():
    """A single active modality (audio) is counted once and gets a weight."""
    gate = ModalityGate(num_modalities=3)
    weights, count, _hops = gate(['audio'])
    assert count == 1
    assert 'audio' in weights
    print(" PASS test_modality_gate_audio_only")
|
|
|
|
def test_model_forward_with_audio():
    """Text plus a raw waveform: logits keep batch 2 and width VOCAB, and indices are returned."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    wave = torch.randn(2, 80000)

    logits, _losses, vq_indices, _ = model(tokens, audio=wave)

    batch, _, vocab_width = logits.shape
    assert batch == 2
    assert vocab_width == VOCAB
    assert vq_indices is not None
    print(" PASS test_model_forward_with_audio")
|
|
|
|
def test_model_forward_all_modalities():
    """Text+image+audio with targets yields a positive scalar total loss."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    pixels = torch.randn(2, 3, 224, 224)
    wave = torch.randn(2, 80000)

    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], images=pixels, audio=wave)

    assert losses is not None
    assert isinstance(losses, LossComponents)
    total = losses.total
    assert total.ndim == 0
    assert total > 0
    print(" PASS test_model_forward_all_modalities")
|
|
|
|
def test_model_audio_disabled_raises():
    """Passing audio to a model built with enable_audio=False must raise ValueError."""
    model = ARBModel(enable_audio=False)
    tokens = torch.randint(0, VOCAB, (2, 66))
    wave = torch.randn(2, 80000)
    try:
        model(tokens, audio=wave)
    except ValueError:
        pass
    else:
        assert False, "Should have raised ValueError"
    print(" PASS test_model_audio_disabled_raises")
|
|
|
|
def test_audio_sequencer_gradient_flow():
    """Backprop from the AudioSequencer output reaches the trainable frame projection.

    The former ``... is not None or True`` assertion could never fail. The
    projection check is therefore reduced to what is actually enforced: the
    projection exposes ``T_accum``. NOTE(review): whether ``T_accum`` itself
    receives a ``.grad`` is not asserted. Ternary weights may be updated
    through hooks instead; confirm before tightening.
    """
    aseq = AudioSequencer()
    waveform = torch.randn(2, 80000)
    out = aseq(waveform)
    loss = out.sum()
    loss.backward()
    assert aseq.frame_proj.weight.grad is not None, "frame_proj should get gradients"
    assert hasattr(aseq.projection, 'T_accum'), "projection should participate"
    print(" PASS test_audio_sequencer_gradient_flow")
|
|
|
|
def test_vq_bridge_audio_codebook_utilization():
    """After an audio batch, utilisation and dead-code reports both include 'audio'."""
    bridge = MultimodalVQBridge()
    audio_in = torch.randn(4, 50, 512)
    _, _, _ = bridge({'text': torch.randn(4, 10, 512), 'audio': audio_in})

    assert 'audio' in bridge.get_codebook_utilization()
    assert 'audio' in bridge.get_dead_code_count()
    print(" PASS test_vq_bridge_audio_codebook_utilization")
|
|
|
|
| |
|
|
def test_loss_components_nine_fields_total():
    """total equals the default-LossWeights-weighted sum of all nine components."""
    w = LossWeights()
    raw = {
        "lm": 1.0,
        "vq_commitment": 0.1,
        "moe_aux": 0.1,
        "graph_l1": 0.1,
        "graph_ponder": 0.1,
        "moe_ponder": 0.1,
        "conv_vq_commitment": 0.1,
        "memgram_decay_reg": 0.01,
        "lstm_hidden_reg": 0.01,
    }
    lc = LossComponents(**{name: torch.tensor(value, requires_grad=True) for name, value in raw.items()})

    total = lc.total
    # Summed in the same field order as the component list above.
    expected = sum(getattr(w, name) * value for name, value in raw.items())
    assert abs(total.item() - expected) < 1e-6, f"total {total.item()} != {expected}"
    print(" PASS test_loss_components_nine_fields_total")
|
|
def test_loss_components_nine_fields_log():
    """LossComponents.log emits the newer regulariser scalars and the total under the prefix.

    A minimal stand-in writer records every ``add_scalar(name, val, step)`` call.
    """
    from types import SimpleNamespace

    records = []
    writer = SimpleNamespace(
        logged=records,
        add_scalar=lambda name, val, step: records.append((name, val, step)),
    )
    raw = {
        "lm": 1.0,
        "vq_commitment": 0.1,
        "moe_aux": 0.1,
        "graph_l1": 0.1,
        "graph_ponder": 0.1,
        "moe_ponder": 0.1,
        "conv_vq_commitment": 0.1,
        "memgram_decay_reg": 0.01,
        "lstm_hidden_reg": 0.01,
    }
    lc = LossComponents(**{name: torch.tensor(value, requires_grad=True) for name, value in raw.items()})

    lc.log(writer, step=0, prefix="loss")

    names = [entry[0] for entry in writer.logged]
    for tag in ("loss/conv_vq_commitment", "loss/memgram_decay_reg", "loss/lstm_hidden_reg", "loss/total"):
        assert tag in names
    print(" PASS test_loss_components_nine_fields_log")
|
|
|
|
def test_loss_weights_custom_total():
    """Custom LossWeights scale each supplied component; a zero weight drops its term."""
    weights = LossWeights(lm=2.0, vq_commitment=0.5, moe_aux=0.0, graph_l1=10.0)
    components = LossComponents(
        lm=torch.tensor(1.0, requires_grad=True),
        vq_commitment=torch.tensor(2.0, requires_grad=True),
        moe_aux=torch.tensor(5.0, requires_grad=True),
        graph_l1=torch.tensor(0.01, requires_grad=True),
        weights=weights,
    )
    expected = 2.0 * 1.0 + 0.5 * 2.0 + 0.0 * 5.0 + 10.0 * 0.01
    observed = components.total.item()
    assert abs(observed - expected) < 1e-5, f"Custom weights total {components.total.item()} != {expected}"
    print(" PASS test_loss_weights_custom_total")
|
|
|
|
def test_loss_weights_zero_skips():
    """Zero-weighted terms contribute neither to total nor to the gradient.

    One parameter feeds three terms (x2, x3, x4). With vq_commitment and
    moe_aux weighted 0, both total and d(total)/dp must come from lm alone (2.0).
    """
    weights = LossWeights(vq_commitment=0.0, moe_aux=0.0)
    param = torch.nn.Parameter(torch.tensor(1.0))
    components = LossComponents(lm=param * 2.0, vq_commitment=param * 3.0, moe_aux=param * 4.0, weights=weights)

    observed = components.total.item()
    reference = 1.0 * 2.0 + 0.0 * 3.0 + 0.0 * 4.0
    assert abs(observed - reference) < 1e-5, f"Zero-weight total {observed}"

    components.total.backward()
    slope = param.grad.item()
    assert abs(slope - 2.0) < 1e-5, f"Zero-weight grad {slope} (should be 2.0, only lm)"
    print(" PASS test_loss_weights_zero_skips")
|
|
|
|
def test_loss_weights_backward_compat():
    """LossWeights() defaults stay at their historical values."""
    defaults = LossWeights()
    expected = (
        ("lm", 1.0),
        ("vq_commitment", 1.0),
        ("moe_aux", 1.0),
        ("graph_l1", 0.001),
        ("graph_ponder", 1.0),
        ("moe_ponder", 1.0),
        ("conv_vq_commitment", 0.1),
        ("memgram_decay_reg", 0.01),
        ("lstm_hidden_reg", 0.01),
    )
    for field, value in expected:
        assert abs(getattr(defaults, field) - value) < 1e-5
    print(" PASS test_loss_weights_backward_compat")
|
|
|
|
def test_model_forward_loss_weights():
    """A LossWeights passed to forward() is attached to the returned LossComponents."""
    model = ARBModel()
    custom = LossWeights(lm=2.0, vq_commitment=0.5)
    tokens = torch.randint(0, VOCAB, (2, 66))

    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], loss_weights=custom)

    assert losses is not None
    attached = losses.weights
    assert isinstance(attached, LossWeights)
    assert abs(attached.lm - 2.0) < 1e-5
    assert abs(attached.vq_commitment - 0.5) < 1e-5
    print(" PASS test_model_forward_loss_weights")
|
|
|
|
def test_model_forward_no_hardcoded_graph_l1():
    """graph_l1 and vq_commitment are reported as components only when targets are given.

    Without targets the model returns no LossComponents at all. With targets,
    both regularisers must be present as separate fields rather than folded
    into the total.
    """
    model = ARBModel()
    x = torch.randint(0, VOCAB, (2, 66))
    _, losses_no_target, _, _ = model(x)
    assert losses_no_target is None, "forward without targets should return no losses"
    _, losses, _, _ = model(x, targets=x[:, 3:])
    assert losses is not None
    assert losses.graph_l1 is not None
    assert losses.vq_commitment is not None
    print(" PASS test_model_forward_no_hardcoded_graph_l1")
|
|
|
|
def test_build_param_groups_shape():
    """build_param_groups yields named groups including graph, memory and an input projection.

    Every group must carry its own 'lr' and 'params' entries.
    """
    from train import build_param_groups

    model = ARBModel()
    groups = build_param_groups(model, base_lr=1e-3, vq_lr_scale=2.0, memory_lr_scale=0.5)
    names = [grp['name'] for grp in groups]

    assert 'graph' in names
    assert 'memory' in names
    assert any(proj in names for proj in ('patch_proj', 'frame_proj'))
    for grp in groups:
        assert 'lr' in grp
        assert 'params' in grp
    print(f" Groups: {names}")
    print(" PASS test_build_param_groups_shape")
|
|
|
|
def test_pinpoint_gradient_isolation():
    """pinpoint_backward must route each loss term only to its target param groups.

    A single forward pass is reused for three backward passes
    (free_graph=False keeps the graph alive). Each pass uses a LossWeights with
    exactly one component set to 1.0. Every param group outside the allowed set
    for that component must keep ``grad is None``.
    """
    from train import pinpoint_backward, build_param_groups

    model = ARBModel()
    param_groups = build_param_groups(model, base_lr=1e-3)

    x = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(x, targets=x[:, 3:])

    weight_fields = (
        "lm", "vq_commitment", "moe_aux", "graph_l1", "graph_ponder",
        "moe_ponder", "conv_vq_commitment", "memgram_decay_reg", "lstm_hidden_reg",
    )

    def only(field):
        # LossWeights with `field` at 1.0 and every other component at 0.0.
        return LossWeights(**{f: (1.0 if f == field else 0.0) for f in weight_fields})

    def check_isolated(field, allowed_groups):
        # Clear grads, backprop only `field`, and require no grads outside `allowed_groups`.
        for g in param_groups:
            for p in g['params']:
                p.grad = None
        pinpoint_backward(losses, only(field), param_groups, free_graph=False)
        for g in param_groups:
            if g['name'] in allowed_groups:
                continue
            for p in g['params']:
                assert p.grad is None, f"{g['name']} param should NOT have grad with only {field}"

    check_isolated("vq_commitment", {'vq_projection', 'patch_proj', 'frame_proj', 'vq_codebook'})
    check_isolated("moe_aux", {'moe', 'moe_act'})
    check_isolated("graph_l1", {'graph'})
    print(" PASS test_pinpoint_gradient_isolation")
|
|
|
|
def test_pinpoint_backward_accumulation():
    """grad_accum=2 must halve the gradients produced with grad_accum=1.

    One forward pass is backpropagated twice (free_graph=False) with an
    lm-only LossWeights. The first parameter whose gradient norm is
    non-negligible in both runs is compared, and the ratio must be ~0.5.
    """
    from train import pinpoint_backward, build_param_groups

    model = ARBModel()
    param_groups = build_param_groups(model, base_lr=1e-3)
    x = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(x, targets=x[:, 3:])

    lm_only = LossWeights(lm=1.0, vq_commitment=0.0, moe_aux=0.0, graph_l1=0.0,
                          graph_ponder=0.0, moe_ponder=0.0, conv_vq_commitment=0.0,
                          memgram_decay_reg=0.0, lstm_hidden_reg=0.0)

    def all_params():
        return (p for g in param_groups for p in g['params'])

    def clear_grads():
        for p in all_params():
            p.grad = None

    clear_grads()
    pinpoint_backward(losses, lm_only, param_groups, grad_accum=1, free_graph=False)
    norms_full = {id(p): p.grad.detach().norm().item() for p in all_params() if p.grad is not None}

    clear_grads()
    pinpoint_backward(losses, lm_only, param_groups, grad_accum=2, free_graph=False)

    probe = next(
        (p for p in all_params()
         if p.grad is not None and id(p) in norms_full and norms_full[id(p)] > 1e-8),
        None,
    )
    assert probe is not None, "no param with non-zero grad in both runs"
    ratio = probe.grad.detach().norm().item() / norms_full[id(probe)]
    assert abs(ratio - 0.5) < 0.01, f"grad_accum=2 should halve grads, got {ratio:.4f}"
    print(" PASS test_pinpoint_backward_accumulation")
|
|
|
|
def test_pinpoint_backward_rescale_effect():
    """pinpoint_backward with memory_grad_scale still delivers memory and lm_core gradients.

    Only the presence of gradients is checked. The magnitude effect of
    memory_grad_scale=0.25 is not measured here.
    """
    # DEFAULT_LOSS_TARGET_MAP is not used below. Importing it keeps this test
    # failing if train stops exporting it.
    from train import pinpoint_backward, build_param_groups, DEFAULT_LOSS_TARGET_MAP  # noqa: F401

    model = ARBModel()
    loss_weights = LossWeights()
    param_groups = build_param_groups(model, base_lr=1e-3)

    x = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(x, targets=x[:, 3:], loss_weights=loss_weights)

    for g in param_groups:
        for p in g['params']:
            p.grad = None
    pinpoint_backward(losses, loss_weights, param_groups, memory_grad_scale=0.25)

    memory_group = next((g for g in param_groups if g.get('name') == 'memory'), None)
    # If the first memory param received a gradient, all memory params must have one.
    if memory_group and memory_group['params'] and memory_group['params'][0].grad is not None:
        for p in memory_group['params']:
            assert p.grad is not None, "memory param should have grad"

    lm_group = next((g for g in param_groups if g.get('name') == 'lm_core'), None)
    if lm_group and lm_group['params']:
        lm_has_grad = any(p.grad is not None for p in lm_group['params'])
        assert lm_has_grad, "lm_core should have grads from lm loss"
    print(" PASS test_pinpoint_backward_rescale_effect")
|
|
def test_moe_router_h_with_h_t():
    """With lstm_enabled and an h_t, backward reaches the h_t-conditioned router.

    After backward, ``router_h`` must carry ``_hook_grad_T_sign``. This test
    takes that attribute as evidence the router received a gradient.
    """
    experts = SharedProjectionMoE(hidden_size=64, num_experts=4, top_k=2,
                                  core_rank=16, shared_inter=128, noise_std=0.0)
    experts.lstm_enabled = True
    tokens = torch.randn(2, 10, 64)
    hidden = torch.randn(2, 64)

    result, _aux = experts(tokens, h_t=hidden)
    assert result.shape == (2, 10, 64)

    result.sum().backward()
    assert hasattr(experts.router_h, '_hook_grad_T_sign'), "router_h not trained"
    print(" PASS test_moe_router_h_with_h_t")
|
|
def test_moe_router_without_h_t():
    """Without h_t, backward reaches the original ``router``.

    Success is judged by the ``_hook_grad_T_sign`` attribute being present on
    ``router`` after backward.
    """
    experts = SharedProjectionMoE(hidden_size=64, num_experts=4, top_k=2,
                                  core_rank=16, shared_inter=128, noise_std=0.0)
    tokens = torch.randn(2, 10, 64)

    result, _aux = experts(tokens, h_t=None)
    assert result.shape == (2, 10, 64)

    result.sum().backward()
    assert hasattr(experts.router, '_hook_grad_T_sign'), "original router not trained when no h_t"
    print(" PASS test_moe_router_without_h_t")
|
|
def test_memgram_shapes():
    """MemGram maps (4, 20) codes + (4, 20, 512) hidden state to a (4, 20, 512) output and scalar decay."""
    memgram = MemGram(struct_primes=[101, 103, 107, 109], conv_primes=[53, 59, 61, 67])
    codes = torch.randint(0, 100, (4, 20))
    hidden = torch.randn(4, 20, 512)

    result, decay_term = memgram(codes, None, None, hidden, timestep=100)

    assert result.shape == (4, 20, 512), f"output shape {result.shape}"
    assert decay_term.ndim == 0
    print(" PASS test_memgram_shapes")
|
|
def test_memgram_hash_indices():
    """_hash_pairs yields one bucket per prime, each strictly below its prime."""
    memgram = MemGram(struct_primes=[101, 103], conv_primes=[53, 59])
    previous = torch.randint(0, 100, (4, 19))
    current = torch.randint(0, 100, (4, 19))
    primes = [101, 103]

    buckets = memgram._hash_pairs(previous, current, primes)

    assert buckets.shape == (4, 19, 2), f"hash shape {buckets.shape}"
    for slot, prime in enumerate(primes):
        assert (buckets[..., slot] < prime).all(), "hash exceeds prime range"
    print(" PASS test_memgram_hash_indices")
|
|
def test_memgram_bilinear_gate_range():
    """With small key/embed dims, MemGram still returns a finite (2, 10, 512) output."""
    memgram = MemGram(struct_primes=[101, 103], conv_primes=[53, 59], key_dim=8, embed_dim=16)
    codes = torch.randint(0, 80, (2, 10))
    hidden = torch.randn(2, 10, 512)

    result, _ = memgram(codes, None, None, hidden, timestep=50)

    assert result.shape == (2, 10, 512)
    assert torch.isfinite(result).all()
    print(" PASS test_memgram_bilinear_gate_range")
|
|
def test_memgram_decay_formula():
    """MemGram._compute_decay must match sigmoid(s) * exp(-exp(r) * t).

    With s = r = 0 the reference reduces to 0.5 * exp(-t). At t = 100 the
    reference (~1.9e-44) is far below atol=1e-6, so that case alone would
    accept any decay <= 1e-6. The t = 1 case (~0.184) actually pins the formula.
    """
    mg = MemGram(struct_primes=[101, 103], conv_primes=[53, 59])
    s = torch.zeros(1)
    r = torch.zeros(1)
    for t in (100.0, 1.0):
        decay = mg._compute_decay(s, r, torch.tensor(t))
        expected = torch.sigmoid(s) * torch.exp(-torch.exp(r) * t)
        assert torch.allclose(decay, expected, atol=1e-6), f"decay {decay} != {expected} (t={t})"
    print(" PASS test_memgram_decay_formula")
|
|
def test_memgram_gradient_flow():
    """Backprop of output + decay gives the first struct embedding a non-zero gradient."""
    memgram = MemGram(struct_primes=[101, 103], conv_primes=[53, 59], embed_dim=16, key_dim=8, hidden_dim=64)
    codes = torch.randint(0, 80, (2, 10))
    hidden = torch.randn(2, 10, 64)

    result, decay_term = memgram(codes, None, None, hidden, timestep=50)
    (result.sum() + decay_term).backward()

    table_grad = memgram.struct_emb[0].grad
    assert table_grad is not None, "no gradient to struct_emb"
    assert table_grad.abs().sum().item() > 0
    print(" PASS test_memgram_gradient_flow")
|
|
def test_memgram_conv_path():
    """Supplying conversation codes must change MemGram's output versus passing None."""
    memgram = MemGram(struct_primes=[101, 103], conv_primes=[53, 59], embed_dim=16, key_dim=8, hidden_dim=64)
    codes = torch.randint(0, 80, (2, 10))
    hidden = torch.randn(2, 10, 64)

    baseline, _ = memgram(codes, None, None, hidden, timestep=50)
    conv_ids = torch.randint(0, 50, (2,))
    conditioned, _ = memgram(codes, conv_ids, conv_ids, hidden, timestep=50)

    assert not torch.allclose(baseline, conditioned, atol=1e-6), "conv path should change output"
    print(" PASS test_memgram_conv_path")
|
|
def test_conv_vq_shapes():
    """ConvVQCodebook returns one code per row, a (B, 512) quantized tensor and a scalar commitment."""
    codebook = ConvVQCodebook(codebook_size=16, code_dim=8)
    summary = torch.randn(4, 512)

    ids, recon, commit = codebook(summary, step=500, enabled=True)

    assert ids.shape == (4,), f"code shape {ids.shape}"
    assert recon.shape == (4, 512), f"quantized shape {recon.shape}"
    assert commit.ndim == 0
    print(" PASS test_conv_vq_shapes")
|
|
def test_conv_vq_hard_cap():
    """Feeding 12 batches of 4 vectors into an 8-entry codebook leaves exactly 8 entries active."""
    codebook = ConvVQCodebook(codebook_size=8, code_dim=8)
    for step in range(12):
        codebook(torch.randn(4, 512), step=step, enabled=True)
    active = codebook.n_active.item()
    assert active == 8, f"n_active={active} (should be 8)"
    print(" PASS test_conv_vq_hard_cap")
|
|
def test_conv_vq_deferred_activation():
    """With ``enabled=False`` the codebook yields all-zero codes and a zero commitment term."""
    codebook = ConvVQCodebook(codebook_size=8, code_dim=8)
    batch = torch.randn(4, 512)
    codes, _quantized, commit = codebook(batch, step=500, enabled=False)
    zero_codes = torch.zeros(4, dtype=torch.long)
    assert torch.equal(codes, zero_codes), "code should be zeros when disabled"
    assert commit.item() == 0.0, "commitment should be 0 when disabled"
    print(" PASS test_conv_vq_deferred_activation")
|
|
def test_conv_vq_ema_update():
    """Codebook entries change once further inputs arrive after an initial fill phase."""
    codebook = ConvVQCodebook(codebook_size=4, code_dim=8)
    # Four offset inputs for a 4-entry codebook (steps 0..3).
    for offset in range(4):
        codebook(torch.randn(1, 512) + offset, step=offset, enabled=True)
    snapshot = codebook.embed.clone()
    # Four more inputs at later steps (10..13).
    for offset in range(4):
        codebook(torch.randn(1, 512), step=10 + offset, enabled=True)
    updated = codebook.embed.clone()
    assert not torch.allclose(snapshot, updated), "entries should change through EMA + replacement"
    print(" PASS test_conv_vq_ema_update")
|
|
def test_conv_vq_persistence():
    """A state_dict round-trip reproduces ``embed`` and the timestamps/n_active/cluster_size tensors."""
    source = ConvVQCodebook(codebook_size=8, code_dim=8)
    source(torch.randn(2, 512), step=0, enabled=True)
    restored = ConvVQCodebook(codebook_size=8, code_dim=8)
    restored.load_state_dict(source.state_dict())
    assert torch.allclose(source.embed, restored.embed)
    for field in ("timestamps", "n_active", "cluster_size"):
        assert torch.equal(getattr(source, field), getattr(restored, field))
    print(" PASS test_conv_vq_persistence")
|
|
def test_conv_vq_fuzzy_retrieve():
    """``fuzzy_retrieve`` on a code_dim-sized query returns exactly ``top_k`` indices and scores."""
    codebook = ConvVQCodebook(codebook_size=16, code_dim=8)
    for step in range(3):
        codebook(torch.randn(2, 512), step=step, enabled=True)
    probe = torch.randn(8)
    hits, scores = codebook.fuzzy_retrieve(probe, top_k=3)
    assert hits.numel() == 3, f"retrieved {hits.numel()} elements, expected 3"
    assert scores.numel() == 3
    print(" PASS test_conv_vq_fuzzy_retrieve")
|
|
def test_conv_vq_commitment_nonneg():
    """The commitment value from an enabled forward pass must be >= 0."""
    codebook = ConvVQCodebook(codebook_size=8, code_dim=8)
    _, _, commit = codebook(torch.randn(2, 512), step=0, enabled=True)
    value = commit.item()
    assert value >= 0, f"negative commitment {value}"
    print(" PASS test_conv_vq_commitment_nonneg")
|
|
def test_lstm_shapes():
    """One ConversationLSTM step from no memory yields five (batch, hidden) tensors and a 0-d regulariser."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    inputs = torch.randn(4, 64)
    h_out, c_focus, h_topic, c_topic, c_proj, reg = cell(inputs, None)
    named_states = (
        ("h_out", h_out),
        ("c_focus", c_focus),
        ("h_topic", h_topic),
        ("c_topic", c_topic),
        ("c_proj", c_proj),
    )
    for label, tensor in named_states:
        assert tensor.shape == (4, 64), f"{label} shape {tensor.shape}"
    assert reg.ndim == 0
    print(" PASS test_lstm_shapes")
|
|
def test_lstm_forget_gate_bias():
    """Forget-gate input biases are initialised to 1.0 (focus cell) and 1.5 (topic cell).

    Rows 64:128 of ``bias_ih`` are read as the forget-gate slice for
    hidden_dim=64. This assumes the cells follow torch.nn.LSTMCell's
    input|forget|cell|output gate layout (TODO confirm).
    """
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    forget_rows = slice(64, 128)
    focus_bias = cell.focus_cell.bias_ih[forget_rows]
    assert torch.allclose(focus_bias, torch.ones_like(focus_bias)), "focus forget gate bias not 1.0"
    topic_bias = cell.topic_cell.bias_ih[forget_rows]
    assert torch.allclose(topic_bias, torch.full_like(topic_bias, 1.5)), "topic forget gate bias not 1.5"
    print(" PASS test_lstm_forget_gate_bias")
|
|
def test_lstm_bptt_detach():
    """h_out keeps autograd history through step 49 and is graph-free on step 50.

    50 is a multiple of both bptt_focus=5 and bptt_topic=10.
    """
    cell = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=5, bptt_topic=10)
    inputs = torch.randn(2, 64)
    state = None
    for _ in range(49):
        h_out, c_focus, h_topic, c_topic, _, _ = cell(inputs, state)
        state = tuple(t.detach() for t in (h_out, c_focus, h_topic, c_topic))
    assert h_out.grad_fn is not None, "grad_fn should exist before BPTT boundary"
    h_out, _, _, _, _, _ = cell(inputs, state)
    assert h_out.grad_fn is None, "h_out should be detached at BPTT focus boundary"
    print(" PASS test_lstm_bptt_detach")
|
|
def test_lstm_hidden_reg():
    """The returned regulariser equals mean(h_out ** 2) for the same forward step.

    Both values are taken from a single call. The LSTM keeps a step counter
    (``step_count``), and in training mode it could in principle contain
    stochastic layers, so comparing ``reg`` from one call with ``h_out`` from
    a second call relies on the two calls being identical.
    """
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    x = torch.randn(2, 64)
    h_out, _, _, _, _, reg = lstm(x, None)
    expected = (h_out ** 2).mean()
    assert torch.allclose(reg, expected, atol=1e-6), f"reg {reg} != expected {expected}"
    print(" PASS test_lstm_hidden_reg")
|
|
def test_lstm_c_t_proj_ternary():
    """Both cell-state projections (focus and topic) are TernaryScaleTensor layers."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    for attr in ("c_focus_proj", "c_topic_proj"):
        assert isinstance(getattr(cell, attr), TernaryScaleTensor), f"{attr} not TernaryScaleTensor"
    print(" PASS test_lstm_c_t_proj_ternary")
|
|
def test_memory_modules_backward_compat():
    """A default-constructed ARBModel still runs a CTX-length forward pass with targets."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, CTX))
    logits, losses, _indices, _ = model(tokens, targets=tokens[:, 3:])
    assert logits.shape[0] == 2
    assert losses is not None
    print(" PASS test_memory_modules_backward_compat")
|
|
|
|
| |
|
|
def test_forward_no_memory_backward_compat():
    """With the LSTM off, a length-66 input yields (2, 64, VOCAB) logits and memory_state None."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _losses, _indices, mem_state = model(tokens, targets=tokens[:, 3:])
    assert logits.shape == (2, 64, VOCAB), f"logits shape {logits.shape}"
    assert mem_state is None, f"memory_state should be None when lstm disabled, got {mem_state}"
    print(" PASS test_forward_no_memory_backward_compat")
|
|
def test_forward_lstm_enabled_h_t_passed():
    """With the LSTM on, forward returns a 4-tensor memory state, each (batch, 512)."""
    model = ARBModel()
    model.lstm_enabled = True
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, _, _, mem_state = model(tokens, targets=tokens[:, 3:], memory_state=None, timestep=0)
    assert mem_state is not None, "memory_state should be tuple when lstm enabled"
    h_out, c_focus, h_topic, c_topic = mem_state
    for label, state in (("h_out", h_out), ("c_focus", c_focus), ("h_topic", h_topic), ("c_topic", c_topic)):
        assert state.shape == (2, 512), f"{label} shape {state.shape}"
    print(" PASS test_forward_lstm_enabled_h_t_passed")
|
|
def test_forward_lstm_c_t_residual():
    """Enabling the LSTM must change the model's logits.

    Both passes run in eval mode under ``torch.no_grad`` so they differ only
    in ``lstm_enabled``. In training mode, any dropout in the model, or
    internal state updated during the first pass, could make the logits
    differ even if the LSTM path had no effect. The inequality would then
    pass vacuously.
    """
    model = ARBModel()
    model.eval()
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    with torch.no_grad():
        logits_no_lstm, _, _, _ = model(x, targets=targets)
        model.lstm_enabled = True
        logits_lstm, _, _, _ = model(x, targets=targets, memory_state=None, timestep=0)
    assert not torch.allclose(logits_no_lstm, logits_lstm, atol=1e-4), "c_t_proj should modify output"
    print(" PASS test_forward_lstm_c_t_residual")
|
|
def test_forward_memgram_injection():
    """With MemGram enabled, the returned losses carry a ``memgram_decay_reg`` term."""
    model = ARBModel()
    model.memgram_enabled = True
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], timestep=100)
    assert losses.memgram_decay_reg is not None, "memgram_decay_reg should be set"
    print(" PASS test_forward_memgram_injection")
|
|
def test_forward_conv_vq_deferred():
    """While ``_conv_vq_ready`` is False, the conv-VQ commitment is absent or exactly zero."""
    model = ARBModel()
    model.conv_vq_enabled = True
    model._conv_vq_ready = False
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], timestep=100)
    commit = losses.conv_vq_commitment
    deferred = commit is None or commit.item() == 0.0
    assert deferred, f"conv_vq should be deferred, got {commit}"
    print(" PASS test_forward_conv_vq_deferred")
|
|
def test_generate_carries_lstm_state():
    """generate() with the LSTM on appends max_new_token long tokens to the prompt."""
    model = ARBModel()
    model.lstm_enabled = True
    prompt = torch.zeros((1, 10), dtype=torch.long)
    generated = model.generate(prompt, max_new_token=10)
    assert generated.shape == (1, 20), f"output shape {generated.shape}"
    assert generated.dtype == torch.long
    print(" PASS test_generate_carries_lstm_state")
|
|
|
|
| |
|
|
def test_memory_schedule_warmup():
    """At step 0 of 10000 every memory component flag is off."""
    from train import compute_memory_schedule
    lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(0, 10000)
    flags = (lstm_on, memgram_on, conv_vq_on, decay_reg_on)
    assert not any(flags), "all off during warmup"
    print(" PASS test_memory_schedule_warmup")
|
|
def test_memory_schedule_lstm_first():
    """At 25% of training (default VQ utilisation) the LSTM is on but MemGram is not."""
    from train import compute_memory_schedule
    lstm_on, memgram_on, _conv_vq_on, _decay_reg_on = compute_memory_schedule(2500, 10000)
    assert lstm_on, "lstm should be on after warmup"
    assert not memgram_on, "memgram should be off at 25%"
    print(" PASS test_memory_schedule_lstm_first")
|
|
def test_memory_schedule_memgram_second():
    """At 35% with VQ utilisation 0.4, LSTM, MemGram and conv-VQ are on; decay regularisation is off."""
    from train import compute_memory_schedule
    lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(3500, 10000, vq_utilization=0.4)
    assert all((lstm_on, memgram_on, conv_vq_on)), "lstm+memgram+conv_vq on at 35% with util>30%"
    assert not decay_reg_on, "decay_reg should be off at 35%"
    print(" PASS test_memory_schedule_memgram_second")
|
|
def test_memory_schedule_all_on():
    """At 50% with VQ utilisation 0.4 every memory component flag is on."""
    from train import compute_memory_schedule
    lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(5000, 10000, vq_utilization=0.4)
    assert lstm_on and memgram_on and conv_vq_on and decay_reg_on, "all on at 50%"
    print(" PASS test_memory_schedule_all_on")
|
|
def test_memory_schedule_conv_vq_requires_vq_util():
    """At 40% with VQ utilisation 0.1 the LSTM is on but conv-VQ stays off."""
    from train import compute_memory_schedule
    lstm_on, _memgram_on, conv_vq_on, _decay_reg_on = compute_memory_schedule(4000, 10000, vq_utilization=0.1)
    assert lstm_on, "lstm should be on"
    assert not conv_vq_on, "conv_vq should be off when util < 30%"
    print(" PASS test_memory_schedule_conv_vq_requires_vq_util")
|
|
def test_memory_schedule_decay_reg_last():
    """At 45% with VQ utilisation 0.4 the decay regulariser is on."""
    from train import compute_memory_schedule
    _, _, _, decay_reg_on = compute_memory_schedule(4500, 10000, vq_utilization=0.4)
    assert decay_reg_on, "decay_reg should be on at 45%"
    print(" PASS test_memory_schedule_decay_reg_last")
|
|
def test_lstm_state_reset_per_batch():
    """Two eval-mode forwards that each start from memory_state=None return same-shaped states."""
    model = ARBModel()
    model.lstm_enabled = True
    model.eval()
    tokens = torch.randint(0, VOCAB, (2, 66))
    with torch.no_grad():
        _, _, _, first_mem = model(tokens, memory_state=None, timestep=0)
        _, _, _, second_mem = model(tokens, memory_state=None, timestep=1)
    h_first, c_focus_first, _, _ = first_mem
    h_second, c_focus_second, _, _ = second_mem
    assert h_first.shape == h_second.shape, "h_out shapes should match"
    assert c_focus_first.shape == c_focus_second.shape, "c_focus shapes should match"
    print(" PASS test_lstm_state_reset_per_batch")
|
|
def test_bptt_counter_separate():
    """step_count advances once per call, and the 5th call (bptt_focus=5) returns a graph-free c_focus."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=5, bptt_topic=10)
    inputs = torch.randn(2, 64)
    state = None
    for _ in range(4):
        h_out, c_f, h_t, c_t, _, _ = cell(inputs, state)
        state = (h_out.detach(), c_f.detach(), h_t.detach(), c_t.detach())
    assert cell.step_count == 4, f"step_count={cell.step_count}"
    _, c_f, _, _, _, _ = cell(inputs, state)
    assert c_f.grad_fn is None, "c_focus should be detached at BPTT focus boundary"
    print(" PASS test_bptt_counter_separate")
|
|
|
|
| |
|
|
def test_focus_gate_no_boundary():
    """Without a boundary signal FocusGate returns all-ones reset (B, 1) and dampen (B, H)."""
    gate = FocusGate(hidden_dim=64)
    hidden = torch.randn(2, 64)
    reset, dampen = gate(hidden, boundary_signal=None)
    assert reset.shape == (2, 1), f"reset shape {reset.shape}"
    assert dampen.shape == (2, 64), f"dampen shape {dampen.shape}"
    for label, gate_out in (("reset", reset), ("dampen", dampen)):
        assert torch.allclose(gate_out, torch.ones_like(gate_out)), f"{label} should be 1.0 for no boundary"
    print(" PASS test_focus_gate_no_boundary")
|
|
def test_focus_gate_boundary_signal():
    """A BOS boundary keeps output shapes and yields mean reset/dampen values within [0, 1]."""
    gate = FocusGate(hidden_dim=64)
    hidden = torch.randn(2, 64)
    reset, dampen = gate(hidden, boundary_signal=SPECIAL_VOCAB['BOS'])
    assert reset.shape == (2, 1)
    assert dampen.shape == (2, 64)
    reset_mean = reset.mean().item()
    assert 0.0 <= reset_mean <= 1.0, f"reset out of range: {reset_mean}"
    dampen_mean = dampen.mean().item()
    assert 0.0 <= dampen_mean <= 1.0, f"dampen out of range: {dampen_mean}"
    print(" PASS test_focus_gate_boundary_signal")
|
|
def test_focus_gate_all_boundary_types():
    """Each conversation boundary token pulls the mean reset below 1.0; no boundary leaves it at 1.0."""
    gate = FocusGate(hidden_dim=64)
    hidden = torch.randn(2, 64)
    boundaries = [(name, SPECIAL_VOCAB[name]) for name in ('BOS', 'SYSTEM', 'USER', 'ASSISTANT')]
    for tok_name, tok_id in boundaries:
        reset, _ = gate(hidden, boundary_signal=tok_id)
        assert reset.mean().item() < 1.0, f"{tok_name} should reduce reset from 1.0"
    baseline, _ = gate(hidden, boundary_signal=None)
    assert torch.allclose(baseline, torch.ones_like(baseline)), "no-boundary reset should be 1.0"
    print(" PASS test_focus_gate_all_boundary_types")
|
|
def test_focus_gate_unknown_token():
    """A token id outside the boundary set (9999) behaves like no boundary: all-ones outputs."""
    gate = FocusGate(hidden_dim=64)
    reset, dampen = gate(torch.randn(2, 64), boundary_signal=9999)
    for label, gate_out in (("reset", reset), ("dampen", dampen)):
        assert torch.allclose(gate_out, torch.ones_like(gate_out)), f"unknown token should return {label}=1.0"
    print(" PASS test_focus_gate_unknown_token")
|
|
def test_focus_gate_c_focus_modulation():
    """Scaling a cell state by reset*dampen changes it on a USER boundary and is a no-op without one."""
    gate = FocusGate(hidden_dim=64)
    hidden = torch.randn(2, 64)
    cell_state = torch.randn(2, 64)
    untouched = cell_state.clone()
    reset, dampen = gate(hidden, boundary_signal=SPECIAL_VOCAB['USER'])
    modulated = cell_state * reset * dampen
    assert not torch.allclose(modulated, untouched, atol=1e-6), "focus gate should modify c_focus on boundary"
    reset_off, dampen_off = gate(hidden, boundary_signal=None)
    passthrough = cell_state * reset_off * dampen_off
    assert torch.allclose(passthrough, cell_state, atol=1e-6), "focus gate should not modify c_focus without boundary"
    print(" PASS test_focus_gate_c_focus_modulation")
|
|
|
|
| |
|
|
def test_conv_stack_push_pop():
    """Pushing a conversation's four state vectors and popping it returns them unchanged."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    h, c_f, h_t, c_t = (torch.randn(64) for _ in range(4))
    stack.push("conv_1", h, c_f, h_t, c_t, "cpu")
    popped = stack.pop("conv_1", "cpu")
    assert popped is not None, "pop should find pushed conversation"
    h_back, c_f_back, h_t_back, c_t_back = popped
    checks = (
        ("h_focus", h_back, h),
        ("c_focus", c_f_back, c_f),
        ("h_topic", h_t_back, h_t),
        ("c_topic", c_t_back, c_t),
    )
    for label, restored, pushed in checks:
        assert torch.allclose(restored, pushed, atol=1e-6), f"{label} not preserved"
    print(" PASS test_conv_stack_push_pop")
|
|
def test_conv_stack_pop_missing():
    """Popping an id that was never pushed returns None."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    assert stack.pop("nonexistent", "cpu") is None, "pop on nonexistent should return None"
    print(" PASS test_conv_stack_pop_missing")
|
|
def test_conv_stack_clear():
    """A cleared conversation can no longer be popped."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    state = [torch.randn(64) for _ in range(4)]
    stack.push("conv_1", *state, "cpu")
    stack.clear("conv_1")
    assert stack.pop("conv_1", "cpu") is None, "cleared conversation should not be found"
    print(" PASS test_conv_stack_clear")
|
|
def test_conv_stack_lru_eviction():
    """With capacity 2, pushing three conversations evicts the oldest one (conv_0)."""
    stack = ConversationStack(max_conversations=2, hidden_dim=64)
    for idx in range(3):
        marker = torch.full((64,), float(idx))
        stack.push(f"conv_{idx}", marker, torch.zeros(64), torch.zeros(64), torch.zeros(64), "cpu")
    assert stack.pop("conv_0", "cpu") is None, "conv_0 should be evicted (LRU)"
    for survivor in ("conv_1", "conv_2"):
        assert stack.pop(survivor, "cpu") is not None, f"{survivor} should still exist"
    print(" PASS test_conv_stack_lru_eviction")
|
|
def test_conv_stack_reset():
    """reset() drops every stored conversation and sets active_slot back to -1."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    conv_ids = ("conv_1", "conv_2")
    for conv_id in conv_ids:
        stack.push(conv_id, torch.randn(64), torch.randn(64), torch.randn(64), torch.randn(64), "cpu")
    stack.reset()
    for conv_id in conv_ids:
        assert stack.pop(conv_id, "cpu") is None, f"{conv_id} should be gone after reset"
    assert stack.active_slot == -1, "active_slot should be -1 after reset"
    print(" PASS test_conv_stack_reset")
|
|
|
|
| |
|
|
def test_conversation_lstm_forward():
    """A first step with explicit BPTT windows returns five (4, 64) tensors and a 0-d regulariser."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=50, bptt_topic=200)
    h_out, c_focus, h_topic, c_topic, c_proj, reg = cell(torch.randn(4, 64), None)
    for state in (h_out, c_focus, h_topic, c_topic, c_proj):
        assert state.shape == (4, 64)
    assert reg.ndim == 0
    print(" PASS test_conversation_lstm_forward")
|
|
def test_conversation_lstm_dual_state():
    """After 11 steps with detached memory, both focus and topic cell states are non-zero."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    inputs = torch.randn(2, 64)
    h_out, c_focus, h_topic, c_topic, _, _ = cell(inputs, None)
    for _ in range(10):
        prev = (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach())
        h_out, c_focus, h_topic, c_topic, _, _ = cell(inputs, prev)
    assert c_focus.abs().sum() > 0, "c_focus should be nonzero after steps"
    assert c_topic.abs().sum() > 0, "c_topic should be nonzero after steps"
    print(" PASS test_conversation_lstm_dual_state")
|
|
def test_conversation_lstm_boundary_reset():
    """A BOS boundary signal drives the LSTM's recorded reset value below 1.0.

    The LSTM is first stepped six times so the boundary is applied to a
    warmed-up state. ``lstm._last_reset`` is read immediately after the
    boundary call (presumably set by that forward; confirm in ConversationLSTM).
    """
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    x = torch.randn(2, 64)
    h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, None)
    for _ in range(5):
        h_out, c_focus, h_topic, c_topic, _, _ = lstm(x, (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach()))
    memory = (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach())
    # Only the side effect on _last_reset is checked; the returned states are unused.
    lstm(x, memory, boundary_signal=SPECIAL_VOCAB['BOS'])
    assert lstm._last_reset < 1.0, "BOS boundary should reduce reset from 1.0"
    print(" PASS test_conversation_lstm_boundary_reset")
|
|
def test_conversation_lstm_topic_gate():
    """The topic hidden state is already non-zero after a single step from empty memory."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    _, _, h_topic, _, _, _ = cell(torch.randn(2, 64), None)
    assert h_topic.abs().sum() > 0, "h_topic should be nonzero after one step"
    print(" PASS test_conversation_lstm_topic_gate")
|
|
def test_conversation_lstm_bptt_dual_windows():
    """With bptt_focus=3, steps 1-2 keep autograd history and step 3 returns a graph-free c_focus."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=3, bptt_topic=6)
    inputs = torch.randn(2, 64)
    state = None
    for _ in range(2):
        h_out, c_f, h_t, c_t, _, _ = cell(inputs, state)
        state = tuple(t.detach() for t in (h_out, c_f, h_t, c_t))
    assert h_out.grad_fn is not None, "h_out should have grad before focus BPTT"
    _, c_f, _, _, _, _ = cell(inputs, state)
    assert c_f.grad_fn is None, "c_focus should be detached at focus BPTT boundary (step 3)"
    print(" PASS test_conversation_lstm_bptt_dual_windows")
|
|
def test_conversation_lstm_topic_preserves_on_boundary():
    """On a USER boundary the topic cell norm shrinks no more than the focus cell norm (10% slack)."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    inputs = torch.randn(2, 64)
    state = None
    for _ in range(5):
        h_out, c_f, h_t, c_t, _, _ = cell(inputs, state)
        state = (h_out.detach(), c_f.detach(), h_t.detach(), c_t.detach())
    # Snapshot the topic state before the boundary step; norms are taken afterwards.
    topic_snapshot = c_t.clone()
    _, c_f_after, _, c_t_after, _, _ = cell(inputs, state, boundary_signal=SPECIAL_VOCAB['USER'])
    topic_ratio = c_t_after.norm() / topic_snapshot.norm().clamp_min(1e-8)
    focus_ratio = c_f_after.norm() / c_f.norm().clamp_min(1e-8)
    assert topic_ratio >= focus_ratio * 0.9, f"topic should decay less than focus on boundary: topic={topic_ratio:.3f} focus={focus_ratio:.3f}"
    print(" PASS test_conversation_lstm_topic_preserves_on_boundary")
|
|
def test_conversation_lstm_h_out_is_sum():
    """In eval mode without grad, h_out and the cell projection both come back as (2, 64)."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    cell.eval()
    inputs = torch.randn(2, 64)
    with torch.no_grad():
        h_out, _c_f, _h_topic, _c_topic, c_proj, _ = cell(inputs, None)
    assert h_out.shape == (2, 64)
    assert c_proj.shape == (2, 64)
    print(" PASS test_conversation_lstm_h_out_is_sum")
|
|
def test_extract_boundary_from_input():
    """Detects BOS at position 0 and USER at position 1; returns None when no boundary token is present."""
    bos_id = SPECIAL_VOCAB['BOS']
    assert _extract_boundary_from_input(torch.tensor([[bos_id, 10, 20]])) == bos_id, "should detect BOS"
    user_id = SPECIAL_VOCAB['USER']
    assert _extract_boundary_from_input(torch.tensor([[10, user_id, 20]])) == user_id, "should detect USER"
    plain = torch.tensor([[10, 20, 30]])
    assert _extract_boundary_from_input(plain) is None, "should return None for no boundary"
    print(" PASS test_extract_boundary_from_input")
|
|
|
|
| |
|
|
def test_model_switch_conversation():
    """switch_conversation stores the id as the LSTM conversation stack's current id."""
    model = ARBModel()
    model.lstm_enabled = True
    model.switch_conversation("conv_A")
    current = model.lstm.conv_stack._current_conv_id
    assert current == "conv_A"
    print(" PASS test_model_switch_conversation")
|
|
def test_model_reset_conversation():
    """Resetting the active conversation clears the stack's current conversation id."""
    model = ARBModel()
    model.lstm_enabled = True
    conv_id = "conv_A"
    model.switch_conversation(conv_id)
    model.reset_conversation(conv_id)
    assert model.lstm.conv_stack._current_conv_id is None
    print(" PASS test_model_reset_conversation")
|
|
def test_model_generate_with_conversation_id():
    """generate() accepts a conversation_id and still appends max_new_token tokens."""
    model = ARBModel()
    model.lstm_enabled = True
    prompt = torch.zeros((1, 10), dtype=torch.long)
    generated = model.generate(prompt, max_new_token=5, conversation_id="conv_test")
    assert generated.shape == (1, 15), f"output shape {generated.shape}"
    print(" PASS test_model_generate_with_conversation_id")
|
|
def test_conversation_lstm_ternary_projections():
    """c_focus_proj and c_topic_proj are both TernaryScaleTensor instances."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    for projection in (cell.c_focus_proj, cell.c_topic_proj):
        assert isinstance(projection, TernaryScaleTensor)
    print(" PASS test_conversation_lstm_ternary_projections")
|
|
def test_conversation_lstm_focus_gate_params():
    """The LSTM owns a FocusGate with an embedding table and two linear heads."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    assert isinstance(cell.focus_gate, FocusGate)
    expected_parts = (
        ("boundary_embed", nn.Embedding),
        ("reset_fc", nn.Linear),
        ("dampen_fc", nn.Linear),
    )
    for attr, module_type in expected_parts:
        assert isinstance(getattr(cell.focus_gate, attr), module_type)
    print(" PASS test_conversation_lstm_focus_gate_params")
|
|
def test_conversation_lstm_topic_cell_bias():
    """Rows 64:128 of the topic cell's bias_ih (read as its forget gate) are all 1.5."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64)
    forget_bias = cell.topic_cell.bias_ih[64:128]
    target = torch.full_like(forget_bias, 1.5)
    assert torch.allclose(forget_bias, target), f"topic forget gate bias should be 1.5, got {forget_bias.mean().item()}"
    print(" PASS test_conversation_lstm_topic_cell_bias")
|
|
|
|
if __name__ == "__main__":
    # Run every test in order and report a pass/fail summary.
    #
    # Tests are listed by name and resolved at run time. A missing or renamed
    # test is then reported as a failure instead of raising NameError while
    # the list is built, which would abort the whole run. The process exits
    # with status 1 if anything failed, so CI and scripts can detect failures.
    test_names = [
        "test_sticky_zone_ste",
        "test_sticky_zone_ste_dtype_preservation",
        "test_scaled_ternary_linear",
        "test_rmsnorm",
        "test_byte_embedding",
        "test_text_sequencer",
        "test_trigram_window",
        "test_image_sequencer",
        "test_image_sequencer_frozen",
        "test_target_alignment",
        "test_model_forward",
        "test_generate",
        "test_param_count",
        "test_gradient_flow",
        "test_model_forward_with_targets",
        "test_save_load_roundtrip",
        "test_vq_adapter_shapes",
        "test_vq_integration",
        "test_vq_disabled",
        "test_vq_with_targets",
        "test_l2_distance_matching",
        "test_vq_ternary_projections",
        "test_multimodal_vq_bridge_text_only",
        "test_multimodal_vq_bridge_text_image",
        "test_modality_gate_shapes",
        "test_ternary_graph_multicodebook",
        "test_vq_no_float_cast_in_model",
        "test_zero_fp32_params",
        "test_sticky_zone_ste_gradient",
        "test_graph_moe_gate_shape",
        "test_ternary_graph_shapes",
        "test_graph_gradient_flow",
        "test_graph_connectivity_monitor",
        "test_model_forward_with_graph",
        "test_model_graph_disabled",
        "test_ternary_graph_in_modules",
        "test_moe_shapes",
        "test_moe_router",
        "test_moe_aux_loss",
        "test_shared_expert",
        "test_moe_gradient_flow",
        "test_moe_zero_fp32",
        "test_ternary_graph_with_gate",
        "test_model_forward_with_moe",
        "test_model_moe_disabled",
        "test_model_moe_loss_components",
        "test_model_moe_gate_modulation",
        "test_param_count_with_moe",
        "test_moe_monitoring",
        "test_loss_components",
        "test_loss_components_none_fields",
        "test_loss_components_backward",
        "test_gnn_lora_adapter",
        "test_gnn_lora_gradient",
        "test_shared_gnn_weight_tying",
        "test_shared_gnn_multi_hop",
        "test_model_losses_components_type",
        "test_halting_unit_shapes",
        "test_halting_unit_ternary_pure",
        "test_graph_act_cell_shapes",
        "test_moe_act_cell_shapes",
        "test_act_early_halt",
        "test_act_weight_sum_one",
        "test_act_gradient_flow",
        "test_loss_components_ponder_fields",
        "test_loss_components_ponder_none",
        "test_act_graph_moe_sequential",
        "test_model_forward_with_act",
        "test_model_act_forward_without_targets",
        "test_model_act_loss_components",
        "test_model_act_backward",
        "test_model_act_disabled",
        "test_model_act_warmup_mode",
        "test_model_act_ponder_cached",
        "test_act_warmup_schedule",
        "test_act_ponder_lambda",
        "test_model_ponder_lambda_scaling",
        "test_text_only_forward",
        "test_image_forward",
        "test_multimodal_backward",
        "test_no_stale_trigram_encoder",
        "test_vocab",
        "test_memgram_shapes",
        "test_memgram_hash_indices",
        "test_memgram_bilinear_gate_range",
        "test_memgram_decay_formula",
        "test_memgram_gradient_flow",
        "test_memgram_conv_path",
        "test_conv_vq_shapes",
        "test_conv_vq_hard_cap",
        "test_conv_vq_deferred_activation",
        "test_conv_vq_ema_update",
        "test_conv_vq_persistence",
        "test_conv_vq_fuzzy_retrieve",
        "test_conv_vq_commitment_nonneg",
        "test_lstm_shapes",
        "test_lstm_forget_gate_bias",
        "test_lstm_bptt_detach",
        "test_lstm_hidden_reg",
        "test_lstm_c_t_proj_ternary",
        "test_memory_modules_backward_compat",
        # NOTE(review): the next four tests are not defined next to their
        # neighbours in this part of the file; confirm they still exist.
        "test_loss_components_nine_fields_total",
        "test_loss_components_nine_fields_log",
        "test_moe_router_h_with_h_t",
        "test_moe_router_without_h_t",
        "test_forward_no_memory_backward_compat",
        "test_forward_lstm_enabled_h_t_passed",
        "test_forward_lstm_c_t_residual",
        "test_forward_memgram_injection",
        "test_forward_conv_vq_deferred",
        "test_generate_carries_lstm_state",
        "test_memory_schedule_warmup",
        "test_memory_schedule_lstm_first",
        "test_memory_schedule_memgram_second",
        "test_memory_schedule_all_on",
        "test_memory_schedule_conv_vq_requires_vq_util",
        "test_memory_schedule_decay_reg_last",
        "test_lstm_state_reset_per_batch",
        "test_bptt_counter_separate",
        "test_focus_gate_no_boundary",
        "test_focus_gate_boundary_signal",
        "test_focus_gate_all_boundary_types",
        "test_focus_gate_unknown_token",
        "test_focus_gate_c_focus_modulation",
        "test_conv_stack_push_pop",
        "test_conv_stack_pop_missing",
        "test_conv_stack_clear",
        "test_conv_stack_lru_eviction",
        "test_conv_stack_reset",
        "test_conversation_lstm_forward",
        "test_conversation_lstm_dual_state",
        "test_conversation_lstm_boundary_reset",
        "test_conversation_lstm_topic_gate",
        "test_conversation_lstm_bptt_dual_windows",
        "test_conversation_lstm_topic_preserves_on_boundary",
        "test_conversation_lstm_h_out_is_sum",
        "test_extract_boundary_from_input",
        "test_model_switch_conversation",
        "test_model_reset_conversation",
        "test_model_generate_with_conversation_id",
        "test_conversation_lstm_ternary_projections",
        "test_conversation_lstm_focus_gate_params",
        "test_conversation_lstm_topic_cell_bias",
    ]
    print("Running MORPH tests (Phase 1 + Phase 2 VQ + Phase 3 Graph + Phase 4 MoE + Explore + Phase 5 ACT + Phase 6 Multi-Modal + Phase 7 Memory)...\n")
    passed = 0
    failed = 0
    for name in test_names:
        test_fn = globals().get(name)
        if test_fn is None:
            print(f" FAIL {name}: not defined")
            failed += 1
            continue
        try:
            test_fn()
        except Exception as e:  # top-level runner boundary: report and continue
            # Include the exception type so bare `assert`s (empty message) are still diagnosable.
            print(f" FAIL {name}: {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(test_names)} tests")
    sys.exit(1 if failed else 0)
|
|