"""Video/latent comprehension tests — verify VideoHead on CPU.

Tests: VideoHead forward, cross-attention conditioning,
ACT halting, latent shape compatibility with pig-vae.
Runs on CPU — quick smoke tests only (no full video decode).
"""
| import os, sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| from arbitor.kernel.ternary_scale import TScaleType |
|
|
device = "cpu"  # NOTE(review): not referenced by any visible check; kept for compatibility.
FAILED = 0  # running count of failed checks; also used as the process exit status


def check(name, condition, detail=""):
    """Print a pass/fail line for one check and tally failures.

    Args:
        name: Human-readable label for the check.
        condition: Truthy on success. Single-element torch tensors (as returned
            by ``.any()`` / ``.all()``) are accepted; they are coerced with
            ``bool``.
        detail: Extra context printed only when the check fails.

    Returns:
        ``True`` if the check passed, ``False`` otherwise.
    """
    global FAILED
    passed = bool(condition)
    if passed:
        print(f"  ✓ {name}")
    else:
        # Only append the separator when there is something to say.
        suffix = f" — {detail}" if detail else ""
        print(f"  ✗ {name}{suffix}")
        FAILED += 1
    return passed
|
|
|
|
print("\n=== Video / Latent Comprehension ===\n")

from arbitor import VideoHead, HIDDEN_DIM

# --- Basic forward pass ------------------------------------------------------
# relational: random features of shape (batch=2, seq=10, HIDDEN_DIM). The shape
# check below expects latents of shape (batch, 16, 1, 32, 32).
head = VideoHead()
relational = torch.randn(2, 10, HIDDEN_DIM)
latents = head(relational)
check("VideoHead forward", latents is not None, "got None")
if latents is None:
    # Record the dependent checks as failures instead of crashing on `.shape`.
    for skipped in ("Latent shape", "No NaN in latents", "Latents finite"):
        check(skipped, False, "forward returned None")
else:
    check("Latent shape", latents.shape == (2, 16, 1, 32, 32),
          f"got {latents.shape}")
    check("No NaN in latents", not torch.isnan(latents).any())
    check("Latents finite", torch.isfinite(latents).all())
|
|
| |
# --- ACT halting with a strong signal ----------------------------------------
# Inputs are scaled x10; large-magnitude activations are where overflow/NaN
# would surface, so verify the output is usable, not merely non-None.
# NOTE(review): max_steps presumably caps the ACT loop — confirm in VideoHead.
head2 = VideoHead(max_steps=6)
relational_clear = torch.randn(2, 10, HIDDEN_DIM) * 10
latents2 = head2(relational_clear)
check("ACT halting with strong signal", latents2 is not None)
if latents2 is not None:
    check("ACT latents keep batch size", latents2.shape[0] == 2,
          f"got {tuple(latents2.shape)}")
    check("ACT latents finite", torch.isfinite(latents2).all())
|
|
| |
# --- Batch-size independence -------------------------------------------------
# The latent shape must follow the batch dimension whatever the sequence length.
for batch, seq_len in ((1, 5), (4, 8)):
    batch_latents = head(torch.randn(batch, seq_len, HIDDEN_DIM))
    check(f"Batch={batch} works",
          batch_latents.shape == (batch, 16, 1, 32, 32))
|
|
| |
| |
# --- pig-vae latent layout ---------------------------------------------------
# Checks the layout (batch, channels=16, frames=1, height=32, width=32) that the
# module docstring ties to pig-vae compatibility.
latent_dims = latents.shape
check("Latent channels = 16", latent_dims[1] == 16)
check("Latent spatial = 32", latent_dims[3] == 32 and latent_dims[4] == 32)
check("Latent temporal = 1 (single frame)", latent_dims[2] == 1)
|
|
| |
# --- VideoHead inside ARBModel -----------------------------------------------
# Build an ARBModel with the image/audio/vq/graph/memory/moe options disabled and
# run its video head on the relational features from the forward-pass section.
# NOTE(review): this calls model.video_head directly, not the full model
# forward, so it checks wiring of the head rather than the end-to-end pipeline.
from arbitor import ARBModel
model = ARBModel(enable_image=False, enable_audio=False,
                 enable_vq=False, enable_graph=False,
                 enable_memory_modules=False, enable_moe=False)
model.eval()
with torch.no_grad():
    video_latents = model.video_head(relational)
check("VideoHead in model pipeline", video_latents.shape == (2, 16, 1, 32, 32))
|
|
| |
# --- Parameter budget --------------------------------------------------------
# Float params: every tensor from head.parameters() lacking a `T_packed`
# attribute. Ternary weights: buffers whose name contains "T_packed", counted in
# packed entries (presumably several ternary values per entry — TODO confirm).
params = 0
for p in head.parameters():
    if not hasattr(p, 'T_packed'):
        params += p.numel()

ternary_buffers = 0
for buf_name, buf in head.named_buffers():
    if 'T_packed' in buf_name:
        ternary_buffers += buf.numel()

check("VideoHead has ternary weights", ternary_buffers > 0,
      f"{ternary_buffers} packed ternary entries")
check("VideoHead minimal float params", params < 5000,
      f"{params} float params")
|
|
# --- Summary -----------------------------------------------------------------
print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All video comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# Exit statuses are truncated to 8 bits on POSIX, so clamp the failure count:
# exactly 256 failures would otherwise be reported as success (status 0).
sys.exit(min(FAILED, 255))
|
|