File size: 2,915 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Video/latent comprehension tests — verify VideoHead on CPU.

Tests: VideoHead forward pass and output shape, cross-attention
conditioning, ACT halting, batch-size handling, latent shape
compatibility with pig-vae, VideoHead inside ARBModel, and ternary
weight quantization.
Runs on CPU — quick smoke tests only (no full video decode).
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor.kernel.ternary_scale import TScaleType

device = "cpu"  # these smoke tests target CPU only
FAILED = 0  # running count of failed checks; reported via the exit status


def check(name, condition, detail=""):
    """Print a pass/fail line for one smoke-test check and tally failures.

    Args:
        name: Human-readable label for the check.
        condition: Truthy on success — a plain bool or a single-element
            tensor such as ``torch.isfinite(x).all()`` (anything that
            supports ``bool()``).
        detail: Optional context printed after the label on failure.

    Returns:
        ``True`` if the check passed, ``False`` otherwise. Each failure
        increments the module-level ``FAILED`` counter.
    """
    global FAILED
    passed = bool(condition)
    if passed:
        print(f"  ✓ {name}")
    else:
        # Only emit the separator when there is detail to show after it,
        # so detail-less failures don't end in a dangling " — ".
        suffix = f" — {detail}" if detail else ""
        print(f"  ✗ {name}{suffix}")
        FAILED += 1
    return passed


print("\n=== Video / Latent Comprehension ===\n")

from arbitor import VideoHead, HIDDEN_DIM

# Seed before building any module or random input so weights and
# conditioning tensors are identical across runs; otherwise a failure may
# not reproduce on the next run.
torch.manual_seed(0)

# Latent layout [B, C, T, H, W] that pig-vae consumes (see section 4).
LATENT_CHANNELS = 16
LATENT_FRAMES = 1    # single frame
LATENT_SIZE = 32     # spatial extent, H == W


def _expected_shape(batch):
    """Return the latent shape VideoHead should emit for *batch* inputs."""
    return (batch, LATENT_CHANNELS, LATENT_FRAMES, LATENT_SIZE, LATENT_SIZE)


# 1. VideoHead forward
head = VideoHead()
relational = torch.randn(2, 10, HIDDEN_DIM)
latents = head(relational)
check("VideoHead forward", latents is not None, "got None")
if latents is None:
    # Every remaining check dereferences `latents`; abort with a clear
    # message (exit status 1) rather than an AttributeError traceback.
    sys.exit("✗ VideoHead returned None — aborting remaining checks")
check("Latent shape", latents.shape == _expected_shape(2),
      f"got {latents.shape}")
check("No NaN in latents", not torch.isnan(latents).any(), "found NaN")
check("Latents finite", torch.isfinite(latents).all(), "found inf/NaN")

# 2. ACT halting with a strong (large-magnitude) conditioning signal.
# NOTE: the halting step count is not observable from here, so this only
# asserts that the forward pass completes and yields well-formed latents;
# it does not verify that ACT actually stopped early.
head2 = VideoHead(max_steps=6)
relational_clear = torch.randn(2, 10, HIDDEN_DIM) * 10  # strong signal
latents2 = head2(relational_clear)
check("ACT halting with strong signal", latents2 is not None, "got None")
if latents2 is not None:
    check("ACT latent shape", latents2.shape == _expected_shape(2),
          f"got {latents2.shape}")
    check("ACT latents finite", torch.isfinite(latents2).all(),
          "found inf/NaN with large-magnitude input")

# 3. Latents with different batch sizes
latents3 = head(torch.randn(1, 5, HIDDEN_DIM))
check("Batch=1 works", latents3.shape == _expected_shape(1),
      f"got {latents3.shape}")

latents4 = head(torch.randn(4, 8, HIDDEN_DIM))
check("Batch=4 works", latents4.shape == _expected_shape(4),
      f"got {latents4.shape}")

# 4. pig-vae latent compatibility check
# The pig-vae expects [B, 16, T, H, W] latents
check("Latent channels = 16", latents.shape[1] == LATENT_CHANNELS,
      f"got {latents.shape[1]}")
check("Latent spatial = 32",
      latents.shape[3] == LATENT_SIZE and latents.shape[4] == LATENT_SIZE,
      f"got {tuple(latents.shape[3:])}")
check("Latent temporal = 1 (single frame)", latents.shape[2] == LATENT_FRAMES,
      f"got {latents.shape[2]}")

# 5. Model with VideoHead. Only the model's `video_head` submodule is
# exercised here, not a full ARBModel forward pass.
from arbitor import ARBModel
model = ARBModel(enable_image=False, enable_audio=False,
                 enable_vq=False, enable_graph=False,
                 enable_memory_modules=False, enable_moe=False)
model.eval()
with torch.no_grad():
    video_latents = model.video_head(relational)
check("VideoHead in model pipeline", video_latents.shape == _expected_shape(2),
      f"got {video_latents.shape}")

# 6. Quantization effects (VideoHead uses TernaryScaleTensor internally).
# Buffers whose names contain 'T_packed' are counted as packed ternary
# weights; the remaining float parameter count is expected to stay small.
params = sum(p.numel() for p in head.parameters() if not hasattr(p, 'T_packed'))
ternary_buffers = sum(b.numel() for n, b in head.named_buffers() if 'T_packed' in n)
check("VideoHead has ternary weights", ternary_buffers > 0,
      f"{ternary_buffers} packed ternary entries")
check("VideoHead minimal float params", params < 5000,
      f"{params} float params")

print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All video comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# POSIX keeps only the low 8 bits of the exit status, so an uncapped count
# of 256 would read as success; cap it so any failure stays non-zero.
sys.exit(min(FAILED, 255))