| from typing import Protocol, Tuple |
|
|
| import torch |
|
|
| from ltx_core.types import AudioLatentShape, VideoLatentShape |
|
|
|
|
| class Patchifier(Protocol): |
| """ |
| Protocol for patchifiers that convert latent tensors into patches and assemble them back. |
| """ |
|
|
| def patchify( |
| self, |
| latents: torch.Tensor, |
| ) -> torch.Tensor: |
| ... |
| """ |
| Convert latent tensors into flattened patch tokens. |
| Args: |
| latents: Latent tensor to patchify. |
| Returns: |
| Flattened patch tokens tensor. |
| """ |
|
|
| def unpatchify( |
| self, |
| latents: torch.Tensor, |
| output_shape: AudioLatentShape | VideoLatentShape, |
| ) -> torch.Tensor: |
| """ |
| Converts latent tensors between spatio-temporal formats and flattened sequence representations. |
| Args: |
| latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`. |
| output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or |
| VideoLatentShape. |
| Returns: |
| Dense latent tensor restored from the flattened representation. |
| """ |
|
|
| @property |
| def patch_size(self) -> Tuple[int, int, int]: |
| ... |
| """ |
| Returns the patch size as a tuple of (temporal, height, width) dimensions |
| """ |
|
|
| def get_patch_grid_bounds( |
| self, |
| output_shape: AudioLatentShape | VideoLatentShape, |
| device: torch.device | None = None, |
| ) -> torch.Tensor: |
| ... |
| """ |
| Compute metadata describing where each latent patch resides within the |
| grid specified by `output_shape`. |
| Args: |
| output_shape: Target grid layout for the patches. |
| device: Target device for the returned tensor. |
| Returns: |
| Tensor containing patch coordinate metadata such as spatial or temporal intervals. |
| """ |
|
|
|
|
class SchedulerProtocol(Protocol):
    """
    Structural interface for schedulers: given a number of steps, a scheduler
    provides the sigmas schedule as a tensor. Device is cpu.
    """

    def execute(
        self,
        steps: int,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Produce the sigmas schedule tensor for `steps` steps.

        Args:
            steps: Number of steps the schedule is requested for.
            **kwargs: Extra options; their meaning is not fixed by this protocol
                and is left to the implementing scheduler.
        Returns:
            Sigmas schedule tensor.
        """
        ...
|
|
|
|
class GuiderProtocol(Protocol):
    """
    Structural interface for guiders, which turn conditioning inputs into a
    delta tensor.

    The delta is meant to be added to the conditional output (cond). Because
    each guider contributes only a delta, several guiders can be chained by
    accumulating their deltas.
    """

    # Guidance scale of the guider. NOTE(review): presumably the value that
    # `enabled` inspects (e.g. 1.0 for CFG) -- confirm against implementations.
    scale: float

    def delta(
        self,
        cond: torch.Tensor,
        uncond: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the delta tensor that should be added to `cond`.

        Args:
            cond: Conditional output.
            uncond: The second conditioning input (named for the unconditional output).
        Returns:
            Delta tensor to accumulate onto the conditional output.
        """
        ...

    def enabled(self) -> bool:
        """
        Tell whether the corresponding perturbation is enabled. For example, a CFG guider should
        report False when the scale is 1.0.
        """
        ...
|
|
|
|
class DiffusionStepProtocol(Protocol):
    """
    Structural interface for diffusion steps: from the current sample tensor,
    the current denoised sample tensor and the sigmas tensor, a step provides
    the next sample tensor.
    """

    def step(
        self,
        sample: torch.Tensor,
        denoised_sample: torch.Tensor,
        sigmas: torch.Tensor,
        step_index: int,
        **kwargs,
    ) -> torch.Tensor:
        """
        Produce the next sample tensor.

        Args:
            sample: Current sample tensor.
            denoised_sample: Current denoised sample tensor.
            sigmas: Sigmas tensor.
            step_index: Index of the step being taken. NOTE(review): presumably
                an index into `sigmas` -- confirm against implementations.
            **kwargs: Extra options; their meaning is not fixed by this protocol.
        Returns:
            Next sample tensor.
        """
        ...
|
|