File size: 2,019 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Unit tests for Open-Sora 3D VAE wrapper (Phase 19 Plan 01)."""
import torch
import sys
import os

# Prepend the directory two levels above the one containing this file to the
# import path, presumably so the ``arbitor`` imports resolve when this file is
# executed directly as a script rather than through an installed package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

from arbitor.config import (
    OPEN_SORA_LATENT_CHANNELS,
    OPEN_SORA_SCALE_FACTOR_SPATIAL,
    OPEN_SORA_SCALE_FACTOR_TEMPORAL,
    BYTEHEAD_ACT_MAX_ITERS,
    VIDEOHEAD_ACT_MIN_FPS,
    TIMESTAMP_MAX_PERIOD,
    FRAME_BUFFER_LOCAL_SIZE,
)
from arbitor.encoders.opensora_vae import OpenSoraVAEWrapper, load_opensora_vae


def test_opensora_vae_config_constants():
    """Pin the Open-Sora VAE and related head/buffer constants from ``arbitor.config``.

    Every assertion carries a ``got ...`` message because the ``__main__`` runner
    at the bottom of this file calls the tests directly, without pytest's
    assertion introspection, so a bare assert would not report the actual value.
    """
    assert OPEN_SORA_LATENT_CHANNELS == 4, f"got {OPEN_SORA_LATENT_CHANNELS}"
    assert OPEN_SORA_SCALE_FACTOR_SPATIAL == 8, f"got {OPEN_SORA_SCALE_FACTOR_SPATIAL}"
    assert OPEN_SORA_SCALE_FACTOR_TEMPORAL == 4, f"got {OPEN_SORA_SCALE_FACTOR_TEMPORAL}"
    assert BYTEHEAD_ACT_MAX_ITERS == 3, f"got {BYTEHEAD_ACT_MAX_ITERS}"
    assert VIDEOHEAD_ACT_MIN_FPS == 1, f"got {VIDEOHEAD_ACT_MIN_FPS}"
    # Compared with a tolerance rather than ==, so a float value is accepted.
    assert abs(TIMESTAMP_MAX_PERIOD - 10000.0) < 1e-5, f"got {TIMESTAMP_MAX_PERIOD}"
    assert FRAME_BUFFER_LOCAL_SIZE == 3, f"got {FRAME_BUFFER_LOCAL_SIZE}"
    print(" PASS test_opensora_vae_config_constants")


def test_opensora_vae_wrapper_construction():
    """Check that ``OpenSoraVAEWrapper`` wraps a bare object and exposes the
    configured latent geometry plus callable ``encode`` / ``decode`` methods.

    The expected geometry comes from the same ``arbitor.config`` constants that
    ``test_opensora_vae_config_constants`` pins, so this test checks that the
    wrapper agrees with the config instead of duplicating the literal values.
    """
    # Attribute-less stand-in object: the test expects construction to succeed
    # without a real VAE instance.
    mock_vae = type("MockVAE", (), {})()
    wrapper = OpenSoraVAEWrapper(mock_vae)
    assert wrapper.latent_channels == OPEN_SORA_LATENT_CHANNELS, (
        f"got {wrapper.latent_channels}"
    )
    assert wrapper.scale_factor_spatial == OPEN_SORA_SCALE_FACTOR_SPATIAL, (
        f"got {wrapper.scale_factor_spatial}"
    )
    assert wrapper.scale_factor_temporal == OPEN_SORA_SCALE_FACTOR_TEMPORAL, (
        f"got {wrapper.scale_factor_temporal}"
    )
    # hasattr() alone would also accept a non-callable attribute; require methods.
    assert callable(getattr(wrapper, "encode", None)), "wrapper.encode should be callable"
    assert callable(getattr(wrapper, "decode", None)), "wrapper.decode should be callable"
    print(" PASS test_opensora_vae_wrapper_construction")


def test_opensora_vae_load_function_exists():
    """Smoke-check that the ``load_opensora_vae`` factory is importable and callable."""
    loader = load_opensora_vae
    assert callable(loader), "load_opensora_vae should be callable"
    print(" PASS test_opensora_vae_load_function_exists")


def test_video_latent_channels_updated():
    """Verify ``VIDEO_LATENT_CHANNELS`` in ``arbitor.config`` equals 4."""
    # Imported lazily so only this test depends on the constant existing.
    from arbitor.config import VIDEO_LATENT_CHANNELS

    expected_channels = 4
    assert VIDEO_LATENT_CHANNELS == expected_channels, f"VIDEO_LATENT_CHANNELS={VIDEO_LATENT_CHANNELS} should be 4"
    print(" PASS test_video_latent_channels_updated")


if __name__ == "__main__":
    # Script mode: run each test in declaration order. The first failing
    # assertion propagates and stops the run before the summary line is printed.
    for _run_test in (
        test_opensora_vae_config_constants,
        test_opensora_vae_wrapper_construction,
        test_opensora_vae_load_function_exists,
        test_video_latent_channels_updated,
    ):
        _run_test()
    print("\nAll Open-Sora VAE tests PASS")