ARBS / arbitor /encoders /opensora_vae.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Open-Sora 3D VAE v1.2 sidecar module.
Latent: [B, 4, T/4, H/8, W/8]
8× spatial compression, 4× temporal compression.
Frozen float32 sidecar (no gradients).
Uses PixArt SDXL VAE (from diffusers) for spatial encoding/decoding.
Temporal VAE requires opensora package or custom module loading.
"""
import os
import torch
import torch.nn as nn
from safetensors import safe_open
# Directory (relative to this module's own location) that is searched for
# locally stored Open-Sora VAE weights.
_LOCAL_VAE_DIR = os.path.join(os.path.dirname(__file__), "models", "opensora-vae")

# Scale/shift 4-tuples and a micro-frame size for the VAE.
# NOTE(review): presumably one scale/shift entry per latent channel — confirm.
# NOTE(review): nothing in the visible part of this file reads this dict;
# verify whether it is consumed elsewhere before relying on these values.
_VAE_CONFIG = {
    "scale": (3.85, 2.32, 2.33, 3.06),
    "shift": (-0.10, 0.34, 0.27, 0.98),
    "micro_frame_size": 17,
}

# Substrings looked for inside submodule class names to decide whether a model
# contains quantized layers. Matching is plain substring containment, so the
# single-letter "Q" entry matches every class whose name contains a capital Q.
_QUANTO_CLASS_MARKERS = ("Q", "Quanto", "Quantized", "WeightQ")
def _mark_quantized_sidecar(module, quant_type, applied):
    """Tag ``module`` with its quantization status and freeze its parameters.

    Args:
        module: Object exposing ``parameters()``; it is mutated in place.
        quant_type: The quantization type that was requested (may be ``None``).
        applied: Truthy when quantization was actually applied.

    Returns:
        The same ``module`` object, now carrying ``_arb_quantize_requested``,
        ``_arb_quantized_int8`` and ``_arb_quantized`` attributes.
    """
    was_applied = bool(applied)
    is_int8 = bool(was_applied and quant_type == "int8")

    module._arb_quantize_requested = quant_type
    module._arb_quantized_int8 = is_int8
    module._arb_quantized = was_applied

    # A sidecar is never trained, so gradients are switched off everywhere.
    for param in module.parameters():
        param.requires_grad = False

    return module
def _has_quantized_modules(module):
    """Return True if any submodule's class name contains a quanto marker.

    Every entry yielded by ``module.modules()`` is inspected; the check is a
    plain substring test of each ``_QUANTO_CLASS_MARKERS`` entry against the
    submodule's class name.
    """
    for submodule in module.modules():
        class_name = type(submodule).__name__
        for marker in _QUANTO_CLASS_MARKERS:
            if marker in class_name:
                return True
    return False
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Freeze ``model`` and record its quantization metadata.

    Thin convenience wrapper over ``_mark_quantized_sidecar`` whose defaults
    describe a model that was neither asked to be nor actually quantized.

    Returns:
        The same ``model`` object that was passed in.
    """
    _mark_quantized_sidecar(
        model,
        quantize_requested,
        quantized,
    )
    return model
def _quantize_int8_if_requested(model, quantize):
    """Quantize ``model`` weights to int8 with optimum-quanto when requested.

    Args:
        model: Module to process. On the int8 path it is quantized and frozen
            in place; on every fallback path it is replaced by
            ``model.to(torch.bfloat16)``.
        quantize: Requested quantization type. Only ``"int8"`` triggers
            quantization; ``None`` or any other value takes the fallback.

    Returns:
        The processed model, frozen and tagged by ``_mark_quantized_sidecar``.
        The fallback is also used when ``optimum.quanto`` cannot be imported.
    """
    quanto = None
    if quantize == "int8":
        try:
            # Import the package under an alias: pulling in its ``quantize``
            # function by name would rebind this function's ``quantize``
            # parameter and corrupt the metadata recorded below.
            import optimum.quanto as quanto
        except ImportError:
            quanto = None

    if quanto is None:
        # Not requested, unsupported type, or quanto unavailable.
        model = model.to(torch.bfloat16)
        _mark_quantized_sidecar(model, quantize, False)
        return model

    quanto.quantize(model, weights=quanto.qint8)
    quanto.freeze(model)
    # Only report success if quantized layers are actually present afterwards.
    _mark_quantized_sidecar(model, quantize, _has_quantized_modules(model))
    return model
def load_opensora_vae(device="cuda", quantize=None):
    """Load the spatial VAE and wrap it together with local temporal tensors.

    The spatial autoencoder is loaded in float32 through diffusers'
    ``AutoencoderKL.from_pretrained``, moved to ``device``, switched to eval
    mode and frozen. If ``model.safetensors`` exists in the local VAE
    directory, tensors whose keys start with ``temporal_vae.``, ``scale`` or
    ``shift`` are read from it and passed to the wrapper; otherwise the
    wrapper receives an empty dict.

    Args:
        device: Device the spatial VAE is moved to. The tensors read from the
            safetensors file are not moved by this function.
        quantize: Recorded on the spatial VAE as the requested quantization
            type. No quantization is performed here.

    Returns:
        An ``OpenSoraVAEWrapper`` around the frozen spatial VAE.

    Raises:
        RuntimeError: If ``diffusers`` cannot be imported.
    """
    try:
        from diffusers import AutoencoderKL
    except ImportError as exc:
        # Chain the original error so the real import failure stays visible.
        raise RuntimeError(
            "need diffusers for Open-Sora VAE spatial component"
        ) from exc

    # Load spatial VAE
    spatial_vae = AutoencoderKL.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        torch_dtype=torch.float32,
    ).to(device)
    spatial_vae.eval()

    # Try to load temporal VAE weights
    temporal_state = {}
    safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
    if os.path.isfile(safetensors_path):
        with safe_open(safetensors_path, framework="pt") as f:
            for key in f.keys():
                # The three prefixes are mutually exclusive, so at most one
                # branch applies to any given key.
                if key.startswith("temporal_vae."):
                    temporal_state[key] = f.get_tensor(key)
                elif key.startswith("scale"):
                    temporal_state["scale"] = f.get_tensor(key)
                elif key.startswith("shift"):
                    temporal_state["shift"] = f.get_tensor(key)

    # The requested type is recorded, but the model is marked as not quantized.
    _freeze_sidecar(spatial_vae, quantize, False)
    return OpenSoraVAEWrapper(spatial_vae, temporal_state)
class OpenSoraVAEWrapper(nn.Module):
    """Frame-by-frame spatial VAE with naive temporal down/up-sampling.

    Each frame is encoded or decoded independently through ``spatial``.
    Temporal compression is done by keeping every ``scale_factor_temporal``-th
    latent frame, and temporal expansion by repeating each latent frame that
    many times. ``temporal_state`` is stored on the instance but is not used
    by ``encode`` or ``decode``.
    """

    # Multiplier applied to latents after encoding and divided out again
    # before decoding.
    _LATENT_SCALE = 0.18215

    def __init__(self, spatial_vae, temporal_state=None):
        """Store the spatial VAE and any temporal tensors.

        Args:
            spatial_vae: Object providing ``encode(frame).latent_dist.sample()``
                and ``decode(latent).sample``.
            temporal_state: Optional mapping of temporal tensors; ``None`` or
                an empty mapping means nothing was loaded.
        """
        super().__init__()
        self.spatial = spatial_vae
        self.latent_channels = 4
        self.scale_factor_spatial = 8
        self.scale_factor_temporal = 4
        self.temporal_state = temporal_state
        self.temporal_loaded = bool(temporal_state)

    @torch.no_grad()
    def encode(self, video_tensor):
        """Encode video tensor: [B,3,T,H,W] → [B,4,T/4,H/8,W/8].

        Clips with fewer frames than ``scale_factor_temporal`` are returned
        without temporal downsampling.

        Raises:
            ValueError: If ``video_tensor`` is not 5-dimensional.
        """
        if video_tensor.dim() != 5:
            raise ValueError(
                f"expected a 5-D [B,C,T,H,W] tensor, got {video_tensor.dim()}-D"
            )

        # Process frame-by-frame through spatial VAE
        per_frame = [
            self.spatial.encode(video_tensor[:, :, t]).latent_dist.sample()
            for t in range(video_tensor.shape[2])
        ]
        latent = torch.stack(per_frame, dim=2) * self._LATENT_SCALE

        # Temporal downsample (simple: keep every stride-th latent frame)
        stride = self.scale_factor_temporal
        if latent.shape[2] >= stride:
            latent = latent[:, :, ::stride]
        return latent

    @torch.no_grad()
    def decode(self, latents, num_frames=None):
        """Decode latents: [B,4,T/4,H/8,W/8] → [B,3,T,H,W].

        Args:
            latents: 5-D latent tensor.
            num_frames: Optional number of output frames to keep. Upsampling
                always produces a multiple of ``scale_factor_temporal``
                frames; a positive value smaller than that trims the excess.
                ``None`` or any other value leaves the output untrimmed.

        Raises:
            ValueError: If ``latents`` is not 5-dimensional.
        """
        if latents.dim() != 5:
            raise ValueError(
                f"expected a 5-D [B,C,T,H,W] tensor, got {latents.dim()}-D"
            )

        # Temporal upsample (repeat each latent frame)
        latents = latents.repeat_interleave(self.scale_factor_temporal, dim=2)

        # Trim before decoding so surplus frames are never run through the VAE.
        if num_frames is not None and 0 < num_frames < latents.shape[2]:
            latents = latents[:, :, :num_frames]

        # Unscale
        latents = latents / self._LATENT_SCALE

        # Decode frame-by-frame
        frames = [
            self.spatial.decode(latents[:, :, t]).sample
            for t in range(latents.shape[2])
        ]
        return torch.stack(frames, dim=2)