# ARBS / testing/model/test_arb.py
# Uploaded with huggingface_hub by CLIWorks (commit d8bc908, verified).
import torch
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch.nn as nn
import sys
import os
from arbitor.config import (
VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD,
CODEBOOK_DIM, CODEBOOK_SIZE,
SPECIAL_VOCAB,
StickyZoneSTE,
ByteEmbedding, Sequencer, TextSequencer, ImageSequencer, AudioSequencer,
MultimodalSequencer,
TernaryGNNLayer, TernaryGraph, GraphMoEGate, SharedProjectionMoE,
ByteHead, ARBModel, VQAdapter, MultimodalVQBridge, ModalityGate,
LossComponents, LossWeights, GNNLoRAAdapter,
HaltingUnit, GraphACTCell, MoEACTCell,
MemGram, ConvVQCodebook,
FocusGate, ConversationStack, ConversationLSTM,
_BOUNDARY_TOKEN_MAP, _extract_boundary_from_input,
)
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
# Module classes whose directly-owned parameters are counted as "ternary" by
# _is_ternary_param (used by the zero-fp32 parameter-budget tests).
TERNARY_MODULES = (
    TernaryScaleTensor,
    TernaryRMSNorm,
    ByteEmbedding,
    TernaryGraph,
    GraphMoEGate,
    SharedProjectionMoE,
    GNNLoRAAdapter,
    HaltingUnit,
    GraphACTCell,
    MoEACTCell,
    Sequencer,
    TextSequencer,
    ImageSequencer,
    AudioSequencer,
    MultimodalVQBridge,
    ModalityGate,
    MemGram,
    ConvVQCodebook,
    ConversationLSTM,
)
def _is_ternary_param(model, name):
    """Return True if parameter *name* is owned directly by a ternary module.

    Args:
        model: the ``nn.Module`` whose ``named_parameters()`` produced *name*.
        name: dotted parameter path, e.g. ``"moe.router.weight"``. The owning
            module is everything before the last dot; a name without a dot
            belongs to *model* itself.

    Returns:
        True when the owning module is an instance of ``TERNARY_MODULES``.
        False when it is not, or when the owner path cannot be resolved.
    """
    # rpartition yields "" for a dot-free name, which get_submodule maps to model.
    parent_name = name.rpartition(".")[0]
    # Resolve the owner directly instead of rebuilding dict(model.named_modules())
    # on every call. The old form made "for each parameter" loops O(params * modules).
    try:
        parent = model.get_submodule(parent_name)
    except AttributeError:
        # Unresolvable path: keep the old dict.get(..., None) -> False behaviour.
        return False
    return isinstance(parent, TERNARY_MODULES)
# ===== Phase 1: Foundation Tests =====
def test_sticky_zone_ste():
    """StickyZoneSTE emits only {-1, 0, 1} and routes gradients as expected.

    Weights beyond the threshold must receive some non-zero gradient. Weights
    inside the sticky zone must receive only non-negative gradient.
    """
    threshold = 0.05
    weights = torch.randn(8, 8, requires_grad=True)
    quantized = StickyZoneSTE.apply(weights, threshold)

    observed = set(quantized.detach().flatten().tolist())
    assert observed.issubset({-1.0, 0.0, 1.0}), f"Non-ternary values: {observed}"

    quantized.sum().backward()
    grad = weights.grad
    assert grad is not None

    active = weights.abs() > threshold
    if active.any():
        assert (grad[active] != 0).any(), "Outside threshold should have non-zero gradient"

    sticky = weights.abs() <= threshold
    if sticky.any():
        assert (grad[sticky] >= 0).all(), "Sticky zone gradient should be non-negative"
    print(" PASS test_sticky_zone_ste")
def test_sticky_zone_ste_dtype_preservation():
    """The STE forward output and its gradient both stay in bfloat16."""
    target_dtype = torch.bfloat16
    w_bf16 = torch.randn(8, 8, dtype=target_dtype, requires_grad=True)
    out = StickyZoneSTE.apply(w_bf16, 0.05)
    assert out.dtype == target_dtype, f"Expected bfloat16, got {out.dtype}"

    out.sum().backward()
    grad_dtype = w_bf16.grad.dtype
    assert grad_dtype == target_dtype, f"Expected bfloat16 grad, got {grad_dtype}"
    print(" PASS test_sticky_zone_ste_dtype_preservation")
def test_scaled_ternary_linear():
    """TernaryScaleTensor maps the last dim 32 -> 16 and honours bias=False."""
    in_features, out_features = 32, 16
    layer = TernaryScaleTensor(in_features, out_features, bias=False)
    projected = layer(torch.randn(2, 10, in_features))
    assert projected.shape == (2, 10, out_features), f"Shape: {projected.shape}"
    assert layer.bias is None, "TernaryScaleTensor bias should be None"
    print(" PASS test_scaled_ternary_linear")
def test_rmsnorm():
    """TernaryRMSNorm preserves the input shape.

    NOTE(review): ``reference`` is a hand-computed unit-weight RMS
    normalisation, but only its *shape* is compared with the module output,
    so values are never checked. TODO: compare values (e.g. ``torch.allclose``)
    once the module's epsilon and weight quantisation are confirmed.
    """
    dim = 32
    norm = TernaryRMSNorm(dim)
    x = torch.randn(2, 10, dim)
    normalized = norm(x)
    assert normalized.shape == x.shape, f"Shape: {normalized.shape}"

    rms = x.pow(2).mean(dim=-1, keepdim=True).add(1e-8).sqrt()
    reference = torch.ones(dim, device=x.device) * (x / rms)
    assert normalized.shape == reference.shape, "RMSNorm mismatch"
    print(" PASS test_rmsnorm")
def test_byte_embedding():
    """ByteEmbedding maps (B, T) token ids to (B, T, EMBEDDING_DIM) vectors."""
    batch, seq_len = 4, 20
    embedding = ByteEmbedding()
    token_ids = torch.randint(0, VOCAB, (batch, seq_len))
    embedded = embedding(token_ids)
    assert embedded.shape == (batch, seq_len, EMBEDDING_DIM), f"Shape: {embedded.shape}"
    print(" PASS test_byte_embedding")
def test_text_sequencer():
    """TextSequencer turns 10 embedded positions into 8 HIDDEN_DIM positions."""
    sequencer = TextSequencer()
    embedded = torch.randn(2, 10, EMBEDDING_DIM)
    hidden = sequencer(embedded)
    expected_shape = (2, 8, HIDDEN_DIM)
    assert hidden.shape == expected_shape, f"Shape: {hidden.shape}, expected (2, 8, {HIDDEN_DIM})"
    print(" PASS test_text_sequencer")
def test_trigram_window():
    """``unfold(size=3, step=1)`` over the sequence yields overlapping trigrams."""
    seq_len = 5
    x = torch.zeros(1, seq_len, EMBEDDING_DIM)
    # Fill every feature of position p with the value p + 1 (broadcast over dim).
    x[0] = torch.arange(1, seq_len + 1, dtype=x.dtype).unsqueeze(-1)

    windows = x.unfold(dimension=1, size=3, step=1)
    assert windows.shape == (1, 3, EMBEDDING_DIM, 3), f"Unfold shape: {windows.shape}"
    # The first window covers positions 0..2, i.e. the values 1, 2, 3.
    for offset, value in enumerate((1.0, 2.0, 3.0)):
        assert windows[0, 0, 0, offset].item() == value
    print(" PASS test_trigram_window")
def test_image_sequencer():
    """ImageSequencer maps one 3x224x224 image to 194 HIDDEN_DIM tokens."""
    sequencer = ImageSequencer()
    image = torch.randn(1, 3, 224, 224)
    tokens = sequencer(image)
    assert tokens.shape == (1, 194, HIDDEN_DIM)
    print(" PASS test_image_sequencer")
def test_image_sequencer_frozen():
    """Every parameter of the ImageSequencer's ``vit`` backbone is frozen."""
    sequencer = ImageSequencer()
    trainable = [p for p in sequencer.vit.parameters() if p.requires_grad]
    assert not trainable
    print(" PASS test_image_sequencer_frozen")
def test_target_alignment():
    """Targets taken from input position 3 onward line up with logits[:, :-1]."""
    model = ARBModel()
    bos, eos = SPECIAL_VOCAB["BOS"], SPECIAL_VOCAB["EOS"]
    x = torch.tensor([[bos, 10, 20, 30, 40, 50, eos]])
    targets = x[:, 3:]
    logits, losses, _, _ = model(x, targets=targets)
    assert losses is not None, "Losses should be computed"
    predicted_len = logits[:, :-1, :].shape[1]
    assert predicted_len == targets.shape[1], "Target alignment mismatch"
    print(" PASS test_target_alignment")
def test_model_forward():
    """A target-less forward returns (B, T - 2, VOCAB) logits and no losses."""
    model = ARBModel()
    batch, seq_len = 2, 66
    tokens = torch.randint(0, VOCAB, (batch, seq_len))
    logits, losses, _, _ = model(tokens)
    expected_shape = (batch, seq_len - 2, VOCAB)
    assert logits.shape == expected_shape, f"Shape: {logits.shape}, expected ({batch}, {seq_len-2}, {VOCAB})"
    assert losses is None, "Losses should be None without targets"
    print(" PASS test_model_forward")
def test_generate():
    """generate() appends 10 tokens to a 4-token prompt, all within VOCAB."""
    model = ARBModel()
    model.eval()
    prompt = [SPECIAL_VOCAB["BOS"], *map(ord, "Hel")]
    seed = torch.tensor([prompt])
    with torch.no_grad():
        generated = model.generate(seed, max_new_token=10, temperature=1.0)
    assert generated.shape == (1, 14), f"Shape: {generated.shape}"
    in_range = (generated >= 0).all() and (generated < VOCAB).all()
    assert in_range, "Tokens out of vocab range"
    print(" PASS test_generate")
def test_param_count():
    """Total parameter count (including frozen backbones) stays within 120M-150M."""
    model = ARBModel()
    total = sum(map(torch.Tensor.numel, model.parameters()))
    print(f" Param count: {total:,}")
    lower, upper = 120e6, 150e6
    assert lower < total < upper, f"Param count {total:,} outside expected range (frozen ViT-base + Whisper-tiny + MemGram + LSTM)"
    print(" PASS test_param_count")
def test_gradient_flow():
    """Every trainable parameter outside the exempt groups receives a gradient.

    Parameters whose names contain any fragment below are skipped. They
    either belong to modules a text-only batch does not exercise (multimodal,
    memory, LSTM paths) or are routing/embedding internals excluded by design.
    """
    exempt_fragments = (
        "embed",
        "graph_pool.query",
        "moe.W_gate",
        "moe.W_transform",
        "moe.W_gate_norms",
        "moe.W_transform_norms",
        "moe.router.bias",
        "moe.router_h",
        "memgram",
        "conv_vq",
        "lstm",
        "patch_proj",
        "image_sequencer.projection",
        "image_sequencer.norm",
        "audio_sequencer",
        "multimodal_sequencer.audio",
        "bridge.image_vq",
        "bridge.audio_vq",
        "bridge.bridge_norm",
        "modality_gate",
    )
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (4, 20))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    losses.total.backward()

    for pname, p in model.named_parameters():
        if not p.requires_grad or any(frag in pname for frag in exempt_fragments):
            continue
        assert p.grad is not None, f"No gradient for {pname}"
    print(" PASS test_gradient_flow")
def test_model_forward_with_targets():
    """Full-context forward with random targets yields a positive scalar LossComponents total."""
    model = ARBModel()
    batch, seq_len = 4, CTX
    inputs = torch.randint(0, VOCAB, (batch, seq_len))
    targets = torch.randint(0, VOCAB, (batch, seq_len - 3))
    _, losses, _, _ = model(inputs, targets=targets)
    assert losses is not None
    assert isinstance(losses, LossComponents)
    total = losses.total
    assert total.ndim == 0
    assert total > 0
    print(" PASS test_model_forward_with_targets")
def test_save_load_roundtrip():
    """A saved-then-reloaded model reproduces the original eval-mode logits.

    Skipped (with a printed message) when the optional
    ``arbitor.converters.convert_to_ternary8`` module cannot be imported.
    """
    try:
        from arbitor.converters.convert_to_ternary8 import save_model, load_model
    except ImportError:
        print(" SKIP test_save_load_roundtrip (convert_to_ternary8 not available)")
        return
    import tempfile

    model = ARBModel()
    # A private temp dir avoids collisions between concurrent runs on a shared
    # /tmp path and removes the checkpoint afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "test-morph-roundtrip.pt")
        save_model(model, ckpt_path)
        loaded = load_model(ckpt_path)

    x = torch.randint(0, VOCAB, (1, 10))
    model.eval()
    loaded.eval()
    with torch.no_grad():
        # The model's forward returns a 4-tuple elsewhere in this file.
        # Unpacking exactly three values here raised ValueError, so take
        # element 0 (the logits) from both outputs instead.
        logits_orig = model(x)[0]
        logits_loaded = loaded(x)[0]
    assert torch.allclose(logits_orig, logits_loaded, atol=1e-6), "Save/load roundtrip mismatch"
    print(" PASS test_save_load_roundtrip")
# ===== Phase 2: VQ Tests =====
def test_vq_adapter_shapes():
    """VQAdapter preserves shape, emits long (B, T) codes and a non-negative loss."""
    adapter = VQAdapter()
    hidden = torch.randn(2, 10, HIDDEN_DIM)
    quantized, commit_loss, codes = adapter(hidden)
    assert quantized.shape == (2, 10, HIDDEN_DIM), f"VQ output shape: {quantized.shape}"
    assert codes.shape == (2, 10), f"VQ indices shape: {codes.shape}"
    assert codes.dtype == torch.long, "Indices must be long"
    assert commit_loss.item() >= 0, "VQ loss must be non-negative"
    print(" PASS test_vq_adapter_shapes")
def test_vq_integration():
    """The full model returns one VQ code per logit position."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, codes, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}"
    assert codes is not None, "VQ indices must be returned"
    assert codes.shape == (2, 64), f"VQ indices shape wrong: {codes.shape}"
    print(" PASS test_vq_integration")
def test_vq_disabled():
    """With VQ, graph and MoE switched off, no VQ codes are returned but logits keep shape."""
    model = ARBModel()
    for flag in ("vq_enabled", "graph_enabled", "moe_enabled"):
        setattr(model, flag, False)
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, codes, _ = model(tokens)
    assert codes is None, "Indices should be None when VQ disabled"
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_vq_disabled")
def test_vq_with_targets():
    """Supplying targets (input positions 3..65) yields a strictly positive total loss."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:66])
    has_positive_loss = losses is not None and losses.total.item() > 0
    assert has_positive_loss, "Loss should be positive with targets"
    print(" PASS test_vq_with_targets")
def test_l2_distance_matching():
    """l2_distance_matching returns (B, T) nearest indices and non-negative distances."""
    adapter = VQAdapter()
    projected = torch.randn(2, 10, 32)
    nearest, distances = adapter.l2_distance_matching(projected)
    assert nearest.shape == (2, 10), f"L2 indices shape: {nearest.shape}"
    assert distances.shape == (2, 10), f"L2 distances shape: {distances.shape}"
    assert (distances >= 0).all(), "L2 distances must be non-negative"
    print(" PASS test_l2_distance_matching")
def test_vq_ternary_projections():
    """VQAdapter's in/out projections are TernaryScaleTensor and the adapter still runs."""
    adapter = VQAdapter()
    for attr in ("proj_in", "proj_out"):
        layer = getattr(adapter, attr)
        assert isinstance(layer, TernaryScaleTensor), \
            f"{attr} should be TernaryScaleTensor, got {type(layer)}"
    hidden = torch.randn(2, 10, HIDDEN_DIM)
    quantized, commit_loss, _ = adapter(hidden)
    assert quantized.shape == (2, 10, HIDDEN_DIM), f"VQ output shape: {quantized.shape}"
    assert commit_loss.item() >= 0, "VQ loss must be non-negative"
    print(" PASS test_vq_ternary_projections")
# ===== Phase 6: Multi-Modal Bridge, Gate, and Graph Tests =====
def test_multimodal_vq_bridge_text_only():
    """Text-only bridge input keeps its shape, reports text_vq loss, uses text codes < 8192."""
    bridge = MultimodalVQBridge()
    modality_inputs = {'text': torch.randn(2, 10, 512)}
    combined, losses, indices = bridge(modality_inputs)
    assert combined.shape == (2, 10, 512)
    assert 'text_vq' in losses
    assert (indices['text'] < 8192).all()
    print(" PASS test_multimodal_vq_bridge_text_only")
def test_multimodal_vq_bridge_text_image():
    """Text + image tokens are concatenated; image codes fall in [8192, 12288)."""
    bridge = MultimodalVQBridge()
    modality_inputs = {
        'text': torch.randn(2, 10, 512),
        'image': torch.randn(2, 20, 512),
    }
    combined, _, indices = bridge(modality_inputs)
    assert combined.shape == (2, 30, 512)
    image_codes = indices['image']
    assert (image_codes >= 8192).all()
    assert (image_codes < 12288).all()
    print(" PASS test_multimodal_vq_bridge_text_image")
def test_modality_gate_shapes():
    """ModalityGate on ['text'] returns a weight dict, count >= 1 and hops >= 2."""
    gate = ModalityGate()
    modality_weights, active_count, n_hops = gate(['text'])
    assert isinstance(modality_weights, dict)
    assert active_count >= 1
    assert n_hops >= 2
    print(" PASS test_modality_gate_shapes")
def test_ternary_graph_multicodebook():
    """TernaryGraph handles indices drawn from three concatenated codebook ranges."""
    graph = TernaryGraph(total_vocab_size=16384)
    # Codebook embeddings for text (8192), image (4096) and audio (4096) entries.
    graph._codebook_embed = torch.cat(
        [torch.randn(1, size, 32) for size in (8192, 4096, 4096)], dim=1
    )
    vq_out = torch.randn(2, 21, 512)
    # (low, high, count): 8 text, 7 image and 6 audio positions per sample.
    index_spans = ((0, 8192, 8), (8192, 12288, 7), (12288, 16384, 6))
    vq_idx = torch.cat(
        [torch.randint(low, high, (2, count)) for low, high, count in index_spans], dim=1
    )
    per_pos, _, _ = graph(vq_out, vq_idx, 0.05)
    assert per_pos.shape == (2, 21, 512)
    print(" PASS test_ternary_graph_multicodebook")
def test_vq_no_float_cast_in_model():
    """Only whitelisted submodules may contain plain nn.Linear / nn.Embedding layers."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, _, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}"

    linear_allowed = (
        "image_sequencer",
        "multimodal_sequencer.image",
        "moe.router",
        "lstm.",
        "multimodal_sequencer.audio",
    )
    embedding_allowed = ("hop_lora.scale", "lstm.", "whisper.", "vit.")
    for mod_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            assert any(tag in mod_name for tag in linear_allowed), f"Unexpected nn.Linear: {mod_name} — only moe.router, image_sequencer, lstm, and audio_sequencer are allowed"
            continue
        if isinstance(module, nn.Embedding):
            assert any(tag in mod_name for tag in embedding_allowed), f"nn.Embedding found: {mod_name} — only hop_lora.scale, lstm.*, whisper.*, vit.* are allowed"
    print(" PASS test_vq_no_float_cast_in_model")
def test_zero_fp32_params():
    """Outside explicitly exempt groups, every parameter belongs to a ternary module.

    Exempt groups are VQ internals, the MoE router, LoRA scales, frozen
    ViT/Whisper backbones, patch/audio projections, and the memory modules
    (memgram, conv_vq, lstm).
    """
    model = ARBModel()
    exempt_substrings = (
        "bridge.text_vq.vq",
        "bridge.image_vq.vq",
        "bridge.audio_vq.vq",
        "moe.router",
        "hop_lora.scale",
        "image_sequencer.vit",
        "multimodal_sequencer.image.vit",
        "patch_proj",
        "mfcc_proj",
        "frame_proj",
        "whisper",
    )
    exempt_prefixes = ("memgram", "conv_vq", "lstm")

    def _exempt(pname):
        return any(s in pname for s in exempt_substrings) or pname.startswith(exempt_prefixes)

    non_ternary_non_vq = sum(
        param.numel()
        for pname, param in model.named_parameters()
        if not _exempt(pname) and not _is_ternary_param(model, pname)
    )
    assert non_ternary_non_vq == 0, \
        f"Found {non_ternary_non_vq} non-ternary, non-VQ, non-router params"
    print(" PASS test_zero_fp32_params")
def test_sticky_zone_ste_gradient():
    """Sticky-zone gradients scale with |w| / threshold; outside the zone they are 1."""
    threshold = 0.05
    w = torch.tensor([-0.01, -0.03, -0.049, 0.06, 0.10], requires_grad=True)
    StickyZoneSTE.apply(w, threshold).sum().backward()
    # -0.01 -> 0.2, -0.03 -> 0.6, -0.049 -> 0.98; beyond the threshold -> 1.0.
    expected_ratios = [0.2, 0.6, 0.98, 1.0, 1.0]
    for idx, ratio in enumerate(expected_ratios):
        got = w.grad[idx].item()
        assert abs(got - ratio) < 0.02, f"w={w[idx].item():.3f}: expected ratio {ratio}, got {got:.3f}"
    print(" PASS test_sticky_zone_ste_gradient")
# ===== Phase 3: Graph Tests (updated for GraphMoEGate) =====
def test_graph_moe_gate_shape():
    """GraphMoEGate pools to (B, H), emits per-position alphas in [0, 1], has H params."""
    gate = GraphMoEGate(dim=HIDDEN_DIM)
    pooled, alpha = gate(torch.randn(2, 10, HIDDEN_DIM))
    assert pooled.shape == (2, HIDDEN_DIM), f"Pooled shape: {pooled.shape}"
    assert alpha.shape == (2, 10, 1), f"Alpha shape: {alpha.shape}"
    assert ((alpha >= 0) & (alpha <= 1)).all(), "Alpha out of [0,1]"
    n_query = gate.query.numel()
    assert n_query == HIDDEN_DIM, f"Gate params: {n_query}"
    print(" PASS test_graph_moe_gate_shape")
def test_ternary_graph_shapes():
    """TernaryGraph returns per-position, pooled and gate-alpha tensors of the right shapes."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    hidden = torch.randn(2, 10, HIDDEN_DIM)
    codes = torch.randint(0, CODEBOOK_SIZE, (2, 10))
    per_position, pooled, alpha = graph(hidden, codes, 0.05)
    assert per_position.shape == (2, 10, HIDDEN_DIM), f"per_position shape: {per_position.shape}"
    assert pooled.shape == (2, HIDDEN_DIM), f"graph_pool shape: {pooled.shape}"
    assert alpha.shape == (2, 10, 1), f"gate_alpha shape: {alpha.shape}"
    print(" PASS test_ternary_graph_shapes")
def test_graph_gradient_flow():
    """Backprop through TernaryGraph reaches both edge_attr and the input tensor."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    hidden = torch.randn(2, 10, HIDDEN_DIM, requires_grad=True)
    codes = torch.randint(0, CODEBOOK_SIZE, (2, 10))
    per_position = graph(hidden, codes, 0.05)[0]
    per_position.sum().backward()
    assert graph.edge_attr.grad is not None, "edge_attr should have gradient"
    assert hidden.grad is not None, "vq_output should have gradient"
    print(" PASS test_graph_gradient_flow")
def test_graph_connectivity_monitor():
    """monitor_graph_health reports the expected keys with sane ranges."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=2)
    report = graph.monitor_graph_health(threshold=0.05)
    for key in ('sparsity', 'isolated_nodes', 'avg_polarity', 'dead_edges'):
        assert key in report
    assert 0.0 <= report['sparsity'] <= 1.0
    assert report['isolated_nodes'] >= 0
    print(" PASS test_graph_connectivity_monitor")
def test_model_forward_with_graph():
    """The default model exposes ternary_graph and returns VQ codes alongside logits."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, codes, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB), f"Logits shape: {logits.shape}"
    assert codes is not None, "VQ indices required for graph"
    assert hasattr(model, 'ternary_graph'), "Model missing ternary_graph"
    print(" PASS test_model_forward_with_graph")
def test_model_graph_disabled():
    """Disabling graph and MoE still yields (B, 64, VOCAB) logits."""
    model = ARBModel()
    for flag in ("graph_enabled", "moe_enabled"):
        setattr(model, flag, False)
    logits = model(torch.randint(0, VOCAB, (2, 66)))[0]
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_model_graph_disabled")
def test_ternary_graph_in_modules():
    """Graph and MoE module classes are registered in TERNARY_MODULES."""
    required = (
        ("TernaryGraph", TernaryGraph),
        ("GraphMoEGate", GraphMoEGate),
        ("SharedProjectionMoE", SharedProjectionMoE),
    )
    for label, cls in required:
        assert cls in TERNARY_MODULES, f"{label} not in TERNARY_MODULES"
    print(" PASS test_ternary_graph_in_modules")
# ===== Phase 4: MoE Tests =====
def test_moe_shapes():
    """SharedProjectionMoE keeps (B, T, H) and returns a non-negative scalar aux loss."""
    moe = SharedProjectionMoE(
        hidden_size=512,
        num_experts=8,
        top_k=2,
        core_rank=192,
        shared_inter=3072,
        tscale_type=TScaleType.T32,
    )
    mixed, aux_loss = moe(torch.randn(4, 10, 512))
    assert mixed.shape == (4, 10, 512), f'MoE output shape: {mixed.shape}'
    assert aux_loss.ndim == 0, f'Aux loss should be scalar, got ndim={aux_loss.ndim}'
    assert aux_loss.item() >= 0, 'Aux loss should be non-negative'
    print(" PASS test_moe_shapes")
def test_moe_router():
    """Top-k routing records valid expert ids; train and eval forwards keep shape.

    With ``top_k=2`` over a (4, 20, 512) input, ``_last_topk_idx`` is checked
    to be (80, 2), i.e. one row of 2 expert ids per flattened token.
    """
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, noise_std=0.25)
    moe.train()
    x = torch.randn(4, 20, 512)
    out, _ = moe(x)
    assert out.shape == x.shape, f'Train-mode MoE output shape: {out.shape}'
    assert moe._last_topk_idx is not None
    assert moe._last_topk_idx.shape == (80, 2), f'topk_idx shape: {moe._last_topk_idx.shape}'
    assert (moe._last_topk_idx >= 0).all() and (moe._last_topk_idx < 8).all()

    # The eval-mode result used to be discarded, so this path was only
    # checked for not crashing. At minimum it must preserve the input shape.
    moe.eval()
    out2, _ = moe(x)
    assert out2.shape == x.shape, f'Eval-mode MoE output shape: {out2.shape}'
    print(" PASS test_moe_router")
def test_moe_aux_loss():
    """The MoE load-balancing aux loss is never negative."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    aux_loss = moe(torch.randn(4, 10, 512))[1]
    assert aux_loss.item() >= 0, 'Aux loss must be non-negative'
    print(" PASS test_moe_aux_loss")
def test_shared_expert():
    """The shared expert's gate/up/down layers are ternary and produce non-zero output."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    for layer in (moe.shared_expert_gate, moe.shared_expert_up, moe.shared_expert_down):
        assert isinstance(layer, TernaryScaleTensor)
    mixed, _ = moe(torch.randn(2, 5, 512))
    assert mixed.norm().item() > 0, 'Shared expert output should be non-zero'
    print(" PASS test_shared_expert")
def test_moe_gradient_flow():
    """Backprop through the MoE reaches the input and triggers the ternary grad hooks."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    x = torch.randn(2, 10, 512).requires_grad_(True)
    mixed, aux_loss = moe(x)
    (mixed.sum() + aux_loss).backward()
    assert x.grad is not None, 'No gradient on input'
    hooked = (('router', moe.router), ('W_gate[0]', moe.W_gate[0]))
    for label, owner in hooked:
        assert hasattr(owner, '_hook_grad_T_sign'), f'No grad captured on {label}'
    print(" PASS test_moe_gradient_flow")
def test_moe_zero_fp32():
    """Every SharedProjectionMoE parameter is owned by a ternary module."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2)
    non_ternary = sum(
        p.numel() for pname, p in moe.named_parameters() if not _is_ternary_param(moe, pname)
    )
    assert non_ternary == 0, f'Expected 0 non-ternary params, got {non_ternary}'
    print(" PASS test_moe_zero_fp32")
def test_ternary_graph_with_gate():
    """TernaryGraph's gate alpha is (B, T, 1) and bounded to [0, 1]."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    hidden = torch.randn(2, 10, HIDDEN_DIM)
    codes = torch.randint(0, CODEBOOK_SIZE, (2, 10))
    alpha = graph(hidden, codes, 0.05)[2]
    assert alpha.shape == (2, 10, 1), f'gate_alpha shape: {alpha.shape}'
    assert ((alpha >= 0) & (alpha <= 1)).all()
    print(" PASS test_ternary_graph_with_gate")
def test_model_forward_with_moe():
    """Default (MoE-enabled) forward yields (B, 64, VOCAB) logits and VQ codes."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _, codes, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB), f'Logits shape: {logits.shape}'
    assert codes is not None
    print(" PASS test_model_forward_with_moe")
def test_model_moe_disabled():
    """Turning MoE off keeps the logits shape unchanged."""
    model = ARBModel()
    model.moe_enabled = False
    logits = model(torch.randint(0, VOCAB, (2, 66)))[0]
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_model_moe_disabled")
def test_model_moe_loss_components():
    """A targeted forward fills every core loss field and caches MoE routing state."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert losses is not None and isinstance(losses, LossComponents)
    assert losses.lm is not None and losses.lm > 0
    for field in ("vq_commitment", "moe_aux", "graph_l1"):
        assert getattr(losses, field) is not None
    assert losses.total > 0
    moe = model.moe
    assert moe._last_topk_idx is not None, 'MoE should have routing info after forward'
    assert moe._last_aux_loss is not None, 'MoE should have aux_loss cached after forward'
    print(" PASS test_model_moe_loss_components")
def test_model_moe_gate_modulation():
    """Gate-modulated MoE forward produces (B, 64, VOCAB) logits."""
    model = ARBModel()
    logits = model(torch.randint(0, VOCAB, (2, 66)))[0]
    assert logits.shape == (2, 64, VOCAB)
    print(" PASS test_model_moe_gate_modulation")
def test_param_count_with_moe():
    """With MoE included, the total parameter count stays within 120M-150M."""
    model = ARBModel()
    total = sum(map(torch.Tensor.numel, model.parameters()))
    print(f" Param count with MoE: {total:,}")
    lower, upper = 120e6, 150e6
    assert lower < total < upper, f'Expected ~133M (frozen ViT-base + Whisper-tiny + ternary buffers), got {total:,}'
    print(" PASS test_param_count_with_moe")
def test_moe_monitoring():
    """After a forward, the MoE caches top-k indices (width top_k) and the aux loss."""
    model = ARBModel()
    model(torch.randint(0, VOCAB, (2, 66)))
    moe = model.moe
    assert moe._last_topk_idx is not None, '_last_topk_idx should be set after forward'
    assert moe._last_aux_loss is not None, '_last_aux_loss should be set after forward'
    assert moe._last_topk_idx.shape[1] == moe.top_k
    print(" PASS test_moe_monitoring")
# ===== Explore: LossComponents + GNN LoRA Tests =====
def test_loss_components():
    """LossComponents.total is a scalar weighted sum and backpropagates to lm."""
    leaves = {
        "lm": torch.tensor(5.0, requires_grad=True),
        "vq_commitment": torch.tensor(0.5, requires_grad=True),
        "moe_aux": torch.tensor(0.01, requires_grad=True),
        "graph_l1": torch.tensor(0.001, requires_grad=True),
    }
    components = LossComponents(**leaves)
    assert components.total.ndim == 0, f"Total should be scalar, got ndim={components.total.ndim}"
    # Only graph_l1 carries a non-unit weight in this expectation.
    expected = 5.0 + 0.5 + 0.01 + LossWeights.graph_l1 * 0.001
    assert abs(components.total.item() - expected) < 1e-5, f"Total mismatch: {components.total.item()} vs {expected}"
    components.total.backward()
    assert leaves["lm"].grad is not None, "LM loss should have gradient"
    print(" PASS test_loss_components")
def test_loss_components_none_fields():
    """None-valued auxiliary fields contribute nothing to the total."""
    components = LossComponents(
        lm=torch.tensor(3.0, requires_grad=True),
        vq_commitment=None,
        moe_aux=None,
        graph_l1=None,
    )
    total_value = components.total.item()
    assert total_value == 3.0, f"Total with None fields: {total_value}"
    print(" PASS test_loss_components_none_fields")
def test_loss_components_backward():
    """LossComponents.backward() propagates gradients to every supplied leaf."""
    lm_loss = torch.tensor(4.0, requires_grad=True)
    vq_loss = torch.tensor(0.3, requires_grad=True)
    LossComponents(lm=lm_loss, vq_commitment=vq_loss).backward()
    assert lm_loss.grad is not None, "LM should have gradient after backward"
    assert vq_loss.grad is not None, "VQ should have gradient after backward"
    print(" PASS test_loss_components_backward")
def test_gnn_lora_adapter():
    """Per-hop LoRA outputs match at init and diverge once hop scales differ."""
    lora = GNNLoRAAdapter(dim=512, rank=32, max_hops=4)
    x = torch.randn(8192, 512)
    out_hop0, out_hop1 = lora(x, hop_t=0), lora(x, hop_t=1)
    assert out_hop0.shape == (8192, 512), f"LoRA output shape: {out_hop0.shape}"
    assert torch.allclose(out_hop0, out_hop1, atol=1e-6), "Zero-init scales should produce same output at init"

    # Give hop 1 a distinct scale row, so its output must move away from hop 0.
    scale_rows = lora.scale.weight.data
    scale_rows[1] = scale_rows[0] + 1.0
    shifted = lora(x, hop_t=1)
    assert not torch.allclose(out_hop0, shifted), "Non-zero scales should differentiate hops"
    print(" PASS test_gnn_lora_adapter")
def test_gnn_lora_gradient():
    """Gradients flow through the LoRA adapter to both the input and its scale table."""
    lora = GNNLoRAAdapter(dim=512, rank=32, max_hops=4)
    nodes = torch.randn(8192, 512, requires_grad=True)
    lora(nodes, hop_t=0).sum().backward()
    assert nodes.grad is not None, "Input should have gradient"
    assert lora.scale.weight.grad is not None, "LoRA scale should have gradient"
    print(" PASS test_gnn_lora_gradient")
def test_shared_gnn_weight_tying():
    """TernaryGraph uses one shared ``gnn`` layer plus ``hop_lora``, not a per-hop list."""
    hops = 3
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=hops)
    assert hasattr(graph, 'gnn'), "Graph should have single shared GNN layer"
    assert not hasattr(graph, 'gnn_layers'), "Graph should NOT have gnn_layers list"
    assert hasattr(graph, 'hop_lora'), "Graph should have hop_lora adapter"
    assert graph.max_hops == hops, f"max_hops should be 3, got {graph.max_hops}"
    print(" PASS test_shared_gnn_weight_tying")
def test_shared_gnn_multi_hop():
    """A 4-hop graph with rank-32 LoRA still returns (B, T, HIDDEN_DIM) per position."""
    graph = TernaryGraph(codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, max_hops=4, lora_rank=32)
    graph._codebook_embed = torch.randn(1, CODEBOOK_SIZE, CODEBOOK_DIM)
    hidden = torch.randn(2, 10, HIDDEN_DIM)
    codes = torch.randint(0, CODEBOOK_SIZE, (2, 10))
    per_position = graph(hidden, codes, 0.05)[0]
    assert per_position.shape == (2, 10, HIDDEN_DIM), f"per_position shape: {per_position.shape}"
    print(" PASS test_shared_gnn_multi_hop")
def test_model_losses_components_type():
    """LossComponents.total equals the weighted sum of every populated loss field."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert isinstance(losses, LossComponents), f"Expected LossComponents, got {type(losses)}"

    always_present = ("lm", "vq_commitment", "moe_aux", "graph_l1")
    for field in always_present:
        assert getattr(losses, field) is not None

    total = losses.total.item()
    weights = losses.weights
    # Accumulate in the same order as the model would list them, so the float
    # sum matches the original left-to-right expression.
    manual = 0.0
    for field in always_present:
        manual += getattr(weights, field) * getattr(losses, field).item()
    optional = ("graph_ponder", "moe_ponder", "conv_vq_commitment", "memgram_decay_reg", "lstm_hidden_reg")
    for field in optional:
        value = getattr(losses, field)
        if value is not None:
            manual += getattr(weights, field) * value.item()
    assert abs(total - manual) < 1e-4, f"Total {total} != weighted sum {manual}"
    print(" PASS test_model_losses_components_type")
# ===== Phase 5: ACT Adaptive Computation Tests =====
def test_halting_unit_shapes():
    """HaltingUnit emits (B, T, 1) probabilities strictly inside (0, 1) and is differentiable."""
    halting = HaltingUnit(dim=512, tscale_type=TScaleType.T32)
    hidden = torch.randn(4, 10, 512).requires_grad_(True)
    probs = halting(hidden)
    assert probs.shape == (4, 10, 1), f"Shape: {probs.shape}"
    assert (probs > 0).all() and (probs < 1).all(), f"Range: ({probs.min():.4f}, {probs.max():.4f})"
    probs.sum().backward()
    assert hidden.grad is not None, "No gradient on input"
    print(" PASS test_halting_unit_shapes")
def test_halting_unit_ternary_pure():
    """HaltingUnit contains no plain nn.Linear or nn.Embedding submodules."""
    halting = HaltingUnit(dim=512)
    for sub_name, sub in halting.named_modules():
        assert not isinstance(sub, nn.Linear), f"nn.Linear found: {sub_name}"
        assert not isinstance(sub, nn.Embedding), f"nn.Embedding found: {sub_name}"
    print(" PASS test_halting_unit_ternary_pure")
def test_graph_act_cell_shapes():
    """GraphACTCell wraps TernaryGraph, adds a positive scalar ponder cost, stays differentiable."""
    graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32)
    graph._codebook_embed = torch.randn(1, 8192, 32)
    cell = GraphACTCell(graph, max_hops=4, halt_threshold=0.01)
    hidden = torch.randn(2, 10, 512).requires_grad_(True)
    codes = torch.randint(0, 8192, (2, 10))

    per_pos, pooled, alpha, ponder = cell(hidden, codes, 0.05)
    assert per_pos.shape == (2, 10, 512), f"per_pos: {per_pos.shape}"
    assert pooled.shape == (2, 512), f"gpool: {pooled.shape}"
    assert alpha.shape == (2, 10, 1), f"gate_alpha: {alpha.shape}"
    assert ponder.ndim == 0
    assert ponder.item() > 0

    per_pos.sum().backward()
    assert hidden.grad is not None, "No gradient on input"
    print(" PASS test_graph_act_cell_shapes")
def test_moe_act_cell_shapes():
    """MoEACTCell returns shaped output, scalar aux/ponder losses, and passes gradients."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    cell = MoEACTCell(moe, dim=512, max_iters=4, halt_threshold=0.01)
    hidden = torch.randn(2, 10, 512).requires_grad_(True)

    mixed, aux_loss, ponder = cell(hidden)
    assert mixed.shape == (2, 10, 512), f"out: {mixed.shape}"
    assert aux_loss.ndim == 0
    assert ponder.ndim == 0
    assert aux_loss.item() >= 0
    assert ponder.item() > 0

    mixed.sum().backward()
    assert hidden.grad is not None, "No gradient on input"
    print(" PASS test_moe_act_cell_shapes")
def test_act_early_halt():
    """A tiny halt threshold must ponder less than a threshold of 100.0 on the same input.

    Both cells share one graph and one input. The message treats the 1e-6
    cell as the early-halting one and the 100.0 cell as the no-halt one.
    """
    graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32)
    graph._codebook_embed = torch.randn(1, 8192, 32)
    no_halt_cell = GraphACTCell(graph, max_hops=8, halt_threshold=100.0)
    hidden = torch.randn(2, 10, 512)
    codes = torch.randint(0, 8192, (2, 10))
    ponder = no_halt_cell(hidden, codes, 0.05)[3]

    early_cell = GraphACTCell(graph, max_hops=8, halt_threshold=1e-6)
    ponder_low = early_cell(hidden, codes, 0.05)[3]
    assert ponder_low.item() < ponder.item(), \
        f"Early halt ponder ({ponder_low:.4f}) should be less than no-halt ponder ({ponder:.4f})"
    print(" PASS test_act_early_halt")
def test_act_weight_sum_one():
    """Fast- and slow-halting MoE ACT cells both give NaN-free, not-all-zero outputs."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    fast_cell = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=1e-6)
    hidden = torch.randn(2, 10, 512)
    out_fast, _, _ = fast_cell(hidden)

    slow_cell = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=100.0)
    out_slow, _, _ = slow_cell(hidden)

    combined_sum = out_fast.sum().item() + out_slow.sum().item()
    assert not torch.isnan(out_fast).any(), "NaN in fast ACT output"
    assert not torch.isnan(out_slow).any(), "NaN in slow ACT output"
    assert combined_sum != 0, "Outputs should be non-zero (weights sum to 1.0)"
    print(" PASS test_act_weight_sum_one")
def test_act_gradient_flow():
    """Output + aux + ponder loss backpropagates through MoEACTCell to its input."""
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    cell = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=0.01)
    hidden = torch.randn(2, 10, 512, requires_grad=True)
    mixed, aux_loss, ponder = cell(hidden)
    objective = mixed.sum() + aux_loss + ponder
    objective.backward()
    assert hidden.grad is not None, "Input grad is None"
    print(" PASS test_act_gradient_flow")
def test_loss_components_ponder_fields():
    """Ponder fields add to the total with unit weight and receive gradients."""
    leaves = {
        "lm": torch.tensor(5.0, requires_grad=True),
        "graph_ponder": torch.tensor(0.1, requires_grad=True),
        "moe_ponder": torch.tensor(0.2, requires_grad=True),
    }
    components = LossComponents(**leaves)
    expected = 5.0 + 0.1 + 0.2
    assert abs(components.total.item() - expected) < 1e-5, f"Total: {components.total.item()} vs {expected}"
    components.total.backward()
    for leaf in leaves.values():
        assert leaf.grad is not None
    print(" PASS test_loss_components_ponder_fields")
def test_loss_components_ponder_none():
    """With both ponder fields None, the total is just the LM loss and still backprops."""
    lm_loss = torch.tensor(3.0, requires_grad=True)
    components = LossComponents(lm=lm_loss, graph_ponder=None, moe_ponder=None)
    assert abs(components.total.item() - 3.0) < 1e-5
    components.total.backward()
    assert lm_loss.grad is not None
    print(" PASS test_loss_components_ponder_none")
def test_act_graph_moe_sequential():
    """Graph ACT feeding MoE ACT, blended by gate alpha, keeps shape and gradient flow."""
    graph = TernaryGraph(codebook_size=8192, codebook_dim=32, max_hops=2, tscale_type=TScaleType.T32)
    graph._codebook_embed = torch.randn(1, 8192, 32)
    graph_cell = GraphACTCell(graph, max_hops=3, halt_threshold=0.01)
    moe = SharedProjectionMoE(hidden_size=512, num_experts=8, top_k=2, tscale_type=TScaleType.T32)
    moe_cell = MoEACTCell(moe, dim=512, max_iters=3, halt_threshold=0.01)

    hidden = torch.randn(2, 10, 512).requires_grad_(True)
    codes = torch.randint(0, 8192, (2, 10))
    per_pos, _, alpha, graph_ponder = graph_cell(hidden, codes, 0.05)
    moe_out, _, moe_ponder = moe_cell(per_pos)

    # Convex blend: alpha weights the MoE path, (1 - alpha) the graph path.
    blended = alpha * moe_out + (1 - alpha) * per_pos
    assert blended.shape == (2, 10, 512), f"Sequential output: {blended.shape}"
    assert graph_ponder.ndim == 0
    assert moe_ponder.ndim == 0
    blended.sum().backward()
    assert hidden.grad is not None, "Input grad is None"
    print(" PASS test_act_graph_moe_sequential")
# ===== Model-level ACT Integration Tests =====
def test_model_forward_with_act():
    """With ACT active, a targeted forward reports both ponder losses and a positive total."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert logits.shape == (2, 64, VOCAB), f"Logits: {logits.shape}"
    assert isinstance(losses, LossComponents)
    for field in ("graph_ponder", "moe_ponder"):
        assert getattr(losses, field) is not None
    assert losses.total > 0
    print(" PASS test_model_forward_with_act")
def test_model_act_forward_without_targets():
    """Without targets the model still produces logits but returns no loss container."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, losses, _, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB)
    assert losses is None
    print(" PASS test_model_act_forward_without_targets")
def test_model_act_loss_components():
    """All six loss terms are populated, and the total exceeds the two ponder terms alone."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    for field in ("lm", "vq_commitment", "moe_aux", "graph_l1", "graph_ponder", "moe_ponder"):
        assert getattr(losses, field) is not None
    # only truthy (present and non-zero) ponder terms contribute to the comparison sum
    ponder_sum = sum(term for term in (losses.graph_ponder, losses.moe_ponder) if term)
    assert losses.total > ponder_sum
    print(" PASS test_model_act_loss_components")
def test_model_act_backward():
    """Backpropagate the total ACT-enabled loss and check the graph's edge attributes get a gradient."""
    model = ARBModel(tscale_type=TScaleType.T32)
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    _, losses, _, _ = model(x, targets=targets)
    # ``losses`` is a LossComponents container, not a tensor; backprop through its
    # scalar ``total`` exactly as the other backward tests in this file do.
    losses.total.backward()
    assert model.ternary_graph.edge_attr.grad is not None, "edge_attr grad None"
    print(" PASS test_model_act_backward")
def test_model_act_disabled():
    """Turning off graph and MoE ACT removes both ponder losses and leaves the flags off."""
    model = ARBModel(tscale_type=TScaleType.T32)
    model.graph_act_enabled = False
    model.moe_act_enabled = False
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert logits.shape == (2, 64, VOCAB)
    assert losses.graph_ponder is None
    assert losses.moe_ponder is None
    # the forward pass must not flip ACT back on
    assert (model.graph_act_enabled, model.moe_act_enabled) == (False, False)
    print(" PASS test_model_act_disabled")
def test_model_act_warmup_mode():
    """``act_warmup_mode=True`` suppresses both ponder losses; ``False`` restores them."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    targets = tokens[:, 3:]

    _, warm, _, _ = model(tokens, targets=targets, act_warmup_mode=True)
    assert warm.graph_ponder is None, f"During warmup, graph_ponder should be None: {warm.graph_ponder}"
    assert warm.moe_ponder is None, f"During warmup, moe_ponder should be None: {warm.moe_ponder}"

    _, live, _, _ = model(tokens, targets=targets, act_warmup_mode=False)
    assert live.graph_ponder is not None, "Without warmup, graph_ponder should be present"
    assert live.moe_ponder is not None, "Without warmup, moe_ponder should be present"
    print(" PASS test_model_act_warmup_mode")
def test_model_act_ponder_cached():
    """After a forward with targets the model caches positive last-step ponder values."""
    model = ARBModel(tscale_type=TScaleType.T32)
    tokens = torch.randint(0, VOCAB, (2, 66))
    model(tokens, targets=tokens[:, 3:])
    graph_cached = model._last_graph_ponder
    assert graph_cached > 0, f"_last_graph_ponder={model._last_graph_ponder}"
    moe_cached = model._last_moe_ponder
    assert moe_cached > 0, f"_last_moe_ponder={model._last_moe_ponder}"
    print(" PASS test_model_act_ponder_cached")
# ===== ACT Warmup and Monitoring Tests =====
def test_act_warmup_schedule():
    """compute_act_warmup reports warmup for steps below 10000 of a 50000-step run."""
    from train import compute_act_warmup
    cases = (
        (0, True, "step 0 is warmup"),
        (9999, True, "step 9999 is warmup"),
        (10000, False, "step 10000 not warmup"),
        (50000, False, "step 50000 not warmup"),
    )
    for step, expected, message in cases:
        assert compute_act_warmup(step, 50000) == expected, message
    print(" PASS test_act_warmup_schedule")
def test_act_ponder_lambda():
    """get_ponder_lambda moves from start_lambda (0.1) to end_lambda (0.01).

    With 50000 total steps and warmup_frac=0.2, the test expects the start value
    at step 0, a strictly intermediate value at step 5000, and the end value by
    step 10000.
    """
    from train import get_ponder_lambda

    def schedule(step):
        return get_ponder_lambda(step, 50000, warmup_frac=0.2, start_lambda=0.1, end_lambda=0.01)

    first = schedule(0)
    assert abs(first - 0.1) < 1e-6, f"start lambda: {first}"
    halfway = schedule(5000)
    assert halfway > 0.01 and halfway < 0.1, f"mid lambda: {halfway}"
    final = schedule(10000)
    assert abs(final - 0.01) < 1e-6, f"end lambda: {final}"
    print(" PASS test_act_ponder_lambda")
def test_model_ponder_lambda_scaling():
    """A larger ``ponder_lambda`` must produce a larger graph ponder loss.

    Both forward passes are required to report ``graph_ponder``. A missing
    ponder term is a failure, not a silent skip: this default T32 model is
    expected to produce ponder losses when targets are given.
    """
    model = ARBModel(tscale_type=TScaleType.T32)
    x = torch.randint(0, VOCAB, (2, 66))
    targets = x[:, 3:]
    _, losses_high, _, _ = model(x, targets=targets, ponder_lambda=0.5)
    _, losses_low, _, _ = model(x, targets=targets, ponder_lambda=0.01)
    # An ``if ... is not None`` guard here would let the test pass without comparing anything.
    assert losses_high.graph_ponder is not None, "graph_ponder missing with ponder_lambda=0.5"
    assert losses_low.graph_ponder is not None, "graph_ponder missing with ponder_lambda=0.01"
    assert losses_high.graph_ponder.item() > losses_low.graph_ponder.item(), \
        "Higher ponder_lambda should produce larger ponder loss"
    print(" PASS test_model_ponder_lambda_scaling")
# ===== Phase 6: Integration Tests =====
def test_text_only_forward():
    """Text-only forward (no targets) yields 64-position logits and VQ indices."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _losses, codes, _ = model(tokens)
    assert logits.shape == (2, 64, VOCAB)
    assert codes is not None
    print(" PASS test_text_only_forward")
def test_image_forward():
    """Adding a 224x224 RGB image batch keeps the text logits shape and yields indices."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    pixels = torch.randn(2, 3, 224, 224)
    logits, _losses, codes, _ = model(tokens, images=pixels)
    assert logits.shape == (2, 64, VOCAB)
    assert codes is not None
    print(" PASS test_image_forward")
def test_multimodal_backward():
    """A text+image forward with targets backpropagates into every trainable parameter.

    Parameters whose names contain any entry of ``exempt`` may legitimately end
    up without a gradient on this path and are skipped.
    """
    exempt = (
        'vit', 'embedding', 'patch_proj', 'frame_proj', 'router.bias', 'router_h',
        'W_gate', 'W_transform', 'hop_lora.scale', 'modality_gate', 'graph_pool.query',
        'memgram', 'conv_vq', 'lstm', 'mfcc_proj', 'audio_sequencer', 'audio_vq',
    )
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    pixels = torch.randn(2, 3, 224, 224)
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], images=pixels)
    assert losses is not None
    losses.total.backward()
    for name, param in model.named_parameters():
        if not param.requires_grad or param.grad is not None:
            continue
        assert any(marker in name for marker in exempt), f'No gradient for {name}'
    print(" PASS test_multimodal_backward")
def test_no_stale_trigram_encoder():
    """Check that a loaded ``trigram`` module no longer exposes ``TrigramEncoder``.

    ``sys.modules['trigram']`` raises KeyError whenever nothing has imported a
    top-level ``trigram`` module. This file's imports come from ``arbitor``, so
    that is the common case. Look the module up tolerantly: a module that is not
    loaded cannot carry a stale encoder.
    """
    trigram = sys.modules.get('trigram')
    assert trigram is None or not hasattr(trigram, 'TrigramEncoder'), 'TrigramEncoder should be removed'
    print(" PASS test_no_stale_trigram_encoder")
def test_vocab():
    """Pin the vocabulary size, the special-token count, and the IMAGE token id."""
    specials = SPECIAL_VOCAB
    assert VOCAB == 288
    assert len(specials) == 32
    assert specials['IMAGE'] == 278
    print(" PASS test_vocab")
# ===== Phase 6b: Audio + Quantized Encoder Tests =====
def test_audio_sequencer_construction():
    """A default AudioSequencer has audio modality, window 5, a frame projection, and frozen Whisper weights."""
    seq = AudioSequencer()
    assert seq.modality == 'audio'
    assert seq.window_size == 5
    assert seq.whisper is not None
    assert seq.frame_proj is not None
    # report the first Whisper parameter that is still trainable, if any
    trainable = next((p for p in seq.whisper.parameters() if p.requires_grad), None)
    assert trainable is None, f"Whisper param {trainable} should be frozen"
    print(" PASS test_audio_sequencer_construction")
def test_audio_sequencer_forward_waveform():
    """A raw (2, 80000) waveform batch maps to a non-empty (batch, frames, HIDDEN_DIM) sequence."""
    seq = AudioSequencer()
    out = seq(torch.randn(2, 80000))
    assert out.dim() == 3
    batch, frames, width = out.shape
    assert batch == 2
    assert width == HIDDEN_DIM
    assert frames > 0
    print(" PASS test_audio_sequencer_forward_waveform")
def test_audio_sequencer_forward_precomputed_mel():
    """A precomputed (2, 80, 3000) input is accepted and mapped to width HIDDEN_DIM."""
    seq = AudioSequencer()
    spectrogram = torch.randn(2, 80, 3000)
    out = seq(spectrogram)
    assert out.dim() == 3
    batch, _frames, width = out.shape
    assert batch == 2
    assert width == HIDDEN_DIM
    print(" PASS test_audio_sequencer_forward_precomputed_mel")
def test_audio_sequencer_quantization_fp8():
    """fp8 quantisation replaces Whisper linears with QLinear modules holding float8_e4m3fn data.

    Only the first QLinear found is inspected for its data dtype.
    """
    seq = AudioSequencer(quantize_weights='fp8')
    first_q = next(
        (module for _, module in seq.whisper.named_modules() if 'QLinear' in type(module).__name__),
        None,
    )
    if first_q is not None and hasattr(first_q, 'weight') and hasattr(first_q.weight, '_data'):
        stored = first_q.weight._data.dtype
        assert stored == torch.float8_e4m3fn, f"Expected fp8 data, got {first_q.weight._data.dtype}"
    assert first_q is not None, "No QLinear modules found — quantization may not have been applied"
    print(" PASS test_audio_sequencer_quantization_fp8")
def test_audio_sequencer_quantization_int8():
    """int8 quantisation leaves at least one QLinear module inside Whisper."""
    seq = AudioSequencer(quantize_weights='int8')
    has_qlinear = any('QLinear' in type(module).__name__ for _, module in seq.whisper.named_modules())
    assert has_qlinear, "No QLinear modules found — int8 quantization may not have been applied"
    print(" PASS test_audio_sequencer_quantization_int8")
def test_audio_sequencer_no_quantize():
    """Without quantisation Whisper has no QLinear modules and all weights are bfloat16."""
    seq = AudioSequencer(quantize_weights=None)
    for _, module in seq.whisper.named_modules():
        assert 'QLinear' not in type(module).__name__, "No QLinear should exist when quantize_weights=None"
    for weight in seq.whisper.parameters():
        assert weight.dtype == torch.bfloat16, f"Expected bfloat16, got {weight.dtype}"
    print(" PASS test_audio_sequencer_no_quantize")
def test_image_sequencer_hf_vit():
    """The default ImageSequencer has window 3 and maps one 224x224 image to (1, 194, HIDDEN_DIM)."""
    seq = ImageSequencer()
    assert seq.modality == 'image'
    assert seq.window_size == 3
    out = seq(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 194, HIDDEN_DIM)
    print(" PASS test_image_sequencer_hf_vit")
def test_image_sequencer_quantization_fp8():
    """fp8 quantisation puts QLinear modules with float8_e4m3fn data into the ViT.

    Only the first QLinear found is inspected for its data dtype.
    """
    seq = ImageSequencer(quantize_weights='fp8')
    first_q = next(
        (module for _, module in seq.vit.named_modules() if 'QLinear' in type(module).__name__),
        None,
    )
    if first_q is not None and hasattr(first_q, 'weight') and hasattr(first_q.weight, '_data'):
        stored = first_q.weight._data.dtype
        assert stored == torch.float8_e4m3fn, f"Expected fp8, got {first_q.weight._data.dtype}"
    assert first_q is not None, "No QLinear modules found in ViT — fp8 quantization may not have been applied"
    print(" PASS test_image_sequencer_quantization_fp8")
def test_image_sequencer_no_quantize():
    """Without quantisation the ViT has no QLinear modules and all weights are bfloat16."""
    seq = ImageSequencer(quantize_weights=None)
    for _, module in seq.vit.named_modules():
        assert 'QLinear' not in type(module).__name__, "No QLinear should exist when quantize_weights=None"
    for weight in seq.vit.parameters():
        assert weight.dtype == torch.bfloat16, f"Expected bfloat16, got {weight.dtype}"
    print(" PASS test_image_sequencer_no_quantize")
def test_multimodal_sequencer_all_modalities():
    """By default MultimodalSequencer enables and builds text, image and audio sub-sequencers."""
    mseq = MultimodalSequencer()
    modalities = ('text', 'image', 'audio')
    for modality in modalities:
        assert modality in mseq.enabled_modalities
    for modality in modalities:
        assert getattr(mseq, modality) is not None
    print(" PASS test_multimodal_sequencer_all_modalities")
def test_multimodal_sequencer_text_only():
    """With image and audio disabled, only the text sequencer exists and only 'text' is returned."""
    mseq = MultimodalSequencer(enable_image=False, enable_audio=False)
    assert mseq.enabled_modalities == ['text']
    assert mseq.image is None
    assert mseq.audio is None
    # The sequencer consumes already-embedded text (EMBEDDING_DIM wide), not token ids,
    # so no token tensor is needed here.
    embedded = torch.randn(2, 20, EMBEDDING_DIM)
    out = mseq({'text': embedded})
    assert 'text' in out
    assert 'image' not in out
    assert 'audio' not in out
    print(" PASS test_multimodal_sequencer_text_only")
def test_multimodal_sequencer_full_forward():
    """Feeding all three modalities returns one HIDDEN_DIM-wide output per modality."""
    mseq = MultimodalSequencer()
    inputs = {
        'text': torch.randn(2, 20, EMBEDDING_DIM),
        'image': torch.randn(2, 3, 224, 224),
        'audio': torch.randn(2, 80000),
    }
    out = mseq(inputs)
    modalities = ('text', 'image', 'audio')
    for modality in modalities:
        assert modality in out
    for modality in modalities:
        assert out[modality].shape[2] == HIDDEN_DIM
    print(" PASS test_multimodal_sequencer_full_forward")
def test_multimodal_vq_bridge_text_audio():
    """Text+audio through the VQ bridge: combined length 10+15, audio codes in [12288, 16384)."""
    bridge = MultimodalVQBridge()
    features = {'text': torch.randn(2, 10, 512), 'audio': torch.randn(2, 15, 512)}
    combined, losses, indices = bridge(features)
    assert combined.shape == (2, 25, 512)
    assert 'audio_vq' in losses
    audio_codes = indices['audio']
    assert (audio_codes >= 12288).all()
    assert (audio_codes < 16384).all()
    print(" PASS test_multimodal_vq_bridge_text_audio")
def test_multimodal_vq_bridge_all_three():
    """All three modalities: per-modality VQ losses and disjoint code ranges.

    Expected ranges: text [0, 8192), image [8192, 12288), audio [12288, 16384).
    """
    bridge = MultimodalVQBridge()
    features = {
        'text': torch.randn(2, 10, 512),
        'image': torch.randn(2, 20, 512),
        'audio': torch.randn(2, 15, 512),
    }
    combined, losses, indices = bridge(features)
    assert combined.shape == (2, 45, 512)
    for key in ('text_vq', 'image_vq', 'audio_vq'):
        assert key in losses
    assert (indices['text'] < 8192).all()
    image_codes, audio_codes = indices['image'], indices['audio']
    assert (image_codes >= 8192).all() and (image_codes < 12288).all()
    assert (audio_codes >= 12288).all() and (audio_codes < 16384).all()
    print(" PASS test_multimodal_vq_bridge_all_three")
def test_modality_gate_three_modalities():
    """With all three modalities active the gate weights each one and asks for at least 2 hops."""
    gate = ModalityGate(num_modalities=3)
    active = ['text', 'image', 'audio']
    weights, count, hops = gate(active)
    assert count == 3
    for modality in active:
        assert modality in weights
    assert hops >= 2
    print(" PASS test_modality_gate_three_modalities")
def test_modality_gate_audio_only():
    """An audio-only input yields a count of 1 and a weight entry for 'audio'."""
    gate = ModalityGate(num_modalities=3)
    weights, count, _hops = gate(['audio'])
    assert count == 1
    assert 'audio' in weights
    print(" PASS test_modality_gate_audio_only")
def test_model_forward_with_audio():
    """A forward pass with a raw audio batch keeps batch size and vocab width and yields indices."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    waveform = torch.randn(2, 80000)
    logits, _losses, codes, _ = model(tokens, audio=waveform)
    assert logits.shape[0] == 2
    assert logits.shape[2] == VOCAB
    assert codes is not None
    print(" PASS test_model_forward_with_audio")
def test_model_forward_all_modalities():
    """Text+image+audio with targets yields a positive scalar total in a LossComponents."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    pixels = torch.randn(2, 3, 224, 224)
    waveform = torch.randn(2, 80000)
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], images=pixels, audio=waveform)
    assert losses is not None
    assert isinstance(losses, LossComponents)
    total = losses.total
    assert total.ndim == 0
    assert total > 0
    print(" PASS test_model_forward_all_modalities")
def test_model_audio_disabled_raises():
    """Passing audio to a model built with ``enable_audio=False`` raises ValueError."""
    model = ARBModel(enable_audio=False)
    tokens = torch.randint(0, VOCAB, (2, 66))
    waveform = torch.randn(2, 80000)
    try:
        model(tokens, audio=waveform)
    except ValueError:
        pass
    else:
        assert False, "Should have raised ValueError"
    print(" PASS test_model_audio_disabled_raises")
def test_audio_sequencer_gradient_flow():
    """Backpropagate through AudioSequencer and check the frame projection receives gradients."""
    aseq = AudioSequencer()
    waveform = torch.randn(2, 80000)
    out = aseq(waveform)
    out.sum().backward()
    assert aseq.frame_proj.weight.grad is not None, "frame_proj should get gradients"
    # The old check was ``assert <T_accum.grad> is not None or True``. It could never fail.
    # Only the presence of the ternary accumulator is asserted here; its gradient is not.
    # TODO: assert projection participation once it is clear how its gradient is exposed.
    # Other ternary tests in this file probe ``_hook_grad_T_sign`` rather than ``.grad``.
    assert hasattr(aseq.projection, "T_accum"), "projection should participate"
    print(" PASS test_audio_sequencer_gradient_flow")
def test_vq_bridge_audio_codebook_utilization():
    """After an audio pass the bridge reports utilisation and dead-code stats for 'audio'."""
    bridge = MultimodalVQBridge()
    audio_feats = torch.randn(4, 50, 512)
    bridge({'text': torch.randn(4, 10, 512), 'audio': audio_feats})
    assert 'audio' in bridge.get_codebook_utilization()
    assert 'audio' in bridge.get_dead_code_count()
    print(" PASS test_vq_bridge_audio_codebook_utilization")
# ===== Phase 7: Memory Module Tests =====
def test_loss_components_nine_fields_total():
    """The total equals the default-weighted sum of all nine loss fields."""
    weights = LossWeights()
    raw = {
        "lm": 1.0,
        "vq_commitment": 0.1,
        "moe_aux": 0.1,
        "graph_l1": 0.1,
        "graph_ponder": 0.1,
        "moe_ponder": 0.1,
        "conv_vq_commitment": 0.1,
        "memgram_decay_reg": 0.01,
        "lstm_hidden_reg": 0.01,
    }
    lc = LossComponents(**{field: torch.tensor(value, requires_grad=True) for field, value in raw.items()})
    total = lc.total
    expected = sum(getattr(weights, field) * value for field, value in raw.items())
    assert abs(total.item() - expected) < 1e-6, f"total {total.item()} != {expected}"
    print(" PASS test_loss_components_nine_fields_total")
def test_loss_components_nine_fields_log():
    """``LossComponents.log`` emits the memory-related scalars and the total under the prefix."""
    from types import SimpleNamespace
    recorded = []
    writer = SimpleNamespace(logged=recorded)
    writer.add_scalar = lambda name, val, step: recorded.append((name, val, step))
    raw = {
        "lm": 1.0,
        "vq_commitment": 0.1,
        "moe_aux": 0.1,
        "graph_l1": 0.1,
        "graph_ponder": 0.1,
        "moe_ponder": 0.1,
        "conv_vq_commitment": 0.1,
        "memgram_decay_reg": 0.01,
        "lstm_hidden_reg": 0.01,
    }
    lc = LossComponents(**{field: torch.tensor(value, requires_grad=True) for field, value in raw.items()})
    lc.log(writer, step=0, prefix="loss")
    tags = [entry[0] for entry in writer.logged]
    for tag in ("loss/conv_vq_commitment", "loss/memgram_decay_reg", "loss/lstm_hidden_reg", "loss/total"):
        assert tag in tags
    print(" PASS test_loss_components_nine_fields_log")
def test_loss_weights_custom_total():
    """Custom LossWeights scale each term; a zero weight drops that term from the total."""
    weights = LossWeights(lm=2.0, vq_commitment=0.5, moe_aux=0.0, graph_l1=10.0)
    lc = LossComponents(
        lm=torch.tensor(1.0, requires_grad=True),
        vq_commitment=torch.tensor(2.0, requires_grad=True),
        moe_aux=torch.tensor(5.0, requires_grad=True),
        graph_l1=torch.tensor(0.01, requires_grad=True),
        weights=weights,
    )
    expected = 2.0*1.0 + 0.5*2.0 + 0.0*5.0 + 10.0*0.01
    assert abs(lc.total.item() - expected) < 1e-5, f"Custom weights total {lc.total.item()} != {expected}"
    print(" PASS test_loss_weights_custom_total")
def test_loss_weights_zero_skips():
    """Zero-weighted terms contribute neither to the total's value nor to its gradient."""
    weights = LossWeights(vq_commitment=0.0, moe_aux=0.0)
    scale = torch.nn.Parameter(torch.tensor(1.0))
    lc = LossComponents(lm=scale * 2.0, vq_commitment=scale * 3.0, moe_aux=scale * 4.0, weights=weights)
    total_val = lc.total.item()
    assert abs(total_val - (1.0*2.0 + 0.0*3.0 + 0.0*4.0)) < 1e-5, f"Zero-weight total {total_val}"
    lc.total.backward()
    # d(total)/d(scale) should come only from the lm term (2.0 * scale)
    grad_val = scale.grad.item()
    assert abs(grad_val - 2.0) < 1e-5, f"Zero-weight grad {grad_val} (should be 2.0, only lm)"
    print(" PASS test_loss_weights_zero_skips")
def test_loss_weights_backward_compat():
    """Pin the default value of every LossWeights field."""
    defaults = LossWeights()
    expected = {
        "lm": 1.0,
        "vq_commitment": 1.0,
        "moe_aux": 1.0,
        "graph_l1": 0.001,
        "graph_ponder": 1.0,
        "moe_ponder": 1.0,
        "conv_vq_commitment": 0.1,
        "memgram_decay_reg": 0.01,
        "lstm_hidden_reg": 0.01,
    }
    for field, value in expected.items():
        assert abs(getattr(defaults, field) - value) < 1e-5
    print(" PASS test_loss_weights_backward_compat")
def test_model_forward_loss_weights():
    """LossWeights passed to the model's forward are attached to the returned LossComponents."""
    model = ARBModel()
    custom = LossWeights(lm=2.0, vq_commitment=0.5)
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], loss_weights=custom)
    assert losses is not None
    attached = losses.weights
    assert isinstance(attached, LossWeights)
    assert abs(attached.lm - 2.0) < 1e-5
    assert abs(attached.vq_commitment - 0.5) < 1e-5
    print(" PASS test_model_forward_loss_weights")
def test_model_forward_no_hardcoded_graph_l1():
    """A target-less forward runs; a forward with targets populates graph_l1 and vq_commitment."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    model(tokens)  # exercise the no-target path first; its result is not inspected
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:])
    assert losses is not None
    assert losses.graph_l1 is not None
    assert losses.vq_commitment is not None
    print(" PASS test_model_forward_no_hardcoded_graph_l1")
def test_build_param_groups_shape():
    """build_param_groups returns named groups, each carrying 'lr' and 'params'.

    Expects 'graph' and 'memory' groups, plus at least one of 'patch_proj' / 'frame_proj'.
    """
    from train import build_param_groups
    groups = build_param_groups(ARBModel(), base_lr=1e-3, vq_lr_scale=2.0, memory_lr_scale=0.5)
    names = [group['name'] for group in groups]
    assert 'graph' in names
    assert 'memory' in names
    assert 'patch_proj' in names or 'frame_proj' in names
    for group in groups:
        assert 'lr' in group
        assert 'params' in group
    print(f" Groups: {names}")
    print(" PASS test_build_param_groups_shape")
def test_pinpoint_gradient_isolation():
    """pinpoint_backward routes a single active loss term only to its target parameter groups.

    One forward pass is reused (``free_graph=False``) for three single-term
    backward passes. After each pass, every parameter outside the allowed
    groups must still have ``grad is None``.
    """
    from train import pinpoint_backward, build_param_groups
    model = ARBModel()
    param_groups = build_param_groups(model, base_lr=1e-3)
    x = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(x, targets=x[:, 3:])
    loss_fields = (
        "lm", "vq_commitment", "moe_aux", "graph_l1", "graph_ponder",
        "moe_ponder", "conv_vq_commitment", "memgram_decay_reg", "lstm_hidden_reg",
    )

    def only(active):
        # LossWeights with every listed term zeroed except ``active``.
        return LossWeights(**{field: (1.0 if field == active else 0.0) for field in loss_fields})

    def check_isolated(active, allowed_groups):
        # Clear grads, backprop only ``active``, and require no grads outside allowed_groups.
        for group in param_groups:
            for p in group['params']:
                p.grad = None
        pinpoint_backward(losses, only(active), param_groups, free_graph=False)
        for group in param_groups:
            if group['name'] in allowed_groups:
                continue
            for p in group['params']:
                assert p.grad is None, f"{group['name']} param should NOT have grad with only {active}"

    # vq_commitment may reach the VQ projection and the modality input projections.
    # 'vq_codebook' is exempt too: per the original note, it holds buffers with no grads.
    check_isolated("vq_commitment", ("vq_projection", "patch_proj", "frame_proj", "vq_codebook"))
    check_isolated("moe_aux", ("moe", "moe_act"))
    check_isolated("graph_l1", ("graph",))
    print(" PASS test_pinpoint_gradient_isolation")
def test_pinpoint_backward_accumulation():
    """With grad_accum=2, pinpoint_backward yields half the gradient of grad_accum=1.

    The same retained graph is backpropagated twice (LM term only). The first
    parameter with a non-negligible gradient in the first run is compared
    between the two runs.
    """
    from train import pinpoint_backward, build_param_groups
    model = ARBModel()
    param_groups = build_param_groups(model, base_lr=1e-3)
    x = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(x, targets=x[:, 3:])
    lm_only = LossWeights(lm=1.0, vq_commitment=0.0, moe_aux=0.0, graph_l1=0.0,
                          graph_ponder=0.0, moe_ponder=0.0, conv_vq_commitment=0.0,
                          memgram_decay_reg=0.0, lstm_hidden_reg=0.0)
    all_params = [p for group in param_groups for p in group['params']]

    def backward_with(accum):
        for p in all_params:
            p.grad = None
        # free_graph=False keeps the graph alive so the same loss can be backpropagated again
        pinpoint_backward(losses, lm_only, param_groups, grad_accum=accum, free_graph=False)

    backward_with(1)
    full_norms = {id(p): p.grad.data.norm().item() for p in all_params if p.grad is not None}
    backward_with(2)
    probe = next(
        (p for p in all_params if p.grad is not None and full_norms.get(id(p), 0.0) > 1e-8),
        None,
    )
    assert probe is not None, "no param with non-zero grad in both runs"
    ratio = probe.grad.data.norm().item() / full_norms[id(probe)]
    assert abs(ratio - 0.5) < 0.01, \
        f"grad_accum=2 should halve grads, got {ratio:.4f}"
    print(" PASS test_pinpoint_backward_accumulation")
def test_pinpoint_backward_rescale_effect():
    """pinpoint_backward with memory_grad_scale=0.25 still delivers gradients where expected.

    NOTE(review): despite its name, this test does not measure the 0.25 rescale.
    It checks two things only:
    - if any memory param received a grad, all of them did;
    - lm_core params receive gradients from the LM loss.
    """
    # DEFAULT_LOSS_TARGET_MAP is not used below; importing it checks that train still exports it.
    from train import pinpoint_backward, build_param_groups, DEFAULT_LOSS_TARGET_MAP  # noqa: F401
    model = ARBModel()
    loss_weights = LossWeights()
    param_groups = build_param_groups(model, base_lr=1e-3)
    x = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(x, targets=x[:, 3:], loss_weights=loss_weights)
    # Start from clean grads. The unused random "reference" copy of every
    # parameter the test used to allocate has been dropped.
    for g in param_groups:
        for p in g['params']:
            p.grad = None
    pinpoint_backward(losses, loss_weights, param_groups, memory_grad_scale=0.25)
    # Memory params only get grads when LSTM/MemGram are active; check them only then.
    memory_group = next((g for g in param_groups if g.get('name') == 'memory'), None)
    if memory_group and memory_group['params'] and memory_group['params'][0].grad is not None:
        for p in memory_group['params']:
            assert p.grad is not None, "memory param should have grad"
    lm_group = next((g for g in param_groups if g.get('name') == 'lm_core'), None)
    if lm_group and lm_group['params']:
        lm_has_grad = any(p.grad is not None for p in lm_group['params'])
        assert lm_has_grad, "lm_core should have grads from lm loss"
    print(" PASS test_pinpoint_backward_rescale_effect")
def test_moe_router_h_with_h_t():
    """With LSTM routing enabled and an h_t supplied, backprop reaches ``router_h``.

    The test treats the ``_hook_grad_T_sign`` attribute on the router as evidence
    that it received a gradient update.
    """
    moe = SharedProjectionMoE(hidden_size=64, num_experts=4, top_k=2,
                              core_rank=16, shared_inter=128, noise_std=0.0)
    moe.lstm_enabled = True
    tokens = torch.randn(2, 10, 64)
    state = torch.randn(2, 64)
    out, _aux = moe(tokens, h_t=state)
    assert out.shape == (2, 10, 64)
    out.sum().backward()
    assert hasattr(moe.router_h, '_hook_grad_T_sign'), "router_h not trained"
    print(" PASS test_moe_router_h_with_h_t")
def test_moe_router_without_h_t():
    """Without h_t, backprop reaches the original ``router``.

    The test treats ``_hook_grad_T_sign`` on the router as evidence it was updated.
    """
    moe = SharedProjectionMoE(hidden_size=64, num_experts=4, top_k=2,
                              core_rank=16, shared_inter=128, noise_std=0.0)
    tokens = torch.randn(2, 10, 64)
    out, _aux = moe(tokens, h_t=None)
    assert out.shape == (2, 10, 64)
    out.sum().backward()
    assert hasattr(moe.router, '_hook_grad_T_sign'), "original router not trained when no h_t"
    print(" PASS test_moe_router_without_h_t")
def test_memgram_shapes():
    """MemGram preserves the hidden-state shape and returns a scalar decay term."""
    mg = MemGram(struct_primes=[101, 103, 107, 109], conv_primes=[53, 59, 61, 67])
    codes = torch.randint(0, 100, (4, 20))
    hidden = torch.randn(4, 20, 512)
    out, decay = mg(codes, None, None, hidden, timestep=100)
    assert out.shape == (4, 20, 512), f"output shape {out.shape}"
    assert decay.ndim == 0
    print(" PASS test_memgram_shapes")
def test_memgram_hash_indices():
    """``_hash_pairs`` returns one bucket per prime, each bucket below its prime."""
    primes = [101, 103]
    mg = MemGram(struct_primes=primes, conv_primes=[53, 59])
    prev = torch.randint(0, 100, (4, 19))
    curr = torch.randint(0, 100, (4, 19))
    buckets = mg._hash_pairs(prev, curr, [101, 103])
    assert buckets.shape == (4, 19, 2), f"hash shape {buckets.shape}"
    for slot, prime in enumerate(primes):
        assert (buckets[..., slot] < prime).all(), "hash exceeds prime range"
    print(" PASS test_memgram_hash_indices")
def test_memgram_bilinear_gate_range():
    """With small key/embed dims MemGram still returns a finite output of the input's shape."""
    mg = MemGram(struct_primes=[101, 103], conv_primes=[53, 59], key_dim=8, embed_dim=16)
    codes = torch.randint(0, 80, (2, 10))
    hidden = torch.randn(2, 10, 512)
    out, _decay = mg(codes, None, None, hidden, timestep=50)
    assert out.shape == (2, 10, 512)
    assert torch.isfinite(out).all()
    print(" PASS test_memgram_bilinear_gate_range")
def test_memgram_decay_formula():
    """With zero inputs, ``_compute_decay`` at t=100 equals sigmoid(0) * exp(-exp(0) * 100)."""
    mg = MemGram(struct_primes=[101, 103], conv_primes=[53, 59])
    strength, rate = torch.zeros(1), torch.zeros(1)
    decay = mg._compute_decay(strength, rate, torch.tensor(100.0))
    zero = torch.zeros(1)
    expected = torch.sigmoid(zero) * torch.exp(-torch.exp(zero) * 100.0)
    assert torch.allclose(decay, expected, atol=1e-6), f"decay {decay} != {expected}"
    print(" PASS test_memgram_decay_formula")
def test_memgram_gradient_flow():
    """Backprop of output+decay yields a non-zero gradient on the first structural embedding table."""
    mg = MemGram(struct_primes=[101, 103], conv_primes=[53, 59], embed_dim=16, key_dim=8, hidden_dim=64)
    codes = torch.randint(0, 80, (2, 10))
    hidden = torch.randn(2, 10, 64)
    out, decay = mg(codes, None, None, hidden, timestep=50)
    (out.sum() + decay).backward()
    first_table_grad = mg.struct_emb[0].grad
    assert first_table_grad is not None, "no gradient to struct_emb"
    assert first_table_grad.abs().sum().item() > 0
    print(" PASS test_memgram_gradient_flow")
def test_memgram_conv_path():
    """Supplying conversation codes changes MemGram's output compared with passing None."""
    mg = MemGram(struct_primes=[101, 103], conv_primes=[53, 59], embed_dim=16, key_dim=8, hidden_dim=64)
    codes = torch.randint(0, 80, (2, 10))
    hidden = torch.randn(2, 10, 64)
    baseline, _ = mg(codes, None, None, hidden, timestep=50)
    conv_code = torch.randint(0, 50, (2,))
    with_conv, _ = mg(codes, conv_code, conv_code, hidden, timestep=50)
    assert not torch.allclose(baseline, with_conv, atol=1e-6), "conv path should change output"
    print(" PASS test_memgram_conv_path")
def test_conv_vq_shapes():
    """An enabled ConvVQCodebook returns one code per row, a same-shape quantized tensor and a scalar commitment."""
    cvq = ConvVQCodebook(codebook_size=16, code_dim=8)
    features = torch.randn(4, 512)
    code, quantized, commitment = cvq(features, step=500, enabled=True)
    assert code.shape == (4,), f"code shape {code.shape}"
    assert quantized.shape == (4, 512), f"quantized shape {quantized.shape}"
    assert commitment.ndim == 0
    print(" PASS test_conv_vq_shapes")
def test_conv_vq_hard_cap():
    """After 12 update steps on an 8-entry codebook, ``n_active`` is capped at 8."""
    cvq = ConvVQCodebook(codebook_size=8, code_dim=8)
    for step in range(12):
        cvq(torch.randn(4, 512), step=step, enabled=True)
    active = cvq.n_active.item()
    assert active == 8, f"n_active={cvq.n_active.item()} (should be 8)"
    print(" PASS test_conv_vq_hard_cap")
def test_conv_vq_deferred_activation():
    """While disabled, ConvVQCodebook emits all-zero codes and a zero commitment."""
    cvq = ConvVQCodebook(codebook_size=8, code_dim=8)
    features = torch.randn(4, 512)
    code, _quantized, commitment = cvq(features, step=500, enabled=False)
    zeros = torch.zeros(4, dtype=torch.long)
    assert torch.equal(code, zeros), "code should be zeros when disabled"
    assert commitment.item() == 0.0, "commitment should be 0 when disabled"
    print(" PASS test_conv_vq_deferred_activation")
def test_conv_vq_ema_update():
    """Once all 4 slots are seeded, further updates change the ``embed`` table.

    The assertion message attributes the change to EMA updates and/or slot replacement.
    """
    cvq = ConvVQCodebook(codebook_size=4, code_dim=8)
    for offset in range(4):
        cvq(torch.randn(1, 512) + offset, step=offset, enabled=True)
    snapshot = cvq.embed.clone()
    for offset in range(4):
        cvq(torch.randn(1, 512), step=10 + offset, enabled=True)
    updated = cvq.embed.clone()
    assert not torch.allclose(snapshot, updated), "entries should change through EMA + replacement"
    print(" PASS test_conv_vq_ema_update")
def test_conv_vq_persistence():
    """Round-tripping through state_dict preserves embed, timestamps, n_active and cluster_size."""
    source = ConvVQCodebook(codebook_size=8, code_dim=8)
    source(torch.randn(2, 512), step=0, enabled=True)
    restored = ConvVQCodebook(codebook_size=8, code_dim=8)
    restored.load_state_dict(source.state_dict())
    assert torch.allclose(source.embed, restored.embed)
    for buffer in ("timestamps", "n_active", "cluster_size"):
        assert torch.equal(getattr(source, buffer), getattr(restored, buffer))
    print(" PASS test_conv_vq_persistence")
def test_conv_vq_fuzzy_retrieve():
    """``fuzzy_retrieve`` with top_k=3 returns three indices and three similarity scores."""
    cvq = ConvVQCodebook(codebook_size=16, code_dim=8)
    for step in range(3):
        cvq(torch.randn(2, 512), step=step, enabled=True)
    probe = torch.randn(8)
    idx, sim = cvq.fuzzy_retrieve(probe, top_k=3)
    assert idx.numel() == 3, f"retrieved {idx.numel()} elements, expected 3"
    assert sim.numel() == 3
    print(" PASS test_conv_vq_fuzzy_retrieve")
def test_conv_vq_commitment_nonneg():
    """The commitment loss from an enabled update is never negative."""
    cvq = ConvVQCodebook(codebook_size=8, code_dim=8)
    features = torch.randn(2, 512)
    commitment = cvq(features, step=0, enabled=True)[2]
    assert commitment.item() >= 0, f"negative commitment {commitment.item()}"
    print(" PASS test_conv_vq_commitment_nonneg")
def test_lstm_shapes():
    """A first ConversationLSTM step yields five (4, 64) state tensors and a scalar regulariser."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    step_input = torch.randn(4, 64)
    h_out, c_focus, h_topic, c_topic, c_proj, reg = lstm(step_input, None)
    named = (("h_out", h_out), ("c_focus", c_focus), ("h_topic", h_topic),
             ("c_topic", c_topic), ("c_proj", c_proj))
    for label, tensor in named:
        assert tensor.shape == (4, 64), f"{label} shape {tensor.shape}"
    assert reg.ndim == 0
    print(" PASS test_lstm_shapes")
def test_lstm_forget_gate_bias():
    """Forget-gate input biases start at 1.0 (focus cell) and 1.5 (topic cell).

    Assumes the cells use PyTorch LSTMCell gate ordering (input, forget, cell,
    output), so rows [hidden, 2*hidden) of ``bias_ih`` are the forget gate.
    TODO: confirm against ConversationLSTM.
    """
    hidden = 64
    lstm = ConversationLSTM(input_dim=hidden, hidden_dim=hidden)
    forget_rows = slice(hidden, 2 * hidden)
    focus_bias = lstm.focus_cell.bias_ih[forget_rows]
    assert torch.allclose(focus_bias, torch.ones_like(focus_bias)), "focus forget gate bias not 1.0"
    topic_bias = lstm.topic_cell.bias_ih[forget_rows]
    assert torch.allclose(topic_bias, torch.full_like(topic_bias, 1.5)), "topic forget gate bias not 1.5"
    print(" PASS test_lstm_forget_gate_bias")
def test_lstm_bptt_detach():
    """The first 49 steps keep a grad_fn on h_out; the 50th returns a detached h_out.

    State is detached by the caller between steps. The 50th-step expectation
    presumably follows from the module's own step counter hitting a BPTT
    boundary for bptt_focus=5 / bptt_topic=10. TODO: confirm against
    ConversationLSTM.
    """
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=5, bptt_topic=10)
    step_input = torch.randn(2, 64)
    state = None
    for _step in range(49):
        h_out, c_focus, h_topic, c_topic, _, _ = lstm(step_input, state)
        state = (h_out.detach(), c_focus.detach(), h_topic.detach(), c_topic.detach())
        assert h_out.grad_fn is not None, "grad_fn should exist before BPTT boundary"
    h_out, c_focus, h_topic, c_topic, _, _ = lstm(step_input, state)
    assert h_out.grad_fn is None, "h_out should be detached at BPTT focus boundary"
    print(" PASS test_lstm_bptt_detach")
def test_lstm_hidden_reg():
    """The LSTM's hidden-state regulariser equals the mean squared output hidden state.

    ``reg`` is compared against the ``h_out`` returned by the same call. Taking
    ``h_out`` from a second forward pass would only match if repeated calls
    were deterministic and carried no internal state between them.
    """
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    x = torch.randn(2, 64)
    h_out, _, _, _, _, reg = lstm(x, None)
    expected = (h_out ** 2).mean()
    assert torch.allclose(reg, expected, atol=1e-6), f"reg {reg} != expected {expected}"
    print(" PASS test_lstm_hidden_reg")
def test_lstm_c_t_proj_ternary():
    """Both cell-state projections of ConversationLSTM are TernaryScaleTensor modules."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    for attr in ("c_focus_proj", "c_topic_proj"):
        assert isinstance(getattr(lstm, attr), TernaryScaleTensor), f"{attr} not TernaryScaleTensor"
    print(" PASS test_lstm_c_t_proj_ternary")
def test_memory_modules_backward_compat():
    """A default model still trains on a CTX-length batch with memory modules at their defaults."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, CTX))
    logits, losses, _codes, _ = model(tokens, targets=tokens[:, 3:])
    assert logits.shape[0] == 2
    assert losses is not None
    print(" PASS test_memory_modules_backward_compat")
# ===== Phase 7: Forward Pipeline Integration Tests =====
def test_forward_no_memory_backward_compat():
    """With the LSTM left disabled, forward returns ``None`` as its memory state."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    logits, _losses, _codes, mem_state = model(tokens, targets=tokens[:, 3:])
    assert logits.shape == (2, 64, VOCAB), f"logits shape {logits.shape}"
    assert mem_state is None, f"memory_state should be None when lstm disabled, got {mem_state}"
    print(" PASS test_forward_no_memory_backward_compat")
def test_forward_lstm_enabled_h_t_passed():
    """With the LSTM enabled, forward returns a 4-tuple memory state of (2, 512) tensors."""
    model = ARBModel()
    model.lstm_enabled = True
    tokens = torch.randint(0, VOCAB, (2, 66))
    _logits, _losses, _indices, mem_state = model(
        tokens, targets=tokens[:, 3:], memory_state=None, timestep=0
    )
    assert mem_state is not None, "memory_state should be tuple when lstm enabled"
    h_out, c_focus, h_topic, c_topic = mem_state
    # NOTE(review): 512 is presumably the model's LSTM hidden size; TODO confirm against config.
    named = (("h_out", h_out), ("c_focus", c_focus), ("h_topic", h_topic), ("c_topic", c_topic))
    for label, tensor in named:
        assert tensor.shape == (2, 512), f"{label} shape {tensor.shape}"
    print(" PASS test_forward_lstm_enabled_h_t_passed")
def test_forward_lstm_c_t_residual():
    """Switching the LSTM on must change the logits produced for the same input."""
    model = ARBModel()
    tokens = torch.randint(0, VOCAB, (2, 66))
    targets = tokens[:, 3:]
    baseline, _, _, _ = model(tokens, targets=targets)
    model.lstm_enabled = True
    with_lstm, _, _, _ = model(tokens, targets=targets, memory_state=None, timestep=0)
    # NOTE(review): the model stays in train mode here; any stochastic layers would
    # also make the two logits differ. Confirm the difference really comes from the LSTM path.
    assert not torch.allclose(baseline, with_lstm, atol=1e-4), "c_t_proj should modify output"
    print(" PASS test_forward_lstm_c_t_residual")
def test_forward_memgram_injection():
    """With MemGram enabled, a forward at timestep 100 populates ``memgram_decay_reg``."""
    model = ARBModel()
    model.memgram_enabled = True
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], timestep=100)
    assert losses.memgram_decay_reg is not None, "memgram_decay_reg should be set"
    print(" PASS test_forward_memgram_injection")
def test_forward_conv_vq_deferred():
    """Conv-VQ enabled but not yet ready must contribute no (or zero) commitment loss."""
    model = ARBModel()
    model.conv_vq_enabled = True
    model._conv_vq_ready = False  # force the deferred, not-yet-activated path
    tokens = torch.randint(0, VOCAB, (2, 66))
    _, losses, _, _ = model(tokens, targets=tokens[:, 3:], timestep=100)
    commitment = losses.conv_vq_commitment
    deferred = commitment is None or commitment.item() == 0.0
    assert deferred, f"conv_vq should be deferred, got {commitment}"
    print(" PASS test_forward_conv_vq_deferred")
def test_generate_carries_lstm_state():
    """generate() with the LSTM enabled appends ``max_new_token`` long-dtype tokens to the prompt."""
    model = ARBModel()
    model.lstm_enabled = True
    prompt = torch.zeros((1, 10), dtype=torch.long)
    new_tokens = 10
    out = model.generate(prompt, max_new_token=new_tokens)
    assert out.shape == (1, prompt.shape[1] + new_tokens), f"output shape {out.shape}"
    assert out.dtype == torch.long
    print(" PASS test_generate_carries_lstm_state")
# ===== Phase 7: Training Schedule Tests =====
def test_memory_schedule_warmup():
    """At step 0 of 10000 every memory component is still switched off."""
    from train import compute_memory_schedule
    lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(0, 10000)
    flags = (lstm_on, memgram_on, conv_vq_on, decay_reg_on)
    assert not any(flags), "all off during warmup"
    print(" PASS test_memory_schedule_warmup")
def test_memory_schedule_lstm_first():
    """At 25% of training the LSTM is on while MemGram is still off."""
    from train import compute_memory_schedule
    step, total = 2500, 10000
    lstm_on, memgram_on, _conv_vq_on, _decay_reg_on = compute_memory_schedule(step, total)
    assert lstm_on, "lstm should be on after warmup"
    assert not memgram_on, "memgram should be off at 25%"
    print(" PASS test_memory_schedule_lstm_first")
def test_memory_schedule_memgram_second():
    """At 35% with VQ utilisation 0.4, LSTM + MemGram + conv-VQ are on but decay-reg is not."""
    from train import compute_memory_schedule
    lstm_on, memgram_on, conv_vq_on, decay_reg_on = compute_memory_schedule(
        3500, 10000, vq_utilization=0.4
    )
    assert all((lstm_on, memgram_on, conv_vq_on)) if False else (lstm_on and memgram_on and conv_vq_on), \
        "lstm+memgram+conv_vq on at 35% with util>30%"
    assert not decay_reg_on, "decay_reg should be off at 35%"
    print(" PASS test_memory_schedule_memgram_second")
def test_memory_schedule_all_on():
    """By 50% of training (VQ utilisation 0.4) every memory component is enabled."""
    from train import compute_memory_schedule
    schedule = compute_memory_schedule(5000, 10000, vq_utilization=0.4)
    lstm_on, memgram_on, conv_vq_on, decay_reg_on = schedule
    assert all((lstm_on, memgram_on, conv_vq_on, decay_reg_on)), "all on at 50%"
    print(" PASS test_memory_schedule_all_on")
def test_memory_schedule_conv_vq_requires_vq_util():
    """Conv-VQ stays off at 40% when VQ utilisation (0.1) is below the 30% gate; the LSTM is on."""
    from train import compute_memory_schedule
    low_util = 0.1
    lstm_on, _memgram_on, conv_vq_on, _decay_reg_on = compute_memory_schedule(
        4000, 10000, vq_utilization=low_util
    )
    assert lstm_on, "lstm should be on"
    assert not conv_vq_on, "conv_vq should be off when util < 30%"
    print(" PASS test_memory_schedule_conv_vq_requires_vq_util")
def test_memory_schedule_decay_reg_last():
    """Decay regularisation is switched on by 45% of training (VQ utilisation 0.4)."""
    from train import compute_memory_schedule
    *_, decay_reg_on = _lstm_on, _memgram_on, _conv_vq_on, _decay = compute_memory_schedule(
        4500, 10000, vq_utilization=0.4
    )
    assert decay_reg_on, "decay_reg should be on at 45%"
    print(" PASS test_memory_schedule_decay_reg_last")
def test_lstm_state_reset_per_batch():
    """Two eval-mode forwards that each start from ``memory_state=None`` return same-shaped states."""
    model = ARBModel()
    model.lstm_enabled = True
    model.eval()
    tokens = torch.randint(0, VOCAB, (2, 66))
    with torch.no_grad():
        _, _, _, first_state = model(tokens, memory_state=None, timestep=0)
        _, _, _, second_state = model(tokens, memory_state=None, timestep=1)
    h_first, c_focus_first, _h_topic_first, _c_topic_first = first_state
    h_second, c_focus_second, _h_topic_second, _c_topic_second = second_state
    assert h_first.shape == h_second.shape, "h_out shapes should match"
    assert c_focus_first.shape == c_focus_second.shape, "c_focus shapes should match"
    print(" PASS test_lstm_state_reset_per_batch")
def test_bptt_counter_separate():
    """``step_count`` counts forward calls; with bptt_focus=5 the 5th call returns a detached c_focus."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=5, bptt_topic=10)
    inp = torch.randn(2, 64)
    state = None
    for _step in range(4):
        h_out, c_f, h_t, c_t, _, _ = cell(inp, state)
        state = tuple(t.detach() for t in (h_out, c_f, h_t, c_t))
    assert cell.step_count == 4, f"step_count={cell.step_count}"
    # Fifth call: first focus-window boundary.
    h_out, c_f, h_t, c_t, _, _ = cell(inp, state)
    assert c_f.grad_fn is None, "c_focus should be detached at BPTT focus boundary"
    print(" PASS test_bptt_counter_separate")
# ===== FocusGate Tests =====
def test_focus_gate_no_boundary():
    """With no boundary token the gate acts as identity: reset (B, 1) and dampen (B, H) are all ones."""
    gate = FocusGate(hidden_dim=64)
    feats = torch.randn(2, 64)
    reset, dampen = gate(feats, boundary_signal=None)
    assert reset.shape == (2, 1), f"reset shape {reset.shape}"
    assert dampen.shape == (2, 64), f"dampen shape {dampen.shape}"
    ones_reset, ones_dampen = torch.ones_like(reset), torch.ones_like(dampen)
    assert torch.allclose(reset, ones_reset), "reset should be 1.0 for no boundary"
    assert torch.allclose(dampen, ones_dampen), "dampen should be 1.0 for no boundary"
    print(" PASS test_focus_gate_no_boundary")
def test_focus_gate_boundary_signal():
    """A BOS boundary yields correctly shaped reset/dampen gates with every element in [0, 1].

    The range check uses each tensor's min and max. A mean inside [0, 1] does not
    rule out individual gate values outside that range.
    """
    fg = FocusGate(hidden_dim=64)
    x = torch.randn(2, 64)
    reset_bos, dampen_bos = fg(x, boundary_signal=SPECIAL_VOCAB['BOS'])
    assert reset_bos.shape == (2, 1)
    assert dampen_bos.shape == (2, 64)
    reset_lo, reset_hi = reset_bos.min().item(), reset_bos.max().item()
    dampen_lo, dampen_hi = dampen_bos.min().item(), dampen_bos.max().item()
    assert 0.0 <= reset_lo and reset_hi <= 1.0, f"reset out of range: [{reset_lo}, {reset_hi}]"
    assert 0.0 <= dampen_lo and dampen_hi <= 1.0, f"dampen out of range: [{dampen_lo}, {dampen_hi}]"
    print(" PASS test_focus_gate_boundary_signal")
def test_focus_gate_all_boundary_types():
    """Every known boundary token lowers the mean reset below 1.0; no boundary keeps it at exactly 1.0."""
    gate = FocusGate(hidden_dim=64)
    feats = torch.randn(2, 64)
    boundaries = {name: SPECIAL_VOCAB[name] for name in ('BOS', 'SYSTEM', 'USER', 'ASSISTANT')}
    for name, token_id in boundaries.items():
        reset, _dampen = gate(feats, boundary_signal=token_id)
        assert reset.mean().item() < 1.0, f"{name} should reduce reset from 1.0"
    baseline_reset, _ = gate(feats, boundary_signal=None)
    assert torch.allclose(baseline_reset, torch.ones_like(baseline_reset)), "no-boundary reset should be 1.0"
    print(" PASS test_focus_gate_all_boundary_types")
def test_focus_gate_unknown_token():
    """A token id that is not a boundary (9999) leaves both gates at 1.0."""
    gate = FocusGate(hidden_dim=64)
    feats = torch.randn(2, 64)
    unknown_id = 9999
    reset, dampen = gate(feats, boundary_signal=unknown_id)
    assert torch.allclose(reset, torch.ones_like(reset)), "unknown token should return reset=1.0"
    assert torch.allclose(dampen, torch.ones_like(dampen)), "unknown token should return dampen=1.0"
    print(" PASS test_focus_gate_unknown_token")
def test_focus_gate_c_focus_modulation():
    """Scaling c_focus by reset*dampen changes it on a USER boundary and leaves it unchanged without one."""
    gate = FocusGate(hidden_dim=64)
    feats = torch.randn(2, 64)
    c_focus = torch.randn(2, 64)
    snapshot = c_focus.clone()
    boundary_reset, boundary_dampen = gate(feats, boundary_signal=SPECIAL_VOCAB['USER'])
    modulated = c_focus * boundary_reset * boundary_dampen
    assert not torch.allclose(modulated, snapshot, atol=1e-6), "focus gate should modify c_focus on boundary"
    idle_reset, idle_dampen = gate(feats, boundary_signal=None)
    untouched = c_focus * idle_reset * idle_dampen
    assert torch.allclose(untouched, c_focus, atol=1e-6), "focus gate should not modify c_focus without boundary"
    print(" PASS test_focus_gate_c_focus_modulation")
# ===== ConversationStack Tests =====
def test_conv_stack_push_pop():
    """A pushed conversation can be popped back with all four state vectors intact."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    h, c_f, h_t, c_t = (torch.randn(64) for _ in range(4))
    stack.push("conv_1", h, c_f, h_t, c_t, "cpu")
    result = stack.pop("conv_1", "cpu")
    assert result is not None, "pop should find pushed conversation"
    h_rest, c_f_rest, h_t_rest, c_t_rest = result
    checks = (
        ("h_focus", h_rest, h),
        ("c_focus", c_f_rest, c_f),
        ("h_topic", h_t_rest, h_t),
        ("c_topic", c_t_rest, c_t),
    )
    for label, restored, original in checks:
        assert torch.allclose(restored, original, atol=1e-6), f"{label} not preserved"
    print(" PASS test_conv_stack_push_pop")
def test_conv_stack_pop_missing():
    """Popping an id that was never pushed returns None."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    assert stack.pop("nonexistent", "cpu") is None, "pop on nonexistent should return None"
    print(" PASS test_conv_stack_pop_missing")
def test_conv_stack_clear():
    """clear(id) removes a stored conversation so a later pop finds nothing."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    focus_h = torch.randn(64)
    stack.push("conv_1", focus_h, *(torch.randn(64) for _ in range(3)), "cpu")
    stack.clear("conv_1")
    assert stack.pop("conv_1", "cpu") is None, "cleared conversation should not be found"
    print(" PASS test_conv_stack_clear")
def test_conv_stack_lru_eviction():
    """With capacity 2, pushing a third conversation evicts the oldest (conv_0)."""
    stack = ConversationStack(max_conversations=2, hidden_dim=64)
    for idx in range(3):
        marker = torch.full((64,), float(idx))
        stack.push(f"conv_{idx}", marker, torch.zeros(64), torch.zeros(64), torch.zeros(64), "cpu")
    assert stack.pop("conv_0", "cpu") is None, "conv_0 should be evicted (LRU)"
    for survivor in ("conv_1", "conv_2"):
        assert stack.pop(survivor, "cpu") is not None, f"{survivor} should still exist"
    print(" PASS test_conv_stack_lru_eviction")
def test_conv_stack_reset():
    """reset() drops every stored conversation and sets active_slot back to -1."""
    stack = ConversationStack(max_conversations=4, hidden_dim=64)
    conv_ids = ("conv_1", "conv_2")
    for conv_id in conv_ids:
        stack.push(conv_id, *(torch.randn(64) for _ in range(4)), "cpu")
    stack.reset()
    for conv_id in conv_ids:
        assert stack.pop(conv_id, "cpu") is None, f"{conv_id} should be gone after reset"
    assert stack.active_slot == -1, "active_slot should be -1 after reset"
    print(" PASS test_conv_stack_reset")
# ===== ConversationLSTM Tests =====
def test_conversation_lstm_forward():
    """One step from empty state returns five (B, H) tensors plus a scalar regulariser."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=50, bptt_topic=200)
    batch = torch.randn(4, 64)
    h_out, c_focus, h_topic, c_topic, c_proj, reg = lstm(batch, None)
    for tensor in (h_out, c_focus, h_topic, c_topic, c_proj):
        assert tensor.shape == (4, 64)
    assert reg.ndim == 0
    print(" PASS test_conversation_lstm_forward")
def test_conversation_lstm_dual_state():
    """After 11 steps, both the focus and topic cell states are nonzero."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    inp = torch.randn(2, 64)
    h_out, c_focus, h_topic, c_topic, _, _ = lstm(inp, None)
    for _step in range(10):
        carried = tuple(t.detach() for t in (h_out, c_focus, h_topic, c_topic))
        h_out, c_focus, h_topic, c_topic, _, _ = lstm(inp, carried)
    assert c_focus.abs().sum() > 0, "c_focus should be nonzero after steps"
    assert c_topic.abs().sum() > 0, "c_topic should be nonzero after steps"
    print(" PASS test_conversation_lstm_dual_state")
def test_conversation_lstm_boundary_reset():
    """Feeding a BOS boundary after 6 steps records a reset factor below 1.0 in ``_last_reset``."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    inp = torch.randn(2, 64)
    h_out, c_focus, h_topic, c_topic, _, _ = lstm(inp, None)
    for _step in range(5):
        carried = tuple(t.detach() for t in (h_out, c_focus, h_topic, c_topic))
        h_out, c_focus, h_topic, c_topic, _, _ = lstm(inp, carried)
    carried = tuple(t.detach() for t in (h_out, c_focus, h_topic, c_topic))
    _h, _c_focus_after, _, _, _, _ = lstm(inp, carried, boundary_signal=SPECIAL_VOCAB['BOS'])
    # NOTE(review): only the recorded reset factor is checked; the post-boundary cell state is not compared.
    assert lstm._last_reset < 1.0, "BOS boundary should reduce reset from 1.0"
    print(" PASS test_conversation_lstm_boundary_reset")
def test_conversation_lstm_topic_gate():
    """A single step from empty state already produces a nonzero topic hidden state."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    inp = torch.randn(2, 64)
    outputs = lstm(inp, None)
    _, _, first_topic_h, _, _, _ = outputs
    assert first_topic_h.abs().sum() > 0, "h_topic should be nonzero after one step"
    print(" PASS test_conversation_lstm_topic_gate")
def test_conversation_lstm_bptt_dual_windows():
    """With bptt_focus=3 the 2nd output still has a graph and the 3rd step returns a detached c_focus."""
    cell = ConversationLSTM(input_dim=64, hidden_dim=64, bptt_focus=3, bptt_topic=6)
    inp = torch.randn(2, 64)
    state = None
    for _step in range(2):
        h_out, c_f, h_t, c_t, _, _ = cell(inp, state)
        state = tuple(t.detach() for t in (h_out, c_f, h_t, c_t))
    assert h_out.grad_fn is not None, "h_out should have grad before focus BPTT"
    h_out, c_f, h_t, c_t, _, _ = cell(inp, state)
    assert c_f.grad_fn is None, "c_focus should be detached at focus BPTT boundary (step 3)"
    print(" PASS test_conversation_lstm_bptt_dual_windows")
def test_conversation_lstm_topic_preserves_on_boundary():
    """On a USER boundary the topic cell keeps at least 90% of the norm ratio the focus cell keeps."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    inp = torch.randn(2, 64)
    state = None
    for _step in range(5):
        h_out, c_f, h_t, c_t, _, _ = lstm(inp, state)
        state = tuple(t.detach() for t in (h_out, c_f, h_t, c_t))
    topic_before = c_t.clone()
    _, c_f_after, _, c_t_after, _, _ = lstm(inp, state, boundary_signal=SPECIAL_VOCAB['USER'])
    eps = 1e-8  # guards the norm ratios against division by zero
    decay_ratio = c_t_after.norm() / max(topic_before.norm(), eps)
    focus_decay = c_f_after.norm() / max(c_f.norm(), eps)
    assert decay_ratio >= focus_decay * 0.9, f"topic should decay less than focus on boundary: topic={decay_ratio:.3f} focus={focus_decay:.3f}"
    print(" PASS test_conversation_lstm_topic_preserves_on_boundary")
def test_conversation_lstm_h_out_is_sum():
    """Eval-mode, no-grad forward returns (B, H) ``h_out`` and ``c_proj``.

    NOTE(review): despite the name, this only checks shapes; it does not verify
    that h_out is a sum of the focus and topic paths.
    """
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    lstm.eval()
    inp = torch.randn(2, 64)
    with torch.no_grad():
        h_out, _c_f, _h_topic, _c_topic, c_proj, _ = lstm(inp, None)
    for tensor in (h_out, c_proj):
        assert tensor.shape == (2, 64)
    print(" PASS test_conversation_lstm_h_out_is_sum")
def test_extract_boundary_from_input():
    """The boundary extractor finds BOS/USER anywhere in the row and returns None when absent."""
    cases = (
        ([SPECIAL_VOCAB['BOS'], 10, 20], SPECIAL_VOCAB['BOS'], "should detect BOS"),
        ([10, SPECIAL_VOCAB['USER'], 20], SPECIAL_VOCAB['USER'], "should detect USER"),
        ([10, 20, 30], None, "should return None for no boundary"),
    )
    for row, expected, message in cases:
        found = _extract_boundary_from_input(torch.tensor([row]))
        if expected is None:
            assert found is None, message
        else:
            assert found == expected, message
    print(" PASS test_extract_boundary_from_input")
# ===== Integration: ConversationLSTM + Model =====
def test_model_switch_conversation():
    """switch_conversation() records the new id as the stack's current conversation."""
    model = ARBModel()
    model.lstm_enabled = True
    target_id = "conv_A"
    model.switch_conversation(target_id)
    assert model.lstm.conv_stack._current_conv_id == target_id
    print(" PASS test_model_switch_conversation")
def test_model_reset_conversation():
    """Resetting the active conversation clears the stack's current conversation id."""
    model = ARBModel()
    model.lstm_enabled = True
    conv_id = "conv_A"
    model.switch_conversation(conv_id)
    model.reset_conversation(conv_id)
    assert model.lstm.conv_stack._current_conv_id is None
    print(" PASS test_model_reset_conversation")
def test_model_generate_with_conversation_id():
    """generate() accepts a conversation_id and still appends the requested number of tokens."""
    model = ARBModel()
    model.lstm_enabled = True
    prompt = torch.zeros((1, 10), dtype=torch.long)
    new_tokens = 5
    out = model.generate(prompt, max_new_token=new_tokens, conversation_id="conv_test")
    assert out.shape == (1, prompt.shape[1] + new_tokens), f"output shape {out.shape}"
    print(" PASS test_model_generate_with_conversation_id")
def test_conversation_lstm_ternary_projections():
    """The focus and topic cell projections are both TernaryScaleTensor instances."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    for projection in (lstm.c_focus_proj, lstm.c_topic_proj):
        assert isinstance(projection, TernaryScaleTensor)
    print(" PASS test_conversation_lstm_ternary_projections")
def test_conversation_lstm_focus_gate_params():
    """The LSTM owns a FocusGate built from an Embedding plus two Linear heads."""
    lstm = ConversationLSTM(input_dim=64, hidden_dim=64)
    gate = lstm.focus_gate
    assert isinstance(gate, FocusGate)
    expected_types = (
        (gate.boundary_embed, nn.Embedding),
        (gate.reset_fc, nn.Linear),
        (gate.dampen_fc, nn.Linear),
    )
    for module, cls in expected_types:
        assert isinstance(module, cls)
    print(" PASS test_conversation_lstm_focus_gate_params")
def test_conversation_lstm_topic_cell_bias():
    """The topic cell's input-side bias slice [64:128] is initialised to 1.5."""
    hidden = 64
    lstm = ConversationLSTM(input_dim=64, hidden_dim=hidden)
    # Assumes topic_cell packs gates like torch.nn.LSTMCell (input, forget, cell, output),
    # so [hidden:2*hidden] is the forget gate. TODO confirm against ConversationLSTM.
    forget_bias = lstm.topic_cell.bias_ih[hidden:2 * hidden]
    target = torch.full_like(forget_bias, 1.5)
    assert torch.allclose(forget_bias, target), f"topic forget gate bias should be 1.5, got {forget_bias.mean().item()}"
    print(" PASS test_conversation_lstm_topic_cell_bias")
if __name__ == "__main__":
    # Ad-hoc runner: execute every test in order, report PASS/FAIL counts, and
    # exit non-zero on any failure so shell/CI callers can detect it.
    tests = [
        test_sticky_zone_ste,
        test_sticky_zone_ste_dtype_preservation,
        test_scaled_ternary_linear,
        test_rmsnorm,
        test_byte_embedding,
        test_text_sequencer,
        test_trigram_window,
        test_image_sequencer,
        test_image_sequencer_frozen,
        test_target_alignment,
        test_model_forward,
        test_generate,
        test_param_count,
        test_gradient_flow,
        test_model_forward_with_targets,
        test_save_load_roundtrip,
        test_vq_adapter_shapes,
        test_vq_integration,
        test_vq_disabled,
        test_vq_with_targets,
        test_l2_distance_matching,
        test_vq_ternary_projections,
        test_multimodal_vq_bridge_text_only,
        test_multimodal_vq_bridge_text_image,
        test_modality_gate_shapes,
        test_ternary_graph_multicodebook,
        test_vq_no_float_cast_in_model,
        test_zero_fp32_params,
        test_sticky_zone_ste_gradient,
        test_graph_moe_gate_shape,
        test_ternary_graph_shapes,
        test_graph_gradient_flow,
        test_graph_connectivity_monitor,
        test_model_forward_with_graph,
        test_model_graph_disabled,
        test_ternary_graph_in_modules,
        test_moe_shapes,
        test_moe_router,
        test_moe_aux_loss,
        test_shared_expert,
        test_moe_gradient_flow,
        test_moe_zero_fp32,
        test_ternary_graph_with_gate,
        test_model_forward_with_moe,
        test_model_moe_disabled,
        test_model_moe_loss_components,
        test_model_moe_gate_modulation,
        test_param_count_with_moe,
        test_moe_monitoring,
        test_loss_components,
        test_loss_components_none_fields,
        test_loss_components_backward,
        test_gnn_lora_adapter,
        test_gnn_lora_gradient,
        test_shared_gnn_weight_tying,
        test_shared_gnn_multi_hop,
        test_model_losses_components_type,
        test_halting_unit_shapes,
        test_halting_unit_ternary_pure,
        test_graph_act_cell_shapes,
        test_moe_act_cell_shapes,
        test_act_early_halt,
        test_act_weight_sum_one,
        test_act_gradient_flow,
        test_loss_components_ponder_fields,
        test_loss_components_ponder_none,
        test_act_graph_moe_sequential,
        test_model_forward_with_act,
        test_model_act_forward_without_targets,
        test_model_act_loss_components,
        test_model_act_backward,
        test_model_act_disabled,
        test_model_act_warmup_mode,
        test_model_act_ponder_cached,
        test_act_warmup_schedule,
        test_act_ponder_lambda,
        test_model_ponder_lambda_scaling,
        test_text_only_forward,
        test_image_forward,
        test_multimodal_backward,
        test_no_stale_trigram_encoder,
        test_vocab,
        test_memgram_shapes,
        test_memgram_hash_indices,
        test_memgram_bilinear_gate_range,
        test_memgram_decay_formula,
        test_memgram_gradient_flow,
        test_memgram_conv_path,
        test_conv_vq_shapes,
        test_conv_vq_hard_cap,
        test_conv_vq_deferred_activation,
        test_conv_vq_ema_update,
        test_conv_vq_persistence,
        test_conv_vq_fuzzy_retrieve,
        test_conv_vq_commitment_nonneg,
        test_lstm_shapes,
        test_lstm_forget_gate_bias,
        test_lstm_bptt_detach,
        test_lstm_hidden_reg,
        test_lstm_c_t_proj_ternary,
        test_memory_modules_backward_compat,
        test_loss_components_nine_fields_total,
        test_loss_components_nine_fields_log,
        test_moe_router_h_with_h_t,
        test_moe_router_without_h_t,
        test_forward_no_memory_backward_compat,
        test_forward_lstm_enabled_h_t_passed,
        test_forward_lstm_c_t_residual,
        test_forward_memgram_injection,
        test_forward_conv_vq_deferred,
        test_generate_carries_lstm_state,
        test_memory_schedule_warmup,
        test_memory_schedule_lstm_first,
        test_memory_schedule_memgram_second,
        test_memory_schedule_all_on,
        test_memory_schedule_conv_vq_requires_vq_util,
        test_memory_schedule_decay_reg_last,
        test_lstm_state_reset_per_batch,
        test_bptt_counter_separate,
        test_focus_gate_no_boundary,
        test_focus_gate_boundary_signal,
        test_focus_gate_all_boundary_types,
        test_focus_gate_unknown_token,
        test_focus_gate_c_focus_modulation,
        test_conv_stack_push_pop,
        test_conv_stack_pop_missing,
        test_conv_stack_clear,
        test_conv_stack_lru_eviction,
        test_conv_stack_reset,
        test_conversation_lstm_forward,
        test_conversation_lstm_dual_state,
        test_conversation_lstm_boundary_reset,
        test_conversation_lstm_topic_gate,
        test_conversation_lstm_bptt_dual_windows,
        test_conversation_lstm_topic_preserves_on_boundary,
        test_conversation_lstm_h_out_is_sum,
        test_extract_boundary_from_input,
        test_model_switch_conversation,
        test_model_reset_conversation,
        test_model_generate_with_conversation_id,
        test_conversation_lstm_ternary_projections,
        test_conversation_lstm_focus_gate_params,
        test_conversation_lstm_topic_cell_bias,
    ]
    print("Running MORPH tests (Phase 1 + Phase 2 VQ + Phase 3 Graph + Phase 4 MoE + Explore + Phase 5 ACT + Phase 6 Multi-Modal + Phase 7 Memory)...\n")
    passed = 0
    failed = 0
    for t in tests:
        try:
            t()
            passed += 1
        except Exception as e:
            # Bare asserts carry an empty message, so show the exception type too.
            print(f" FAIL {t.__name__}: {type(e).__name__}: {e}")
            failed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")
    # A non-zero exit status lets CI and shell scripts detect failures.
    if failed:
        sys.exit(1)