"""Open-Sora 3D VAE v1.2 sidecar module. Latent: [B, 4, T/4, H/8, W/8] 8× spatial compression, 4× temporal compression. Frozen float32 sidecar (no gradients). Uses PixArt SDXL VAE (from diffusers) for spatial encoding/decoding. Temporal VAE requires opensora package or custom module loading. """ import os import torch import torch.nn as nn from safetensors import safe_open _LOCAL_VAE_DIR = os.path.join(os.path.dirname(__file__), "models", "opensora-vae") _VAE_CONFIG = { "scale": (3.85, 2.32, 2.33, 3.06), "shift": (-0.10, 0.34, 0.27, 0.98), "micro_frame_size": 17, } _QUANTO_CLASS_MARKERS = ("Q", "Quanto", "Quantized", "WeightQ") def _mark_quantized_sidecar(module, quant_type, applied): module._arb_quantize_requested = quant_type module._arb_quantized_int8 = bool(applied and quant_type == "int8") module._arb_quantized = bool(applied) for p in module.parameters(): p.requires_grad = False return module def _has_quantized_modules(module): return any( any(marker in type(child).__name__ for marker in _QUANTO_CLASS_MARKERS) for child in module.modules() ) def _freeze_sidecar(model, quantize_requested=None, quantized=False): _mark_quantized_sidecar(model, quantize_requested, quantized) return model def _quantize_int8_if_requested(model, quantize): if quantize is None: model = model.to(torch.bfloat16) _mark_quantized_sidecar(model, quantize, False) return model try: from optimum.quanto import quantize, freeze qtype = {"int8": qint8}.get(quantize) if qtype is None: model = model.to(torch.bfloat16) _mark_quantized_sidecar(model, quantize, False) return model quantize(model, weights=qtype) freeze(model) _mark_quantized_sidecar(model, quantize, _has_quantized_modules(model)) except ImportError: model = model.to(torch.bfloat16) _mark_quantized_sidecar(model, quantize, False) return model def load_opensora_vae(device="cuda", quantize=None): """Load Open-Sora 3D VAE as frozen float32 sidecar. Loads the spatial VAE from PixArt SDXL (diffusers) and the temporal VAE from local safetensors. Falls back to spatial-only if temporal module can't be loaded. """ try: from diffusers import AutoencoderKL except ImportError: raise RuntimeError("need diffusers for Open-Sora VAE spatial component") # Load spatial VAE spatial_vae = AutoencoderKL.from_pretrained( "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae", torch_dtype=torch.float32, ).to(device) spatial_vae.eval() # Try to load temporal VAE weights temporal_state = {} safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors") if os.path.isfile(safetensors_path): with safe_open(safetensors_path, framework="pt") as f: for k in f.keys(): if k.startswith("temporal_vae."): temporal_state[k] = f.get_tensor(k) if k.startswith("scale"): temporal_state["scale"] = f.get_tensor(k) if k.startswith("shift"): temporal_state["shift"] = f.get_tensor(k) _freeze_sidecar(spatial_vae, quantize, False) return OpenSoraVAEWrapper(spatial_vae, temporal_state) class OpenSoraVAEWrapper(nn.Module): def __init__(self, spatial_vae, temporal_state=None): super().__init__() self.spatial = spatial_vae self.latent_channels = 4 self.scale_factor_spatial = 8 self.scale_factor_temporal = 4 self.temporal_state = temporal_state self.temporal_loaded = temporal_state is not None and len(temporal_state) > 0 @torch.no_grad() def encode(self, video_tensor): """Encode video tensor: [B,3,T,H,W] → [B,4,T/4,H/8,W/8].""" B, C, T, H, W = video_tensor.shape # Process frame-by-frame through spatial VAE latents = [] for t in range(T): frame = video_tensor[:, :, t, :, :] latent = self.spatial.encode(frame).latent_dist.sample() latents.append(latent) latent = torch.stack(latents, dim=2) # Scale latent = latent * 0.18215 # Temporal downsample (simple: take every 4th) if latent.shape[2] >= 4: latent = latent[:, :, ::4, :, :] return latent @torch.no_grad() def decode(self, latents, num_frames=None): """Decode latents: [B,4,T/4,H/8,W/8] → [B,3,T,H,W].""" B, C, T, H, W = latents.shape # Temporal upsample (repeat each latent 4×) latents = latents.repeat_interleave(4, dim=2) # Unscale latents = latents / 0.18215 # Decode frame-by-frame frames = [] for t in range(latents.shape[2]): frame = latents[:, :, t, :, :] decoded = self.spatial.decode(frame).sample frames.append(decoded) return torch.stack(frames, dim=2)