| """pig-vae (WanVAE) sidecar module. |
| |
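| Loads from a local safetensors or GGUF file, or falls back to diffusers |
| AutoencoderKLWan.from_pretrained. A .pth loader exists in _load_local, but |
| load_vae() never selects it. |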
| Exposes encode() and decode() for the VideoHead training pipeline. |
| |
| Latent shape: [B, 16, T/4, H/8, W/8] for input video of T frames at HxW. |
| """ |
| import os, torch |
| import torch.nn as nn |
|
|
# Folder (relative to this module's own location) that is probed for local
# pig-vae weight files before anything else is tried.
_LOCAL_VAE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "models",
    "pig-vae",
)

# Keyword arguments that _build_vae() unpacks into the AutoencoderKLWan
# constructor; "z_dim" is also read by VAEWrapper as its latent channel count.
# NOTE(review): "temperal_downsample" is presumably spelled to match the
# constructor's own parameter name -- confirm against diffusers before
# "correcting" it.
_VAE_CONFIG = {
    "base_dim": 96,
    "z_dim": 16,
    "dim_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "dropout": 0.0,
    "temperal_downsample": [False, True, True],
    "in_channels": 3,
    "out_channels": 3,
    "scale_factor_temporal": 4,
    "scale_factor_spatial": 8,
}
|
|
|
|
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Record quantization bookkeeping on ``model`` and freeze its parameters.

    Args:
        model: Object exposing ``parameters()``; it is modified in place.
        quantize_requested: Quantization mode the caller asked for; stored
            unchanged as ``model._arb_quantize_requested``.
        quantized: Whether quantization actually took effect.

    Returns:
        The same ``model`` object, with ``_arb_quantize_requested``,
        ``_arb_quantized_int8`` and ``_arb_quantized`` set and
        ``requires_grad`` turned off on every parameter.
    """
    did_quantize = bool(quantized)
    # int8 is only reported when quantization happened AND int8 was the mode
    # that had been requested.
    is_int8 = bool(quantized and quantize_requested == "int8")

    model._arb_quantize_requested = quantize_requested
    model._arb_quantized_int8 = is_int8
    model._arb_quantized = did_quantize

    for param in model.parameters():
        param.requires_grad = False
    return model
|
|
|
|
def _has_quantized_modules(model):
    """Return True if any submodule's class name contains a quantization marker.

    The check is a plain substring test against ``type(module).__name__`` for
    every entry yielded by ``model.modules()``. Because every marker contains
    a capital "Q", the test is effectively "does any module class name contain
    a capital Q".
    """
    markers = ("Q", "Quanto", "Quantized", "WeightQ")
    for module in model.modules():
        class_name = type(module).__name__
        for marker in markers:
            if marker in class_name:
                return True
    return False
|
|
|
|
def _quantize_int8_if_requested(model, quantize):
    """Apply int8 weight quantization when asked for, then freeze the model.

    Args:
        model: The model to (optionally) quantize; modified in place.
        quantize: Requested mode. Only the exact value ``'int8'`` triggers
            quantization; any other value leaves the weights untouched.

    Returns:
        The model as returned by ``_freeze_sidecar``. On the int8 path the
        ``quantized`` flag reflects whether ``_has_quantized_modules`` found
        quantized modules afterwards; otherwise it is ``False``.
    """
    if quantize != 'int8':
        return _freeze_sidecar(model, quantize_requested=quantize, quantized=False)

    # Imported lazily so optimum-quanto is only needed when int8 is requested.
    from optimum.quanto import quantize as quanto_quantize, freeze, qint8

    quanto_quantize(model, weights=qint8)
    freeze(model)
    took_effect = _has_quantized_modules(model)
    return _freeze_sidecar(model, quantize_requested=quantize, quantized=took_effect)
|
|
|
|
def _wan_vae_cls():
    """Return the ``AutoencoderKLWan`` class from diffusers.

    The import is performed on each call rather than at module import time,
    so this module can be imported without diffusers installed.

    Raises:
        RuntimeError: If importing from diffusers raises
            ``ModuleNotFoundError``; the original error is chained as the
            cause.
    """
    try:
        from diffusers import AutoencoderKLWan as wan_vae_cls
    except ModuleNotFoundError as exc:
        message = (
            "pig-vae requires the optional diffusers dependency. "
            "Install the project with `pip install -e .[diffusers]` in a venv "
            "before loading or verifying pig-vae int8 quantization."
        )
        raise RuntimeError(message) from exc
    else:
        return wan_vae_cls
|
|
|
|
def load_vae(device='cuda', quantize='int8'):
    """Load pig-vae, preferring local weight files over the diffusers fallback.

    Lookup order inside ``_LOCAL_VAE_DIR``:
      1. ``model.safetensors``
      2. ``pig_wan_vae_fp32-f16.gguf``
    If neither file exists, the model is loaded through ``_load_from_hf``.

    Args:
        device: Device the model is moved to.
        quantize: Quantization mode forwarded to the chosen loader
            (``'int8'`` by default).

    Returns:
        Whatever the selected loader returns (a ``VAEWrapper``).
    """
    local_safetensors = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
    if os.path.isfile(local_safetensors):
        return _load_local(local_safetensors, device, quantize, is_safetensors=True)

    local_gguf = os.path.join(_LOCAL_VAE_DIR, "pig_wan_vae_fp32-f16.gguf")
    if os.path.isfile(local_gguf):
        return _load_gguf(local_gguf, device, quantize)

    return _load_from_hf(device, quantize)
|
|
|
|
def _build_vae():
    """Instantiate an ``AutoencoderKLWan`` from ``_VAE_CONFIG`` plus latent statistics.

    No checkpoint is loaded here; callers load a state dict into the returned
    model afterwards.
    """
    # 16 values each -- presumably one per latent channel (z_dim is 16);
    # confirm against the upstream VAE config.
    channel_mean = [
        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653,
        -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632,
        -0.1922, -0.9497, 0.2503, -0.2921,
    ]
    channel_std = [
        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708,
        2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579,
        1.6382, 1.1253, 2.8251, 1.916,
    ]
    vae_cls = _wan_vae_cls()
    return vae_cls(
        **_VAE_CONFIG,
        latents_mean=channel_mean,
        latents_std=channel_std,
    )
|
|
|
|
def _load_local(path, device, quantize, is_safetensors=False):
    """Load the VAE from a local checkpoint file and wrap it.

    Args:
        path: Path to the checkpoint file.
        device: Device the model is moved to after loading.
        quantize: Quantization mode passed to ``_quantize_int8_if_requested``.
        is_safetensors: When true, ``path`` is handed to
            ``AutoencoderKLWan.from_single_file``. Otherwise ``path`` is read
            with ``torch.load`` and loaded into a freshly built model.

    Returns:
        A ``VAEWrapper`` around the model, which is in eval mode, on
        ``device``, optionally quantized, and frozen.

    Raises:
        RuntimeError: On the ``torch.load`` path, if the checkpoint has any
            missing or unexpected keys.
    """
    if not is_safetensors:
        model = _build_vae()
        # weights_only=True restricts unpickling to tensors and plain containers.
        state = torch.load(path, map_location="cpu", weights_only=True)
        # strict=False is used only so the mismatch can be reported with our
        # own message; any mismatch is still fatal.
        missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
        if missing_keys or unexpected_keys:
            raise RuntimeError(
                "pig-vae local .pth checkpoint does not match AutoencoderKLWan "
                f"(missing={len(missing_keys)}, unexpected={len(unexpected_keys)})."
            )
    else:
        model = _wan_vae_cls().from_single_file(path)

    model = model.to(device)
    model.eval()
    return VAEWrapper(_quantize_int8_if_requested(model, quantize))
|
|
|
|
def _load_gguf(path, device, quantize):
    """Load the VAE from a local GGUF file and wrap it.

    Every tensor listed by the GGUF reader is converted with ``torch.tensor``
    and loaded, by name, into a freshly built ``AutoencoderKLWan``.

    Args:
        path: Path to the ``.gguf`` file.
        device: Device the model is moved to after loading.
        quantize: Quantization mode passed to ``_quantize_int8_if_requested``.

    Returns:
        A ``VAEWrapper`` around the model, which is in eval mode, on
        ``device``, optionally quantized, and frozen.

    Raises:
        RuntimeError: If the GGUF tensor names leave any missing or
            unexpected keys.
    """
    # Imported lazily so gguf is only required when a GGUF file is used.
    import gguf

    reader = gguf.GGUFReader(path)
    # NOTE(review): assumes each tensor's ``data`` is already a plain float
    # array in the model's layout -- confirm for the files actually shipped.
    state_dict = {}
    for tensor in reader.tensors:
        state_dict[tensor.name] = torch.tensor(tensor.data)

    model = _build_vae()
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys or unexpected_keys:
        raise RuntimeError(
            "pig-vae local GGUF checkpoint does not match AutoencoderKLWan "
            f"(missing={len(missing_keys)}, unexpected={len(unexpected_keys)})."
        )

    model = model.to(device)
    model.eval()
    return VAEWrapper(_quantize_int8_if_requested(model, quantize))
|
|
|
|
def _load_from_hf(device, quantize):
    """Load the VAE through ``AutoencoderKLWan.from_pretrained`` and wrap it.

    The weights come from the ``vae`` subfolder of ``Wan-AI/Wan2.1-T2V-1.3B``
    and are requested in bfloat16.

    Args:
        device: Device the model is moved to after loading.
        quantize: Quantization mode passed to ``_quantize_int8_if_requested``.

    Returns:
        A ``VAEWrapper`` around the model, which is in eval mode, on
        ``device``, optionally quantized, and frozen.
    """
    vae_cls = _wan_vae_cls()
    pretrained = vae_cls.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )
    pretrained = pretrained.to(device)
    pretrained.eval()
    return VAEWrapper(_quantize_int8_if_requested(pretrained, quantize))
|
|
|
|
class VAEWrapper(nn.Module):
    """Gradient-free encode/decode facade around a loaded VAE.

    Attributes are documented inline in ``__init__``. Both ``encode`` and
    ``decode`` run under ``torch.no_grad()``.
    """

    def __init__(self, vae):
        super().__init__()
        # The wrapped VAE; must provide ``encode`` and ``decode``.
        self.vae = vae
        # Number of latent channels, taken from the shared VAE config.
        self.latent_channels = _VAE_CONFIG["z_dim"]
        # Scalar applied to latents after encoding and removed again before
        # decoding. NOTE(review): the origin of this constant is not shown in
        # this file -- confirm before changing it.
        self.scale_factor = 0.476986

    def encode(self, video_tensor):
        """Encode ``video_tensor`` into latents multiplied by ``scale_factor``.

        If the wrapped VAE returns an object with a ``latent_dist`` attribute,
        a sample is drawn from it; otherwise the return value itself is used
        as the latents.
        """
        with torch.no_grad():
            encoded = self.vae.encode(video_tensor)
            if hasattr(encoded, 'latent_dist'):
                raw_latents = encoded.latent_dist.sample()
            else:
                raw_latents = encoded
            scaled = raw_latents * self.scale_factor
        return scaled

    def decode(self, latents):
        """Decode latents (as produced by ``encode``) back to video.

        The latents are divided by ``scale_factor`` first. If the wrapped VAE
        returns an object with a ``sample`` attribute, that attribute is
        returned; otherwise the return value itself is returned.
        """
        with torch.no_grad():
            decoded = self.vae.decode(latents / self.scale_factor)
            if hasattr(decoded, 'sample'):
                decoded = decoded.sample
        return decoded
|
|