"""Image comprehension tests — verify the image pipeline on CPU.

Tests: ImageSequencer forward, DINOv2 feature extraction,
patch_proj → unfold → projection → norm pipeline.
Runs on CPU — large model download first time.

Exit status is the number of failed checks (0 when everything passes).
"""
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
from arbitor.kernel.ternary_scale import TScaleType  # noqa: F401  NOTE(review): unused below; possibly kept as an import smoke test — confirm

device = "cpu"  # NOTE(review): not referenced below; nothing is moved to this device explicitly
FAILED = 0


def check(name, condition, detail=""):
    """Print a pass/fail line for *name*; increment the global ``FAILED`` on failure.

    Args:
        name: Human-readable label for the check.
        condition: Anything truthy/falsy (a one-element bool tensor works too).
        detail: Extra text printed after the label when the check fails.

    Returns:
        ``bool(condition)``, so callers can skip checks that depend on this one.
    """
    global FAILED
    ok = bool(condition)
    if ok:
        print(f" ✓ {name}")
    else:
        print(f" ✗ {name} — {detail}")
        FAILED += 1
    return ok


print("\n=== Image Comprehension ===\n")
print("Loading ImageSequencer (downloads DINOv2-small on first run)...")

from arbitor import ARBModel, HIDDEN_DIM

# 1. ImageSequencer forward with synthetic image
model = ARBModel(enable_image=True, enable_audio=False, enable_vq=False,
                 enable_graph=False, enable_memory_modules=False, enable_moe=False)
model.eval()
img = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    seq_out = model.image_sequencer(img)

# Only inspect the output when it exists: the attribute accesses below would
# otherwise crash the script before the summary and exit status are produced.
if check("ImageSequencer forward", seq_out is not None, "got None"):
    check("Output shape", seq_out.shape[-1] == HIDDEN_DIM,
          f"last dim={seq_out.shape[-1]}")
    check("No NaN in image features", not torch.isnan(seq_out).any(),
          "found NaN")

    # 2. Image features are finite and reasonable
    check("Image features finite", torch.isfinite(seq_out).all(),
          "found non-finite values")
    seq_std = seq_out.std().item()
    check("Image features have variance", seq_std > 0.001, f"std={seq_std}")

# 3. Full model with image input
with torch.no_grad():
    logits, losses, _, _ = model(x=None, images=img,
                                 targets=torch.randint(0, 256, (1, 100)))
check("Model with image forward", logits is not None, "got None")
if losses is not None:
    # as_tensor accepts a tensor or a plain Python number; torch.isfinite
    # alone raises TypeError on non-tensor input.
    total_loss = torch.as_tensor(losses.total)
    check("Image loss is finite", torch.isfinite(total_loss).all(),
          f"total={total_loss}")

# 4. Image sequencer on a model built with VQ, graph and MoE enabled.
# NOTE(review): this calls image_sequencer directly; whether the VQ stage is
# applied inside it is not visible from this file — confirm.
del model  # drop the first model before building the second one

model2 = ARBModel(enable_image=True, enable_audio=False, enable_vq=True,
                  enable_graph=True, enable_memory_modules=False, enable_moe=True)
model2.eval()
with torch.no_grad():
    seq_out2 = model2.image_sequencer(img)
check("Image with VQ pipeline", seq_out2 is not None, "got None")

print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All image comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
sys.exit(FAILED)