"""Text comprehension tests — verify the byte-level text pipeline on CPU.

Tests: ByteEmbedding, TextSequencer, ByteHead, OutputRouter, forward pass.
Runs on CPU — safe to parallelize with GPU training.

Run as a script: every check prints a pass/fail line, and the process exit
status is the number of failed checks (0 means everything passed).
"""
import os
import sys

# Make the repository root importable so `arbitor` resolves when this file is
# run directly from its subdirectory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch

from arbitor import ARBModel, VOCAB, HIDDEN_DIM
from arbitor.config import SPECIAL_VOCAB
from arbitor.kernel.ternary_scale import TScaleType

device = "cpu"
FAILED = 0

# Seed once so the random inputs and module initialisation are identical on
# every run; a failing check can then be reproduced and inspected.
torch.manual_seed(0)


def check(name, condition, detail=""):
    """Report one named check and count it in FAILED if it does not hold.

    Args:
        name: Human-readable label printed next to the pass/fail mark.
        condition: Anything truthy on success (Python bool or a one-element
            tensor such as the result of ``tensor.all()``).
        detail: Optional extra context appended to the failure line.
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
    else:
        # Only append the separator when there is something to show.
        suffix = f" — {detail}" if detail else ""
        print(f" ✗ {name}{suffix}")
        FAILED += 1


print("\n=== Text Comprehension ===\n")

# 1. ByteEmbedding
from arbitor.sequencers import ByteEmbedding

emb = ByteEmbedding()
x = torch.randint(0, 256, (4, 20))
out = emb(x)
check("ByteEmbedding forward", out.shape == (4, 20, 256), f"got {out.shape}")

# 2. Special token embedding
bos = torch.full((1, 1), SPECIAL_VOCAB['BOS'], dtype=torch.long)
bos_out = emb(bos)
check("BOS token embedded", bos_out.shape == (1, 1, 256), f"got {bos_out.shape}")
check("BOS has non-zero values", bos_out.abs().sum().item() > 0)

# 3. TextSequencer
from arbitor.sequencers import TextSequencer

seq = TextSequencer()
x_unfolded = torch.randn(4, 20, 256)
rel = seq(x_unfolded)
check("TextSequencer forward", rel.shape == (4, 18, 512), f"got {rel.shape}")
check("No NaN in relational", not torch.isnan(rel).any())

# 4. Full model forward — step by step to avoid training-mode shape edge cases
model = ARBModel(
    enable_image=False,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
    enable_moe=False,
)
model.eval()
x = torch.randint(0, 256, (2, 10))
with torch.no_grad():
    embedded = model.embedding(x)
    seq_out = model.multimodal_sequencer({'text': embedded})
    relational = seq_out['text']
    logits = model.byte_head(relational)
check("Model forward (no MoE)", logits.shape[-1] == VOCAB, f"last dim={logits.shape[-1]}")
check("Logits finite (no MoE)", torch.isfinite(logits).all())

# 5. Full model with MoE (step-by-step)
model2 = ARBModel(
    enable_image=False,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
    enable_moe=True,
)
model2.eval()
x2 = torch.randint(0, 256, (2, 10))
with torch.no_grad():
    embedded2 = model2.embedding(x2)
    seq_out2 = model2.multimodal_sequencer({'text': embedded2})
    relational2 = seq_out2['text']
    moe_out2, _ = model2.moe(relational2)
    logits2 = model2.byte_head(moe_out2)
check("Model forward (MoE)", logits2.shape[-1] == VOCAB, f"last dim={logits2.shape[-1]}")
check("Logits finite (MoE)", torch.isfinite(logits2).all())

# 6. OutputRouter
from arbitor import OutputRouter

router = OutputRouter()
router_out = router(rel, training=False)
check("OutputRouter argmax", router_out.shape == (4, 18), f"got {router_out.shape}")
router_weights, router_logits = router(rel, training=True)
check(
    "OutputRouter soft weights",
    router_weights.shape == (4, 18, 4),
    f"got {router_weights.shape}",
)

# 7. ByteHead output distribution — softmax must normalise at EVERY position,
# so check all rows rather than a single one (a NaN anywhere breaks this).
from arbitor import ByteHead

head = ByteHead()
head_logits = head(rel)
probs = torch.softmax(head_logits, dim=-1)
max_dev = (probs.sum(dim=-1) - 1.0).abs().max().item()
check("ByteHead output sum approx 1", max_dev < 1e-4, f"max |sum-1|={max_dev}")

# 8. Logit range sanity — raw logits are unnormalised and may legitimately be
# negative, so bound their magnitude instead of requiring a non-negative max.
max_abs_logit = head_logits.abs().max().item()
check("Logits in valid range", max_abs_logit <= 100, f"max |logit|={max_abs_logit}")

print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All text comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
sys.exit(FAILED)