| """Open-Sora 3D VAE v1.2 sidecar module. |
| |
| Latent: [B, 4, T/4, H/8, W/8] |
| 8× spatial compression, 4× temporal compression. |
| Frozen float32 sidecar (no gradients). |
| |
| Uses PixArt SDXL VAE (from diffusers) for spatial encoding/decoding. |
| Temporal VAE requires opensora package or custom module loading. |
| """ |
| import os |
| import torch |
| import torch.nn as nn |
| from safetensors import safe_open |
|
|
# Directory next to this module where the local Open-Sora VAE files are looked up.
_LOCAL_VAE_DIR = os.path.join(
    os.path.dirname(__file__),
    "models",
    "opensora-vae",
)

# Static VAE settings. The four-element "scale"/"shift" tuples presumably hold
# one value per latent channel -- TODO confirm against the checkpoint config.
_VAE_CONFIG = {
    "scale": (3.85, 2.32, 2.33, 3.06),
    "shift": (-0.10, 0.34, 0.27, 0.98),
    "micro_frame_size": 17,
}

# Substrings searched for in module class names to detect quantized layers.
_QUANTO_CLASS_MARKERS = (
    "Q",
    "Quanto",
    "Quantized",
    "WeightQ",
)
|
|
|
|
def _mark_quantized_sidecar(module, quant_type, applied):
    """Record quantization bookkeeping on *module* and freeze its parameters.

    Args:
        module: Object exposing ``parameters()``; three ``_arb_*`` attributes
            are written onto it.
        quant_type: The quantization mode that was requested; stored as-is.
        applied: Truthy when quantization actually took effect.

    Returns:
        The same ``module`` object that was passed in.
    """
    was_applied = bool(applied)

    module._arb_quantize_requested = quant_type
    module._arb_quantized = was_applied
    # The int8 flag is only set when quantization was applied AND int8 was asked for.
    module._arb_quantized_int8 = bool(was_applied and quant_type == "int8")

    # A sidecar is never trained, so every parameter is excluded from autograd.
    for param in module.parameters():
        param.requires_grad = False

    return module
|
|
|
|
def _has_quantized_modules(module):
    """Return True if any submodule's class name contains a quantization marker.

    Walks ``module.modules()`` and checks each class name for any substring
    listed in ``_QUANTO_CLASS_MARKERS``.
    """
    # NOTE(review): the "Q" marker matches every class name containing a
    # capital Q, which also covers the three longer markers -- confirm that
    # this breadth is intended.
    for submodule in module.modules():
        class_name = type(submodule).__name__
        for marker in _QUANTO_CLASS_MARKERS:
            if marker in class_name:
                return True
    return False
|
|
|
|
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Freeze *model* and record its quantization status.

    Delegates to ``_mark_quantized_sidecar``; the defaults describe a model
    for which no quantization was requested or applied.

    Returns:
        The same ``model`` object.
    """
    _mark_quantized_sidecar(
        module=model,
        quant_type=quantize_requested,
        applied=quantized,
    )
    return model
|
|
|
|
def _quantize_int8_if_requested(model, quantize):
    """Quantize *model* weights to int8 when requested, else cast to bfloat16.

    Args:
        model: The module to quantize or cast.
        quantize: Requested mode. Only ``"int8"`` triggers quantization; ``None``
            or any other value falls back to a bfloat16 cast.

    Returns:
        The model, tagged through ``_mark_quantized_sidecar`` with the
        requested mode and whether quantization was actually applied.
    """
    quanto = None
    if quantize == "int8":
        # Import under a module alias so the library's ``quantize`` function
        # cannot shadow the ``quantize`` argument. Only the import is guarded:
        # a missing optional dependency means "fall back", while an error
        # raised during quantization itself must not be swallowed.
        try:
            import optimum.quanto as quanto
        except ImportError:
            quanto = None

    if quanto is None:
        # Nothing requested, unsupported mode, or optimum.quanto unavailable.
        model = model.to(torch.bfloat16)
        _mark_quantized_sidecar(model, quantize, False)
        return model

    quanto.quantize(model, weights=quanto.qint8)
    quanto.freeze(model)
    # Report success only if quantized module classes are actually present.
    _mark_quantized_sidecar(model, quantize, _has_quantized_modules(model))
    return model
|
|
|
|
def load_opensora_vae(device="cuda", quantize=None):
    """Load the Open-Sora VAE sidecar as a frozen float32 wrapper.

    The spatial VAE is the PixArt SDXL ``AutoencoderKL`` from diffusers. If
    ``model.safetensors`` exists under ``_LOCAL_VAE_DIR``, its temporal
    tensors are collected into a plain dict and handed to the wrapper;
    otherwise the wrapper is created with an empty temporal state.

    Args:
        device: Device the spatial VAE is moved to.
        quantize: Requested quantization mode. It is only recorded on the
            spatial VAE; no quantization is performed here.

    Returns:
        An ``OpenSoraVAEWrapper`` around the frozen spatial VAE.

    Raises:
        RuntimeError: If diffusers is not installed.
    """
    try:
        from diffusers import AutoencoderKL
    except ImportError as err:
        # Chain explicitly so the original import failure stays visible.
        raise RuntimeError(
            "need diffusers for Open-Sora VAE spatial component"
        ) from err

    spatial_vae = AutoencoderKL.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        torch_dtype=torch.float32,
    ).to(device)
    spatial_vae.eval()

    temporal_state = {}
    safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
    if os.path.isfile(safetensors_path):
        with safe_open(safetensors_path, framework="pt") as f:
            for k in f.keys():
                if k.startswith("temporal_vae."):
                    temporal_state[k] = f.get_tensor(k)
                elif k.startswith("scale"):
                    # Stored under a fixed key; a later match overwrites an earlier one.
                    temporal_state["scale"] = f.get_tensor(k)
                elif k.startswith("shift"):
                    temporal_state["shift"] = f.get_tensor(k)

    # Freeze parameters and record that quantization was requested but not applied.
    _freeze_sidecar(spatial_vae, quantize, False)
    return OpenSoraVAEWrapper(spatial_vae, temporal_state)
|
|
|
|
class OpenSoraVAEWrapper(nn.Module):
    """Frame-wise video wrapper around a spatial (image) VAE.

    Every frame is encoded/decoded independently by ``spatial``. Temporal
    compression is approximated by keeping every ``scale_factor_temporal``-th
    latent frame on encode and repeating each latent frame that many times
    on decode. ``temporal_state`` is stored but is not used by ``encode`` or
    ``decode``.
    """

    # Factor multiplied into latents by encode() and divided out by decode().
    _LATENT_SCALE = 0.18215

    def __init__(self, spatial_vae, temporal_state=None):
        """Store the spatial VAE and the optional temporal tensors.

        Args:
            spatial_vae: Object whose ``encode(x).latent_dist.sample()`` and
                ``decode(z).sample`` are called on single frames.
            temporal_state: Optional mapping of temporal tensors.
        """
        super().__init__()
        self.spatial = spatial_vae
        self.latent_channels = 4
        self.scale_factor_spatial = 8
        self.scale_factor_temporal = 4
        self.temporal_state = temporal_state
        # True only when a non-empty temporal state was supplied.
        self.temporal_loaded = temporal_state is not None and len(temporal_state) > 0

    @torch.no_grad()
    def encode(self, video_tensor):
        """Encode video tensor: [B,3,T,H,W] → [B,4,T/4,H/8,W/8].

        The channel and spatial sizes of the result come from the spatial
        VAE. The time axis is subsampled only when at least
        ``scale_factor_temporal`` frames are present.

        Raises:
            ValueError: If ``video_tensor`` is not 5-dimensional.
        """
        if video_tensor.dim() != 5:
            raise ValueError(
                f"expected a 5-D video tensor [B,C,T,H,W], got {video_tensor.dim()} dims"
            )

        per_frame = [
            self.spatial.encode(video_tensor[:, :, t, :, :]).latent_dist.sample()
            for t in range(video_tensor.shape[2])
        ]
        latent = torch.stack(per_frame, dim=2) * self._LATENT_SCALE

        stride = self.scale_factor_temporal
        if latent.shape[2] >= stride:
            # Keep frames 0, stride, 2*stride, ...
            latent = latent[:, :, ::stride, :, :]
        return latent

    @torch.no_grad()
    def decode(self, latents, num_frames=None):
        """Decode latents: [B,4,T/4,H/8,W/8] → [B,3,T,H,W].

        Args:
            latents: 5-D latent tensor.
            num_frames: Optional number of output frames. Temporal upsampling
                always yields ``T_latent * scale_factor_temporal`` frames;
                when given, only the first ``num_frames`` of them are decoded
                and returned. ``None`` returns all of them.

        Raises:
            ValueError: If ``latents`` is not 5-dimensional or ``num_frames``
                is smaller than 1.
        """
        if latents.dim() != 5:
            raise ValueError(
                f"expected a 5-D latent tensor [B,C,T,H,W], got {latents.dim()} dims"
            )
        if num_frames is not None and num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")

        latents = latents.repeat_interleave(self.scale_factor_temporal, dim=2)
        if num_frames is not None:
            # Trim before decoding so surplus frames are never computed.
            latents = latents[:, :, :num_frames]
        latents = latents / self._LATENT_SCALE

        frames = [
            self.spatial.decode(latents[:, :, t, :, :]).sample
            for t in range(latents.shape[2])
        ]
        return torch.stack(frames, dim=2)
|
|