File size: 5,083 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | """Open-Sora 3D VAE v1.2 sidecar module.
Latent: [B, 4, T/4, H/8, W/8]
8× spatial compression, 4× temporal compression.
Frozen float32 sidecar (no gradients).
Uses PixArt SDXL VAE (from diffusers) for spatial encoding/decoding.
Temporal VAE requires opensora package or custom module loading.
"""
import os
import torch
import torch.nn as nn
from safetensors import safe_open
# Directory located next to this module; load_opensora_vae() looks for a
# "model.safetensors" file inside it.
_LOCAL_VAE_DIR = os.path.join(os.path.dirname(__file__), "models", "opensora-vae")
# NOTE(review): not referenced by any code visible in this module. The
# four-element "scale"/"shift" tuples presumably line up with the four latent
# channels -- confirm before relying on them.
_VAE_CONFIG = {
    "scale": (3.85, 2.32, 2.33, 3.06),
    "shift": (-0.10, 0.34, 0.27, 0.98),
    "micro_frame_size": 17,
}
# Substrings searched for in module class names by _has_quantized_modules().
# NOTE(review): "Q" is contained in each of the other markers, so any class
# whose name contains a capital Q is treated as quantized.
_QUANTO_CLASS_MARKERS = ("Q", "Quanto", "Quantized", "WeightQ")
def _mark_quantized_sidecar(module, quant_type, applied):
module._arb_quantize_requested = quant_type
module._arb_quantized_int8 = bool(applied and quant_type == "int8")
module._arb_quantized = bool(applied)
for p in module.parameters():
p.requires_grad = False
return module
def _has_quantized_modules(module):
    """Return True if any module yielded by ``module.modules()`` looks quantized.

    A module "looks quantized" when its class name contains at least one of
    the substrings in ``_QUANTO_CLASS_MARKERS``.
    """
    for submodule in module.modules():
        class_name = type(submodule).__name__
        for marker in _QUANTO_CLASS_MARKERS:
            if marker in class_name:
                # One match is enough; stop scanning.
                return True
    return False
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Freeze ``model`` and record its quantization status; return ``model``.

    Thin convenience wrapper around ``_mark_quantized_sidecar`` whose defaults
    describe a sidecar for which no quantization was requested or applied.
    """
    _mark_quantized_sidecar(
        module=model,
        quant_type=quantize_requested,
        applied=quantized,
    )
    return model
def _quantize_int8_if_requested(model, quantize):
    """Quantize ``model`` with optimum-quanto when asked to, else cast to bfloat16.

    Args:
        model: The module to quantize or cast. It is also frozen and tagged
            via ``_mark_quantized_sidecar``.
        quantize: Requested quantization name. Only ``"int8"`` is mapped to a
            quanto dtype; ``None`` or any other value selects the bfloat16
            fallback.

    Returns:
        The quantized model, or the bfloat16-cast model when quantization was
        not requested, the requested type is unsupported, or optimum-quanto
        is not installed.
    """
    qtype = None
    if quantize is not None:
        try:
            # Imported under aliases so the ``quantize`` argument (the
            # requested type name) is not overwritten by quanto's function.
            from optimum.quanto import freeze as quanto_freeze
            from optimum.quanto import qint8
            from optimum.quanto import quantize as quanto_quantize
        except ImportError:
            # optimum-quanto is optional: fall through to the bfloat16 path.
            qtype = None
        else:
            qtype = {"int8": qint8}.get(quantize)

    if qtype is None:
        # Nothing to quantize with (or nothing requested): half-precision cast.
        model = model.to(torch.bfloat16)
        _mark_quantized_sidecar(model, quantize, False)
        return model

    quanto_quantize(model, weights=qtype)
    quanto_freeze(model)
    # Only report success if quantized module classes are actually present.
    _mark_quantized_sidecar(model, quantize, _has_quantized_modules(model))
    return model
def load_opensora_vae(device="cuda", quantize=None):
    """Build the Open-Sora VAE sidecar and return it wrapped for video use.

    The spatial autoencoder is fetched through diffusers in float32, moved to
    ``device`` and put in eval mode. If ``model.safetensors`` exists under
    ``_LOCAL_VAE_DIR``, tensors whose key starts with ``"temporal_vae."`` are
    collected under their own key, and tensors whose key starts with
    ``"scale"`` / ``"shift"`` are collected under ``"scale"`` / ``"shift"``.
    When the file is absent the collected state is simply empty.

    Args:
        device: Target device for the spatial autoencoder.
        quantize: Recorded on the model as the requested quantization; no
            quantization is performed by this function.

    Returns:
        An ``OpenSoraVAEWrapper`` holding the frozen spatial autoencoder and
        the collected temporal tensors.

    Raises:
        RuntimeError: If ``diffusers`` cannot be imported.
    """
    try:
        from diffusers import AutoencoderKL
    except ImportError:
        raise RuntimeError("need diffusers for Open-Sora VAE spatial component")

    # Spatial component.
    spatial_vae = AutoencoderKL.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    spatial_vae = spatial_vae.to(device)
    spatial_vae.eval()

    # Temporal component: best effort, only if the local weights file exists.
    temporal_state = {}
    weights_file = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
    if os.path.isfile(weights_file):
        with safe_open(weights_file, framework="pt") as reader:
            for key in reader.keys():
                if key.startswith("temporal_vae."):
                    temporal_state[key] = reader.get_tensor(key)
                    continue
                # Normalisation statistics are stored under a fixed short name.
                for stat_name in ("scale", "shift"):
                    if key.startswith(stat_name):
                        temporal_state[stat_name] = reader.get_tensor(key)

    _freeze_sidecar(spatial_vae, quantize, False)
    return OpenSoraVAEWrapper(spatial_vae, temporal_state)
class OpenSoraVAEWrapper(nn.Module):
    """Video wrapper that runs a spatial (image) VAE frame by frame.

    Temporal compression is approximated without a learned temporal model:
    ``encode`` keeps every ``scale_factor_temporal``-th latent frame and
    ``decode`` repeats each decoded frame that many times. ``temporal_state``
    is stored on the instance but is not used by ``encode`` or ``decode``.
    """

    # Multiplier applied to encoded latents and divided back out before decoding.
    _LATENT_SCALE = 0.18215

    def __init__(self, spatial_vae, temporal_state=None):
        """Store the spatial VAE and any temporal tensors that were loaded.

        Args:
            spatial_vae: Object exposing ``encode(frame).latent_dist.sample()``
                and ``decode(latent).sample``.
            temporal_state: Optional mapping of temporal tensors, or ``None``.
        """
        super().__init__()
        self.spatial = spatial_vae
        self.latent_channels = 4
        self.scale_factor_spatial = 8
        self.scale_factor_temporal = 4
        self.temporal_state = temporal_state
        # True only when a non-empty temporal state was supplied.
        self.temporal_loaded = bool(temporal_state)

    @torch.no_grad()
    def encode(self, video_tensor):
        """Encode video tensor: [B,3,T,H,W] → [B,4,T/4,H/8,W/8].

        Each frame is encoded independently and a latent is sampled from the
        returned distribution. Clips with fewer frames than
        ``scale_factor_temporal`` keep all of their latent frames; longer
        clips keep frames 0, 4, 8, ... (for the default factor of 4).

        Raises:
            ValueError: If ``video_tensor`` is not 5-dimensional.
        """
        if video_tensor.dim() != 5:
            raise ValueError(
                f"expected a 5D [B,C,T,H,W] tensor, got {video_tensor.dim()}D"
            )
        # Process frame-by-frame through the spatial VAE.
        frame_latents = [
            self.spatial.encode(frame).latent_dist.sample()
            for frame in video_tensor.unbind(dim=2)
        ]
        latent = torch.stack(frame_latents, dim=2) * self._LATENT_SCALE
        # Temporal downsample (simple: keep every stride-th latent frame).
        stride = self.scale_factor_temporal
        if latent.shape[2] >= stride:
            latent = latent[:, :, ::stride]
        return latent

    @torch.no_grad()
    def decode(self, latents, num_frames=None):
        """Decode latents: [B,4,T/4,H/8,W/8] → [B,3,T,H,W].

        Args:
            latents: 5-dimensional latent tensor with time on dim 2.
            num_frames: Optional number of output frames wanted. When it is a
                positive value smaller than the decoded length, the output is
                truncated to that many frames; otherwise it is ignored.

        Returns:
            Decoded video with ``scale_factor_temporal`` frames per latent
            frame, possibly truncated to ``num_frames``.

        Raises:
            ValueError: If ``latents`` is not 5-dimensional.
        """
        if latents.dim() != 5:
            raise ValueError(
                f"expected a 5D [B,C,T,H,W] tensor, got {latents.dim()}D"
            )
        # Unscale, then decode every latent frame exactly once.
        unscaled = latents / self._LATENT_SCALE
        decoded_frames = [
            self.spatial.decode(frame).sample for frame in unscaled.unbind(dim=2)
        ]
        video = torch.stack(decoded_frames, dim=2)
        # Temporal upsample: repeating the decoded frames gives the same
        # result as decoding repeated latents, without the redundant passes.
        video = video.repeat_interleave(self.scale_factor_temporal, dim=2)
        if num_frames is not None and 0 < num_frames < video.shape[2]:
            video = video[:, :, :num_frames]
        return video
|