ARBS / testing /model /image-comprehension.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Image comprehension tests — verify the image pipeline on CPU.
Tests: ImageSequencer forward, DINOv2 feature extraction,
patch_proj → unfold → projection → norm pipeline.
Runs on CPU — the first run triggers a large model download.
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor.kernel.ternary_scale import TScaleType
# Device label for this CPU-only test script.
# NOTE(review): not referenced anywhere in the visible part of the file.
device = "cpu"
# Running count of failed checks: incremented by check() on every failure and
# handed to sys.exit() at the end of the script as the exit status.
FAILED = 0
def check(name, condition, detail=""):
    """Report the outcome of one named check and count failures.

    Prints a check-mark line when *condition* is truthy. Otherwise prints a
    cross-mark line (with *detail* appended when one is supplied) and
    increments the module-level ``FAILED`` counter.

    Args:
        name: Human-readable label for the check.
        condition: Outcome of the check; only its truthiness is used (call
            sites pass plain bools as well as results of torch reductions).
        detail: Optional extra text shown only when the check fails.
    """
    global FAILED
    if condition:
        print(f" ✓ {name}")
        return
    # Only emit the separator when there is something to show after it;
    # several call sites pass no detail at all.
    suffix = f" — {detail}" if detail else ""
    print(f" ✗ {name}{suffix}")
    FAILED += 1
print("\n=== Image Comprehension ===\n")
print("Loading ImageSequencer (downloads DINOv2-small on first run)...")
# Kept after the banner lines so that the messages above are printed first.
from arbitor import ARBModel, HIDDEN_DIM

# 1. ImageSequencer forward with synthetic image
# Build the model with only the image path enabled; every other optional
# feature flag is switched off.
model = ARBModel(
    enable_image=True,
    enable_audio=False,
    enable_vq=False,
    enable_graph=False,
    enable_memory_modules=False,
    enable_moe=False,
)
model.eval()

# Random input of shape (1, 3, 224, 224) — presumably
# (batch, channels, height, width); confirm against ImageSequencer.
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    seq_out = model.image_sequencer(img)

check("ImageSequencer forward", seq_out is not None, "got None")

# The trailing dimension of the output is compared against HIDDEN_DIM.
last_dim = seq_out.shape[-1]
check("Output shape", last_dim == HIDDEN_DIM, f"last dim={last_dim}")

has_nan = torch.isnan(seq_out).any()
check("No NaN in image features", not has_nan)
# 2. Image features are finite and reasonable
all_finite = torch.isfinite(seq_out).all()
check("Image features finite", all_finite)

# Compute the spread once and reuse it for both the threshold test and the
# failure message.
feature_std = seq_out.std().item()
check("Image features have variance", feature_std > 0.001,
      f"std={feature_std}")
# 3. Full model with image input
# The model is called with x=None, the synthetic image, and random integer
# targets drawn from [0, 256) with shape (1, 100). It returns four values;
# only the first two are inspected here.
with torch.no_grad():
    dummy_targets = torch.randint(0, 256, (1, 100))
    logits, losses, _, _ = model(x=None, images=img, targets=dummy_targets)

check("Model with image forward", logits is not None)

# The loss check is skipped when the model reports no losses.
if losses is not None:
    loss_is_finite = torch.isfinite(losses.total)
    check("Image loss is finite", loss_is_finite)
# 4. Image sequencer on a model built with VQ, graph and MoE enabled
# NOTE(review): this section was labelled "Modality gate", but nothing in the
# visible code references a gate; it only re-runs image_sequencer on a second
# model configuration.
del model  # free memory
from arbitor import ARBModel as ARBModel2

model2 = ARBModel2(
    enable_image=True,
    enable_audio=False,
    enable_vq=True,
    enable_graph=True,
    enable_memory_modules=False,
    enable_moe=True,
)
model2.eval()

with torch.no_grad():
    seq_out2 = model2.image_sequencer(img)

produced_output = seq_out2 is not None
check("Image with VQ pipeline", produced_output, "got None")
# Summary: print a divider and the overall verdict for the run.
print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All image comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# The failure count doubles as the process exit status (0 means success).
sys.exit(FAILED)