# jsflow/back/preprocessing/encoders.py
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
"""Converting between pixel and latent representations of image data."""
import os
import warnings
import numpy as np
import torch
from torch_utils import persistence
from torch_utils import misc
warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
warnings.filterwarnings('ignore', '`resume_download` is deprecated')
#----------------------------------------------------------------------------
# Abstract base class for encoders/decoders that convert back and forth
# between pixel and latent representations of image data.
#
# Logically, "raw pixels" are first encoded into "raw latents" that are
# then further encoded into "final latents". Decoding, on the other hand,
# goes directly from the final latents to raw pixels. The final latents are
# used as inputs and outputs of the model, whereas the raw latents are
# stored in the dataset. This separation provides added flexibility in terms
# of performing just-in-time adjustments, such as data whitening, without
# having to construct a new dataset.
#
# All image data is represented as PyTorch tensors in NCHW order.
# Raw pixels are represented as 3-channel uint8.
@persistence.persistent_class
class Encoder:
    """Abstract base class for pixel/latent converters.

    Subclasses implement `encode_pixels` to turn raw uint8 NCHW pixels
    into raw latents; this base class only defines the interface.
    """
    def __init__(self):
        pass

    def init(self, device):
        """Force any lazy initialization (e.g. model loading) to happen now."""
        pass

    def __getstate__(self):
        """Pickle the instance attributes verbatim."""
        return self.__dict__

    def encode_pixels(self, x):
        """Convert raw pixels to raw latents. Must be overridden by subclass."""
        raise NotImplementedError
#----------------------------------------------------------------------------
# Pre-trained VAE encoder from Stability AI.
@persistence.persistent_class
class StabilityVAEEncoder(Encoder):
    """Pre-trained VAE encoder from Stability AI.

    Wraps a diffusers `AutoencoderKL`; the underlying model is loaded lazily
    on first use and is never pickled.
    """
    def __init__(self,
        vae_name   = 'stabilityai/sd-vae-ft-mse',   # Name of the VAE to use.
        batch_size = 8,                             # Batch size to use when running the VAE.
    ):
        super().__init__()
        self.vae_name = vae_name
        self.batch_size = int(batch_size)
        self._vae = None  # loaded on demand in init()

    def init(self, device):
        """Load the VAE (first call) or move the cached one to `device`."""
        super().init(device)
        if self._vae is not None:
            self._vae.to(device)
        else:
            self._vae = load_stability_vae(self.vae_name, device=device)

    def __getstate__(self):
        # Drop the model itself so pickles stay small; it is re-loaded lazily.
        state = dict(super().__getstate__())
        state['_vae'] = None
        return state

    def _run_vae_encoder(self, x):
        """Encode one batch and stack the posterior mean and std along channels."""
        dist = self._vae.encode(x)['latent_dist']
        return torch.cat([dist.mean, dist.std], dim=1)

    def encode_pixels(self, x):
        """Convert raw uint8 pixels in [0, 255] to raw latents."""
        self.init(x.device)
        scaled = x.to(torch.float32) / 127.5 - 1  # map to [-1, 1] as the VAE expects
        parts = [self._run_vae_encoder(chunk) for chunk in scaled.split(self.batch_size)]
        return torch.cat(parts)
#----------------------------------------------------------------------------
def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
    """Load a pre-trained Stability AI VAE via diffusers.

    Args:
        vae_name: Hugging Face Hub name of the VAE model to load.
        device:   Device to move the loaded model to.

    Returns:
        The `AutoencoderKL` in eval mode, with gradients disabled, on `device`.
    """
    import dnnlib
    cache_dir = dnnlib.make_cache_dir_path('diffusers')
    # Configure the HF hub before importing diffusers so the settings take effect.
    os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
    os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
    os.environ['HF_HOME'] = cache_dir
    import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
    try:
        # First try with local_files_only to avoid consulting the hub if the
        # model is already in cache.
        vae = diffusers.models.AutoencoderKL.from_pretrained(
            vae_name, cache_dir=cache_dir, local_files_only=True
        )
    except Exception:
        # Note: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit. Catch Exception so only actual load
        # failures fall through to the network fetch.
        vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir)
    return vae.eval().requires_grad_(False).to(device)
#----------------------------------------------------------------------------