# Source: ARBS/testing/model/video-comprehension.py
# Uploaded by CLIWorks using huggingface_hub (revision d8bc908, verified).
"""Video/latent comprehension tests — verify VideoHead on CPU.
Tests: VideoHead forward, cross-attention conditioning,
ACT halting, latent shape compatibility with pig-vae.
Runs on CPU — quick smoke tests only (no full video decode).
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor.kernel.ternary_scale import TScaleType
# Device label for this CPU-only script; nothing in the visible code reads it.
device = "cpu"
# Running count of failed checks: incremented by check() and passed to
# sys.exit() at the end of the script.
FAILED = 0
def check(name, condition, detail=""):
    """Print a pass/fail line for one named check and count failures.

    Args:
        name: Human-readable label printed for the check.
        condition: Truthy when the check passed.
        detail: Optional explanation, printed only when the check fails.

    Side effects:
        Writes one line to stdout and increments the module-level
        ``FAILED`` counter when ``condition`` is falsy.
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
        return
    # Only emit the separator when there is a detail to show, so a failure
    # without detail does not end in a dangling dash.
    suffix = f" — {detail}" if detail else ""
    print(f" ✗ {name}{suffix}")
    FAILED += 1
print("\n=== Video / Latent Comprehension ===\n")
from arbitor import HIDDEN_DIM, VideoHead

# --- 1. VideoHead forward pass ---------------------------------------------
# `head`, `relational` and `latents` are reused by later sections of the
# script, so their names must stay as they are.
head = VideoHead()
relational = torch.randn(2, 10, HIDDEN_DIM)
latents = head(relational)

# The checks run in the original order; each tensor expression is evaluated
# only after the preceding check has been reported.
check("VideoHead forward", latents is not None, "got None")

expected_shape = (2, 16, 1, 32, 32)
check("Latent shape", latents.shape == expected_shape, f"got {latents.shape}")

nan_mask = torch.isnan(latents)
check("No NaN in latents", not nan_mask.any())

finite_mask = torch.isfinite(latents)
check("Latents finite", finite_mask.all())
# --- 2. ACT halting (should stop early for clear conditioning) -------------
# A 10x-scaled random input stands in for a strong conditioning signal.
# Only the presence of an output is checked here; the number of steps the
# head actually took is not inspected.
act_head = VideoHead(max_steps=6)
strong_signal = 10 * torch.randn(2, 10, HIDDEN_DIM)
act_latents = act_head(strong_signal)
check("ACT halting with strong signal", act_latents is not None)
# --- 3. Latents with different batch sizes ---------------------------------
# Each (batch, sequence length) pair is pushed through the section-1 head;
# the batch dimension of the output must match the input batch size.
for batch_size, seq_len in ((1, 5), (4, 8)):
    batch_latents = head(torch.randn(batch_size, seq_len, HIDDEN_DIM))
    check(f"Batch={batch_size} works",
          batch_latents.shape == (batch_size, 16, 1, 32, 32))
# --- 4. pig-vae latent compatibility check ---------------------------------
# The pig-vae expects [B, 16, T, H, W] latents; inspect the output produced
# in section 1 axis by axis.
latent_dims = tuple(latents.shape)
check("Latent channels = 16", latent_dims[1] == 16)
check("Latent spatial = 32", all(latent_dims[axis] == 32 for axis in (3, 4)))
check("Latent temporal = 1 (single frame)", latent_dims[2] == 1)
# --- 5. Model with VideoHead -----------------------------------------------
from arbitor import ARBModel

# Build the model with every optional feature flag switched off.
model = ARBModel(
    enable_image=False,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
    enable_moe=False,
)
model.eval()

# Inference only: run the model's own video head on the section-1 input
# without recording an autograd graph. The unused token tensor `x` that the
# original created here has been dropped.
with torch.no_grad():
    video_latents = model.video_head(relational)
check("VideoHead in model pipeline", video_latents.shape == (2, 16, 1, 32, 32))
# --- 6. Quantization effects -----------------------------------------------
# (VideoHead uses TernaryScaleTensor internally.)
# Count parameter elements, skipping any parameter object that carries a
# `T_packed` attribute, and separately count the elements of every buffer
# whose registered name contains "T_packed".
float_param_count = 0
for param in head.parameters():
    if hasattr(param, 'T_packed'):
        continue
    float_param_count += param.numel()

packed_entry_count = sum(
    buf.numel()
    for buf_name, buf in head.named_buffers()
    if 'T_packed' in buf_name
)

check("VideoHead has ternary weights", packed_entry_count > 0,
      f"{packed_entry_count} packed ternary entries")
check("VideoHead minimal float params", float_param_count < 5000,
      f"{float_param_count} float params")
# --- Summary -----------------------------------------------------------------
print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All video comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# The failure count doubles as the process exit status (0 means success).
sys.exit(FAILED)