# Source: ARBS/arbitor/encoders/pig_vae.py
# Uploaded by CLIWorks via huggingface_hub (commit d8bc908).
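"""pig-vae (WanVAE) sidecar module.
Loads the VAE from a local safetensors or GGUF file, falling back to the
diffusers AutoencoderKLWan weights on the Hugging Face Hub (a loader for a
local .pth state_dict is also available via _load_local).
load_vae() returns a VAEWrapper exposing encode() and decode() for the
VideoHead training pipeline.
Latent shape: [B, 16, T', H/8, W/8] for input video of T frames at HxW, where
T' = (T - 1) // 4 + 1 (Wan's causal temporal compression, roughly T/4).
"""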
import os, torch
import torch.nn as nn
# Directory searched for local pig-vae checkpoints: "<dir of this file>/models/pig-vae".
_LOCAL_VAE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "models",
    "pig-vae",
)

# Constructor kwargs for diffusers' AutoencoderKLWan, splatted into the model by
# _build_vae; "z_dim" also provides VAEWrapper.latent_channels. Key names are
# diffusers' own, including its "temperal" spelling.
_VAE_CONFIG = dict(
    base_dim=96,
    z_dim=16,
    dim_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    dropout=0.0,
    temperal_downsample=[False, True, True],
    in_channels=3,
    out_channels=3,
    scale_factor_temporal=4,
    scale_factor_spatial=8,
)
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Tag *model* with quantization metadata and disable its gradients.

    Three attributes are written onto the model:
      * ``_arb_quantize_requested`` -- the requested mode, stored verbatim;
      * ``_arb_quantized`` -- ``bool(quantized)``;
      * ``_arb_quantized_int8`` -- True only when quantized and the request
        was ``"int8"``.
    Every parameter then gets ``requires_grad = False``. The same model object
    is returned (it is modified in place).
    """
    is_quantized = bool(quantized)
    model._arb_quantize_requested = quantize_requested
    model._arb_quantized_int8 = bool(is_quantized and quantize_requested == "int8")
    model._arb_quantized = is_quantized
    for param in model.parameters():
        param.requires_grad_(False)
    return model
def _has_quantized_modules(model):
    """Return True if any submodule of *model* looks like a quantized layer.

    Detection is by class name only, so optimum.quanto need not be importable
    here. A module counts as quantized when its type name either

    * starts with ``Q`` immediately followed by an upper-case letter, the
      naming used by quanto's module classes (e.g. ``QLinear``, ``QConv2d``), or
    * contains one of the markers ``Quanto``, ``Quantized`` or ``WeightQ``.

    A bare ``"Q"`` substring test is deliberately not used: it would subsume
    every other marker and match any class name that merely contains a
    capital Q (e.g. ``EQualizer``).
    """
    markers = ("Quanto", "Quantized", "WeightQ")

    def is_quantized_name(name):
        # Quanto-style prefix: "Q" + capitalised base layer name.
        if len(name) >= 2 and name[0] == "Q" and name[1].isupper():
            return True
        return any(marker in name for marker in markers)

    return any(is_quantized_name(type(module).__name__) for module in model.modules())
def _quantize_int8_if_requested(model, quantize):
    """Optionally apply optimum.quanto int8 weight quantization, then freeze *model*.

    Only ``quantize == 'int8'`` triggers quantization: weights are quantized to
    qint8 and quanto's ``freeze`` is applied. _has_quantized_modules then reports
    whether quantized modules were actually produced. Any other value leaves
    the weights untouched. In every case the model goes through
    _freeze_sidecar, which records the request and disables gradients, and the
    resulting model is returned.
    """
    quantized = False
    if quantize == 'int8':
        # Imported lazily so optimum-quanto is only needed when int8 is requested.
        from optimum.quanto import freeze, qint8, quantize as quanto_quantize
        quanto_quantize(model, weights=qint8)
        freeze(model)
        quantized = _has_quantized_modules(model)
    return _freeze_sidecar(model, quantize_requested=quantize, quantized=quantized)
def _wan_vae_cls():
    """Return diffusers' ``AutoencoderKLWan`` class, imported lazily.

    Raises:
        RuntimeError: if diffusers is not installed, or if the installed
            diffusers cannot provide ``AutoencoderKLWan`` (e.g. a release that
            predates it). The underlying ImportError is chained as
            ``__cause__``.
    """
    try:
        from diffusers import AutoencoderKLWan
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "pig-vae requires the optional diffusers dependency. "
            "Install the project with `pip install -e .[diffusers]` in a venv "
            "before loading or verifying pig-vae int8 quantization."
        ) from exc
    except ImportError as exc:
        # diffusers itself imported, but the name is unavailable (typically an
        # older diffusers release); report that instead of a raw ImportError.
        raise RuntimeError(
            "pig-vae could not import AutoencoderKLWan from the installed "
            "diffusers package; upgrade diffusers to a release that provides "
            "it (`pip install -U diffusers`)."
        ) from exc
    return AutoencoderKLWan
def load_vae(device='cuda', quantize='int8'):
    """Return a frozen VAEWrapper around pig-vae, loaded from the first available source.

    Sources are tried in this order:
      1. ``model.safetensors`` in ``_LOCAL_VAE_DIR``;
      2. ``pig_wan_vae_fp32-f16.gguf`` in the same directory;
      3. the Hugging Face Hub via diffusers (when neither local file exists).

    Args:
        device: device the model is moved to (default ``'cuda'``).
        quantize: ``'int8'`` (default) applies optimum.quanto int8 weight
            quantization; any other value leaves the weights unquantized.
    """
    local_loaders = (
        ("model.safetensors",
         lambda p: _load_local(p, device, quantize, is_safetensors=True)),
        ("pig_wan_vae_fp32-f16.gguf",
         lambda p: _load_gguf(p, device, quantize)),
    )
    for filename, loader in local_loaders:
        candidate = os.path.join(_LOCAL_VAE_DIR, filename)
        if os.path.isfile(candidate):
            return loader(candidate)
    return _load_from_hf(device, quantize)
def _build_vae():
    """Instantiate an AutoencoderKLWan with freshly initialized weights.

    The architecture comes from ``_VAE_CONFIG``. The fixed per-channel
    ``latents_mean`` / ``latents_std`` statistics (16 values each, one per
    latent channel) are stored in the model config. Callers are expected to
    load weights afterwards.
    """
    vae_cls = _wan_vae_cls()
    channel_mean = [
        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653,
        -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632,
        -0.1922, -0.9497, 0.2503, -0.2921,
    ]
    channel_std = [
        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708,
        2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579,
        1.6382, 1.1253, 2.8251, 1.916,
    ]
    return vae_cls(latents_mean=channel_mean, latents_std=channel_std, **_VAE_CONFIG)
def _load_local(path, device, quantize, is_safetensors=False):
    """Load pig-vae weights from a local file and return a frozen VAEWrapper.

    Args:
        path: checkpoint file to read.
        device: device the model is moved to.
        quantize: forwarded to _quantize_int8_if_requested ('int8' quantizes).
        is_safetensors: when true, *path* is a single-file checkpoint read with
            ``AutoencoderKLWan.from_single_file``. Otherwise it is a
            ``torch.save``-d state_dict (.pth) loaded into ``_build_vae()``
            with ``weights_only=True``.

    Raises:
        RuntimeError: if the .pth state_dict has missing or unexpected keys.
    """
    if not is_safetensors:
        vae = _build_vae()
        state = torch.load(path, map_location="cpu", weights_only=True)
        missing, unexpected = vae.load_state_dict(state, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                "pig-vae local .pth checkpoint does not match AutoencoderKLWan "
                f"(missing={len(missing)}, unexpected={len(unexpected)})."
            )
    else:
        vae = _wan_vae_cls().from_single_file(path)
    vae = vae.to(device)
    vae.eval()
    vae = _quantize_int8_if_requested(vae, quantize)
    return VAEWrapper(vae)
def _load_gguf(path, device, quantize):
    """Build a Wan VAE and fill it from the GGUF checkpoint at *path*.

    Every tensor in the file is converted with ``torch.tensor`` and keyed by
    its GGUF name. The resulting dict is loaded into a fresh ``_build_vae()``
    model, and any missing or unexpected key raises RuntimeError. The model is
    then moved to *device*, put in eval mode, optionally int8-quantized, frozen
    and wrapped in VAEWrapper.

    NOTE(review): this assumes the GGUF tensor names equal the diffusers
    AutoencoderKLWan state_dict keys, and that tensors are stored unquantized
    (f16/f32) in their final shapes. GGUF tensors are limited to 4 dims, so
    5-D Conv3d weights may have been stored reshaped. Confirm against the
    actual file; a shape mismatch surfaces as a RuntimeError from
    load_state_dict.
    """
    import gguf

    reader = gguf.GGUFReader(path)
    weights = {}
    for tensor in reader.tensors:
        weights[tensor.name] = torch.tensor(tensor.data)
    vae = _build_vae()
    missing, unexpected = vae.load_state_dict(weights, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "pig-vae local GGUF checkpoint does not match AutoencoderKLWan "
            f"(missing={len(missing)}, unexpected={len(unexpected)})."
        )
    vae = vae.to(device)
    vae.eval()
    return VAEWrapper(_quantize_int8_if_requested(vae, quantize))
def _load_from_hf(device, quantize):
    """Download the Wan 2.1 VAE from the Hugging Face Hub and wrap it.

    Weights come from the ``vae`` subfolder of the diffusers-format repository
    and are loaded in bfloat16; the local loaders do not pass a dtype. The
    model is moved to *device*, set to eval mode, optionally int8-quantized
    and frozen, then wrapped in VAEWrapper.

    Raises:
        RuntimeError: from _wan_vae_cls when diffusers / AutoencoderKLWan is
            unavailable.
    """
    AutoencoderKLWan = _wan_vae_cls()
    # "Wan-AI/Wan2.1-T2V-1.3B" holds the original (non-diffusers) checkpoint
    # layout with no "vae/" subfolder, so from_pretrained(subfolder="vae")
    # cannot resolve a config there; the "-Diffusers" repo provides it.
    model = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae",
        torch_dtype=torch.bfloat16,
    )
    model = model.to(device)
    model.eval()
    model = _quantize_int8_if_requested(model, quantize)
    return VAEWrapper(model)
class VAEWrapper(nn.Module):
    """Inference-only wrapper around a Wan VAE for the VideoHead pipeline.

    ``encode`` returns VAE latents multiplied by ``scale_factor``. ``decode``
    divides by it before decoding. Both run under ``torch.no_grad()``. Both
    first move their input to the wrapped VAE's device and, for floating-point
    inputs, cast it to the VAE's dtype. This lets float32 tensors reach the
    bfloat16 model built by _load_from_hf, which would otherwise raise a dtype
    mismatch inside the conv layers.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae
        # Latent channel count, taken from _VAE_CONFIG["z_dim"].
        self.latent_channels = _VAE_CONFIG["z_dim"]
        # Scalar applied after encode and undone before decode.
        # NOTE(review): 0.476986 matches the scaling_factor published for the
        # HunyuanVideo VAE; Wan VAEs are usually normalized per channel with
        # latents_mean / latents_std (the values passed in _build_vae).
        # Confirm this is intentional. Changing it invalidates latents and
        # checkpoints produced with the current value.
        self.scale_factor = 0.476986

    def _match_vae(self, tensor):
        """Return *tensor* on the VAE's device, cast to its dtype if floating point.

        The reference is the VAE's first floating-point parameter. If the VAE
        has none, *tensor* is returned unchanged. Non-floating tensors only
        change device.
        """
        ref = next((p for p in self.vae.parameters() if p.is_floating_point()), None)
        if ref is None:
            return tensor
        if tensor.is_floating_point():
            return tensor.to(device=ref.device, dtype=ref.dtype)
        return tensor.to(device=ref.device)

    def encode(self, video_tensor):
        """Encode *video_tensor* and return latents scaled by ``scale_factor``.

        If the VAE returns an object with ``latent_dist`` (diffusers'
        encoder output), a latent is *sampled* from that distribution, so
        repeated calls are stochastic. Otherwise the return value is used
        directly.
        """
        with torch.no_grad():
            dist = self.vae.encode(self._match_vae(video_tensor))
            latents = dist.latent_dist.sample() if hasattr(dist, 'latent_dist') else dist
            return latents * self.scale_factor

    def decode(self, latents):
        """Undo ``scale_factor`` on *latents* and decode them to video.

        Returns the ``.sample`` attribute of the VAE output when present,
        otherwise the output itself.
        """
        with torch.no_grad():
            # Unscale before the dtype cast so the division keeps the caller's precision.
            latents = self._match_vae(latents / self.scale_factor)
            video = self.vae.decode(latents)
            return video.sample if hasattr(video, 'sample') else video