import math
from typing import Optional, Tuple

import einops
import torch

from ltx_core.components.protocols import Patchifier
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape


class VideoLatentPatchifier(Patchifier):
    """
    Patchifier for 5D video latents laid out as `(batch, channels, frames, height, width)`.

    Each token covers one latent frame and a `patch_size x patch_size` spatial window; the
    channel values of that window are concatenated into the token's feature vector.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents. The temporal size is fixed at 1, which `unpatchify`
        # relies on (its rearrange pattern has no temporal patch axis).
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch extent as `(frames, height, width)`."""
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """
        Number of tokens `patchify` produces for a latent of shape `tgt_shape`.

        Divides the number of latent positions (the product of every axis after the first two
        of `tgt_shape.to_torch_shape()`; presumably frames, height and width — confirm against
        `VideoLatentShape`) by the number of positions covered by one patch.
        """
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten `(b, c, f, h, w)` latents into a `(b, num_patches, c * p_f * p_h * p_w)` token sequence.

        Patches are enumerated frame-major, then by row, then by column — the same order used by
        `get_patch_grid_bounds`. Each of frames/height/width must be divisible by its patch size,
        otherwise einops raises.
        """
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: rebuild `(b, c, f, h, w)` latents from a token sequence.

        Args:
            latents: Tokens shaped `(b, num_patches, c * p_h * p_w)`.
            output_shape: Target latent shape; its frames/height/width define the patch grid.
                The channel count is inferred by einops from the token width.

        Returns:
            Latents shaped `(b, c, frames, height, width)`.
        """
        # The pattern below has no temporal patch axis, so it is only correct for a temporal
        # patch size of 1 (always true for instances built by `__init__`).
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every patch produced by `patchify`.

        The bounds are expressed in the original video grid coordinates: frame/time, height, and width.
        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
            - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
            - axis 3 (size 2) stores `[start, end)` indices within each dimension
        Patches are enumerated in the same frame-major order in which `patchify` emits tokens.

        Args:
            output_shape: Video grid description containing frames, height, width and batch size.
            device: Device on which to allocate the bounds tensor (torch's default device when None).

        Raises:
            ValueError: If `output_shape` is not a `VideoLatentShape`.
            AssertionError: If frames, height, width or batch size is not positive.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Starting coordinate of every patch along each axis; indexing="ij" keeps the
        # (frame, height, width) axis order so flattening matches `patchify`'s "(f h w)".
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Shape: (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Per-axis patch extent reshaped to (3, 1, 1, 1) so it broadcasts over the grid.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # End coordinates: start + patch size. Shape: (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Shape: (3, grid_f, grid_h, grid_w, 2), last axis is [start, end)
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast over the batch and flatten the grid into the patch sequence.
        # Final shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords


def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to pixel space.

    Each axis (frame/time, height, width) is multiplied by the corresponding VAE
    downsampling factor.

    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors applied per axis.
        causal_fix: When True, adjusts the temporal axis for causal VAEs whose first latent
            frame covers a single pixel frame. Every temporal bound (not only frame 0's) is
            shifted by `1 - scale_factors[0]` and clamped at 0. For a latent frame bound
            `[k, k + 1)` and temporal factor `s >= 1`, frame 0 maps to `[0, 1)` and frame
            `k >= 1` maps to `[1 + (k - 1) * s, 1 + k * s)`.

    Returns:
        A new tensor with the same shape as `latent_coords`; the input is not modified.
    """
    # Reshape the scale factors to (1, 3, 1, ...) so they broadcast over the
    # `(batch, axis, patch, bound)` layout.
    broadcast_shape = [1] * latent_coords.ndim
    broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
    scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)

    # The product is a fresh tensor, so the in-place edit below never touches `latent_coords`.
    pixel_coords = latent_coords * scale_tensor

    if causal_fix:
        # The first latent frame has a temporal stride of 1 instead of `scale_factors[0]`;
        # shift all temporal bounds accordingly and clamp so none become negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords


class AudioPatchifier(Patchifier):
    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.

        Args:
            patch_size: Intended number of mel bins combined into a single patch. It is only
                reported through the `patch_size` property as `(1, patch_size, patch_size)`;
                `patchify`/`unpatchify` in this class do not read it.
            sample_rate: Original waveform sampling rate. Used to map latent indices back to
                seconds so downstream consumers can align audio and video cues.
            hop_length: Window hop length used for the spectrogram, i.e. the number of waveform
                samples between two consecutive mel frames.
            audio_latent_downsample_factor: Ratio between spectrogram (mel) frames and latent
                frames; compensates for additional downsampling inside the VAE encoder.
            is_causal: When True, each latent frame's mel index `k * d` is shifted by `1 - d`
                (with `d = audio_latent_downsample_factor`) and clamped at 0 before conversion
                to seconds.
            shift: Integer offset added to the latent frame indices before they are converted
                to timestamps. Enables constructing overlapping windows from the same sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Configured patch extent as `(1, patch_size, patch_size)`."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """One token per latent time step, so the token count equals `tgt_shape.frames`."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Convert latent indices `start_latent .. end_latent - 1` into timestamps in seconds.

        Args:
            start_latent: Inclusive start index inside the latent sequence.
            end_latent: Exclusive end index.
            dtype: Floating-point dtype of the returned tensor.
            device: Target device for the timestamp tensor; CPU when omitted.

        Returns:
            1D tensor of `end_latent - start_latent` timestamps.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build a `(batch_size, 1, num_steps, 2)` tensor of `[start, end)` timestamps in seconds.

        Latent frame `i` spans from the timestamp of latent index `shift + i` to that of
        `shift + i + 1`. All `num_steps + 1` boundary timestamps are therefore computed once,
        and the start/end columns are taken as overlapping slices of that sequence.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside (CPU when None).
        """
        boundaries = self._get_audio_latent_time_in_sec(
            self.shift,
            self.shift + num_steps + 1,
            torch.float32,
            device,
        )

        # (num_steps, 2): column 0 holds start times, column 1 end times.
        per_frame = torch.stack([boundaries[:-1], boundaries[1:]], dim=-1)

        # `repeat` materializes a contiguous copy per batch element (not a stride-0 view),
        # so callers may modify the result in place.
        return per_frame.unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten `(b, c, t, f)` audio latents into a `(b, t, c * f)` token sequence.

        Time stays as the token axis: each latent time step becomes one token whose features
        concatenate every channel's mel-bin values (channel-major). Use `get_patch_grid_bounds`
        to derive the timestamps of each token.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restore the `(B, C, T, F)` latent tensor from flattened tokens (inverse of `patchify`).

        Args:
            audio_latents: Tokens shaped `(B, T, C * F)`.
            output_shape: Provides the channel count (`channels`) and mel-bin count (`mel_bins`).

        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, channels * mel_bins), channel-major as produced by `patchify`.
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every token emitted by `patchify`.

        For audio these are timestamps in seconds aligned with the original spectrogram grid.
        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
            - axis 1 (size 1) represents the temporal dimension
            - axis 3 (size 2) stores the `[start, end)` timestamps per token

        Args:
            output_shape: Audio grid specification describing the batch size and number of time steps.
            device: Target device for the returned tensor (CPU when None).

        Raises:
            ValueError: If `output_shape` is not an `AudioLatentShape`.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)