"""Tests for VAE2DSequencer and VAEAudioSequencer.

These are slow, opt-in tests (set ``ARB_RUN_SLOW_TESTS=1``). Constructing a
sequencer loads its full sidecar encoder, so each sequencer is built once per
module through a fixture instead of once per test. Forward passes run under
``torch.no_grad()``: no test here inspects gradients, and skipping autograd
bookkeeping keeps memory down for the 7168-wide outputs.
"""

import os

import pytest
import torch

from arbitor.sequencers import VAE2DSequencer, VAEAudioSequencer

pytestmark = pytest.mark.skipif(
    os.environ.get("ARB_RUN_SLOW_TESTS") != "1",
    reason="VAE sequencer tests load full sidecar encoders and 7168d projections",
)

# Embedding width both sequencers are expected to emit per token.
EMBED_DIM = 7168
# The 2D sequencer is expected to emit (H // 8) * (W // 8) tokens
# (256x256 -> 1024, 224x224 -> 784).
DOWNSAMPLE = 8
# Ceiling on trainable parameters: only a small head should be trainable, not
# the ViT / Moonshine encoder weights the test names refer to.
MAX_TRAINABLE_PARAMS = 100_000


@pytest.fixture(autouse=True)
def _seed_rng():
    """Seed torch's RNG so the random inputs in every test are reproducible."""
    torch.manual_seed(0)


@pytest.fixture(scope="module")
def vae2d():
    """A single VAE2DSequencer shared by all 2D tests in this module."""
    return VAE2DSequencer()


@pytest.fixture(scope="module")
def vae_audio():
    """A single VAEAudioSequencer shared by all audio tests in this module."""
    return VAEAudioSequencer()


def _forward(seq, x):
    """Run ``seq(x)`` without recording an autograd graph."""
    with torch.no_grad():
        return seq(x)


def _n_trainable(module):
    """Count the parameters of *module* that have ``requires_grad`` set."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# --------------------------------------------------------------------------- #
# VAE2DSequencer
# --------------------------------------------------------------------------- #


def test_vae2d_sequencer_output_shape(vae2d):
    out = _forward(vae2d, torch.randn(2, 3, 256, 256))
    assert out.shape == (2, 1024, EMBED_DIM)


def test_vae2d_sequencer_224(vae2d):
    out = _forward(vae2d, torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 784, EMBED_DIM)


@pytest.mark.parametrize("h, w", [(128, 128), (256, 192), (512, 512)])
def test_vae2d_sequencer_different_resolutions(vae2d, h, w):
    out = _forward(vae2d, torch.randn(1, 3, h, w))
    assert out.shape[-1] == EMBED_DIM
    assert out.shape[1] == (h // DOWNSAMPLE) * (w // DOWNSAMPLE)


def test_vae2d_sequencer_no_vit_params(vae2d):
    n_params = _n_trainable(vae2d)
    assert n_params < MAX_TRAINABLE_PARAMS, f"{n_params:,} trainable parameters"


def test_vae2d_sequencer_output_range(vae2d):
    out = _forward(vae2d, torch.randn(1, 3, 256, 256))
    assert torch.isfinite(out).all()
    # .item() so a failure prints a plain float rather than a tensor repr.
    assert out.abs().mean().item() < 100.0


def test_vae2d_sequencer_batch(vae2d):
    out = _forward(vae2d, torch.randn(4, 3, 256, 256))
    # Same per-image token count as test_vae2d_sequencer_output_shape.
    assert out.shape == (4, 1024, EMBED_DIM)


# --------------------------------------------------------------------------- #
# VAEAudioSequencer
# --------------------------------------------------------------------------- #


def test_vae_audio_sequencer_output_shape(vae_audio):
    out = _forward(vae_audio, torch.randn(1, 48000))
    assert out.shape[0] == 1
    assert out.shape[-1] == EMBED_DIM


def test_vae_audio_sequencer_mono_tensor(vae_audio):
    # 3-D input with a singleton middle dim (presumably an explicit mono
    # channel axis) must be accepted alongside the 2-D form.
    out = _forward(vae_audio, torch.randn(1, 1, 16000))
    assert out.shape[-1] == EMBED_DIM


def test_vae_audio_sequencer_batch(vae_audio):
    out = _forward(vae_audio, torch.randn(2, 16000))
    assert out.shape[0] == 2


def test_vae_audio_no_moonshine_params(vae_audio):
    n_trainable = _n_trainable(vae_audio)
    assert n_trainable < MAX_TRAINABLE_PARAMS, f"{n_trainable:,} trainable parameters"


def test_vae_audio_output_range(vae_audio):
    out = _forward(vae_audio, torch.randn(1, 16000))
    assert torch.isfinite(out).all()


def test_vae_audio_variable_length(vae_audio):
    # A longer clip must produce strictly more tokens than a shorter one.
    out_short = _forward(vae_audio, torch.randn(1, 8000))
    out_long = _forward(vae_audio, torch.randn(1, 48000))
    assert out_short.shape[1] < out_long.shape[1]