"""Video/latent comprehension tests — verify VideoHead on CPU.
Tests: VideoHead forward, cross-attention conditioning,
ACT halting, latent shape compatibility with pig-vae.
Runs on CPU — quick smoke tests only (no full video decode).
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor.kernel.ternary_scale import TScaleType
device = "cpu"  # the suite is intended to run on CPU (see module docstring)
FAILED = 0  # running count of failed checks, updated by check()


def check(name, condition, detail=""):
    """Print a pass/fail line for one named check and count failures.

    Args:
        name: Label printed next to the pass/fail marker.
        condition: Truthy when the check passed. Anything usable in a
            boolean context is accepted (callers in this file pass Python
            bools and 0-dim tensors).
        detail: Optional diagnostic appended to the failure line.

    Side effects:
        Writes one line to stdout and increments the module-level FAILED
        counter when ``condition`` is falsy.
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
        return
    # Only add the separator when there is something to say, so checks that
    # pass no detail do not end in a dangling " — ".
    suffix = f" — {detail}" if detail else ""
    print(f" ✗ {name}{suffix}")
    FAILED += 1
print("\n=== Video / Latent Comprehension ===\n")
from arbitor import VideoHead, HIDDEN_DIM

# Fix the RNG so the random inputs and freshly constructed heads below are
# the same on every run; a failing check can then be reproduced.
torch.manual_seed(0)

# Per-sample latent layout [C, T, H, W] asserted throughout this script.
LATENT_DIMS = (16, 1, 32, 32)

# 1. VideoHead forward
head = VideoHead()
relational = torch.randn(2, 10, HIDDEN_DIM)
latents = head(relational)
check("VideoHead forward", latents is not None, "got None")
check("Latent shape", latents.shape == (2, *LATENT_DIMS),
      f"got {latents.shape}")
check("No NaN in latents", not torch.isnan(latents).any(),
      f"{int(torch.isnan(latents).sum())} NaN value(s)")
check("Latents finite", torch.isfinite(latents).all(),
      "found inf or NaN values")

# 2. ACT halting (should stop early for clear conditioning)
# NOTE(review): this only confirms that a forward pass with max_steps=6
# returns something; the number of ACT steps actually taken is not observed.
head2 = VideoHead(max_steps=6)
relational_clear = torch.randn(2, 10, HIDDEN_DIM) * 10  # strong signal
latents2 = head2(relational_clear)
check("ACT halting with strong signal", latents2 is not None, "got None")

# 3. Latents with different batch sizes
latents3 = head(torch.randn(1, 5, HIDDEN_DIM))
check("Batch=1 works", latents3.shape == (1, *LATENT_DIMS),
      f"got {latents3.shape}")
latents4 = head(torch.randn(4, 8, HIDDEN_DIM))
check("Batch=4 works", latents4.shape == (4, *LATENT_DIMS),
      f"got {latents4.shape}")

# 4. pig-vae latent compatibility check
# The pig-vae expects [B, 16, T, H, W] latents
check("Latent channels = 16", latents.shape[1] == 16,
      f"got {latents.shape[1]}")
check("Latent spatial = 32", latents.shape[3] == 32 and latents.shape[4] == 32,
      f"got {latents.shape[3]}x{latents.shape[4]}")
check("Latent temporal = 1 (single frame)", latents.shape[2] == 1,
      f"got {latents.shape[2]}")

# 5. Model with VideoHead
from arbitor import ARBModel
model = ARBModel(enable_image=False, enable_audio=False,
                 enable_vq=False, enable_graph=False,
                 enable_memory_modules=False, enable_moe=False)
model.eval()
# Only the model's video head is exercised, fed with the same relational
# tensor used in section 1; no token input is needed for this check.
with torch.no_grad():
    video_latents = model.video_head(relational)
check("VideoHead in model pipeline",
      video_latents.shape == (2, *LATENT_DIMS),
      f"got {video_latents.shape}")

# 6. Quantization effects (VideoHead uses TernaryScaleTensor internally)
# Float parameters: every parameter that does not carry a T_packed attribute.
params = sum(p.numel() for p in head.parameters() if not hasattr(p, 'T_packed'))
# Packed ternary storage: buffers whose registered name contains 'T_packed'.
ternary_buffers = sum(
    buf.numel() for buf_name, buf in head.named_buffers()
    if 'T_packed' in buf_name
)
check("VideoHead has ternary weights", ternary_buffers > 0,
      f"{ternary_buffers} packed ternary entries")
check("VideoHead minimal float params", params < 5000,
      f"{params} float params")
# Final summary: one banner line, the overall verdict, then the exit status.
print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All video comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# The failure count doubles as the exit status. POSIX keeps only the low
# 8 bits of a status, so a raw count of 256 would wrap to 0 and read as
# success; clamp it so any failure stays non-zero.
sys.exit(min(FAILED, 255))