"""Cross-modal integration tests for Phase 20."""

import os

import pytest
import torch

from arbitor.main import ARBModel

pytestmark = pytest.mark.skipif(
    os.environ.get("ARB_RUN_SLOW_TESTS") != "1",
    reason="full cross-modal ARBModel tests require the 3B target model and sidecar encoders",
)

# Token ids are sampled from [0, VOCAB_SIZE); test_text_only_still_works also
# expects the logits' last dimension to equal this value.
VOCAB_SIZE = 288
# Raw audio length in samples (presumably 3 s at 16 kHz -- confirm against the
# audio encoder's expected sample rate).
AUDIO_SAMPLES = 16000 * 3


@pytest.fixture(autouse=True)
def _seed_rng():
    """Seed torch's global RNG before each test so failures are reproducible.

    Covers both the random inputs below and any random weight initialisation
    that happens inside ``ARBModel``'s constructor.
    """
    torch.manual_seed(0)


def _build_model(*, enable_image, enable_audio):
    """Construct an ``ARBModel`` with the given sidecar flags, in eval mode."""
    model = ARBModel(enable_image=enable_image, enable_audio=enable_audio)
    model.eval()
    return model


def _forward(model, text, **modal_inputs):
    """Run one forward pass without recording an autograd graph.

    ``model.eval()`` only switches layer behaviour; it does not stop autograd
    from tracking operations. ``torch.no_grad()`` does, which avoids keeping
    activations alive for a backward pass these tests never run.

    Returns whatever ``model(text, **modal_inputs)`` returns (a 4-tuple in
    the tests below).
    """
    with torch.no_grad():
        return model(text, **modal_inputs)


def test_cross_modality_unified_latent():
    model = _build_model(enable_image=True, enable_audio=True)
    text = torch.randint(0, VOCAB_SIZE, (1, 50))
    img = torch.randn(1, 3, 256, 256)
    audio = torch.randn(1, AUDIO_SAMPLES)

    logits, _, indices, _ = _forward(model, text, images=img, audio=audio)

    assert logits is not None
    assert indices is not None
    # More index positions than text tokens: the image/audio inputs are
    # expected to contribute to the sequence.
    assert indices.shape[1] > text.shape[1]


def test_text_only_still_works():
    model = _build_model(enable_image=False, enable_audio=False)
    text = torch.randint(0, VOCAB_SIZE, (1, 50))

    logits, _, _, _ = _forward(model, text)

    assert logits is not None
    assert logits.shape[-1] == VOCAB_SIZE


def test_image_only():
    model = _build_model(enable_image=True, enable_audio=False)
    text = torch.randint(0, VOCAB_SIZE, (1, 10))
    img = torch.randn(1, 3, 224, 224)

    logits, _, indices, _ = _forward(model, text, images=img)

    assert logits is not None
    # Check for None first so a missing result fails as an assertion rather
    # than an AttributeError on ``.shape``.
    assert indices is not None
    assert indices.shape[1] > text.shape[1]