| """Unit tests for Open-Sora 3D VAE wrapper (Phase 19 Plan 01).""" |
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.config import ( |
| OPEN_SORA_LATENT_CHANNELS, |
| OPEN_SORA_SCALE_FACTOR_SPATIAL, |
| OPEN_SORA_SCALE_FACTOR_TEMPORAL, |
| BYTEHEAD_ACT_MAX_ITERS, |
| VIDEOHEAD_ACT_MIN_FPS, |
| TIMESTAMP_MAX_PERIOD, |
| FRAME_BUFFER_LOCAL_SIZE, |
| ) |
| from arbitor.encoders.opensora_vae import OpenSoraVAEWrapper, load_opensora_vae |
|
|
|
|
def test_opensora_vae_config_constants():
    """Pin the Open-Sora VAE and related head/buffer constants in ``arbitor.config``.

    Every assertion carries a ``got ...`` message so a failure is diagnosable
    even when this file is run directly (``python <file>``), where pytest's
    assertion rewriting is not active and a bare ``assert`` reports nothing.
    """
    assert OPEN_SORA_LATENT_CHANNELS == 4, f"got {OPEN_SORA_LATENT_CHANNELS}"
    assert OPEN_SORA_SCALE_FACTOR_SPATIAL == 8, f"got {OPEN_SORA_SCALE_FACTOR_SPATIAL}"
    assert OPEN_SORA_SCALE_FACTOR_TEMPORAL == 4, f"got {OPEN_SORA_SCALE_FACTOR_TEMPORAL}"
    assert BYTEHEAD_ACT_MAX_ITERS == 3, f"got {BYTEHEAD_ACT_MAX_ITERS}"
    assert VIDEOHEAD_ACT_MIN_FPS == 1, f"got {VIDEOHEAD_ACT_MIN_FPS}"
    # Float constant: compare with a tolerance rather than exact equality.
    assert abs(TIMESTAMP_MAX_PERIOD - 10000.0) < 1e-5, f"got {TIMESTAMP_MAX_PERIOD}"
    assert FRAME_BUFFER_LOCAL_SIZE == 3, f"got {FRAME_BUFFER_LOCAL_SIZE}"
    print(" PASS test_opensora_vae_config_constants")
|
|
|
|
def test_opensora_vae_wrapper_construction():
    """Wrap an attribute-less stand-in VAE and check the wrapper's interface.

    The stand-in object has no attributes at all, so construction must not
    depend on any VAE attribute. The wrapper's configuration is compared
    against the ``arbitor.config`` constants (rather than duplicated literals)
    so the wrapper and config cannot silently drift apart, and ``encode`` /
    ``decode`` must be callable, not merely present.
    """
    mock_vae = type("MockVAE", (), {})()
    wrapper = OpenSoraVAEWrapper(mock_vae)
    assert wrapper.latent_channels == OPEN_SORA_LATENT_CHANNELS, (
        f"latent_channels={wrapper.latent_channels}, expected {OPEN_SORA_LATENT_CHANNELS}"
    )
    assert wrapper.scale_factor_spatial == OPEN_SORA_SCALE_FACTOR_SPATIAL, (
        f"scale_factor_spatial={wrapper.scale_factor_spatial}, "
        f"expected {OPEN_SORA_SCALE_FACTOR_SPATIAL}"
    )
    assert wrapper.scale_factor_temporal == OPEN_SORA_SCALE_FACTOR_TEMPORAL, (
        f"scale_factor_temporal={wrapper.scale_factor_temporal}, "
        f"expected {OPEN_SORA_SCALE_FACTOR_TEMPORAL}"
    )
    # hasattr() would accept a non-callable attribute; require real methods.
    assert callable(getattr(wrapper, "encode", None)), "wrapper.encode must be callable"
    assert callable(getattr(wrapper, "decode", None)), "wrapper.decode must be callable"
    print(" PASS test_opensora_vae_wrapper_construction")
|
|
|
|
def test_opensora_vae_load_function_exists():
    """Check that the ``load_opensora_vae`` entry point is importable and callable."""
    loader_is_callable = callable(load_opensora_vae)
    assert loader_is_callable, "load_opensora_vae should be callable"
    print(" PASS test_opensora_vae_load_function_exists")
|
|
|
|
def test_video_latent_channels_updated():
    """Check that ``VIDEO_LATENT_CHANNELS`` in ``arbitor.config`` equals 4."""
    # Imported locally so only this test depends on the name existing.
    from arbitor.config import VIDEO_LATENT_CHANNELS as channels

    expected = 4
    assert channels == expected, f"VIDEO_LATENT_CHANNELS={channels} should be 4"
    print(" PASS test_video_latent_channels_updated")
|
|
|
|
if __name__ == "__main__":
    # Run each test in declaration order; the first failing assert aborts the run.
    for _test_fn in (
        test_opensora_vae_config_constants,
        test_opensora_vae_wrapper_construction,
        test_opensora_vae_load_function_exists,
        test_video_latent_channels_updated,
    ):
        _test_fn()
    print("\nAll Open-Sora VAE tests PASS")
|
|