"""Image comprehension tests — verify the image pipeline on CPU.

Tests: ImageSequencer forward, DINOv2 feature extraction,
patch_proj → unfold → projection → norm pipeline.
Runs on CPU — the first run downloads a large model (DINOv2-small).
"""
| import os, sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| from arbitor.kernel.ternary_scale import TScaleType |
|
|
device = "cpu"  # target device for this CPU-only test run
FAILED = 0      # running count of failed checks; also used as the exit code


def check(name, condition, detail=""):
    """Print a pass/fail line for one check and tally failures.

    Args:
        name: Human-readable label for the check.
        condition: Anything truthy on success. A single-element torch tensor
            (e.g. the result of ``.any()`` / ``.all()``) also works, because it
            is converted with ``bool()`` exactly like an ``if`` test would.
        detail: Extra context printed only when the check fails.

    Returns:
        True if the check passed, False otherwise, so callers can skip
        follow-up checks that depend on this one.
    """
    global FAILED
    passed = bool(condition)
    if passed:
        print(f" ✓ {name}")
    else:
        # Distinct marker so failures stand out from passes in the log.
        print(f" ✗ {name} — {detail}")
        FAILED += 1
    return passed
|
|
|
|
print("\n=== Image Comprehension ===\n")
print("Loading ImageSequencer (downloads DINOv2-small on first run)...")


from arbitor import ARBModel, HIDDEN_DIM

# Seed the global RNG so the random input image, the random targets and any
# randomly-initialised model weights are identical across runs; a failing
# run can then be reproduced exactly.
torch.manual_seed(0)

# --- 1. Image-only model: ImageSequencer in isolation ----------------------
model = ARBModel(enable_image=True, enable_audio=False,
                 enable_vq=False, enable_graph=False,
                 enable_memory_modules=False, enable_moe=False)
model.eval()
# Random stand-in for one image; presumably (batch, channels, H, W) — confirm
# against ImageSequencer's expected input layout.
img = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    seq_out = model.image_sequencer(img)
has_output = seq_out is not None
check("ImageSequencer forward", has_output, "got None")
# The remaining checks dereference seq_out; skip them when it is None (the
# failure is already recorded above) instead of crashing on AttributeError.
if has_output:
    check("Output shape", seq_out.shape[-1] == HIDDEN_DIM,
          f"last dim={seq_out.shape[-1]}")
    check("No NaN in image features", not torch.isnan(seq_out).any())
    check("Image features finite", torch.isfinite(seq_out).all())
    feat_std = seq_out.std().item()
    check("Image features have variance", feat_std > 0.001,
          f"std={feat_std}")

# --- 2. Full model forward with an image and random targets in [0, 256) ----
with torch.no_grad():
    logits, losses, _, _ = model(x=None, images=img,
                                 targets=torch.randint(0, 256, (1, 100)))
check("Model with image forward", logits is not None)
if losses is not None:
    # .all() keeps the check valid whether losses.total is a scalar or a
    # multi-element tensor (bool() of a multi-element tensor raises).
    # Assumes losses.total is a tensor — torch.isfinite rejects plain floats.
    check("Image loss is finite", torch.isfinite(losses.total).all())

# --- 3. ImageSequencer inside a model with VQ, graph and MoE enabled -------
del model  # drop the reference to the first model before building the next
# Same class object as ARBModel; the alias only separates the two setups.
from arbitor import ARBModel as ARBModel2
model2 = ARBModel2(enable_image=True, enable_audio=False,
                   enable_vq=True, enable_graph=True,
                   enable_memory_modules=False, enable_moe=True)
model2.eval()
with torch.no_grad():
    seq_out2 = model2.image_sequencer(img)
# NOTE(review): only image_sequencer is called here, so whether the VQ path is
# actually exercised depends on ImageSequencer internals — confirm.
check("Image with VQ pipeline", seq_out2 is not None, "got None")


print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All image comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# Process exit status is the number of failed checks (0 means success).
sys.exit(FAILED)
|
|