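"""Health check — MoEGraph routing, VQ codebooks, composite head, OutputRouter,
LTI injection, and the video/talker decoder heads.

Fast — tests modules directly, not through full model init.
Gradient flow is listed as a goal, but no backward pass is exercised yet.
Run: python testing/model/health.py
"""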
| import os, sys, math |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| import torch.nn.functional as F |
| from arbitor.components import MoEGraph, OutputRouter, LTIInjection |
| from arbitor.vq import SharedVQ |
| from arbitor.config import HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE |
| from arbitor.kernel.ternary_audit import audit_model |
| from arbitor.decoders import VideoHead, TalkerHead |
| import time |
|
|
# Module-level run state.
device = "cpu"  # device label; not referenced elsewhere in this script
# Counters bumped by check() / warn() whenever a check fails or a warning fires.
FAILED, WARNINGS = 0, 0
|
|
def check(name, condition, detail=""):
    """Report a pass/fail line for one health check and count failures.

    Args:
        name: Label printed on the report line.
        condition: Truthy means pass. Callers pass plain bools or one-element
            tensors such as ``torch.isfinite(t).all()``.
        detail: Extra context, printed only when the check fails.

    Side effects:
        Increments the module-level ``FAILED`` counter on failure.
    """
    global FAILED
    if not condition:
        print(f" ✗ {name} — {detail}")
        FAILED += 1
        return
    print(f" ✓ {name}")
|
|
def warn(name, condition, detail=""):
    """Print a non-fatal warning for ``name`` when ``condition`` is falsy.

    Warnings are tallied in the module-level ``WARNINGS`` counter. Unlike
    ``check`` failures, they are not used for the script's exit status.

    Args:
        name: Label printed on the warning line.
        condition: Truthy means "no warning" (bool or one-element tensor).
        detail: Extra context appended to the warning line.
    """
    global WARNINGS
    if not condition:
        # Space after the glyph keeps the layout aligned with check()'s ✓/✗ lines.
        print(f" ⚠ {name} — {detail}")
        WARNINGS += 1
|
|
t_start = time.time()
# Seed the global RNG so module initialisation and the random probe inputs below
# are reproducible. Several checks use data-dependent thresholds (centroid norm
# spread, codebook utilization), which would otherwise vary from run to run.
torch.manual_seed(0)
print("\n=== ARB System Health Check ===\n")
|
|
| |
print("--- MoEGraph Routing ---")
mg = MoEGraph(top_k=4)
check(f"Experts: {mg.num_experts}", mg.num_experts >= 2)
check(f"Top-k: {mg.top_k}", mg.top_k >= 1)
check(f"Workspace: {mg.cb_dim}x{mg.core_rank}x{mg.shared_inter}", mg.cb_dim > 0)

# Expert centroids must be finite, and their norms should not all collapse to
# the same value. The norm std is computed once and reused for the message.
with torch.no_grad():
    cids = torch.arange(mg.num_experts)
    cents = mg.centroids(cids).float()
    check("Centroids finite", torch.isfinite(cents).all())
    cent_norm_std = cents.norm(dim=-1).std()
    warn("Centroids diverse", cent_norm_std > 0.01, f"std={cent_norm_std:.4f}")

# The LTI A coefficients must lie strictly inside (0, 1). The lookup is guarded
# so that a missing `lti` shows up as a failed check, not an AttributeError.
has_lti = hasattr(mg, 'lti')
check("LTI present", has_lti)
if has_lti:
    A = mg.lti.get_A()
    check("LTI A < 1", (A > 0).all() and (A < 1).all(), f"A range [{A.min():.4f}, {A.max():.4f}]")

# Forward probe: random hidden states plus random VQ indices.
# `x` is reused by the composite-head and OutputRouter sections below.
x = torch.randn(2, 10, HIDDEN_DIM)
# NOTE(review): indices are drawn from [0, 1000). Confirm this range is valid for
# MoEGraph. SHARED_VQ_SIZE is imported but unused, and may be what was intended.
vq_i = torch.randint(0, 1000, (2, 10))
out, ponder = mg(x, vq_i)
check("Forward shape", out.shape == (2, 10, HIDDEN_DIM), f"got {out.shape}")
check("Ponder finite", torch.isfinite(ponder).all())
|
|
| |
print("\n--- SharedVQ Codebook ---")
bridge = SharedVQ(codebook_size=4096, codebook_dim=64)
vq = bridge.vq
check(f"Codebook: {vq.codebook_size}x{vq.codebook_dim}", vq.codebook_size > 0)

# Quantize a random 'text' batch. Every returned loss term must be finite.
x_vq = torch.randn(2, 10, HIDDEN_DIM)
combined, vq_losses, indices = bridge({'text': x_vq})
check("Forward shape", combined.shape[-1] == vq.codebook_dim, f"got {combined.shape}")
for loss_key in vq_losses:
    loss_val = vq_losses[loss_key]
    check(f"VQ loss {loss_key}", torch.isfinite(loss_val).all(), f"non-finite: {loss_val}")

# Codebook utilization: distinct codes selected by this batch / codebook size.
# NOTE(review): if indices['text'] holds one code per token, this 2x10 probe can
# hit at most 20 of 4096 entries (~0.49%). The 1% threshold would then never be
# met and the warning would always fire. Consider a larger probe or a
# batch-relative threshold.
used = indices['text'].unique().numel()
util = used / vq.codebook_size
warn(f"Utilization ({used}/{vq.codebook_size} = {util:.1%})", util > 0.01, "very low")

# Report which lookup path the codebook size selects.
lookup = "CANDIDATE" if vq.codebook_size > vq.exact_lookup_max else "EXACT"
print(f" Lookup mode: {lookup}")
|
|
| |
print("\n--- Composite Head (KGVQ) ---")
from arbitor.components import CompositeProposalHead
from arbitor.config import KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES
# NOTE(review): KGVQ_CODEBOOK_SIZE / KGVQ_CODEBOOK_DIM are imported, but the head is
# built with literal 512 / 64. Confirm whether the config values were intended.
ch = CompositeProposalHead(
    dim=HIDDEN_DIM,
    codebook_dim=64,
    k_max=K_MAX_COMPOSITES,
    codebook_size=512,
)
kgvq = ch.kgvq
print(f" k_max={ch.k_max}, codebook: {kgvq.codebook_size}x{kgvq.codebook_dim}")

# Average the MoEGraph probe input `x` over dim 1, then request composite proposals.
pool = x.mean(dim=1)
cids, c_loss, _ = ch(pool)
check("Forward shape", cids.shape[-1] == ch.k_max, f"got {cids.shape}")
# Negative ids are counted as invalid proposals.
n_valid = (cids >= 0).sum().item()
warn(f"Valid proposals ({n_valid}/{cids.numel()})", n_valid > 0, "all negative")
|
|
| |
print("\n--- OutputRouter ---")
router = OutputRouter(depth=3)
# getattr() so a router missing a hidden layer fails the check instead of crashing.
check("3-layer", getattr(router, 'hidden1', None) is not None
      and getattr(router, 'hidden2', None) is not None)

# Called twice on the same input: training=False returns one tensor, and
# training=True returns a (weights, logits) pair.
with torch.no_grad():
    inf_out = router(x, training=False)
    weights, logits = router(x, training=True)
check("Inference shape", inf_out.shape == (2, 10), f"got {inf_out.shape}")
# Every routing distribution must sum to 1. The original only looked at [0, 0].
weight_sums = weights.sum(dim=-1)
max_dev = (weight_sums - 1.0).abs().max().item()
check("Weights sum to 1", max_dev < 1e-5, f"max |sum - 1| = {max_dev:.2e}")
check("Logits finite", torch.isfinite(logits).all())
|
|
| |
print("\n--- LTI Injection ---")
lti = LTIInjection(64)
h = torch.randn(2, 10, 64)
e = torch.randn(2, 10, 64)
t = torch.randn(2, 10, 64)
out = lti(h, e, t)
check("Forward shape", out.shape == h.shape, f"got {out.shape}")
# Spectral radius is the largest |A| (this assumes get_A() returns elementwise /
# diagonal coefficients — TODO confirm). A bare max() would let a large negative
# coefficient pass.
rho = lti.get_A().abs().max().item()
check("Spectral radius < 1", rho < 1.0, f"max |A| = {rho:.4f}")
|
|
| |
print("\n--- VideoHead ---")
vh = VideoHead()
x_vh = torch.randn(2, 10, HIDDEN_DIM)
# Decode a random hidden-state batch without tracking gradients.
with torch.no_grad():
    latents = vh(x_vh)

# Expect a 5-D latent tensor with 4 channels on axis 1, all finite.
lat_shape = latents.shape
check("Forward", latents.dim() == 5, f"got {lat_shape}")
check("Latent channels = 4", lat_shape[1] == 4, f"got {lat_shape[1]}")
check("Latents finite", torch.isfinite(latents).all())
print(f" latent shape: {lat_shape} ([B, 4, chunks, 64, 64])")
print(f" LTI active: {hasattr(vh, 'lti')}")
|
|
| |
print("\n--- TalkerHead ---")
th = TalkerHead()
x_th = torch.randn(2, 10, HIDDEN_DIM)
# Token logits first, then the full forward call, both without gradients.
with torch.no_grad():
    logits_th = th.token_logits(x_th)
    gen = th(x_th)

check("Token logits", logits_th.shape[-1] == 288, f"got {logits_th.shape}")
check("Generation", gen.shape == (2, 500), f"got {gen.shape}")
has_hidden = hasattr(th, 'hidden')
check("Has 2-layer MLP", has_hidden, "missing hidden layer")
if has_hidden:
    print(f" MLP: {th.hidden._T_shape.tolist()}→288")
|
|
| |
elapsed = time.time() - t_start
print(f"\n{'='*50}")
print(f" {FAILED} failures, {WARNINGS} warnings ({elapsed:.1f}s)")
if FAILED == 0:
    print("✓ ARB System health check passed!")
else:
    print(f"✗ {FAILED} test(s) failed — review above")
# The exit status carries the failure count. It is clamped to the portable 0–127
# range, because larger values can wrap (e.g. 256 -> 0) and falsely report success.
sys.exit(min(FAILED, 127))
|
|