Spaces:
Running on Zero
Running on Zero
| from dataclasses import dataclass, replace | |
| from typing import Protocol | |
| import torch | |
| from torch._prims_common import DeviceLikeType | |
| from ltx_core.components.patchifiers import ( | |
| AudioLatentShape, | |
| AudioPatchifier, | |
| VideoLatentPatchifier, | |
| VideoLatentShape, | |
| get_pixel_coords, | |
| ) | |
| from ltx_core.components.protocols import Patchifier | |
| from ltx_core.types import LatentState, SpatioTemporalScaleFactors | |
# Module-wide default scale factors; VideoLatentTools uses this as the default for its `scale_factors` field.
DEFAULT_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class LatentTools(Protocol):
    """
    Shared helpers for constructing and reshaping latent states.

    Implementations provide ``patchifier`` and ``target_shape`` and define ``create_initial_state``;
    the remaining methods are implemented here in terms of those two attributes.
    """

    patchifier: Patchifier
    target_shape: VideoLatentShape | AudioLatentShape

    def create_initial_state(
        self,
        device: DeviceLikeType,
        dtype: torch.dtype,
        initial_latent: torch.Tensor | None = None,
    ) -> LatentState:
        """
        Build the starting latent state. When ``initial_latent`` is supplied, it seeds the state.
        """
        ...

    def patchify(self, latent_state: LatentState) -> LatentState:
        """
        Return a copy of ``latent_state`` with its latent, clean latent and denoise mask patchified.

        Raises:
            ValueError: if the latent's shape differs from ``self.target_shape.to_torch_shape()``.
        """
        expected_shape = self.target_shape.to_torch_shape()
        actual_shape = latent_state.latent.shape
        if actual_shape != expected_shape:
            raise ValueError(
                f"Latent state has shape {actual_shape}, expected shape is "
                f"{expected_shape}"
            )

        # Work on a clone so the caller's state object is not the one passed to replace().
        working = latent_state.clone()
        to_patches = self.patchifier.patchify
        return replace(
            working,
            latent=to_patches(working.latent),
            clean_latent=to_patches(working.clean_latent),
            denoise_mask=to_patches(working.denoise_mask),
        )

    def unpatchify(self, latent_state: LatentState) -> LatentState:
        """
        Return a copy of ``latent_state`` with its latent, clean latent and denoise mask unpatchified.

        The latents are restored to ``self.target_shape``; the mask is restored to
        ``self.target_shape.mask_shape()``.
        """
        working = latent_state.clone()
        from_patches = self.patchifier.unpatchify
        shape = self.target_shape
        return replace(
            working,
            latent=from_patches(working.latent, output_shape=shape),
            clean_latent=from_patches(working.clean_latent, output_shape=shape),
            denoise_mask=from_patches(working.denoise_mask, output_shape=shape.mask_shape()),
        )

    def clear_conditioning(self, latent_state: LatentState) -> LatentState:
        """
        Strip conditioning from the latent state by dropping every token past the target token count.

        Only trailing tokens are removed, so conditioning items must append their extra tokens
        ONLY at the end of the latent. The returned state has an all-ones denoise mask and no
        attention mask.
        """
        working = latent_state.clone()
        token_count = self.patchifier.get_token_count(self.target_shape)
        keep = slice(None, token_count)
        return LatentState(
            latent=working.latent[:, keep],
            # The mask is reset to ones rather than sliced from the existing mask values.
            denoise_mask=torch.ones_like(working.denoise_mask)[:, keep],
            # Positions are sliced along their third dimension, unlike the other tensors.
            positions=working.positions[:, :, keep],
            clean_latent=working.clean_latent[:, keep],
            attention_mask=None,
        )
class VideoLatentTools(LatentTools):
    """
    Tools for building video latent states.
    """

    # NOTE(review): these annotated fields imply a generated constructor, but no decorator is
    # visible on this class in this view of the file — confirm how instances are built.
    patchifier: VideoLatentPatchifier
    target_shape: VideoLatentShape
    fps: float
    scale_factors: SpatioTemporalScaleFactors = DEFAULT_SCALE_FACTORS
    causal_fix: bool = True

    def create_initial_state(
        self,
        device: DeviceLikeType,
        dtype: torch.dtype,
        initial_latent: torch.Tensor | None = None,
    ) -> LatentState:
        """
        Create the patchified initial latent state for video.

        Args:
            device: Device on which the zero latent (if created), denoise mask and positions are allocated.
            dtype: Dtype of the zero latent (if created) and of the returned positions.
            initial_latent: Optional latent to start from. It must already have the target shape.
                When omitted, a zero latent of the target shape is used.

        Returns:
            The result of ``self.patchify`` applied to a state holding the latent, a clone of it as
            the clean latent, an all-ones float32 denoise mask, and the computed positions.

        Raises:
            ValueError: if ``initial_latent`` is given and its shape differs from the target shape.
        """
        expected_shape = self.target_shape.to_torch_shape()
        if initial_latent is None:
            initial_latent = torch.zeros(*expected_shape, device=device, dtype=dtype)
        elif initial_latent.shape != expected_shape:
            # Explicit raise instead of `assert`: asserts are stripped under `python -O`, and
            # `patchify` reports the same kind of mismatch with ValueError.
            raise ValueError(
                f"Latent shape {initial_latent.shape} does not match target shape {expected_shape}"
            )

        clean_latent = initial_latent.clone()

        # The mask is always float32, independent of the requested `dtype`.
        denoise_mask = torch.ones(
            *self.target_shape.mask_shape().to_torch_shape(),
            device=device,
            dtype=torch.float32,
        )

        latent_coords = self.patchifier.get_patch_grid_bounds(
            output_shape=self.target_shape,
            device=device,
        )
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.scale_factors,
            causal_fix=self.causal_fix,
        ).float()
        # Index 0 along dim 1 is divided by fps — presumably the temporal axis, converting
        # frame indices to seconds; confirm against get_pixel_coords.
        positions[:, 0, ...] = positions[:, 0, ...] / self.fps

        return self.patchify(
            LatentState(
                latent=initial_latent,
                denoise_mask=denoise_mask,
                # The division above is done in float32; the cast to `dtype` happens only afterwards.
                positions=positions.to(dtype),
                clean_latent=clean_latent,
            )
        )
class AudioLatentTools(LatentTools):
    """
    Tools for building audio latent states.
    """

    # NOTE(review): these annotated fields imply a generated constructor, but no decorator is
    # visible on this class in this view of the file — confirm how instances are built.
    patchifier: AudioPatchifier
    target_shape: AudioLatentShape

    def create_initial_state(
        self,
        device: DeviceLikeType,
        dtype: torch.dtype,
        initial_latent: torch.Tensor | None = None,
    ) -> LatentState:
        """
        Create the patchified initial latent state for audio.

        Args:
            device: Device on which the zero latent (if created), denoise mask and positions are allocated.
            dtype: Dtype of the zero latent, used only when ``initial_latent`` is omitted.
            initial_latent: Optional latent to start from. It must already have the target shape.
                When omitted, a zero latent of the target shape is used.

        Returns:
            The result of ``self.patchify`` applied to a state holding the latent, a clone of it as
            the clean latent, an all-ones float32 denoise mask, and the patch grid bounds as positions.

        Raises:
            ValueError: if ``initial_latent`` is given and its shape differs from the target shape.
        """
        expected_shape = self.target_shape.to_torch_shape()
        if initial_latent is None:
            initial_latent = torch.zeros(*expected_shape, device=device, dtype=dtype)
        elif initial_latent.shape != expected_shape:
            # Explicit raise instead of `assert`: asserts are stripped under `python -O`, and
            # `patchify` reports the same kind of mismatch with ValueError.
            raise ValueError(
                f"Latent shape {initial_latent.shape} does not match target shape {expected_shape}"
            )

        clean_latent = initial_latent.clone()

        # The mask is always float32, independent of the requested `dtype`.
        denoise_mask = torch.ones(
            *self.target_shape.mask_shape().to_torch_shape(),
            device=device,
            dtype=torch.float32,
        )

        # The patch grid bounds are used as positions as returned, with no cast to `dtype`.
        latent_coords = self.patchifier.get_patch_grid_bounds(
            output_shape=self.target_shape,
            device=device,
        )

        return self.patchify(
            LatentState(
                latent=initial_latent,
                denoise_mask=denoise_mask,
                positions=latent_coords,
                clean_latent=clean_latent,
            )
        )