"""pig-vae (WanVAE) sidecar module.

``load_vae()`` resolves weights in this order: a local ``model.safetensors``,
a local GGUF file, then the Hugging Face diffusers checkpoint.
``_load_local()`` can additionally read a ``.pth`` state dict when called
with ``is_safetensors=False``; ``load_vae()`` itself never takes that path.

Exposes encode() and decode() for the VideoHead training pipeline.

Latent shape: [B, 16, T/4, H/8, W/8] for input video of T frames at HxW.
"""

import os

import torch
import torch.nn as nn

# Directory probed by load_vae() for local weights, relative to this file.
_LOCAL_VAE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "models", "pig-vae"
)

# Keyword arguments forwarded verbatim to AutoencoderKLWan by _build_vae().
# The key spelling "temperal_downsample" is passed through as-is; do not
# "correct" it here without checking the constructor's parameter name.
_VAE_CONFIG = {
    "base_dim": 96,
    "z_dim": 16,
    "dim_mult": [1, 2, 4, 4],
    "num_res_blocks": 2,
    "dropout": 0.0,
    "temperal_downsample": [False, True, True],
    "in_channels": 3,
    "out_channels": 3,
    "scale_factor_temporal": 4,
    "scale_factor_spatial": 8,
}


def _freeze_sidecar(model, quantize_requested=None, quantized=False):
    """Tag *model* with quantization bookkeeping and disable all gradients.

    Sets ``_arb_quantize_requested``, ``_arb_quantized_int8`` and
    ``_arb_quantized`` on the model, then sets ``requires_grad = False`` on
    every parameter.

    Args:
        model: Module to tag and freeze (modified in place).
        quantize_requested: The quantization mode the caller asked for.
        quantized: Whether quantization was actually observed on the model.

    Returns:
        The same *model* object.
    """
    model._arb_quantize_requested = quantize_requested
    # Only true when int8 was both requested and actually applied.
    model._arb_quantized_int8 = bool(quantized and quantize_requested == "int8")
    model._arb_quantized = bool(quantized)
    for param in model.parameters():
        param.requires_grad = False
    return model


def _has_quantized_modules(model):
    """Return True if any submodule's class name contains a quantization marker.

    NOTE(review): the bare "Q" marker is a substring of every other marker,
    so it alone decides the result, and it also matches any unrelated class
    whose name contains a capital "Q". Behaviour is kept as-is.
    """
    markers = ("Q", "Quanto", "Quantized", "WeightQ")
    return any(
        any(marker in type(module).__name__ for marker in markers)
        for module in model.modules()
    )


def _quantize_int8_if_requested(model, quantize):
    """Quantize weights to int8 when ``quantize == 'int8'``, then freeze.

    Any other value of *quantize* leaves the weights untouched; the request
    is still recorded on the model by ``_freeze_sidecar``.

    Args:
        model: Module to (optionally) quantize.
        quantize: Requested mode; only the string ``'int8'`` triggers
            quantization.

    Returns:
        The frozen model, tagged with the quantization bookkeeping flags.
    """
    if quantize != "int8":
        return _freeze_sidecar(model, quantize_requested=quantize, quantized=False)

    # Imported lazily so the module can be imported without optimum-quanto.
    from optimum.quanto import quantize as quanto_quantize, freeze, qint8

    quanto_quantize(model, weights=qint8)
    freeze(model)
    return _freeze_sidecar(
        model,
        quantize_requested=quantize,
        quantized=_has_quantized_modules(model),
    )


def _wan_vae_cls():
    """Import and return ``diffusers.AutoencoderKLWan``.

    Raises:
        RuntimeError: If the import fails with ``ModuleNotFoundError``.
    """
    try:
        from diffusers import AutoencoderKLWan
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "pig-vae requires the optional diffusers dependency. "
            "Install the project with `pip install -e .[diffusers]` in a venv "
            "before loading or verifying pig-vae int8 quantization."
        ) from exc
    return AutoencoderKLWan


def load_vae(device='cuda', quantize='int8'):
    """Load pig-vae from local cache or diffusers. Optionally int8 quantize.

    Lookup order: ``model.safetensors`` in the local directory, then the
    local GGUF file, then the Hugging Face checkpoint.

    Args:
        device: Device the model is moved to before quantization.
        quantize: ``'int8'`` to quantize weights; any other value skips
            quantization.

    Returns:
        A ``VAEWrapper`` around the loaded, frozen model in eval mode.

    Raises:
        RuntimeError: If diffusers is not installed, or a local GGUF
            checkpoint does not match the model's state dict.
    """
    safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
    gguf_path = os.path.join(_LOCAL_VAE_DIR, "pig_wan_vae_fp32-f16.gguf")
    if os.path.isfile(safetensors_path):
        return _load_local(safetensors_path, device, quantize, is_safetensors=True)
    if os.path.isfile(gguf_path):
        return _load_gguf(gguf_path, device, quantize)
    return _load_from_hf(device, quantize)


def _build_vae():
    """Construct an AutoencoderKLWan from ``_VAE_CONFIG`` and fixed latent stats."""
    AutoencoderKLWan = _wan_vae_cls()
    return AutoencoderKLWan(
        **_VAE_CONFIG,
        latents_mean=[
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ],
        latents_std=[
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916,
        ],
    )


def _load_checked_state_dict(model, state_dict, kind):
    """Load *state_dict* into *model*, rejecting any key mismatch.

    ``strict=False`` is used only so that the mismatch can be reported with
    counts in a single project-specific error.

    Args:
        model: Target module.
        state_dict: Mapping of parameter names to tensors.
        kind: Checkpoint label inserted into the error message
            (``".pth"`` or ``"GGUF"``).

    Raises:
        RuntimeError: If any key is missing or unexpected.
    """
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            f"pig-vae local {kind} checkpoint does not match AutoencoderKLWan "
            f"(missing={len(missing)}, unexpected={len(unexpected)})."
        )


def _finalize(model, device, quantize):
    """Move to *device*, switch to eval mode, quantize/freeze, and wrap."""
    model = model.to(device)
    model.eval()
    model = _quantize_int8_if_requested(model, quantize)
    return VAEWrapper(model)


def _load_local(path, device, quantize, is_safetensors=False):
    """Load the VAE from a local safetensors file or ``.pth`` state dict.

    Args:
        path: Checkpoint path.
        device: Device the model is moved to.
        quantize: Quantization mode passed to ``_quantize_int8_if_requested``.
        is_safetensors: When True, load via
            ``AutoencoderKLWan.from_single_file``; otherwise build the model
            from ``_VAE_CONFIG`` and load *path* with ``torch.load``.

    Returns:
        A ``VAEWrapper`` around the loaded model.

    Raises:
        RuntimeError: If a ``.pth`` checkpoint's keys do not match the model.
    """
    if is_safetensors:
        AutoencoderKLWan = _wan_vae_cls()
        model = AutoencoderKLWan.from_single_file(path)
    else:
        model = _build_vae()
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        _load_checked_state_dict(model, ckpt, ".pth")
    return _finalize(model, device, quantize)


def _load_gguf(path, device, quantize):
    """Load the VAE from a local GGUF file.

    Every tensor in the file is converted with ``torch.tensor`` and loaded
    under its GGUF tensor name.

    Raises:
        RuntimeError: If the GGUF tensor names do not match the model.
    """
    # Imported lazily so the module can be imported without the gguf package.
    import gguf

    reader = gguf.GGUFReader(path)
    state_dict = {t.name: torch.tensor(t.data) for t in reader.tensors}
    model = _build_vae()
    _load_checked_state_dict(model, state_dict, "GGUF")
    return _finalize(model, device, quantize)


def _load_from_hf(device, quantize):
    """Load the VAE in bfloat16 from the Hugging Face diffusers checkpoint."""
    AutoencoderKLWan = _wan_vae_cls()
    # NOTE(review): the repo id was changed from "Wan-AI/Wan2.1-T2V-1.3B".
    # The "vae" subfolder is part of the diffusers-format repository; the
    # plain repository is believed to hold only the original single-file
    # checkpoints. Confirm against the Hub before relying on this fallback.
    model = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )
    return _finalize(model, device, quantize)


class VAEWrapper(nn.Module):
    """Gradient-free encode/decode facade over a loaded VAE.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae
        # Taken from _VAE_CONFIG["z_dim"].
        self.latent_channels = _VAE_CONFIG["z_dim"]
        # Multiplied into latents by encode() and divided out by decode().
        self.scale_factor = 0.476986

    def encode(self, video_tensor):
        """Encode *video_tensor* and return latents scaled by ``scale_factor``.

        If the VAE's output has a ``latent_dist`` attribute, a sample is drawn
        from it; otherwise the output itself is used as the latents. Runs
        under ``torch.no_grad()``.
        """
        with torch.no_grad():
            encoded = self.vae.encode(video_tensor)
            if hasattr(encoded, 'latent_dist'):
                latents = encoded.latent_dist.sample()
            else:
                latents = encoded
            latents = latents * self.scale_factor
        return latents

    def decode(self, latents):
        """Undo the ``scale_factor`` scaling and decode *latents*.

        If the VAE's output has a ``sample`` attribute, that attribute is
        returned; otherwise the output itself. Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            latents = latents / self.scale_factor
            decoded = self.vae.decode(latents)
            video = decoded.sample if hasattr(decoded, 'sample') else decoded
        return video