# ARBS/testing/vae/test_opensora_vae.py
# CLIWorks's picture
# Upload folder using huggingface_hub
# d8bc908 verified
"""Unit tests for Open-Sora 3D VAE wrapper (Phase 19 Plan 01)."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.config import (
OPEN_SORA_LATENT_CHANNELS,
OPEN_SORA_SCALE_FACTOR_SPATIAL,
OPEN_SORA_SCALE_FACTOR_TEMPORAL,
BYTEHEAD_ACT_MAX_ITERS,
VIDEOHEAD_ACT_MIN_FPS,
TIMESTAMP_MAX_PERIOD,
FRAME_BUFFER_LOCAL_SIZE,
)
from arbitor.encoders.opensora_vae import OpenSoraVAEWrapper, load_opensora_vae
def test_opensora_vae_config_constants():
    """Pin the constants imported from ``arbitor.config`` to the values this suite expects."""
    # Only the latent-channel check carries a custom failure message.
    assert OPEN_SORA_LATENT_CHANNELS == 4, f"got {OPEN_SORA_LATENT_CHANNELS}"

    # Constants compared for exact equality, kept in a fixed order so the
    # first mismatch reported is always the same one.
    exact_expectations = (
        (OPEN_SORA_SCALE_FACTOR_SPATIAL, 8),
        (OPEN_SORA_SCALE_FACTOR_TEMPORAL, 4),
        (BYTEHEAD_ACT_MAX_ITERS, 3),
        (VIDEOHEAD_ACT_MIN_FPS, 1),
    )
    for actual, expected in exact_expectations:
        assert actual == expected

    # The timestamp period is checked with an absolute tolerance instead of ==.
    period_deviation = abs(TIMESTAMP_MAX_PERIOD - 10000.0)
    assert period_deviation < 1e-5

    assert FRAME_BUFFER_LOCAL_SIZE == 3
    print(" PASS test_opensora_vae_config_constants")
def test_opensora_vae_wrapper_construction():
    """Wrap an empty stand-in object and check the attributes the wrapper exposes."""
    # An attribute-less placeholder is enough here: this test only inspects
    # the wrapper object and never calls into the wrapped VAE.
    mock_vae_cls = type("MockVAE", (), {})
    wrapper = OpenSoraVAEWrapper(mock_vae_cls())

    # Attribute name -> value expected on the constructed wrapper.
    expected_attributes = (
        ("latent_channels", 4),
        ("scale_factor_spatial", 8),
        ("scale_factor_temporal", 4),
    )
    for attribute_name, expected_value in expected_attributes:
        assert getattr(wrapper, attribute_name) == expected_value

    # Both entry points must be present on the wrapper.
    for method_name in ("encode", "decode"):
        assert hasattr(wrapper, method_name)

    print(" PASS test_opensora_vae_wrapper_construction")
def test_opensora_vae_load_function_exists():
    """Check that the imported ``load_opensora_vae`` name is a callable object."""
    # Only callability is checked; the loader itself is never invoked here.
    loader = load_opensora_vae
    assert callable(loader), "load_opensora_vae should be callable"
    print(" PASS test_opensora_vae_load_function_exists")
def test_video_latent_channels_updated():
    """Check that ``VIDEO_LATENT_CHANNELS`` in ``arbitor.config`` equals 4."""
    # Imported inside the test, so a missing constant fails only this test
    # rather than the import of the whole module.
    from arbitor.config import VIDEO_LATENT_CHANNELS

    expected_channels = 4
    assert (
        VIDEO_LATENT_CHANNELS == expected_channels
    ), f"VIDEO_LATENT_CHANNELS={VIDEO_LATENT_CHANNELS} should be 4"
    print(" PASS test_video_latent_channels_updated")
if __name__ == "__main__":
    # Script entry point: run each test in definition order. Every test prints
    # its own PASS line, and an uncaught AssertionError stops the run before
    # the final summary is printed.
    for run_test in (
        test_opensora_vae_config_constants,
        test_opensora_vae_wrapper_construction,
        test_opensora_vae_load_function_exists,
        test_video_latent_channels_updated,
    ):
        run_test()
    print("\nAll Open-Sora VAE tests PASS")