"""Health check — MoEGraph routing, SharedVQ codebook, composite (KGVQ) head,
OutputRouter, LTI injection, and the VideoHead / TalkerHead decoders.

Fast — tests modules directly, not through full model init. Each section runs
in isolation: an exception raised inside one section is reported as a failure
and the remaining sections still run, so one broken module does not hide the
status of the others.

Exit status is the number of failed checks, capped at 255.

Run: python testing/model/health.py
"""
import os
import sys
import math
import time
import traceback

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
import torch.nn.functional as F

from arbitor.components import MoEGraph, OutputRouter, LTIInjection
from arbitor.vq import SharedVQ
from arbitor.config import HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE
from arbitor.kernel.ternary_audit import audit_model
from arbitor.decoders import VideoHead, TalkerHead

device = "cpu"  # NOTE(review): not referenced below; tensors use torch's default device
FAILED = 0  # failed `check`s plus crashed sections; becomes the exit status
WARNINGS = 0  # soft `warn`ings; reported in the summary only


def check(name, condition, detail=""):
    """Record a hard check.

    Prints ✓ when *condition* is truthy; otherwise prints ✗ with *detail*
    and increments FAILED. *condition* may be a bool or a single-element
    tensor (it is evaluated with ``bool()``).
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
    else:
        print(f" ✗ {name} — {detail}")
        FAILED += 1


def warn(name, condition, detail=""):
    """Record a soft check: print ⚠ with *detail* and increment WARNINGS when *condition* is falsy.

    Passing warnings print nothing.
    """
    global WARNINGS
    if not condition:
        print(f" ⚠ {name} — {detail}")
        WARNINGS += 1


def _run_section(title, fn):
    """Print a ``--- title ---`` header and run *fn*, counting an uncaught exception as one failure."""
    global FAILED
    print(f"\n--- {title} ---")
    try:
        fn()
    except Exception as exc:  # report and continue: one broken module must not abort the run
        print(f" ✗ {title} crashed — {type(exc).__name__}: {exc}")
        traceback.print_exc()
        FAILED += 1


def _check_moegraph(x):
    """MoEGraph: expert/top-k config, centroid sanity, LTI stability, forward shape."""
    mg = MoEGraph(top_k=4)
    check(f"Experts: {mg.num_experts}", mg.num_experts >= 2)
    check(f"Top-k: {mg.top_k}", mg.top_k >= 1)
    check(f"Workspace: {mg.cb_dim}x{mg.core_rank}x{mg.shared_inter}", mg.cb_dim > 0)

    # Centroids: one vector per expert id.
    with torch.no_grad():
        expert_ids = torch.arange(mg.num_experts)
        cents = mg.centroids(expert_ids).float()
    check("Centroids finite", torch.isfinite(cents).all())
    norm_std = cents.norm(dim=-1).std()
    warn("Centroids diverse", norm_std > 0.01, f"std={norm_std:.4f}")

    # LTI: only probe get_A() when the submodule exists, so a missing LTI is a
    # reported failure rather than an AttributeError.
    has_lti = hasattr(mg, "lti")
    check("LTI present", has_lti)
    if has_lti:
        with torch.no_grad():
            A = mg.lti.get_A()
        check("LTI A < 1", (A > 0).all() and (A < 1).all(),
              f"A range [{A.min():.4f}, {A.max():.4f}]")

    # Forward
    # NOTE(review): index upper bound 1000 is hard-coded; confirm it is within
    # the vocabulary MoEGraph expects (SHARED_VQ_SIZE is imported but unused).
    vq_i = torch.randint(0, 1000, (2, 10))
    out, ponder = mg(x, vq_i)
    check("Forward shape", out.shape == (2, 10, HIDDEN_DIM), f"got {out.shape}")
    # as_tensor lets this accept a plain Python float as well as a tensor.
    check("Ponder finite", torch.isfinite(torch.as_tensor(ponder)).all())


def _check_shared_vq():
    """SharedVQ: codebook size, forward output width, finite losses, code utilization."""
    bridge = SharedVQ(codebook_size=4096, codebook_dim=64)
    vq = bridge.vq
    check(f"Codebook: {vq.codebook_size}x{vq.codebook_dim}", vq.codebook_size > 0)

    x_vq = torch.randn(2, 10, HIDDEN_DIM)
    combined, vq_losses, indices = bridge({'text': x_vq})
    check("Forward shape", combined.shape[-1] == vq.codebook_dim, f"got {combined.shape}")
    for k, loss in vq_losses.items():
        check(f"VQ loss {k}", torch.isfinite(torch.as_tensor(loss)).all(), f"non-finite: {loss}")

    uniq = indices['text'].unique().numel()
    util = uniq / vq.codebook_size
    warn(f"Utilization ({uniq}/{vq.codebook_size} = {util:.1%})", util > 0.01, "very low")

    lookup = "EXACT" if vq.codebook_size <= vq.exact_lookup_max else "CANDIDATE"
    print(f" Lookup mode: {lookup}")


def _check_composite_head(x):
    """CompositeProposalHead (KGVQ): pooled forward yields k_max proposal ids."""
    # Imported here so a missing symbol fails only this section.
    from arbitor.components import CompositeProposalHead
    from arbitor.config import KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES

    # NOTE(review): uses a small 512x64 codebook rather than KGVQ_CODEBOOK_SIZE/DIM,
    # presumably to keep the check fast — confirm.
    ch = CompositeProposalHead(dim=HIDDEN_DIM, codebook_dim=64, k_max=K_MAX_COMPOSITES, codebook_size=512)
    print(f" k_max={ch.k_max}, codebook: {ch.kgvq.codebook_size}x{ch.kgvq.codebook_dim}")

    pool = x.mean(dim=1)
    comp_ids, _c_loss, _ = ch(pool)
    check("Forward shape", comp_ids.shape[-1] == ch.k_max, f"got {comp_ids.shape}")
    n_valid = (comp_ids >= 0).sum().item()
    warn(f"Valid proposals ({n_valid}/{comp_ids.numel()})", n_valid > 0, "all negative")


def _check_output_router(x):
    """OutputRouter (3-layer): layer presence, inference shape, normalized weights, finite logits."""
    router = OutputRouter(depth=3)
    # getattr so a missing layer fails the check instead of raising.
    check("3-layer", getattr(router, "hidden1", None) is not None
          and getattr(router, "hidden2", None) is not None)

    with torch.no_grad():
        inf_out = router(x, training=False)
        weights, logits = router(x, training=True)
    check("Inference shape", inf_out.shape == (2, 10), f"got {inf_out.shape}")

    # Check normalization at every position, not just [0, 0].
    max_err = (weights.float().sum(dim=-1) - 1.0).abs().max().item()
    check("Weights sum to 1", max_err < 1e-5, f"max |sum - 1| = {max_err:.2e}")
    check("Logits finite", torch.isfinite(logits).all())


def _check_lti_injection():
    """LTIInjection: forward preserves shape and the state transition is contractive."""
    lti = LTIInjection(64)
    h = torch.randn(2, 10, 64)
    e = torch.randn(2, 10, 64)
    t = torch.randn(2, 10, 64)
    out = lti(h, e, t)
    check("Forward shape", out.shape == h.shape, f"got {out.shape}")

    # max |a_i| is the spectral radius when get_A() holds the diagonal of A
    # (NOTE(review): assumed diagonal — confirm). Plain max() would miss
    # negative entries with magnitude >= 1.
    with torch.no_grad():
        radius = lti.get_A().abs().max().item()
    check("Spectral radius < 1", radius < 1.0, f"max |A| = {radius:.4f}")


def _check_video_head():
    """VideoHead: 5-D finite latents with 4 channels."""
    vh = VideoHead()
    x_vh = torch.randn(2, 10, HIDDEN_DIM)
    with torch.no_grad():
        latents = vh(x_vh)
    check("Forward", latents.dim() == 5, f"got {latents.shape}")
    # Guard the index so a low-rank output fails the check instead of raising.
    check("Latent channels = 4", latents.dim() >= 2 and latents.shape[1] == 4,
          f"got {tuple(latents.shape)}")
    check("Latents finite", torch.isfinite(latents).all())
    print(f" latent shape: {latents.shape} ([B, 4, chunks, 64, 64])")
    print(f" LTI active: {hasattr(vh, 'lti')}")


def _check_talker_head():
    """TalkerHead: token-logit width, generation shape, presence of the hidden MLP layer."""
    n_tokens = 288  # expected last-dim width of TalkerHead.token_logits
    th = TalkerHead()
    x_th = torch.randn(2, 10, HIDDEN_DIM)
    with torch.no_grad():
        logits_th = th.token_logits(x_th)
        gen = th(x_th)
    check("Token logits", logits_th.shape[-1] == n_tokens, f"got {logits_th.shape}")
    check("Generation", gen.shape == (2, 500), f"got {gen.shape}")
    check("Has 2-layer MLP", hasattr(th, 'hidden'), "missing hidden layer")

    # Purely informational; reads a private attribute, so skip it when absent
    # rather than failing the section.
    t_shape = getattr(getattr(th, 'hidden', None), '_T_shape', None)
    if t_shape is not None:
        print(f" MLP: {t_shape.tolist()}→{n_tokens}")


def main():
    """Run every section, print the summary, and exit with the (capped) failure count."""
    t_start = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
    print("\n=== ARB System Health Check ===")

    # Fixed seed so a failing run (module init + random inputs) can be reproduced.
    torch.manual_seed(0)
    # Shared input of shape (2, 10, HIDDEN_DIM) for MoEGraph, composite head and router.
    x = torch.randn(2, 10, HIDDEN_DIM)

    sections = [
        ("MoEGraph Routing", lambda: _check_moegraph(x)),
        ("SharedVQ Codebook", _check_shared_vq),
        ("Composite Head (KGVQ)", lambda: _check_composite_head(x)),
        ("OutputRouter", lambda: _check_output_router(x)),
        ("LTI Injection", _check_lti_injection),
        ("VideoHead", _check_video_head),
        ("TalkerHead", _check_talker_head),
    ]
    for title, fn in sections:
        _run_section(title, fn)

    elapsed = time.perf_counter() - t_start
    print(f"\n{'='*50}")
    print(f" {FAILED} failures, {WARNINGS} warnings ({elapsed:.1f}s)")
    if FAILED == 0:
        print("✓ ARB System health check passed!")
    else:
        print(f"✗ {FAILED} test(s) failed — review above")
    # POSIX exit statuses are taken modulo 256, so 256 failures would read as
    # success; cap the count instead of letting it wrap.
    sys.exit(min(FAILED, 255))


if __name__ == "__main__":
    main()