Spaces:
Running on Zero
Running on Zero
File size: 3,147 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | from typing import Protocol, Tuple
import torch
from ltx_core.types import AudioLatentShape, VideoLatentShape
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    # NOTE: in every member below the docstring is the first statement of the body and
    # the `...` placeholder follows it; a string placed after `...` is not a docstring.

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.

        Args:
            latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Rearrange flattened patch tokens back into a dense latent tensor.

        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.

        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions.
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.

        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.

        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
class SchedulerProtocol(Protocol):
    """
    Structural interface for schedulers.

    A scheduler produces the sigmas schedule tensor for a requested number of
    steps. The schedule is expected to live on the cpu device.
    """

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor:
        """
        Build the sigmas schedule.

        Args:
            steps: Number of steps the schedule is built for.
            **kwargs: Additional implementation-specific options.

        Returns:
            The sigmas schedule tensor.
        """
        ...
class GuiderProtocol(Protocol):
    """
    Structural interface for guiders.

    A guider computes a delta tensor from conditioning inputs. The delta is meant
    to be added to the conditional output (cond), so several guiders can be
    chained by accumulating their deltas.
    """

    # Guidance scale of this guider.
    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        """
        Compute the delta tensor for this guider.

        Args:
            cond: Conditional output; the returned delta should be added to it.
            uncond: Counterpart output (presumably the unconditional/perturbed
                prediction -- confirm against implementations).

        Returns:
            The delta tensor.
        """

    def enabled(self) -> bool:
        """
        Tell whether the corresponding perturbation is enabled.

        E.g. for CFG, this should return False if the scale is 1.0.
        """
class DiffusionStepProtocol(Protocol):
    """
    Structural interface for diffusion steps.

    A diffusion step produces the next sample tensor from the current sample
    tensor, the current denoised sample tensor, and the sigmas tensor.
    """

    def step(
        self,
        sample: torch.Tensor,
        denoised_sample: torch.Tensor,
        sigmas: torch.Tensor,
        step_index: int,
        **kwargs,
    ) -> torch.Tensor:
        """
        Advance the sample by one diffusion step.

        Args:
            sample: Current sample tensor.
            denoised_sample: Current denoised sample tensor.
            sigmas: Sigmas tensor.
            step_index: Index of the current step (presumably a position within
                `sigmas` -- confirm against implementations).
            **kwargs: Additional implementation-specific options.

        Returns:
            The next sample tensor.
        """
        ...
|