from typing import Protocol, Tuple

import torch

from ltx_core.types import AudioLatentShape, VideoLatentShape


class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        # The docstring must precede the `...` placeholder; a string placed after it
        # is a plain expression statement and never reaches `__doc__`.
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Rearrange flattened patch tokens back into the dense latent layout (inverse of `patchify`).
        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.
        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...


class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that provide a sigmas schedule tensor for a given number of steps.
    Device is cpu.
    """

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor:
        """
        Produce the sigmas schedule for the requested number of steps.
        Args:
            steps: Number of steps the schedule is built for.
            **kwargs: Implementation-specific options; none are defined by this protocol.
        Returns:
            Sigmas schedule tensor.
        """
        ...


class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.
    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    # Guidance scale. Per `enabled`, a CFG guider with a scale of 1.0 counts as disabled.
    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        """
        Compute the delta tensor for this guider.
        Args:
            cond: Conditional output; the returned delta is meant to be added to it.
            uncond: Counterpart output to compare against `cond` (presumably the
                unconditional / perturbed prediction -- confirm with implementations).
        Returns:
            Delta tensor to accumulate onto `cond`.
        """
        ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
        is 1.0.
        """
        ...


class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
    current denoised sample tensor, and sigmas tensor.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **kwargs
    ) -> torch.Tensor:
        """
        Compute the next sample tensor.
        Args:
            sample: Current sample tensor.
            denoised_sample: Current denoised sample tensor.
            sigmas: Sigmas tensor.
            step_index: Index of the current step (presumably a position within `sigmas` -- confirm
                with implementations).
            **kwargs: Implementation-specific options; none are defined by this protocol.
        Returns:
            Next sample tensor.
        """
        ...