File size: 3,573 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Text comprehension tests — verify the byte-level text pipeline on CPU.

Tests: ByteEmbedding, TextSequencer, ByteHead, OutputRouter, forward pass.
Runs on CPU — safe to parallelize with GPU training.
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor import ARBModel, VOCAB, HIDDEN_DIM
from arbitor.config import SPECIAL_VOCAB
from arbitor.kernel.ternary_scale import TScaleType

device = "cpu"  # NOTE(review): not referenced anywhere else in this script
FAILED = 0      # running count of failed checks; the script exits with it


def check(name, condition, detail=""):
    """Report one named check and tally it if it fails.

    Args:
        name: label printed next to the pass/fail mark.
        condition: pass when truthy. A single-element bool tensor also works,
            because the test below calls ``bool()`` on it.
        detail: extra context appended to the line, shown only on failure.

    Side effects:
        Prints one indented result line. On failure, increments the
        module-level ``FAILED`` counter.
    """
    global FAILED
    if not condition:
        print(f"  ✗ {name} — {detail}")
        FAILED += 1
    else:
        print(f"  ✓ {name}")


print("\n=== Text Comprehension ===\n")

# Seed torch's global RNG before anything random happens. This covers the
# random test inputs below and any module weight init drawn from that RNG.
# With a fixed seed, a failing check can be reproduced exactly on a rerun.
torch.manual_seed(0)

# 1. ByteEmbedding: a (4, 20) batch of random byte ids should embed to
# (4, 20, 256).
from arbitor.sequencers import ByteEmbedding
emb = ByteEmbedding()
x = torch.randint(0, 256, (4, 20))
out = emb(x)
check("ByteEmbedding forward", out.shape == (4, 20, 256), f"got {out.shape}")

# 2. Special token embedding. The BOS id comes from SPECIAL_VOCAB rather than
# the raw byte range. The same embedding module (`emb`, section 1) must
# accept it and produce a non-trivial vector.
bos = torch.full((1, 1), SPECIAL_VOCAB['BOS'], dtype=torch.long)
bos_out = emb(bos)
check("BOS token embedded", bos_out.shape == (1, 1, 256), f"got {bos_out.shape}")
bos_abs_sum = bos_out.abs().sum().item()
check("BOS has non-zero values", bos_abs_sum > 0, f"abs sum={bos_abs_sum}")

# 3. TextSequencer: feed a random (4, 20, 256) tensor and expect a
# (4, 18, 512) output. `rel` is reused by sections 6 and 7.
# Construction comes before the random input so the RNG draw order is fixed.
from arbitor.sequencers import TextSequencer
text_sequencer = TextSequencer()
unfolded = torch.randn(4, 20, 256)
rel = text_sequencer(unfolded)
rel_shape = rel.shape
check("TextSequencer forward", rel_shape == (4, 18, 512), f"got {rel_shape}")
rel_has_nan = bool(torch.isnan(rel).any())
check("No NaN in relational", not rel_has_nan)

# 4–5. Full model forward, without and with MoE. The stages run one at a time
# (embedding -> multimodal sequencer -> [MoE] -> byte head) instead of calling
# the model as a whole. This avoids training-mode shape edge cases.
def _text_only_model(enable_moe):
    """Return an eval-mode ARBModel with only the text path enabled.

    ``enable_moe`` toggles the mixture-of-experts stage. Image, audio, VQ,
    graph and memory modules are always disabled.
    """
    text_model = ARBModel(enable_image=False, enable_audio=False,
                          enable_vq=False, enable_graph=False,
                          enable_memory_modules=False, enable_moe=enable_moe)
    text_model.eval()
    return text_model


def _text_logits(text_model, tokens, use_moe):
    """Run byte ids ``tokens`` through ``text_model``'s text stages without grad.

    Returns the byte-head output. When ``use_moe`` is true, the sequencer output
    first goes through ``text_model.moe``. Only the first element of the tuple
    it returns is used.
    """
    with torch.no_grad():
        embedded = text_model.embedding(tokens)
        relational = text_model.multimodal_sequencer({'text': embedded})['text']
        if use_moe:
            relational, _ = text_model.moe(relational)
        return text_model.byte_head(relational)


# Each variant gets its own label, so a failing "Logits finite" line says which
# model produced it. Build the model before drawing its inputs to keep the
# original RNG draw order.
for label, use_moe in (("no MoE", False), ("MoE", True)):
    model = _text_only_model(use_moe)
    tokens = torch.randint(0, 256, (2, 10))
    variant_logits = _text_logits(model, tokens, use_moe)
    check(f"Model forward ({label})", variant_logits.shape[-1] == VOCAB,
          f"last dim={variant_logits.shape[-1]}")
    check(f"Logits finite ({label})", torch.isfinite(variant_logits).all())

# 6. OutputRouter, run on `rel` from section 3.
# - training=False: one choice per position, shape (4, 18).
# - training=True: a pair whose first element is checked as per-position
#   weights of shape (4, 18, 4). The trailing 4 is presumably the number of
#   routes; confirm against OutputRouter.
from arbitor import OutputRouter
router = OutputRouter()
router_out = router(rel, training=False)
check("OutputRouter argmax", router_out.shape == (4, 18), f"got {router_out.shape}")
router_weights, _router_logits = router(rel, training=True)  # logits unused here
check("OutputRouter soft weights", router_weights.shape == (4, 18, 4),
      f"got {router_weights.shape}")

# 7. ByteHead: project `rel` (section 3) to per-position byte logits.
# `logits` stays at module level because section 8 inspects it.
from arbitor import ByteHead
head = ByteHead()
logits = head(rel)
check("ByteHead vocab size", logits.shape[-1] == VOCAB,
      f"last dim={logits.shape[-1]}")
probs = torch.softmax(logits, dim=-1)
# Check every position, not only [0, 0]. A NaN/inf anywhere in the logits
# yields a row whose sum is not ~1. torch's max propagates NaN, so such a row
# also makes the comparison below fail.
max_sum_dev = (probs.sum(dim=-1) - 1.0).abs().max().item()
check("ByteHead output sum approx 1", max_sum_dev < 1e-4,
      f"max |sum-1|={max_sum_dev}")

# 8. Token range sanity: the byte-head logits from section 7 should stay in a
# moderate range. The check bounds their magnitude in both directions rather
# than requiring a non-negative maximum, since all-negative logits are still a
# valid output. A NaN also fails, because NaN <= 100 is False.
max_abs_logit = logits.abs().max().item()
check("Logits in valid range", max_abs_logit <= 100,
      f"max |logit|={max_abs_logit}")

# Summary: print the overall verdict, then exit with the failure count.
# A zero exit status therefore means every check passed. (POSIX keeps only the
# low 8 bits of the status, which this script's handful of checks cannot
# exceed.)
print("\n" + "=" * 50)
if FAILED:
    print(f"✗ {FAILED} test(s) failed")
else:
    print("✓ All text comprehension tests passed!")
sys.exit(FAILED)