"""Unit tests for Open-Sora 3D VAE wrapper (Phase 19 Plan 01).

Runnable under pytest or directly as a script (``python <this file>``).
"""
import math
import os
import sys

import torch  # noqa: F401  # NOTE(review): not referenced below; presumably imported so a missing torch install fails early — confirm

# Make the repository root (two directories above this file) importable so
# `arbitor` resolves when the file is run directly. The membership check keeps
# repeated imports of this module from prepending the same entry again.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from arbitor.config import (  # noqa: E402  (must follow the sys.path setup)
    OPEN_SORA_LATENT_CHANNELS,
    OPEN_SORA_SCALE_FACTOR_SPATIAL,
    OPEN_SORA_SCALE_FACTOR_TEMPORAL,
    BYTEHEAD_ACT_MAX_ITERS,
    VIDEOHEAD_ACT_MIN_FPS,
    TIMESTAMP_MAX_PERIOD,
    FRAME_BUFFER_LOCAL_SIZE,
)
from arbitor.encoders.opensora_vae import OpenSoraVAEWrapper, load_opensora_vae  # noqa: E402


def test_opensora_vae_config_constants():
    """Pin the Open-Sora / video-head configuration constants to their expected values."""
    assert OPEN_SORA_LATENT_CHANNELS == 4, f"got {OPEN_SORA_LATENT_CHANNELS}"
    assert OPEN_SORA_SCALE_FACTOR_SPATIAL == 8, f"got {OPEN_SORA_SCALE_FACTOR_SPATIAL}"
    assert OPEN_SORA_SCALE_FACTOR_TEMPORAL == 4, f"got {OPEN_SORA_SCALE_FACTOR_TEMPORAL}"
    assert BYTEHEAD_ACT_MAX_ITERS == 3, f"got {BYTEHEAD_ACT_MAX_ITERS}"
    assert VIDEOHEAD_ACT_MIN_FPS == 1, f"got {VIDEOHEAD_ACT_MIN_FPS}"
    # Floating-point constant: compare with a purely absolute tolerance (same
    # 1e-5 bound as before) rather than exact equality.
    assert math.isclose(
        TIMESTAMP_MAX_PERIOD, 10000.0, rel_tol=0.0, abs_tol=1e-5
    ), f"got {TIMESTAMP_MAX_PERIOD}"
    assert FRAME_BUFFER_LOCAL_SIZE == 3, f"got {FRAME_BUFFER_LOCAL_SIZE}"
    print(" PASS test_opensora_vae_config_constants")


def test_opensora_vae_wrapper_construction():
    """Wrap a bare stand-in object and check the wrapper's exposed attributes and methods."""
    # Instance of an empty class (no attributes or methods); this test only
    # inspects what the wrapper exposes after construction.
    mock_vae = type("MockVAE", (), {})()
    wrapper = OpenSoraVAEWrapper(mock_vae)
    assert wrapper.latent_channels == 4, f"got {wrapper.latent_channels}"
    assert wrapper.scale_factor_spatial == 8, f"got {wrapper.scale_factor_spatial}"
    assert wrapper.scale_factor_temporal == 4, f"got {wrapper.scale_factor_temporal}"
    # Require real callables, not merely attributes that happen to exist.
    assert callable(getattr(wrapper, "encode", None)), "wrapper.encode should be callable"
    assert callable(getattr(wrapper, "decode", None)), "wrapper.decode should be callable"
    print(" PASS test_opensora_vae_wrapper_construction")


def test_opensora_vae_load_function_exists():
    """Check that the loader entry point is importable and callable (it is not invoked)."""
    assert callable(load_opensora_vae), "load_opensora_vae should be callable"
    print(" PASS test_opensora_vae_load_function_exists")


def test_video_latent_channels_updated():
    """Check that the generic video latent channel count is 4."""
    from arbitor.config import VIDEO_LATENT_CHANNELS

    assert VIDEO_LATENT_CHANNELS == 4, f"VIDEO_LATENT_CHANNELS={VIDEO_LATENT_CHANNELS} should be 4"
    print(" PASS test_video_latent_channels_updated")


if __name__ == "__main__":
    test_opensora_vae_config_constants()
    test_opensora_vae_wrapper_construction()
    test_opensora_vae_load_function_exists()
    test_video_latent_channels_updated()
    print("\nAll Open-Sora VAE tests PASS")