"""Tests for VAE2DEncoder and MelSpectrogram3Band.""" import math import os import torch import pytest from arbitor.encoders.vae2d import load_vae2d from arbitor.encoders.mel_frontend import MelSpectrogram3Band pytestmark = pytest.mark.skipif( os.environ.get("ARB_RUN_SLOW_TESTS") != "1", reason="VAE2D sidecar tests load the full PixArt/OpenSora VAE path", ) def test_vae2d_encoder_output_shape(): encoder = load_vae2d("cpu") img = torch.randn(1, 3, 256, 256) latent = encoder(img) assert latent.shape == (1, 4, 32, 32) def test_vae2d_encoder_requires_divisible_by_8(): encoder = load_vae2d("cpu") img = torch.randn(1, 3, 224, 224) latent = encoder(img) assert latent.shape == (1, 4, 28, 28) def test_mel_3band_output_shape(): mel = MelSpectrogram3Band(sample_rate=16000) audio = torch.randn(1, 80000) spec = mel(audio) T_mel = math.ceil(80000 / 512) assert spec.shape == (1, 3, 64, T_mel) def test_mel_3band_channels_distinct(): audio = torch.randn(1, 16000) spec = MelSpectrogram3Band()(audio) assert not torch.allclose(spec[0, 0], spec[0, 1]) assert not torch.allclose(spec[0, 1], spec[0, 2]) def test_vae2d_frozen(): encoder = load_vae2d("cpu") for p in encoder.parameters(): assert not p.requires_grad def test_vae2d_no_decoder(): encoder = load_vae2d("cpu") total = sum(p.numel() for p in encoder.parameters()) assert total < 60_000_000 def test_vae2d_batch_independence(): encoder = load_vae2d("cpu") imgs = torch.randn(2, 3, 256, 256) latent = encoder(imgs) assert latent.shape[0] == 2 assert not torch.allclose(latent[0], latent[1]) def test_vae2d_on_mel_spectrogram(): encoder = load_vae2d("cpu") mel = MelSpectrogram3Band(sample_rate=16000) length = 48641 audio = torch.randn(1, length) spec = mel(audio) assert spec.shape[-1] % 8 == 0 latent = encoder(spec) assert latent.shape[1] == 4 assert latent.shape[2] == 8 # 64 mel bands / 8 = 8 T_latent = spec.shape[-1] // 8 assert latent.shape[3] == T_latent