| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| __all__ = ['WanVideoVAE'] |
|
|
| from typing import List |
| import torch |
| from torch import Tensor |
| from einops import rearrange |
|
|
| from common.utils.logging import get_logger |
| from common.utils.distributed import get_device |
| from common.utils.misc import AutoEncoderParams |
| from .vae2_2 import Wan2_2_VAE |
|
|
|
|
def reparameterize(mu, log_var):
    """Draw a sample from a diagonal Gaussian via the reparameterization trick.

    Args:
        mu: Tensor holding the mean of the distribution.
        log_var: Tensor holding the log-variance.

    Returns:
        ``mu + noise * sigma`` where ``sigma = exp(0.5 * log_var)`` and
        ``noise`` is standard-normal noise with the shape/dtype/device of
        ``sigma``.
    """
    # exp(log_var / 2) turns a log-variance into a standard deviation.
    sigma = torch.exp(log_var * 0.5)
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
|
|
|
|
class WanVideoVAE(object):
    """Wrapper around ``Wan2_2_VAE`` that encodes/decodes lists of tensors.

    Each tensor is processed on its own: a leading batch axis is added
    before calling the underlying model and removed again afterwards.
    Latents handed out by :meth:`vae_encode` (and expected by
    :meth:`vae_decode`) are channel-last, i.e. the channel axis of the
    model's latent is moved from position 1 to the end.
    """

    __version__ = "v2.2"
    __name__ = "WanVideoVAE"
    # Logger shared by all instances; created lazily on first construction.
    __logger__ = None

    # Checkpoint used when the configured model path cannot be resolved.
    _FALLBACK_VAE_PATH = "downloads/Wan2.2_VAE.pth"

    def __init__(self, config_path: str = "", **kwargs) -> None:
        """Build the underlying VAE and record its static parameters.

        Args:
            config_path: Accepted for interface compatibility; this
                constructor does not read it.
            **kwargs: Optional settings. ``dtype`` (default
                ``torch.bfloat16``) is used both to construct the model and
                as the autocast dtype. ``use_sample`` (default ``True``)
                selects whether :meth:`vae_encode` draws a sample through
                ``reparameterize`` or returns the first encoder output
                unchanged.
        """
        cls = self.__class__
        if cls.__logger__ is None:
            cls.__logger__ = get_logger(cls.__name__)
        self.logger = cls.__logger__

        self.dtype = kwargs.get("dtype", torch.bfloat16)
        self.configure_vae_model()
        self.use_sample = kwargs.get("use_sample", True)

        self.vae_config = AutoEncoderParams(
            downsample_spatial=16,
            downsample_temporal=4,
            z_channels=48,
        )

    def configure_vae_model(self):
        """Resolve the checkpoint path and instantiate ``self.vae``.

        The path is looked up through ``config.config_factory`` under the
        key ``"vae.wan"``. If that lookup fails for any reason (including
        the config package being unavailable), the hard-coded fallback
        checkpoint path is used instead and the failure is logged.
        """
        device = get_device()

        try:
            # Imported lazily so an unavailable config package only triggers
            # the fallback instead of breaking module import.
            from config.config_factory import get_model_path

            vae_path = get_model_path("vae.wan")
        except Exception as exc:
            # Deliberately best-effort, but never silent: loading a different
            # checkpoint than the configured one must be visible in the logs.
            vae_path = self._FALLBACK_VAE_PATH
            logger = getattr(self, "logger", None)
            if logger is not None:
                logger.warning(
                    f"Could not resolve model path 'vae.wan' ({exc!r}); "
                    f"falling back to '{vae_path}'."
                )

        self.vae: Wan2_2_VAE = Wan2_2_VAE(vae_pth=vae_path, device=device, dtype=self.dtype)

    @torch.no_grad()
    def vae_encode(self, samples: List[Tensor], **kwargs) -> List[Tensor]:
        """Encode each sample into a channel-last latent.

        Args:
            samples: Tensors without a batch axis; each one is moved to the
                current device and encoded individually.
            **kwargs: Ignored.

        Returns:
            One latent per input, in the same order, with the batch axis
            removed and the channel axis moved to the last position.
        """
        device = get_device()

        latents = []
        # NOTE(review): device_type is fixed to "cuda" regardless of what
        # get_device() returns — confirm this wrapper is only used on CUDA.
        with torch.autocast(device_type="cuda", dtype=self.dtype):
            for x in samples:
                x = x.to(device=device).unsqueeze(0)

                # The encoder output is unpacked as (u, log_var); presumably
                # the posterior mean and log-variance — verify in Wan2_2_VAE.
                u, log_var = self.vae.encode(x)

                if self.use_sample:
                    u = reparameterize(u, log_var)

                # Channel-first -> channel-last.
                u = rearrange(u, "b c ... -> b ... c")

                latents.append(u.squeeze(0))

        return latents

    @torch.no_grad()
    def vae_decode(self, latents: List[Tensor], **kwargs) -> List[Tensor]:
        """Decode channel-last latents back into samples.

        Args:
            latents: Channel-last latents without a batch axis, i.e. the
                layout produced by :meth:`vae_encode`.
            **kwargs: Ignored.

        Returns:
            One decoded tensor per latent, in the same order, with the batch
            axis removed.
        """
        device = get_device()

        samples = []
        # NOTE(review): device_type is fixed to "cuda" regardless of what
        # get_device() returns — confirm this wrapper is only used on CUDA.
        with torch.autocast(device_type="cuda", dtype=self.dtype):
            for u in latents:
                u = u.unsqueeze(0).to(device=device)
                # Channel-last -> channel-first, undoing vae_encode's layout.
                u = rearrange(u, "b ... c -> b c ...")

                x_hat = self.vae.decode(u)

                samples.append(x_hat.squeeze(0))

        return samples
|
|