File size: 1,271 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | """Cross-modal integration tests for Phase 20."""
import os
import torch
import pytest
from arbitor.main import ARBModel
# Every test in this module builds the full ARBModel with its sidecar
# encoders, so the whole module is skipped unless the caller opts in by
# setting ARB_RUN_SLOW_TESTS=1 in the environment.
_SLOW_TESTS_ENABLED = os.environ.get("ARB_RUN_SLOW_TESTS") == "1"
pytestmark = pytest.mark.skipif(
    not _SLOW_TESTS_ENABLED,
    reason="full cross-modal ARBModel tests require the 3B target model and sidecar encoders",
)
def test_cross_modality_unified_latent():
    """Run text, image and audio through a single forward pass.

    Builds an ARBModel with both the image and audio paths enabled and feeds
    it 50 text tokens plus one image and one audio clip. Checks that logits
    and indices are produced and that the indices sequence is longer than the
    50 text tokens. Presumably the extra length comes from the image/audio
    inputs being merged into the same sequence; confirm against ARBModel.
    """
    model = ARBModel(enable_image=True, enable_audio=True)
    model.eval()
    # Token ids in [0, 288); presumably the model's text vocabulary size — confirm.
    text = torch.randint(0, 288, (1, 50))
    img = torch.randn(1, 3, 256, 256)
    # 48000 samples; presumably 3 s of 16 kHz audio — confirm expected rate.
    audio = torch.randn(1, 16000 * 3)
    # This is an inference-only check, so disable autograd. That keeps the
    # large model's forward pass from retaining activations for a backward
    # pass that never happens.
    with torch.no_grad():
        logits, _, indices, _ = model(text, images=img, audio=audio)
    assert logits is not None
    assert indices is not None
    assert indices.shape[1] > 50
def test_text_only_still_works():
    """Check that a text-only model still produces 288-wide logits.

    Builds an ARBModel with both the image and audio paths disabled, runs
    50 random text tokens through it, and asserts the last logits dimension
    is 288, which matches the token-id range used for the input.
    """
    model = ARBModel(enable_image=False, enable_audio=False)
    model.eval()
    text = torch.randint(0, 288, (1, 50))
    # Inference-only: no gradients are needed, so skip graph construction.
    with torch.no_grad():
        logits, _, _, _ = model(text)
    assert logits is not None
    assert logits.shape[-1] == 288
def test_image_only():
    """Run text plus a single image through a forward pass (audio disabled).

    Checks that logits and indices are produced and that the indices sequence
    is longer than the 10 input text tokens.
    """
    model = ARBModel(enable_image=True, enable_audio=False)
    model.eval()
    text = torch.randint(0, 288, (1, 10))
    # NOTE(review): 224x224 here vs 256x256 in the cross-modal test; presumably
    # deliberate to cover two input resolutions — confirm.
    img = torch.randn(1, 3, 224, 224)
    # Inference-only: no gradients are needed, so skip graph construction.
    with torch.no_grad():
        logits, _, indices, _ = model(text, images=img)
    assert logits is not None
    # Check for None explicitly so a missing result fails as a clear assertion
    # instead of an AttributeError on `.shape` below.
    assert indices is not None
    assert indices.shape[1] > 10
|