# Source: Dramabox / ltx2 / ltx_core / components / patchifiers.py
# Last commit by Manmay: "DramaBox Space — initial app + vendored ltx2" (08c5e28, verified)
import math
from typing import Optional, Tuple
import einops
import torch
from ltx_core.components.protocols import Patchifier
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
class VideoLatentPatchifier(Patchifier):
    """
    Patchifier for video latents shaped `(batch, channels, frames, height, width)`.

    Each token covers one latent frame and a `patch_size x patch_size` spatial
    window; the temporal patch size is fixed to 1.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents, ordered (temporal, height, width).
        # The temporal size is fixed to 1; `unpatchify` relies on this.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch extent along (frames, height, width)."""
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """
        Number of tokens `patchify` emits for a latent of shape `tgt_shape`.

        Divides the product of all dimensions after the first two of
        `tgt_shape.to_torch_shape()` (presumably batch and channels) by the
        patch volume.
        """
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten `(B, C, F, H, W)` latents into a token sequence.

        Returns a tensor shaped `(B, (F/p1)*(H/p2)*(W/p3), C*p1*p2*p3)`, where
        `(p1, p2, p3)` is `patch_size`. Tokens are ordered frame-major, then
        height, then width. Each of F, H, W must be divisible by its patch size
        (einops raises otherwise).
        """
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: rebuild `(B, C, F, H, W)` latents from tokens.

        Args:
            latents: Tokens shaped `(B, num_tokens, C * patch volume)`.
            output_shape: Target grid; its frames/height/width determine the
                patch grid used to unfold the token axis.
        """
        # The pattern below has no temporal patch axis, which is only valid
        # because the temporal patch size is fixed to 1 in __init__.
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"
        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]
        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds `[inclusive start, exclusive end)` for
        every patch produced by `patchify`, in latent-grid coordinates
        (frame/time, height, width).

        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates the (frame/time, height, width) dimensions
        - axis 3 (size 2) stores the `[start, end)` indices within each dimension
        Patches are ordered frame-major, then height, then width, matching the
        token order of `patchify`.

        Note: patch starts are taken every `patch_size` steps. When a dimension
        is not a multiple of its patch size, the last patch runs past the grid
        edge (`patchify` would reject such a shape).

        Args:
            output_shape: Video grid description containing batch, frames,
                height, and width.
            device: Device on which to allocate the coordinates.

        Returns:
            Tensor of patch bounds shaped `[batch_size, 3, num_patches, 2]`.

        Raises:
            ValueError: If `output_shape` is not a `VideoLatentShape`, or if any
                of its dimensions is not positive.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate with explicit raises rather than `assert`: asserts are
        # stripped under `python -O`, and ValueError matches the type check above.
        for label, value in (
            ("frames", frames),
            ("height", height),
            ("width", width),
            ("batch_size", batch_size),
        ):
            if value <= 0:
                raise ValueError(f"{label} must be positive, got {value}")

        # Start coordinate of every patch along each dimension; indexing="ij"
        # keeps the output axes in (frame, height, width) order.
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )
        # Shape: (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Patch extent per dimension, shaped (3, 1, 1, 1) to broadcast over the grid.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)
        patch_ends = patch_starts + patch_size_delta

        # Shape: (3, grid_f, grid_h, grid_w, 2), last axis holding [start, end).
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Replicate across the batch and flatten the grid into one patch axis:
        # (batch_size, 3, num_patches, 2).
        return einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Convert latent-space `[start, end)` bounds into pixel-space bounds.

    Every axis (frame/time, height, width) is multiplied by its VAE scale
    factor. With `causal_fix`, the temporal bounds are then shifted down by
    `scale_factors[0] - 1` and clamped at zero.

    Args:
        latent_coords: Latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors,
            applied per axis.
        causal_fix: Shift and clamp the temporal axis to account for a causal
            VAE whose first frame has a temporal stride of 1, keeping
            timestamps non-negative.

    Returns:
        A new tensor of pixel-space bounds; `latent_coords` is not modified.
    """
    # Lay the three factors along axis 1 so they broadcast over the batch,
    # patch, and bound axes.
    view_shape = [1 for _ in range(latent_coords.ndim)]
    view_shape[1] = -1
    factors = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)
    scaled = latent_coords * factors

    if not causal_fix:
        return scaled

    temporal_scale = scale_factors[0]
    # The causal VAE encodes its very first frame with stride 1 instead of
    # `temporal_scale`; shift the temporal bounds back accordingly and clamp
    # so the first frame's timestamps stay non-negative.
    scaled[:, 0, ...] = (scaled[:, 0, ...] + 1 - temporal_scale).clamp(min=0)
    return scaled
class AudioPatchifier(Patchifier):
    """
    Patchifier for audio (spectrogram) latents shaped `(B, C, T, F)`.

    Each latent time step becomes one token holding all `C * F` values of that
    step. Token bounds are expressed in seconds (see `get_patch_grid_bounds`).
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.

        Args:
            patch_size: Stored and exposed through `patch_size` as
                `(1, patch_size, patch_size)`. Note that `patchify`,
                `unpatchify` and `get_token_count` do not read it: every token
                always spans exactly one latent time step and all
                channels x mel bins.
            sample_rate: Original waveform sampling rate. Divides the sample
                offset to map latent indices back to seconds, so downstream
                consumers can align audio and video cues.
            hop_length: Spectrogram hop length, i.e. the number of waveform
                samples between two consecutive spectrogram (mel) frames.
                Consecutive latent frames are therefore
                `hop_length * audio_latent_downsample_factor` samples apart.
            audio_latent_downsample_factor: Number of spectrogram frames per
                latent frame (downsampling inside the VAE encoder).
            is_causal: When True, timestamps are shifted back by
                `audio_latent_downsample_factor - 1` mel frames and clamped at
                zero, to account for a causal receptive field.
            shift: Integer offset added to the latent indices before they are
                converted to timestamps. Enables constructing overlapping
                windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """The `(1, patch_size, patch_size)` tuple given at construction."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """One token per latent time step."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Convert the latent indices `range(start_latent, end_latent)` to seconds.

        Each latent index is mapped to a mel-frame index
        (`index * audio_latent_downsample_factor`). The index is optionally
        shifted for causality, then converted to seconds via
        `hop_length / sample_rate`.

        Args:
            start_latent: Inclusive start index inside the latent sequence.
            end_latent: Exclusive end index; `end_latent - start_latent`
                timestamps are produced.
            dtype: Floating-point dtype of the returned tensor.
            device: Target device for the timestamps; defaults to CPU.

        Returns:
            1-D tensor of timestamps in seconds.
        """
        if device is None:
            device = torch.device("cpu")
        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
        if self.is_causal:
            # Shift back by (downsample_factor - 1) mel frames and clamp at 0.
            # NOTE(review): the "+1" is meant to make the timestamp correspond to
            # the first sample that is fully available to a causal encoder — confirm.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)
        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build a `(B, 1, T, 2)` tensor of `[start, end)` timestamps per latent frame.

        Token `i` starts at the time of latent index `shift + i` and ends at the
        time of `shift + i + 1`, so one token's end is the next token's start.
        Both are sliced from a single run of `num_steps + 1` boundary timestamps.
        This helper underpins `get_patch_grid_bounds` for the audio patchifier.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert.
            device: Device for the result; `_get_audio_latent_time_in_sec`
                defaults it to CPU.

        Raises:
            ValueError: If `num_steps` is negative.
        """
        if num_steps < 0:
            raise ValueError(f"num_steps must be non-negative, got {num_steps}")
        # `num_steps + 1` consecutive boundaries; token i spans boundaries i and i + 1.
        boundaries = self._get_audio_latent_time_in_sec(
            self.shift,
            self.shift + num_steps + 1,
            torch.float32,
            device,
        )
        start_timings = boundaries[:-1].unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        end_timings = boundaries[1:].unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        # torch.stack copies, so the result owns its memory even though the
        # inputs are expanded (stride-0) views.
        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten `(B, C, T, F)` audio latents into `(B, T, C * F)` tokens.

        Use `get_patch_grid_bounds` to derive timestamps for each latent frame
        from the configured hop length and downsampling.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens, one per latent time step.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )
        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restore the `(B, C, T, F)` latent tensor from `(B, T, C * F)` tokens.

        Use `get_patch_grid_bounds` to recompute the timestamps that describe
        each frame's position in real time.

        Args:
            audio_latents: Tokens to unpatchify.
            output_shape: Provides `channels` (C) and `mel_bins` (F) used to
                split the feature axis.

        Returns:
            Unpatched latent tensor.
        """
        # audio_latents shape: (batch, time, channels * mel_bins)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )
        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`, as timestamps in seconds (float32).

        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch

        Args:
            output_shape: Audio grid specification describing the batch size
                and number of time steps.
            device: Target device for the returned tensor.

        Raises:
            ValueError: If `output_shape` is not an `AudioLatentShape`.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)