ARBS / tests /test_cross_modal.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Cross-modal integration tests for Phase 20."""
import os
import torch
import pytest
from arbitor.main import ARBModel
# Opt-in gate: these tests build the full ARBModel (with image/audio
# sidecars), so every test in this module is skipped unless the environment
# variable ARB_RUN_SLOW_TESTS is set to exactly "1".
_SLOW_TESTS_ENABLED = os.getenv("ARB_RUN_SLOW_TESTS") == "1"
pytestmark = pytest.mark.skipif(
    not _SLOW_TESTS_ENABLED,
    reason="full cross-modal ARBModel tests require the 3B target model and sidecar encoders",
)
def test_cross_modality_unified_latent():
    """Run text, image and audio through one forward pass together.

    Builds an ARBModel with both image and audio enabled and feeds it:

    * a batch of 50 random text token ids drawn from ``[0, 288)``,
    * one random 3x256x256 image tensor,
    * one random audio tensor of ``16000 * 3`` samples (presumably 3 s at
      16 kHz -- TODO confirm against the audio encoder).

    It then checks that logits and indices come back. It also checks that
    the indices sequence is longer than the 50 text tokens, which presumably
    means the image/audio inputs added positions to the shared sequence.
    """
    model = ARBModel(enable_image=True, enable_audio=True)
    model.eval()
    text = torch.randint(0, 288, (1, 50))
    img = torch.randn(1, 3, 256, 256)
    audio = torch.randn(1, 16000 * 3)
    # Inference only: disable autograd so the forward pass does not build and
    # retain a gradient graph over the whole model.
    with torch.no_grad():
        logits, _losses, indices, _ = model(text, images=img, audio=audio)
    assert logits is not None
    assert indices is not None
    # Longer than the 50 text tokens => non-text modalities contributed.
    assert indices.shape[1] > 50
def test_text_only_still_works():
    """Check that a text-only model still produces vocabulary-sized logits.

    Builds an ARBModel with image and audio disabled and runs a forward pass
    on 50 random token ids in ``[0, 288)``. It asserts that the last logits
    dimension is 288.
    """
    model = ARBModel(enable_image=False, enable_audio=False)
    model.eval()
    text = torch.randint(0, 288, (1, 50))
    # Inference only: no gradient graph is needed for these checks.
    with torch.no_grad():
        logits, _losses, _indices, _ = model(text)
    assert logits is not None
    assert logits.shape[-1] == 288
def test_image_only():
    """Check that image tokens extend the sequence when audio is disabled.

    Builds an ARBModel with only the image path enabled. It feeds 10 random
    text token ids in ``[0, 288)`` plus one random 3x224x224 image tensor, and
    asserts that the returned indices are longer than the 10 text tokens.
    """
    model = ARBModel(enable_image=True, enable_audio=False)
    model.eval()
    text = torch.randint(0, 288, (1, 10))
    img = torch.randn(1, 3, 224, 224)
    # Inference only: no gradient graph is needed for these checks.
    with torch.no_grad():
        logits, _losses, indices, _ = model(text, images=img)
    assert logits is not None
    # Fail with a clear assertion (not an AttributeError) if indices is None.
    assert indices is not None
    assert indices.shape[1] > 10