File size: 5,441 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Health check — MoEGraph routing, VQ codebooks, OutputRouter, gradient flow.
Fast — tests modules directly, not through full model init.
Run: python testing/model/health.py
"""
import os, sys, math
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
import torch.nn.functional as F
from arbitor.components import MoEGraph, OutputRouter, LTIInjection
from arbitor.vq import SharedVQ
from arbitor.config import HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE
from arbitor.kernel.ternary_audit import audit_model
from arbitor.decoders import VideoHead, TalkerHead
import time

# Device name string; nothing else in the visible script reads it.
device = "cpu"
# Running tallies: check() bumps FAILED, warn() bumps WARNINGS.
FAILED = WARNINGS = 0

def check(name, condition, detail=""):
    """Report a hard pass/fail result and count failures.

    Prints a check-mark line when ``condition`` is truthy. Otherwise prints a
    cross-mark line and increments the module-level ``FAILED`` counter.

    Args:
        name: Label printed for this check.
        condition: Truthy for a pass, falsy for a failure.
        detail: Optional explanation appended to the failure line.
    """
    global FAILED
    if condition:
        print(f"  ✓ {name}")
        return
    # Only emit the separator when there is something to show after it.
    suffix = f" — {detail}" if detail else ""
    print(f"  ✗ {name}{suffix}")
    FAILED += 1

def warn(name, condition, detail=""):
    """Report a soft (non-fatal) problem and count it.

    Prints nothing when ``condition`` is truthy. Otherwise prints a warning
    line and increments the module-level ``WARNINGS`` counter; ``FAILED`` is
    not touched.

    Args:
        name: Label printed for this warning.
        condition: Truthy when everything is fine, falsy to raise the warning.
        detail: Optional explanation appended to the warning line.
    """
    global WARNINGS
    if condition:
        return
    # Only emit the separator when there is something to show after it.
    suffix = f" — {detail}" if detail else ""
    print(f"  ⚠ {name}{suffix}")
    WARNINGS += 1

# Start the wall-clock timer that the final summary reports against.
t_start = time.time()
print("\n=== ARB System Health Check ===\n")

# ------- MoEGraph -------
print("--- MoEGraph Routing ---")
mg = MoEGraph(top_k=4)
check(f"Experts: {mg.num_experts}", mg.num_experts >= 2)
check(f"Top-k: {mg.top_k}", mg.top_k >= 1)
check(f"Workspace: {mg.cb_dim}x{mg.core_rank}x{mg.shared_inter}", mg.cb_dim > 0)

# Centroids: look up one centroid per expert id and inspect their norms.
with torch.no_grad():
    cids = torch.arange(mg.num_experts)
    cents = mg.centroids(cids).float()
    check("Centroids finite", torch.isfinite(cents).all())
    # Spread of the per-expert centroid norms; computed once, used for both
    # the threshold test and the message.
    norm_std = cents.norm(dim=-1).std()
    warn("Centroids diverse", norm_std > 0.01, f"std={norm_std:.4f}")

# LTI
# Guard the attribute access so a missing `lti` is reported as a failed check
# instead of aborting the whole script with an AttributeError.
has_lti = hasattr(mg, 'lti')
check("LTI present", has_lti)
if has_lti:
    A = mg.lti.get_A()
    # Passes only when every entry of A lies strictly between 0 and 1.
    check("LTI A < 1", (A > 0).all() and (A < 1).all(), f"A range [{A.min():.4f}, {A.max():.4f}]")

# Forward
# `x` is reused by later sections of this script (composite head, router).
x = torch.randn(2, 10, HIDDEN_DIM)
# NOTE(review): the index upper bound 1000 is hard-coded; confirm it is within
# the range of VQ indices MoEGraph accepts.
vq_i = torch.randint(0, 1000, (2, 10))
out, ponder = mg(x, vq_i)
check("Forward shape", out.shape == (2, 10, HIDDEN_DIM), f"got {out.shape}")
check("Ponder finite", torch.isfinite(ponder).all())

# ------- SharedVQ -------
print("\n--- SharedVQ Codebook ---")
shared_vq = SharedVQ(codebook_size=4096, codebook_dim=64)
codebook = shared_vq.vq
check(
    f"Codebook: {codebook.codebook_size}x{codebook.codebook_dim}",
    codebook.codebook_size > 0,
)

# Push one random batch through the bridge under the 'text' key.
text_batch = torch.randn(2, 10, HIDDEN_DIM)
quantized, loss_terms, code_indices = shared_vq({'text': text_batch})
check(
    "Forward shape",
    quantized.shape[-1] == codebook.codebook_dim,
    f"got {quantized.shape}",
)
# Every returned loss term must be finite.
for loss_key in loss_terms:
    loss_value = loss_terms[loss_key]
    check(
        f"VQ loss {loss_key}",
        torch.isfinite(loss_value).all(),
        f"non-finite: {loss_value}",
    )

# Share of codebook entries selected at least once by this batch.
used_codes = code_indices['text'].unique().numel()
utilization = used_codes / codebook.codebook_size
warn(
    f"Utilization ({used_codes}/{codebook.codebook_size} = {utilization:.1%})",
    utilization > 0.01,
    "very low",
)

# Report which lookup mode the size threshold on the codebook selects.
if codebook.codebook_size <= codebook.exact_lookup_max:
    lookup_mode = "EXACT"
else:
    lookup_mode = "CANDIDATE"
print(f"  Lookup mode: {lookup_mode}")

# ------- KGVQ/Composite -------
print("\n--- Composite Head (KGVQ) ---")
from arbitor.components import CompositeProposalHead
from arbitor.config import KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES
composite_head = CompositeProposalHead(
    dim=HIDDEN_DIM,
    codebook_dim=64,
    k_max=K_MAX_COMPOSITES,
    codebook_size=512,
)
kgvq = composite_head.kgvq
print(f"  k_max={composite_head.k_max}, codebook: {kgvq.codebook_size}x{kgvq.codebook_dim}")

# Average the earlier random batch `x` over dim 1 and feed it to the head.
pooled = x.mean(dim=1)
proposal_ids, composite_loss, _unused = composite_head(pooled)
check(
    "Forward shape",
    proposal_ids.shape[-1] == composite_head.k_max,
    f"got {proposal_ids.shape}",
)
# Count the non-negative ids; warn only when there are none at all.
valid_count = (proposal_ids >= 0).sum().item()
warn(
    f"Valid proposals ({valid_count}/{proposal_ids.numel()})",
    valid_count > 0,
    "all negative",
)

# ------- OutputRouter (3-layer) -------
print("\n--- OutputRouter ---")
router = OutputRouter(depth=3)
check("3-layer", router.hidden1 is not None and router.hidden2 is not None)

# Run the router on the earlier random batch `x` in both call modes.
with torch.no_grad():
    inf_out = router(x, training=False)
    weights, logits = router(x, training=True)
check("Inference shape", inf_out.shape == (2, 10), f"got {inf_out.shape}")
# Verify the normalisation at every position, not just element [0, 0]: take
# the largest deviation of any last-dim sum from 1. A NaN anywhere makes the
# comparison false, so it is reported as a failure too.
sum_err = (weights.sum(dim=-1) - 1.0).abs().max().item()
check("Weights sum to 1", sum_err < 1e-5, f"max |sum - 1| = {sum_err:.2e}")
check("Logits finite", torch.isfinite(logits).all())

# ------- LTI Injection -------
print("\n--- LTI Injection ---")
lti = LTIInjection(64)
# Three independent random tensors of identical shape, drawn in call order
# and passed positionally to the module.
sample_shape = (2, 10, 64)
lti_inputs = [torch.randn(sample_shape) for _ in range(3)]
injected = lti(*lti_inputs)
check("Forward shape", injected.shape == lti_inputs[0].shape)
# The largest entry returned by get_A() must stay strictly below 1.
largest_a = lti.get_A().max().item()
check("Spectral radius < 1", largest_a < 1.0)

# ------- VideoHead -------
print("\n--- VideoHead ---")
video_head = VideoHead()
video_input = torch.randn(2, 10, HIDDEN_DIM)
with torch.no_grad():
    latents = video_head(video_input)
latent_shape = latents.shape
# The output is expected to be rank 5 with 4 entries along dim 1.
check("Forward", len(latent_shape) == 5, f"got {latent_shape}")
check("Latent channels = 4", latent_shape[1] == 4, f"got {latent_shape[1]}")
check("Latents finite", torch.isfinite(latents).all())
print(f"  latent shape: {latent_shape} ([B, 4, chunks, 64, 64])")
lti_active = hasattr(video_head, 'lti')
print(f"  LTI active: {lti_active}")

# ------- TalkerHead -------
print("\n--- TalkerHead ---")
th = TalkerHead()
x_th = torch.randn(2, 10, HIDDEN_DIM)
# Expected size of the last logits dimension; also echoed in the MLP line.
expected_logit_dim = 288
with torch.no_grad():
    logits_th = th.token_logits(x_th)
    gen = th(x_th)
check("Token logits", logits_th.shape[-1] == expected_logit_dim, f"got {logits_th.shape}")
check("Generation", gen.shape == (2, 500), f"got {gen.shape}")
has_hidden = hasattr(th, 'hidden')
check("Has 2-layer MLP", has_hidden, "missing hidden layer")
# Only describe the hidden layer when it exists, so a missing attribute is
# reported by the check above rather than raising here.
if has_hidden:
    print(f"  MLP: {th.hidden._T_shape.tolist()}→{expected_logit_dim}")

# ------- Summary -------
elapsed = time.time() - t_start
print(f"\n{'='*50}")
print(f"  {FAILED} failures, {WARNINGS} warnings ({elapsed:.1f}s)")
if FAILED == 0:
    print("✓ ARB System health check passed!")
else:
    print(f"✗ {FAILED} test(s) failed — review above")
# Exit status is the failure count, capped so a large count can never wrap
# around to 0 (success) on platforms that truncate the status.
sys.exit(min(FAILED, 127))