# File size: 5,441 Bytes
# d8bc908
"""Health check — MoEGraph routing, VQ codebooks, OutputRouter, gradient flow.

Fast — tests modules directly, not through full model init.
Run: python testing/model/health.py
"""
import os, sys, math
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
import torch.nn.functional as F
from arbitor.components import MoEGraph, OutputRouter, LTIInjection
from arbitor.vq import SharedVQ
from arbitor.config import HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE
from arbitor.kernel.ternary_audit import audit_model
from arbitor.decoders import VideoHead, TalkerHead
import time
# Device label for this script. NOTE(review): nothing in the visible code
# reads it — confirm whether it is still needed.
device = "cpu"
# Number of hard checks that failed; incremented by check() and used as the
# process exit status at the end of the script.
FAILED = 0
# Number of soft warnings raised; incremented by warn() and reported in the
# final summary line.
WARNINGS = 0
def check(name, condition, detail=""):
    """Record a hard health-check assertion and print its outcome.

    Args:
        name: Label printed for this check.
        condition: Any truthy/falsy value; truthy means the check passed.
        detail: Optional diagnostic text, shown only when the check fails.

    Side effects:
        Prints one status line. On failure, increments the module-level
        ``FAILED`` counter.
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
        return
    # Append the separator only when there is diagnostic text to follow it,
    # so checks called without ``detail`` do not end in a dangling dash.
    suffix = f" — {detail}" if detail else ""
    print(f" ✗ {name}{suffix}")
    FAILED += 1
def warn(name, condition, detail=""):
    """Record a soft (non-fatal) health expectation.

    Nothing is printed when the expectation holds.

    Args:
        name: Label printed when the expectation is not met.
        condition: Any truthy/falsy value; truthy means the expectation holds.
        detail: Optional diagnostic text appended to the warning line.

    Side effects:
        When ``condition`` is falsy, prints one warning line and increments
        the module-level ``WARNINGS`` counter.
    """
    global WARNINGS
    if condition:
        return
    # Append the separator only when there is diagnostic text to follow it.
    suffix = f" — {detail}" if detail else ""
    print(f" ⚠ {name}{suffix}")
    WARNINGS += 1
# Wall-clock start of the run; the summary at the end of the script reports
# the elapsed seconds measured from this point.
t_start = time.time()
print("\n=== ARB System Health Check ===\n")
# ------- MoEGraph -------
# Structural, centroid, LTI and forward-pass sanity checks for MoEGraph.
print("--- MoEGraph Routing ---")
mg = MoEGraph(top_k=4)
check(f"Experts: {mg.num_experts}", mg.num_experts >= 2)
check(f"Top-k: {mg.top_k}", mg.top_k >= 1)
check(f"Workspace: {mg.cb_dim}x{mg.core_rank}x{mg.shared_inter}", mg.cb_dim > 0)

# Centroids: look up one entry per expert id without tracking gradients.
with torch.no_grad():
    cids = torch.arange(mg.num_experts)
    cents = mg.centroids(cids).float()
check("Centroids finite", torch.isfinite(cents).all())
# Spread of the per-expert centroid norms, computed once for both the
# condition and the diagnostic text.
cent_norm_std = cents.norm(dim=-1).std()
warn("Centroids diverse", cent_norm_std > 0.01, f"std={cent_norm_std:.4f}")

# LTI: only query the module when it exists, so a missing attribute is
# reported as a failed check instead of aborting the run with AttributeError.
has_lti = hasattr(mg, 'lti')
check("LTI present", has_lti)
if has_lti:
    A = mg.lti.get_A()
    check("LTI A < 1", (A > 0).all() and (A < 1).all(), f"A range [{A.min():.4f}, {A.max():.4f}]")

# Forward: `x` is reused by later sections of this script.
x = torch.randn(2, 10, HIDDEN_DIM)
# NOTE(review): 1000 is a hard-coded upper bound for the random indices —
# confirm it stays within the index range MoEGraph expects.
vq_i = torch.randint(0, 1000, (2, 10))
out, ponder = mg(x, vq_i)
check("Forward shape", out.shape == (2, 10, HIDDEN_DIM), f"got {out.shape}")
check("Ponder finite", torch.isfinite(ponder).all())
# ------- SharedVQ -------
# Push one random "text" batch through SharedVQ and inspect the result.
print("\n--- SharedVQ Codebook ---")
# NOTE(review): the sizes below are literals although SHARED_VQ_SIZE and
# CODEBOOK_DIM are imported from arbitor.config — confirm which is intended.
bridge = SharedVQ(codebook_size=4096, codebook_dim=64)
vq = bridge.vq
check(f"Codebook: {vq.codebook_size}x{vq.codebook_dim}", vq.codebook_size > 0)
x_vq = torch.randn(2, 10, HIDDEN_DIM)
combined, vq_losses, indices = bridge({'text': x_vq})
check("Forward shape", combined.shape[-1] == vq.codebook_dim, f"got {combined.shape}")
# Every reported loss term must be finite.
for loss_name, loss_value in vq_losses.items():
    check(f"VQ loss {loss_name}", torch.isfinite(loss_value).all(), f"non-finite: {loss_value}")
# Fraction of codebook entries hit by this single batch.
uniq = indices['text'].unique().numel()
util = uniq / vq.codebook_size
warn(f"Utilization ({uniq}/{vq.codebook_size} = {util:.1%})", util > 0.01, "very low")
lookup = "EXACT" if vq.codebook_size <= vq.exact_lookup_max else "CANDIDATE"
print(f" Lookup mode: {lookup}")
# ------- KGVQ/Composite -------
# Composite proposal head, fed with the mean over dim 1 of the `x` batch
# created in the MoEGraph section.
print("\n--- Composite Head (KGVQ) ---")
from arbitor.components import CompositeProposalHead
from arbitor.config import KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES
ch = CompositeProposalHead(
    dim=HIDDEN_DIM,
    codebook_dim=64,
    k_max=K_MAX_COMPOSITES,
    codebook_size=512,
)
print(f" k_max={ch.k_max}, codebook: {ch.kgvq.codebook_size}x{ch.kgvq.codebook_dim}")
pool = x.mean(dim=1)
cids, c_loss, _ = ch(pool)
last_dim_matches = cids.shape[-1] == ch.k_max
check("Forward shape", last_dim_matches, f"got {cids.shape}")
# Count the non-negative ids; the warning fires when there are none.
non_negative = cids >= 0
n_valid = non_negative.sum().item()
warn(f"Valid proposals ({n_valid}/{cids.numel()})", n_valid > 0, "all negative")
# ------- OutputRouter (3-layer) -------
# Exercise both call modes of the router on the shared `x` batch.
print("\n--- OutputRouter ---")
router = OutputRouter(depth=3)
check("3-layer", router.hidden1 is not None and router.hidden2 is not None)
# NOTE(review): the source's indentation was lost; both router calls are
# assumed to belong to the no_grad block. None of the checks below depends
# on gradient tracking either way.
with torch.no_grad():
    inf_out = router(x, training=False)
    weights, logits = router(x, training=True)
check("Inference shape", inf_out.shape == (2, 10), f"got {inf_out.shape}")
# Validate the normalisation at every position, not just element [0, 0]:
# take the worst deviation of the last-dim sums from 1.
weight_sums = weights.sum(dim=-1)
max_dev = (weight_sums - 1.0).abs().max().item()
check("Weights sum to 1", max_dev < 1e-5, f"max |sum - 1| = {max_dev:.2e}")
check("Logits finite", torch.isfinite(logits).all())
# ------- LTI Injection -------
# Stand-alone LTIInjection: the output must keep the shape of `h`, and the
# entries of A must stay below 1 in magnitude.
print("\n--- LTI Injection ---")
lti = LTIInjection(64)
h = torch.randn(2, 10, 64)
e = torch.randn(2, 10, 64)
t = torch.randn(2, 10, 64)
out = lti(h, e, t)
check("Forward shape", out.shape == h.shape, f"got {out.shape}, expected {h.shape}")
# Compare magnitudes: a plain max() would let a large negative entry pass a
# check that is labelled as a spectral-radius bound.
# NOTE(review): this treats get_A() as element-wise coefficients, as the
# MoEGraph LTI check in this script does — confirm.
a_abs_max = lti.get_A().abs().max().item()
check("Spectral radius < 1", a_abs_max < 1.0, f"max |A| = {a_abs_max:.4f}")
# ------- VideoHead -------
# Run the video head on a random batch and validate the latent tensor.
print("\n--- VideoHead ---")
vh = VideoHead()
x_vh = torch.randn(2, 10, HIDDEN_DIM)
with torch.no_grad():
    latents = vh(x_vh)
check("Forward", latents.dim() == 5, f"got {latents.shape}")
# Read the channel dim only when it exists, so an unexpected rank is
# reported as a failed check rather than raising IndexError.
n_channels = latents.shape[1] if latents.dim() > 1 else None
check("Latent channels = 4", n_channels == 4, f"got {n_channels}")
check("Latents finite", torch.isfinite(latents).all())
print(f" latent shape: {latents.shape} ([B, 4, chunks, 64, 64])")
print(f" LTI active: {hasattr(vh, 'lti')}")
# ------- TalkerHead -------
# Check the token-logit width and the shape of the generated output.
print("\n--- TalkerHead ---")
th = TalkerHead()
x_th = torch.randn(2, 10, HIDDEN_DIM)
# NOTE(review): the source's indentation was lost; both calls are assumed to
# belong to the no_grad block. The checks below only inspect shapes.
with torch.no_grad():
    logits_th = th.token_logits(x_th)
    gen = th(x_th)
check("Token logits", logits_th.shape[-1] == 288, f"got {logits_th.shape}")
check("Generation", gen.shape == (2, 500), f"got {gen.shape}")
# Evaluate the attribute test once and reuse it for the check and the print.
has_hidden = hasattr(th, 'hidden')
check("Has 2-layer MLP", has_hidden, "missing hidden layer")
if has_hidden:
    # NOTE(review): `_T_shape` is a private attribute of the hidden layer —
    # confirm it is stable enough to print from here.
    print(f" MLP: {th.hidden._T_shape.tolist()}→288")
# ------- Summary -------
# Report totals and exit with the failure count as the process status.
elapsed = time.time() - t_start
print(f"\n{'='*50}")
print(f" {FAILED} failures, {WARNINGS} warnings ({elapsed:.1f}s)")
if FAILED == 0:
    print("✓ ARB System health check passed!")
else:
    print(f"✗ {FAILED} test(s) failed — review above")
# Exit statuses are truncated to 8 bits on POSIX, so a count that is a
# multiple of 256 would read as success; clamp it to keep failure non-zero.
sys.exit(min(FAILED, 255))