# ARBS / testing / model / text-comprehension.py
# CLIWorks's picture
# Upload folder using huggingface_hub
# d8bc908 verified
"""Text comprehension tests β€” verify the byte-level text pipeline on CPU.
Tests: ByteEmbedding, TextSequencer, ByteHead, OutputRouter, forward pass.
Runs on CPU β€” safe to parallelize with GPU training.
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor import ARBModel, VOCAB, HIDDEN_DIM
from arbitor.config import SPECIAL_VOCAB
from arbitor.kernel.ternary_scale import TScaleType
# Device name for this CPU-only script; nothing in this file reads it.
device = "cpu"

# Number of checks that have failed so far. Updated by check() and reported
# (and used as the process exit status) at the end of the script.
FAILED = 0


def check(name, condition, detail=""):
    """Print a pass/fail line for one check and count failures.

    Args:
        name: Human-readable label printed next to the pass/fail mark.
        condition: Anything usable in a boolean context (a bool or a
            single-element tensor); truthy means the check passed.
        detail: Optional diagnostic text appended to the failure line.

    Side effects:
        Writes one line to stdout and increments the module-level
        ``FAILED`` counter when ``condition`` is falsy.
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
        return
    # Only add the separator when there is something to show after it, so a
    # failure without diagnostics does not end in a dangling dash.
    suffix = f" — {detail}" if detail else ""
    print(f" ✗ {name}{suffix}")
    FAILED += 1
# Banner that opens this suite's output.
print("\n=== Text Comprehension ===\n")

# 1. ByteEmbedding: feed a (4, 20) batch of random ids drawn from [0, 256)
#    and expect a (4, 20, 256) result.
from arbitor.sequencers import ByteEmbedding

emb = ByteEmbedding()  # reused by the special-token check further down
byte_ids = torch.randint(0, 256, (4, 20))
embedded_bytes = emb(byte_ids)
expected_embedding_shape = (4, 20, 256)
check(
    "ByteEmbedding forward",
    embedded_bytes.shape == expected_embedding_shape,
    f"got {embedded_bytes.shape}",
)
# 2. Special token embedding: push the BOS id alone through the embedding as
#    a (1, 1) batch of long integers.
bos_id = SPECIAL_VOCAB['BOS']
bos_batch = torch.full((1, 1), bos_id, dtype=torch.long)
bos_embedding = emb(bos_batch)

bos_shape_ok = bos_embedding.shape == (1, 1, 256)
check("BOS token embedded", bos_shape_ok)

# Sum of absolute values is zero only if every component is zero.
bos_magnitude = bos_embedding.abs().sum().item()
check("BOS has non-zero values", bos_magnitude > 0)
# 3. TextSequencer: a random (4, 20, 256) float input is expected to come
#    back as (4, 18, 512) with no NaN entries.
from arbitor.sequencers import TextSequencer

text_sequencer = TextSequencer()
unfolded_input = torch.randn(4, 20, 256)
rel = text_sequencer(unfolded_input)  # reused by the router and head checks

expected_rel_shape = (4, 18, 512)
check("TextSequencer forward", rel.shape == expected_rel_shape, f"got {rel.shape}")

has_nan = torch.isnan(rel).any()
check("No NaN in relational", not has_nan)
# 4. Full model forward with MoE disabled. The stages are called one by one
#    (embedding, then sequencer, then byte head) to avoid training-mode shape
#    edge cases.
model = ARBModel(
    enable_image=False,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
    enable_moe=False,
)
model.eval()

x = torch.randint(0, 256, (2, 10))
with torch.no_grad():
    embedded = model.embedding(x)
    # The sequencer takes and returns a dict keyed by modality name.
    relational = model.multimodal_sequencer({'text': embedded})['text']
    logits = model.byte_head(relational)

vocab_dim = logits.shape[-1]
check("Model forward (no MoE)", vocab_dim == VOCAB, f"last dim={vocab_dim}")

all_finite = torch.isfinite(logits).all()
check("Logits finite", all_finite)
# 5. Full model with MoE enabled, again driven step by step: embedding, then
#    sequencer, then the MoE layer, then the byte head.
model2 = ARBModel(enable_image=False, enable_audio=False,
                  enable_vq=False, enable_graph=False,
                  enable_memory_modules=False, enable_moe=True)
model2.eval()
x2 = torch.randint(0, 256, (2, 10))
with torch.no_grad():
    embedded2 = model2.embedding(x2)
    seq_out2 = model2.multimodal_sequencer({'text': embedded2})
    relational2 = seq_out2['text']
    # The MoE layer returns a pair; only the first element feeds the head.
    moe_out2, _ = model2.moe(relational2)
    logits2 = model2.byte_head(moe_out2)
check("Model forward (MoE)", logits2.shape[-1] == VOCAB, f"last dim={logits2.shape[-1]}")
# Labelled with "(MoE)" so a failure here can be told apart from the
# identically worded finite-logits check on the non-MoE model.
check("Logits finite (MoE)", torch.isfinite(logits2).all())
# 6. OutputRouter, exercised in both modes on the TextSequencer output.
from arbitor import OutputRouter

router = OutputRouter()

# training=False: expect one value per position, i.e. shape (4, 18).
hard_choice = router(rel, training=False)
check("OutputRouter argmax", hard_choice.shape == (4, 18), f"got {hard_choice.shape}")

# training=True: the result is unpacked as a (weights, logits) pair; only the
# weights are inspected here.
soft_result = router(rel, training=True)
router_weights, _router_logits = soft_result
expected_weight_shape = (4, 18, 4)
check("OutputRouter soft weights", router_weights.shape == expected_weight_shape)
# 7. ByteHead output distribution: softmax over the last axis of the head's
#    output should sum to ~1; only position [0, 0] is sampled.
from arbitor import ByteHead

head = ByteHead()
head_logits = head(rel)
probs = torch.softmax(head_logits, dim=-1)
first_position_total = probs[0, 0].sum().item()
check(
    "ByteHead output sum approx 1",
    abs(first_position_total - 1.0) < 1e-4,
    f"sum={first_position_total}",
)

# 8. Token range sanity: the largest logit must lie in [0, 100].
# NOTE(review): the lower bound of 0 on the *maximum* logit looks arbitrary
# and large negative logits are not bounded at all; confirm the intent.
peak_logit = head_logits.max().item()
check("Logits in valid range", 0 <= peak_logit <= 100)
# Summary: print a separator, report the overall result, and exit with a
# status that is non-zero whenever at least one check failed.
print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All text comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# Exit statuses are truncated to 8 bits on POSIX systems, so clamp the count;
# otherwise a failure total that is a multiple of 256 would read as success.
sys.exit(min(FAILED, 255))