Spaces:
Running on Zero
Running on Zero
| from typing import Protocol, Tuple | |
| import torch | |
| from ltx_core.types import AudioLatentShape, VideoLatentShape | |
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        # The docstring must be the first statement of the body; placed after `...`
        # it is a discarded expression and `__doc__` stays None.
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Restore a dense latent tensor from flattened patch tokens (the inverse of `patchify`).
        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.
        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    # NOTE(review): declared as a plain method here; confirm whether implementations
    # expose `patch_size` as a property instead.
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
class SchedulerProtocol(Protocol):
    """
    Structural interface for schedulers: given a number of steps, a scheduler
    yields the corresponding sigmas schedule as a tensor. Device is cpu.
    """

    def execute(
        self,
        steps: int,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Build the sigmas schedule for the requested number of steps.
        Args:
            steps: Number of steps the schedule is generated for.
            **kwargs: Extra options; their meaning is implementation-specific and not defined by this protocol.
        Returns:
            Sigmas schedule tensor.
        """
class GuiderProtocol(Protocol):
    """
    Structural interface for guiders. A guider turns conditioning inputs into a
    delta tensor that is meant to be added to the conditional output (cond);
    because each guider contributes only a delta, several guiders can be chained
    by accumulating their deltas.
    """

    # Guidance strength. Per `enabled`, a CFG guider with scale 1.0 counts as disabled.
    scale: float

    def delta(
        self,
        cond: torch.Tensor,
        uncond: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the guidance delta for a pair of outputs.
        Args:
            cond: Conditional output; the returned delta is intended to be added to it.
            uncond: Output the guider contrasts with `cond` (presumably the unconditional
                or perturbed prediction; confirm against implementations).
        Returns:
            Delta tensor to accumulate onto `cond`.
        """

    def enabled(self) -> bool:
        """
        Tell whether the perturbation handled by this guider is active. For CFG,
        for instance, this should be False when the scale is 1.0.
        """
class DiffusionStepProtocol(Protocol):
    """
    Structural interface for diffusion steps: from the current sample tensor, the
    current denoised sample tensor and the sigmas tensor, a step produces the
    next sample tensor.
    """

    def step(
        self,
        sample: torch.Tensor,
        denoised_sample: torch.Tensor,
        sigmas: torch.Tensor,
        step_index: int,
        **kwargs,
    ) -> torch.Tensor:
        """
        Advance the sample by one diffusion step.
        Args:
            sample: Current sample tensor.
            denoised_sample: Current denoised sample tensor.
            sigmas: Sigmas tensor.
            step_index: Index of the current step (presumably a position in `sigmas`; confirm).
            **kwargs: Extra options; their meaning is implementation-specific and not defined by this protocol.
        Returns:
            Next sample tensor.
        """