# ARBS / testing/model/health.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""Health check — MoEGraph routing, SharedVQ codebook, composite (KGVQ) head,
OutputRouter, LTI injection, and the VideoHead/TalkerHead decoders.
Fast — tests modules directly, not through full model init.
Run: python testing/model/health.py
"""
import os, sys, math
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
import torch.nn.functional as F
from arbitor.components import MoEGraph, OutputRouter, LTIInjection
from arbitor.vq import SharedVQ
from arbitor.config import HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE
from arbitor.kernel.ternary_audit import audit_model
from arbitor.decoders import VideoHead, TalkerHead
import time
# Module-level state for the health-check report.
device = "cpu"  # NOTE(review): not referenced anywhere else in this script as shown
FAILED = WARNINGS = 0  # running counts of failed checks / emitted warnings
def check(name, condition, detail=""):
    """Print and record the outcome of a hard health-check.

    Prints ``✓ name`` when *condition* is truthy. Otherwise prints ``✗ name``
    (followed by `` — detail`` only when *detail* is non-empty) and
    increments the module-level ``FAILED`` counter.

    Args:
        name: Label shown in the report.
        condition: Anything with a truth value, e.g. a Python bool or a
            single-element tensor such as ``torch.isfinite(t).all()``.
        detail: Optional explanation, printed only on failure.

    Returns:
        True if the check passed, False otherwise.
    """
    global FAILED
    passed = bool(condition)
    if passed:
        print(f" ✓ {name}")
    else:
        # Only add the separator when there is something to show, so a
        # detail-less failure does not end in a dangling " — ".
        suffix = f" — {detail}" if detail else ""
        print(f" ✗ {name}{suffix}")
        FAILED += 1
    return passed
def warn(name, condition, detail=""):
    """Report a soft health-check: print only when *condition* is falsy.

    A passing condition prints nothing. A failing one prints ``⚠ name``
    (followed by `` — detail`` only when *detail* is non-empty) and
    increments the module-level ``WARNINGS`` counter; ``FAILED`` is untouched.

    Args:
        name: Label shown in the report.
        condition: Anything with a truth value (bool or single-element tensor).
        detail: Optional explanation, printed only when the warning fires.

    Returns:
        True if the condition held, False if a warning was emitted.
    """
    global WARNINGS
    passed = bool(condition)
    if not passed:
        # Same formatting rule as hard checks: no dangling separator.
        suffix = f" — {detail}" if detail else ""
        print(f" ⚠ {name}{suffix}")
        WARNINGS += 1
    return passed
t_start = time.time()
print("\n=== ARB System Health Check ===\n")

# ------- MoEGraph -------
print("--- MoEGraph Routing ---")
mg = MoEGraph(top_k=4)
check(f"Experts: {mg.num_experts}", mg.num_experts >= 2)
check(f"Top-k: {mg.top_k}", mg.top_k >= 1)
check(f"Workspace: {mg.cb_dim}x{mg.core_rank}x{mg.shared_inter}", mg.cb_dim > 0)

# Centroids: look up one centroid per expert id; they must be finite, and we
# warn when their norms are nearly identical (std <= 0.01).
with torch.no_grad():
    cids = torch.arange(mg.num_experts)
    cents = mg.centroids(cids).float()
check("Centroids finite", torch.isfinite(cents).all())
cent_norm_std = cents.norm(dim=-1).std()
warn("Centroids diverse", cent_norm_std > 0.01, f"std={cent_norm_std:.4f}")

# LTI: every value returned by get_A() is expected to lie strictly in (0, 1).
has_lti = hasattr(mg, 'lti')
check("LTI present", has_lti)
if has_lti:
    # Guarded so a missing `lti` is recorded as one failure above instead of
    # raising AttributeError and aborting the rest of the report.
    A = mg.lti.get_A()
    check("LTI A < 1", (A > 0).all() and (A < 1).all(), f"A range [{A.min():.4f}, {A.max():.4f}]")

# Forward pass on random hidden states. `x` is module-level on purpose: the
# Composite Head and OutputRouter sections below reuse it.
x = torch.randn(2, 10, HIDDEN_DIM)
# NOTE(review): indices are drawn from [0, 1000); presumably this must stay
# below the codebook size MoEGraph indexes into (SHARED_VQ_SIZE?) — confirm.
vq_i = torch.randint(0, 1000, (2, 10))
out, ponder = mg(x, vq_i)
check("Forward shape", out.shape == (2, 10, HIDDEN_DIM), f"got {out.shape}")
check("Ponder finite", torch.isfinite(ponder).all())
# ------- SharedVQ -------
# Build a 4096x64 shared quantizer, push one random 'text' stream through it,
# and report shapes, loss finiteness, codebook utilization and lookup mode.
print("\n--- SharedVQ Codebook ---")
bridge = SharedVQ(codebook_size=4096, codebook_dim=64)
vq = bridge.vq
check(f"Codebook: {vq.codebook_size}x{vq.codebook_dim}", vq.codebook_size > 0)

x_vq = torch.randn(2, 10, HIDDEN_DIM)
combined, vq_losses, indices = bridge({'text': x_vq})
check("Forward shape", combined.shape[-1] == vq.codebook_dim, f"got {combined.shape}")
for loss_name, loss_val in vq_losses.items():
    check(f"VQ loss {loss_name}", torch.isfinite(loss_val).all(), f"non-finite: {loss_val}")

# Utilization = distinct codes hit by the 'text' indices / codebook size.
used_codes = indices['text'].unique().numel()
util = used_codes / vq.codebook_size
warn(f"Utilization ({used_codes}/{vq.codebook_size} = {util:.1%})", util > 0.01, "very low")

if vq.codebook_size <= vq.exact_lookup_max:
    lookup = "EXACT"
else:
    lookup = "CANDIDATE"
print(f" Lookup mode: {lookup}")
# ------- KGVQ/Composite -------
print("\n--- Composite Head (KGVQ) ---")
from arbitor.components import CompositeProposalHead
from arbitor.config import KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES
# NOTE(review): codebook_dim=64 / codebook_size=512 are hard-coded rather than
# taken from KGVQ_CODEBOOK_DIM / KGVQ_CODEBOOK_SIZE imported above — confirm.
ch = CompositeProposalHead(dim=HIDDEN_DIM, codebook_dim=64, k_max=K_MAX_COMPOSITES, codebook_size=512)
print(f" k_max={ch.k_max}, codebook: {ch.kgvq.codebook_size}x{ch.kgvq.codebook_dim}")

# Average the MoEGraph section's `x` (2, 10, HIDDEN_DIM) over dim 1.
pool = x.mean(dim=1)
cids, c_loss, _ = ch(pool)
check("Forward shape", cids.shape[-1] == ch.k_max, f"got {cids.shape}")
# Validate the returned loss the same way the SharedVQ losses are validated,
# instead of discarding it. Guarded because its type isn't visible here.
if torch.is_tensor(c_loss):
    check("Composite loss finite", torch.isfinite(c_loss).all(), f"non-finite: {c_loss}")

# Ids >= 0 are counted as valid proposals; warn if every slot is negative.
n_valid = (cids >= 0).sum().item()
warn(f"Valid proposals ({n_valid}/{cids.numel()})", n_valid > 0, "all negative")
# ------- OutputRouter (3-layer) -------
print("\n--- OutputRouter ---")
router = OutputRouter(depth=3)
check("3-layer", router.hidden1 is not None and router.hidden2 is not None)
with torch.no_grad():
    inf_out = router(x, training=False)
weights, logits = router(x, training=True)
check("Inference shape", inf_out.shape == (2, 10), f"got {inf_out.shape}")
# Check every position's weights, not just [0, 0]. Sum in float32 so the
# 1e-5 tolerance is meaningful even if the weights come back in lower precision.
weight_sums = weights.float().sum(dim=-1)
max_dev = (weight_sums - 1.0).abs().max().item()
check("Weights sum to 1", max_dev < 1e-5, f"max |sum - 1| = {max_dev:.2e}")
check("Logits finite", torch.isfinite(logits).all())
# ------- LTI Injection -------
print("\n--- LTI Injection ---")
lti = LTIInjection(64)
h = torch.randn(2, 10, 64)
e = torch.randn(2, 10, 64)
t = torch.randn(2, 10, 64)
out = lti(h, e, t)
check("Forward shape", out.shape == h.shape)
# Stability needs |A| < 1, not just A < 1: take the largest magnitude so that
# entries <= -1 are caught too. Assumes get_A() is element-wise (diagonal),
# as the MoEGraph "LTI A < 1" check treats it — TODO confirm.
rho = lti.get_A().abs().max().item()
check("Spectral radius < 1", rho < 1.0, f"max |A| = {rho:.4f}")
# ------- VideoHead -------
# Decode random hidden states into latents and check rank, channel count
# (dim 1) and finiteness.
print("\n--- VideoHead ---")
vh = VideoHead()
x_vh = torch.randn(2, 10, HIDDEN_DIM)
with torch.no_grad():
    latents = vh(x_vh)
lat_shape = latents.shape
check("Forward", latents.dim() == 5, f"got {lat_shape}")
check("Latent channels = 4", lat_shape[1] == 4, f"got {lat_shape[1]}")
check("Latents finite", torch.isfinite(latents).all())
print(f" latent shape: {lat_shape} ([B, 4, chunks, 64, 64])")
has_video_lti = hasattr(vh, 'lti')
print(f" LTI active: {has_video_lti}")
# ------- TalkerHead -------
# Expect 288-way token logits and a (2, 500) generation for a batch of 2.
print("\n--- TalkerHead ---")
th = TalkerHead()
x_th = torch.randn(2, 10, HIDDEN_DIM)
with torch.no_grad():
    logits_th = th.token_logits(x_th)
    gen = th(x_th)
token_ok = logits_th.shape[-1] == 288
check("Token logits", token_ok, f"got {logits_th.shape}")
gen_ok = gen.shape == (2, 500)
check("Generation", gen_ok, f"got {gen.shape}")
has_hidden = hasattr(th, 'hidden')
check("Has 2-layer MLP", has_hidden, "missing hidden layer")
if has_hidden:
    # _T_shape is a private attribute of the hidden layer; printed for info only.
    print(f" MLP: {th.hidden._T_shape.tolist()}→288")
# ------- Summary -------
elapsed = time.time() - t_start
print(f"\n{'='*50}")
print(f" {FAILED} failures, {WARNINGS} warnings ({elapsed:.1f}s)")
if FAILED == 0:
    print("✓ ARB System health check passed!")
else:
    print(f"✗ {FAILED} test(s) failed — review above")
# Exit status is the failure count, capped at 255. POSIX keeps only the low
# 8 bits, so an uncapped count of 256 would wrap to 0 and read as success.
sys.exit(min(FAILED, 255))