| from typing import Tuple |
| import torch |
| from diffusers import AutoencoderKL |
| from einops import rearrange |
from torch import Tensor

from ltx_video.models.autoencoders.causal_video_autoencoder import (
| CausalVideoAutoencoder, |
| ) |
| from ltx_video.models.autoencoders.video_autoencoder import ( |
| Downsample3D, |
| VideoAutoencoder, |
| ) |
|
|
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    # torch_xla is optional; xm is only used when tensors live on an XLA device.
    xm = None


def vae_encode(
| media_items: Tensor, |
| vae: AutoencoderKL, |
| split_size: int = 1, |
| vae_per_channel_normalize=False, |
| ) -> Tensor: |
| """ |
| Encodes media items (images or videos) into latent representations using a specified VAE model. |
| The function supports processing batches of images or video frames and can handle the processing |
| in smaller sub-batches if needed. |
| |
| Args: |
| media_items (Tensor): A torch Tensor containing the media items to encode. The expected |
| shape is (batch_size, channels, height, width) for images or (batch_size, channels, |
| frames, height, width) for videos. |
| vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, |
| pre-configured and loaded with the appropriate model weights. |
| split_size (int, optional): The number of sub-batches to split the input batch into for encoding. |
| If set to more than 1, the input media items are processed in smaller batches according to |
| this value. Defaults to 1, which processes all items in a single batch. |
| |
| Returns: |
| Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted |
| to match the input shape, scaled by the model's configuration. |
| |
| Examples: |
| >>> import torch |
| >>> from diffusers import AutoencoderKL |
| >>> vae = AutoencoderKL.from_pretrained('your-model-name') |
| >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. |
| >>> latents = vae_encode(images, vae) |
| >>> print(latents.shape) # Output shape will depend on the model's latent configuration. |
| |
| Note: |
| In case of a video, the function encodes the media item frame-by frame. |
| """ |
| is_video_shaped = media_items.dim() == 5 |
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
| vae, (VideoAutoencoder, CausalVideoAutoencoder) |
| ): |
| media_items = rearrange(media_items, "b c n h w -> (b n) c h w") |
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(media_items) // split_size

        latents = []
        if media_items.device.type == "xla":
            xm.mark_step()  # Flush pending XLA ops before the encode loop.
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()  # Materialize each sub-batch eagerly on XLA.
        latents = torch.cat(latents, dim=0)
| else: |
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
| if is_video_shaped and not isinstance( |
| vae, (VideoAutoencoder, CausalVideoAutoencoder) |
| ): |
| latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) |
    return latents


def vae_decode(
| latents: Tensor, |
| vae: AutoencoderKL, |
| is_video: bool = True, |
| split_size: int = 1, |
| vae_per_channel_normalize=False, |
| timestep=None, |
| ) -> Tensor: |
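    """
    Decodes latent representations back into pixel space using a specified VAE model,
    optionally splitting the batch into sub-batches to reduce peak memory usage.

    Args:
        latents (Tensor): Latents of shape (batch_size, channels, height, width) for
            images or (batch_size, channels, frames, height, width) for videos.
        vae (AutoencoderKL): The VAE model to decode with.
        is_video (bool, optional): Whether the latents represent video. Defaults to True.
        split_size (int, optional): The number of sub-batches to decode separately.
            Defaults to 1.
        vae_per_channel_normalize (bool, optional): Should match the value used during
            encoding. Defaults to False.
        timestep: Optional denoising timestep, forwarded to the decoders of video
            autoencoders that accept one.

    Returns:
        Tensor: The decoded media items, with the same dimension layout as the
        corresponding `vae_encode` input.

    Example (illustrative; assumes `vae` and `latents` come from the `vae_encode` example):
        >>> images = vae_decode(latents, vae, is_video=True)
    """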
| is_video_shaped = latents.dim() == 5 |
    batch_size = latents.shape[0]

    if is_video_shaped and not isinstance(
| vae, (VideoAutoencoder, CausalVideoAutoencoder) |
| ): |
| latents = rearrange(latents, "b c n h w -> (b n) c h w") |
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
| encode_bs = len(latents) // split_size |
| image_batch = [ |
| _run_decoder( |
| latent_batch, vae, is_video, vae_per_channel_normalize, timestep |
| ) |
| for latent_batch in latents.split(encode_bs) |
| ] |
| images = torch.cat(image_batch, dim=0) |
| else: |
| images = _run_decoder( |
| latents, vae, is_video, vae_per_channel_normalize, timestep |
| ) |
|
|
| if is_video_shaped and not isinstance( |
| vae, (VideoAutoencoder, CausalVideoAutoencoder) |
| ): |
| images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) |
    return images


def _run_decoder(
| latents: Tensor, |
| vae: AutoencoderKL, |
| is_video: bool, |
| vae_per_channel_normalize=False, |
| timestep=None, |
| ) -> Tensor: |
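    """
    Runs a single decode call. Video autoencoders receive an explicit `target_shape`
    (and `timestep`, when one is given); a plain image VAE is called on the
    un-normalized latents alone.
    """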
| if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): |
| *_, fl, hl, wl = latents.shape |
| temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) |
| latents = latents.to(vae.dtype) |
| vae_decode_kwargs = {} |
| if timestep is not None: |
| vae_decode_kwargs["timestep"] = timestep |
| image = vae.decode( |
| un_normalize_latents(latents, vae, vae_per_channel_normalize), |
| return_dict=False, |
| target_shape=( |
| 1, |
| 3, |
| fl * temporal_scale if is_video else 1, |
| hl * spatial_scale, |
| wl * spatial_scale, |
| ), |
| **vae_decode_kwargs, |
| )[0] |
| else: |
| image = vae.decode( |
| un_normalize_latents(latents, vae, vae_per_channel_normalize), |
| return_dict=False, |
| )[0] |
    return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
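    """
    Returns the VAE's (temporal, spatial, spatial) pixel-per-latent scale factors.
    A `CausalVideoAutoencoder` exposes these directly; otherwise they are derived
    from the patch size and the number of `Downsample3D` blocks in the encoder.

    Example (illustrative; assumes a causal VAE with 8x temporal and 32x spatial
    downscaling):
        >>> get_vae_size_scale_factor(vae)
        (8, 32, 32)
    """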
| if isinstance(vae, CausalVideoAutoencoder): |
| spatial = vae.spatial_downscale_factor |
| temporal = vae.temporal_downscale_factor |
| else: |
| down_blocks = len( |
| [ |
| block |
| for block in vae.encoder.down_blocks |
| if isinstance(block.downsample, Downsample3D) |
| ] |
| ) |
| spatial = vae.config.patch_size * 2**down_blocks |
| temporal = ( |
| vae.config.patch_size_t * 2**down_blocks |
| if isinstance(vae, VideoAutoencoder) |
| else 1 |
| ) |
|
|
| return (temporal, spatial, spatial) |
|
|
|
|
| def latent_to_pixel_coords( |
| latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False |
| ) -> Tensor: |
| """ |
| Converts latent coordinates to pixel coordinates by scaling them according to the VAE's |
| configuration. |
| |
| Args: |
| latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] |
| containing the latent corner coordinates of each token. |
| vae (AutoencoderKL): The VAE model |
| causal_fix (bool): Whether to take into account the different temporal scale |
| of the first frame. Default = False for backwards compatibility. |
| Returns: |
| Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. |
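
    Example (illustrative; assumes `vae` is a CausalVideoAutoencoder with scale
    factors (8, 32, 32)):
        >>> latent_coords = torch.tensor([[[0, 1], [2, 2], [3, 3]]])  # t, y, x rows
        >>> latent_to_pixel_coords(latent_coords, vae, causal_fix=True)
        tensor([[[ 0,  1],
                 [64, 64],
                 [96, 96]]])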
| """ |
|
|
| scale_factors = get_vae_size_scale_factor(vae) |
| causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix |
| pixel_coords = latent_to_pixel_coords_from_factors( |
| latent_coords, scale_factors, causal_fix |
| ) |
    return pixel_coords


def latent_to_pixel_coords_from_factors(
| latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False |
| ) -> Tensor: |
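    """
    Scales latent coordinates to pixel coordinates using per-axis `scale_factors`
    (temporal, height, width). With `causal_fix`, the temporal axis is shifted so
    that the first latent frame maps to pixel frame 0 and covers only that single
    frame, matching the causal VAE's encoding.
    """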
| pixel_coords = ( |
| latent_coords |
| * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] |
| ) |
    if causal_fix:
        # In a causal VAE the first latent frame encodes only the first pixel
        # frame, so shift the temporal coordinates instead of purely scaling.
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords


def normalize_latents(
| latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False |
| ) -> Tensor: |
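    """
    Maps raw VAE latents to a normalized scale: per-channel standardization with
    `mean_of_means` / `std_of_means` when `vae_per_channel_normalize` is set,
    otherwise multiplication by `vae.config.scaling_factor`.
    """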
| return ( |
| (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) |
| / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) |
| if vae_per_channel_normalize |
| else latents * vae.config.scaling_factor |
| ) |
|
|
|
|
| def un_normalize_latents( |
| latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False |
| ) -> Tensor: |
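    """
    Inverse of `normalize_latents`: restores latents to the VAE's native scale
    before decoding.
    """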
| return ( |
| latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) |
| + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) |
| if vae_per_channel_normalize |
| else latents / vae.config.scaling_factor |
| ) |
|
|