File size: 5,083 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
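"""Open-Sora 3D VAE v1.2 sidecar module.

Latent shape: [B, 4, T/4 (rounded up), H/8, W/8] — 8× spatial and 4× temporal
compression. The VAE is loaded as a frozen float32 sidecar (no gradients).

Spatial encoding/decoding uses the PixArt SDXL VAE from diffusers. Running the
actual Open-Sora temporal VAE requires the opensora package or custom module
loading; its weights are read from disk when available, but the wrapper in this
module does not apply them and instead approximates temporal compression by
keeping every 4th frame on encode and repeating each decoded frame 4× on decode.
"""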
import os
import torch
import torch.nn as nn
from safetensors import safe_open

# Directory (next to this file) where the local Open-Sora VAE weights live;
# load_opensora_vae looks for "model.safetensors" inside it.
_LOCAL_VAE_DIR = os.path.join(
    os.path.dirname(__file__),
    "models",
    "opensora-vae",
)

# Open-Sora VAE constants. "scale"/"shift" have 4 entries each, presumably one
# per latent channel (latent_channels == 4) — confirm against the Open-Sora
# release. "micro_frame_size" is the frame-chunk length recorded for the
# temporal VAE. None of these are referenced by the code in this module.
_VAE_CONFIG = dict(
    scale=(3.85, 2.32, 2.33, 3.06),
    shift=(-0.10, 0.34, 0.27, 0.98),
    micro_frame_size=17,
)

# Substrings of module class names treated as evidence of optimum-quanto
# quantization (e.g. "Q" matches QLinear / QConv2d); see _has_quantized_modules.
_QUANTO_CLASS_MARKERS = ("Q", "Quanto", "Quantized", "WeightQ")


def _mark_quantized_sidecar(module, quant_type, applied):
    """Tag *module* with its quantization status and freeze its parameters.

    Args:
        module: torch module to annotate (modified in place).
        quant_type: the quantization that was requested (e.g. "int8" or None).
        applied: truthy when quantization actually took effect.

    Returns:
        The same *module*, for call chaining.
    """
    was_applied = bool(applied)
    module._arb_quantize_requested = quant_type
    # Only a successfully applied "int8" request counts as int8-quantized.
    module._arb_quantized_int8 = bool(was_applied and quant_type == "int8")
    module._arb_quantized = was_applied
    # Sidecars are never trained: switch every parameter's gradient off.
    module.requires_grad_(False)
    return module


def _has_quantized_modules(module):
    """Return True if any submodule of *module* (itself included) looks quantized.

    The check is purely name-based: a submodule counts as quantized when its
    class name contains any substring from ``_QUANTO_CLASS_MARKERS``.
    """
    for submodule in module.modules():
        cls_name = type(submodule).__name__
        if any(tag in cls_name for tag in _QUANTO_CLASS_MARKERS):
            return True
    return False


def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Record quantization metadata on *model*, freeze its parameters, return it.

    Thin wrapper over ``_mark_quantized_sidecar`` with defaults meaning
    "no quantization requested, none applied".
    """
    return _mark_quantized_sidecar(model, quantize_requested, quantized)


def _quantize_int8_if_requested(model, quantize):
    """Optionally int8-quantize *model* with optimum-quanto.

    Args:
        model: torch module to convert.
        quantize: requested quantization type; only "int8" is supported.
            When it is None or unsupported, or when optimum-quanto is not
            installed, the weights are left unquantized and the model is cast
            to bfloat16 instead.

    Returns:
        The module, annotated via ``_mark_quantized_sidecar`` (parameters
        frozen, ``_arb_quantized*`` attributes set).
    """

    def _fallback(m):
        # Unquantized path: bf16 weights, marked as "requested but not applied".
        m = m.to(torch.bfloat16)
        _mark_quantized_sidecar(m, quantize, False)
        return m

    if quantize is None:
        return _fallback(model)
    try:
        # quanto's entry point is aliased: importing it as ``quantize`` would
        # shadow this function's parameter of the same name. ``qint8`` must be
        # imported explicitly, or the int8 lookup below raises NameError.
        from optimum.quanto import freeze, qint8
        from optimum.quanto import quantize as quanto_quantize
    except ImportError:
        return _fallback(model)

    qtype = {"int8": qint8}.get(quantize)
    if qtype is None:
        return _fallback(model)
    quanto_quantize(model, weights=qtype)
    freeze(model)
    # Confirm quantization really replaced modules rather than trusting the call.
    _mark_quantized_sidecar(model, quantize, _has_quantized_modules(model))
    return model


def _read_temporal_state(safetensors_path):
    """Read temporal-VAE tensors from *safetensors_path*; {} if the file is absent.

    Keys starting with ``temporal_vae.`` are kept verbatim. Any key starting
    with ``scale`` / ``shift`` is stored under ``"scale"`` / ``"shift"`` (a
    later matching key overwrites an earlier one).
    """
    state = {}
    if not os.path.isfile(safetensors_path):
        return state
    with safe_open(safetensors_path, framework="pt") as f:
        for key in f.keys():
            # The three prefixes are mutually exclusive, so elif is equivalent.
            if key.startswith("temporal_vae."):
                state[key] = f.get_tensor(key)
            elif key.startswith("scale"):
                state["scale"] = f.get_tensor(key)
            elif key.startswith("shift"):
                state["shift"] = f.get_tensor(key)
    return state


def load_opensora_vae(device="cuda", quantize=None):
    """Load the Open-Sora VAE sidecar: spatial VAE plus optional temporal weights.

    The spatial VAE is the PixArt-Sigma SDXL ``AutoencoderKL`` from diffusers,
    loaded in float32, moved to *device*, put in eval mode and frozen.
    Temporal-VAE tensors are read from ``model.safetensors`` in
    ``_LOCAL_VAE_DIR`` when that file exists. If it does not exist, the
    wrapper gets an empty temporal state (spatial-only).

    Args:
        device: device the spatial VAE is moved to.
        quantize: recorded on the spatial VAE as the requested quantization;
            no quantization is applied here (the VAE stays float32).

    Returns:
        OpenSoraVAEWrapper around the spatial VAE and the temporal state.

    Raises:
        RuntimeError: if diffusers is not installed (chained from the
            original ImportError).
    """
    try:
        from diffusers import AutoencoderKL
    except ImportError as err:
        raise RuntimeError(
            "need diffusers for Open-Sora VAE spatial component"
        ) from err

    spatial_vae = AutoencoderKL.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        torch_dtype=torch.float32,
    ).to(device)
    spatial_vae.eval()

    temporal_state = _read_temporal_state(
        os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
    )

    # Marked "not quantized": this loader keeps the spatial VAE in float32.
    _freeze_sidecar(spatial_vae, quantize, False)
    return OpenSoraVAEWrapper(spatial_vae, temporal_state)


class OpenSoraVAEWrapper(nn.Module):
    """Open-Sora-style video VAE built on a per-frame 2D spatial VAE.

    Video is encoded frame by frame with ``spatial_vae``. Temporal compression
    is approximated by keeping every ``scale_factor_temporal``-th frame, and
    decoding repeats each decoded frame that many times. ``temporal_state`` is
    stored, and ``temporal_loaded`` reports whether it is non-empty, but
    neither ``encode`` nor ``decode`` applies it.

    Args:
        spatial_vae: object with diffusers ``AutoencoderKL``-style
            ``encode(x).latent_dist.sample()`` and ``decode(z).sample``.
        temporal_state: optional mapping of temporal-VAE tensors.
    """

    def __init__(self, spatial_vae, temporal_state=None):
        super().__init__()
        self.spatial = spatial_vae
        self.latent_channels = 4
        self.scale_factor_spatial = 8
        self.scale_factor_temporal = 4
        # Latent multiplier applied after encoding and divided out before decoding.
        self.scaling_factor = 0.18215
        self.temporal_state = temporal_state
        self.temporal_loaded = temporal_state is not None and len(temporal_state) > 0

    @torch.no_grad()
    def encode(self, video_tensor):
        """Encode video ``[B, 3, T, H, W]`` into latents ``[B, 4, T', H/8, W/8]``.

        ``T' = ceil(T / 4)`` when ``T >= 4``; shorter clips keep all ``T``
        frames. Frames are encoded independently, so only the frames kept by
        temporal subsampling are passed through the spatial VAE.
        """
        _, _, num_frames, _, _ = video_tensor.shape
        stride = self.scale_factor_temporal
        if num_frames < stride:
            # Too short to subsample: keep every frame.
            stride = 1
        latents = [
            self.spatial.encode(video_tensor[:, :, t]).latent_dist.sample()
            for t in range(0, num_frames, stride)
        ]
        return torch.stack(latents, dim=2) * self.scaling_factor

    @torch.no_grad()
    def decode(self, latents, num_frames=None):
        """Decode latents ``[B, 4, T', H/8, W/8]`` back to video ``[B, 3, T, H, W]``.

        Each latent frame is decoded once, and the decoded frame is repeated
        ``scale_factor_temporal`` times, giving ``4 * T'`` frames. If
        ``num_frames`` is given, the output is truncated to that many frames,
        e.g. 17 frames -> 5 latents -> 20 decoded frames -> 17. When
        ``num_frames`` exceeds ``4 * T'``, all decoded frames are returned.

        Raises:
            ValueError: if ``num_frames`` is given and is not positive.
        """
        if num_frames is not None and num_frames < 1:
            raise ValueError(f"num_frames must be positive, got {num_frames}")
        _, _, latent_frames, _, _ = latents.shape
        unscaled = latents / self.scaling_factor
        # Decoding is deterministic, so decoding once and repeating the output
        # matches decoding a temporally repeated latent, at a quarter of the cost.
        frames = [
            self.spatial.decode(unscaled[:, :, t]).sample
            for t in range(latent_frames)
        ]
        video = torch.stack(frames, dim=2).repeat_interleave(
            self.scale_factor_temporal, dim=2
        )
        if num_frames is not None:
            video = video[:, :, :num_frames]
        return video