"""Text comprehension tests: verify the byte-level text pipeline on CPU.

Tests: ByteEmbedding, TextSequencer, ByteHead, OutputRouter, forward pass.
Runs on CPU, so it is safe to run in parallel with GPU training.
"""
| import os, sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| from arbitor import ARBModel, VOCAB, HIDDEN_DIM |
| from arbitor.config import SPECIAL_VOCAB |
| from arbitor.kernel.ternary_scale import TScaleType |
|
|
device = "cpu"
FAILED = 0  # running count of failed checks; incremented by check()

# Status glyphs are spelled as escapes so this source stays pure ASCII and the
# markers cannot be corrupted by a lossy re-encoding of the file.
_PASS_MARK = "\u2713"   # check mark
_FAIL_MARK = "\u2717"   # ballot x
_DETAIL_SEP = "\u2014"  # em dash


def check(name, condition, detail=""):
    """Print the outcome of one named check and count it if it failed.

    Args:
        name: Label printed for the check.
        condition: Truthy when the check passed.
        detail: Optional diagnostic appended to the failure line; when it is
            empty the separator is left out as well.

    Side effects:
        Writes one line to stdout and increments the module-level ``FAILED``
        counter on failure.
    """
    global FAILED
    if condition:
        print(f" {_PASS_MARK} {name}")
        return
    # Only append the separator when there is something to show after it.
    suffix = f" {_DETAIL_SEP} {detail}" if detail else ""
    print(f" {_FAIL_MARK} {name}{suffix}")
    FAILED += 1
|
|
|
|
print("\n=== Text Comprehension ===\n")

# ByteEmbedding: a (4, 20) batch of random ids in [0, 256) is expected to come
# back with shape (4, 20, 256).
from arbitor.sequencers import ByteEmbedding

emb = ByteEmbedding()  # also used by the BOS checks that follow
byte_ids = torch.randint(0, 256, (4, 20))
byte_vectors = emb(byte_ids)
expected_emb_shape = (4, 20, 256)
check(
    "ByteEmbedding forward",
    byte_vectors.shape == expected_emb_shape,
    f"got {byte_vectors.shape}",
)
|
|
| |
# BOS: the SPECIAL_VOCAB['BOS'] id, fed alone as a (1, 1) tensor, must embed to
# a (1, 1, 256) tensor whose absolute sum is positive.
bos_id = SPECIAL_VOCAB['BOS']
bos = torch.full((1, 1), bos_id, dtype=torch.long)
bos_out = emb(bos)

bos_shape_ok = bos_out.shape == (1, 1, 256)
check("BOS token embedded", bos_shape_ok)

bos_abs_total = bos_out.abs().sum().item()
check("BOS has non-zero values", bos_abs_total > 0)
|
|
| |
# TextSequencer: a random (4, 20, 256) input is expected to produce a NaN-free
# (4, 18, 512) tensor.
from arbitor.sequencers import TextSequencer

text_sequencer = TextSequencer()
random_features = torch.randn(4, 20, 256)
rel = text_sequencer(random_features)  # reused by the OutputRouter and ByteHead checks below
check("TextSequencer forward", rel.shape == (4, 18, 512), f"got {rel.shape}")

rel_has_nan = torch.isnan(rel).any().item()
check("No NaN in relational", not rel_has_nan)
|
|
| |
# Text-only model with MoE disabled: embedding -> multimodal sequencer -> byte
# head, run in eval mode without gradient tracking.
text_only_flags = dict(
    enable_image=False,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
)
model = ARBModel(**text_only_flags, enable_moe=False)
model.eval()

x = torch.randint(0, 256, (2, 10))
with torch.no_grad():
    text_features = model.multimodal_sequencer({'text': model.embedding(x)})['text']
    logits = model.byte_head(text_features)

# The last logits dimension must equal VOCAB and every value must be finite.
last_dim = logits.shape[-1]
check("Model forward (no MoE)", last_dim == VOCAB, f"last dim={last_dim}")
check("Logits finite", torch.isfinite(logits).all())
|
|
| |
# Text-only model with MoE enabled: embedding -> multimodal sequencer -> MoE ->
# byte head, run in eval mode without gradient tracking.
model2 = ARBModel(
    enable_image=False,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
    enable_moe=True,
)
model2.eval()
x2 = torch.randint(0, 256, (2, 10))
with torch.no_grad():
    embedded2 = model2.embedding(x2)
    seq_out2 = model2.multimodal_sequencer({'text': embedded2})
    relational2 = seq_out2['text']
    moe_out2, _ = model2.moe(relational2)  # second return value is not used here
    logits2 = model2.byte_head(moe_out2)

check("Model forward (MoE)", logits2.shape[-1] == VOCAB, f"last dim={logits2.shape[-1]}")

# Labelled "(MoE)" so a failure here cannot be confused with the identically
# named finiteness check of the no-MoE run; the detail reports how many
# entries are NaN or infinite.
finite_mask = torch.isfinite(logits2)
check(
    "Logits finite (MoE)",
    finite_mask.all(),
    f"{(~finite_mask).sum().item()} non-finite value(s)",
)
|
|
| |
# OutputRouter: exercised on the TextSequencer output `rel` in both modes.
from arbitor import OutputRouter

router = OutputRouter()

# training=False: a single tensor of shape (4, 18) is expected.
hard_choice = router(rel, training=False)
check("OutputRouter argmax", hard_choice.shape == (4, 18), f"got {hard_choice.shape}")

# training=True: a pair is returned; only the first element's shape is checked.
soft_pair = router(rel, training=True)
router_weights, router_logits = soft_pair
check("OutputRouter soft weights", router_weights.shape == (4, 18, 4))
|
|
| |
# ByteHead: run a standalone head on `rel`, then sanity-check the softmax of
# its logits and the size of the largest logit.
from arbitor import ByteHead

head = ByteHead()
logits = head(rel)
probs = torch.softmax(logits, dim=-1)

# Only the distribution at position [0, 0] is inspected.
first_row_total = probs[0, 0].sum().item()
check(
    "ByteHead output sum approx 1",
    abs(first_row_total - 1.0) < 1e-4,
    f"sum={first_row_total}",
)

# NOTE(review): the lower bound rejects a tensor whose logits are all
# negative -- confirm that this is intended.
peak_logit = logits.max().item()
check("Logits in valid range", 0 <= peak_logit <= 100)
|
|
# Final summary: print a rule, then the overall verdict. The status glyphs are
# written as escapes so they survive a lossy re-encoding of this file.
print(f"\n{'='*50}")
if FAILED == 0:
    print("\u2713 All text comprehension tests passed!")
else:
    print(f"\u2717 {FAILED} test(s) failed")
    # The exit status is the failure count, clamped to 255: POSIX reports only
    # the low 8 bits, so an unclamped count of 256 would read as success.
    sys.exit(min(FAILED, 255))
|
|