File size: 2,102 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Tests for VAE2DEncoder and MelSpectrogram3Band."""
import math
import os
import torch
import pytest
from arbitor.encoders.vae2d import load_vae2d
from arbitor.encoders.mel_frontend import MelSpectrogram3Band

# Opt-in gate: unless ARB_RUN_SLOW_TESTS is exactly "1", skip the module.
# NOTE(review): being module-level, this mark applies to every test in this
# file, including the mel-only tests that never call load_vae2d.
pytestmark = pytest.mark.skipif(
    os.getenv("ARB_RUN_SLOW_TESTS") != "1",
    reason="VAE2D sidecar tests load the full PixArt/OpenSora VAE path",
)


def test_vae2d_encoder_output_shape():
    """A single 3-channel 256x256 input must encode to a (1, 4, 32, 32) latent."""
    vae = load_vae2d("cpu")
    pixels = torch.randn(1, 3, 256, 256)
    encoded = vae(pixels)
    expected_shape = (1, 4, 32, 32)
    assert tuple(encoded.shape) == expected_shape


def test_vae2d_encoder_requires_divisible_by_8():
    """A 224x224 input (a multiple of 8 per side) must encode to a 28x28 latent.

    Only the divisible-by-8 case is exercised here; inputs whose sides are
    not multiples of 8 are not covered by this test.
    """
    side = 224
    vae = load_vae2d("cpu")
    pixels = torch.randn(1, 3, side, side)
    encoded = vae(pixels)
    # 224 // 8 == 28 on both spatial axes.
    assert tuple(encoded.shape) == (1, 4, side // 8, side // 8)


def test_mel_3band_output_shape():
    """80000 samples at 16 kHz must produce a (1, 3, 64, ceil(80000 / 512)) output."""
    num_samples = 80000
    frontend = MelSpectrogram3Band(sample_rate=16000)
    waveform = torch.randn(1, num_samples)
    bands = frontend(waveform)
    # 512 is presumably the frontend's hop length -- confirm against
    # MelSpectrogram3Band's defaults.
    expected_frames = math.ceil(num_samples / 512)
    assert tuple(bands.shape) == (1, 3, 64, expected_frames)


def test_mel_3band_channels_distinct():
    """Adjacent output channels (0 vs 1, 1 vs 2) must not be numerically equal."""
    waveform = torch.randn(1, 16000)
    channels = MelSpectrogram3Band()(waveform)[0]
    for lower, upper in ((0, 1), (1, 2)):
        assert not torch.allclose(channels[lower], channels[upper])


def test_vae2d_frozen():
    """Every parameter of the loaded encoder must have gradients disabled.

    The parameters are materialized into a list first so the test fails,
    rather than passing vacuously, when the encoder exposes no parameters.
    """
    encoder = load_vae2d("cpu")
    params = list(encoder.parameters())
    assert params, "encoder exposes no parameters; nothing was checked"
    trainable = sum(1 for p in params if p.requires_grad)
    assert trainable == 0, (
        f"{trainable} of {len(params)} parameters still require grad"
    )


def test_vae2d_no_decoder():
    """The loaded encoder must hold fewer than 60M parameters.

    Going by the test name, the budget presumably guards against decoder
    weights being loaded alongside the encoder -- confirm against
    load_vae2d. A positive lower bound is asserted as well, so an encoder
    that exposes no parameters cannot satisfy the budget trivially.
    """
    encoder = load_vae2d("cpu")
    total = sum(p.numel() for p in encoder.parameters())
    assert total > 0, "encoder exposes no parameters; budget check is vacuous"
    assert total < 60_000_000, f"encoder holds {total:,} parameters"


def test_vae2d_batch_independence():
    """Two different random inputs in one batch must yield two different latents."""
    vae = load_vae2d("cpu")
    pair = torch.randn(2, 3, 256, 256)
    encoded = vae(pair)
    # The batch dimension must be preserved before comparing the two items.
    assert encoded.size(0) == 2
    first = encoded[0]
    second = encoded[1]
    assert not torch.allclose(first, second)


def test_vae2d_on_mel_spectrogram():
    """The encoder must accept the 3-band mel output and downsample it by 8.

    A 48641-sample clip is passed through the mel frontend; the resulting
    frame count is expected to be a multiple of 8 before encoding.
    """
    vae = load_vae2d("cpu")
    frontend = MelSpectrogram3Band(sample_rate=16000)
    num_samples = 48641
    waveform = torch.randn(1, num_samples)
    bands = frontend(waveform)
    frames = bands.shape[-1]
    assert frames % 8 == 0
    encoded = vae(bands)
    # Channels == 4; height == 8 (64 mel bands / 8 = 8); width == frames // 8.
    assert tuple(encoded.shape[1:4]) == (4, 8, frames // 8)