| import math |
| from typing import Optional, Tuple |
|
|
| import einops |
| import torch |
|
|
| from ltx_core.components.protocols import Patchifier |
| from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape |
|
|
|
|
| class VideoLatentPatchifier(Patchifier): |
| def __init__(self, patch_size: int): |
| |
| self._patch_size = ( |
| 1, |
| patch_size, |
| patch_size, |
| ) |
|
|
| @property |
| def patch_size(self) -> Tuple[int, int, int]: |
| return self._patch_size |
|
|
| def get_token_count(self, tgt_shape: VideoLatentShape) -> int: |
| return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) |
|
|
| def patchify( |
| self, |
| latents: torch.Tensor, |
| ) -> torch.Tensor: |
| latents = einops.rearrange( |
| latents, |
| "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", |
| p1=self._patch_size[0], |
| p2=self._patch_size[1], |
| p3=self._patch_size[2], |
| ) |
|
|
| return latents |
|
|
| def unpatchify( |
| self, |
| latents: torch.Tensor, |
| output_shape: VideoLatentShape, |
| ) -> torch.Tensor: |
| assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" |
|
|
| patch_grid_frames = output_shape.frames // self._patch_size[0] |
| patch_grid_height = output_shape.height // self._patch_size[1] |
| patch_grid_width = output_shape.width // self._patch_size[2] |
|
|
| latents = einops.rearrange( |
| latents, |
| "b (f h w) (c p q) -> b c f (h p) (w q)", |
| f=patch_grid_frames, |
| h=patch_grid_height, |
| w=patch_grid_width, |
| p=self._patch_size[1], |
| q=self._patch_size[2], |
| ) |
|
|
| return latents |
|
|
| def get_patch_grid_bounds( |
| self, |
| output_shape: AudioLatentShape | VideoLatentShape, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Return the per-dimension bounds [inclusive start, exclusive end) for every |
| patch produced by `patchify`. The bounds are expressed in the original |
| video grid coordinates: frame/time, height, and width. |
| The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: |
| - axis 1 (size 3) enumerates (frame/time, height, width) dimensions |
| - axis 3 (size 2) stores `[start, end)` indices within each dimension |
| Args: |
| output_shape: Video grid description containing frames, height, and width. |
| device: Device of the latent tensor. |
| """ |
| if not isinstance(output_shape, VideoLatentShape): |
| raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") |
|
|
| frames = output_shape.frames |
| height = output_shape.height |
| width = output_shape.width |
| batch_size = output_shape.batch |
|
|
| |
| assert frames > 0, f"frames must be positive, got {frames}" |
| assert height > 0, f"height must be positive, got {height}" |
| assert width > 0, f"width must be positive, got {width}" |
| assert batch_size > 0, f"batch_size must be positive, got {batch_size}" |
|
|
| |
| |
| |
| grid_coords = torch.meshgrid( |
| torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), |
| torch.arange(start=0, end=height, step=self._patch_size[1], device=device), |
| torch.arange(start=0, end=width, step=self._patch_size[2], device=device), |
| indexing="ij", |
| ) |
|
|
| |
| |
| patch_starts = torch.stack(grid_coords, dim=0) |
|
|
| |
| |
| |
| patch_size_delta = torch.tensor( |
| self._patch_size, |
| device=patch_starts.device, |
| dtype=patch_starts.dtype, |
| ).view(3, 1, 1, 1) |
|
|
| |
| |
| patch_ends = patch_starts + patch_size_delta |
|
|
| |
| |
| latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) |
|
|
| |
| |
| latent_coords = einops.repeat( |
| latent_coords, |
| "c f h w bounds -> b c (f h w) bounds", |
| b=batch_size, |
| bounds=2, |
| ) |
|
|
| return latent_coords |
|
|
|
|
| def get_pixel_coords( |
| latent_coords: torch.Tensor, |
| scale_factors: SpatioTemporalScaleFactors, |
| causal_fix: bool = False, |
| ) -> torch.Tensor: |
| """ |
| Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling |
| each axis (frame/time, height, width) with the corresponding VAE downsampling factors. |
| Optionally compensate for causal encoding that keeps the first frame at unit temporal scale. |
| Args: |
| latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`. |
| scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied |
| per axis. |
| causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs |
| that treat frame zero differently still yield non-negative timestamps. |
| """ |
| |
| broadcast_shape = [1] * latent_coords.ndim |
| broadcast_shape[1] = -1 |
| scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape) |
|
|
| |
| pixel_coords = latent_coords * scale_tensor |
|
|
| if causal_fix: |
| |
| |
| pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) |
|
|
| return pixel_coords |
|
|
|
|
| class AudioPatchifier(Patchifier): |
| def __init__( |
| self, |
| patch_size: int, |
| sample_rate: int = 16000, |
| hop_length: int = 160, |
| audio_latent_downsample_factor: int = 4, |
| is_causal: bool = True, |
| shift: int = 0, |
| ): |
| """ |
| Patchifier tailored for spectrogram/audio latents. |
| Args: |
| patch_size: Number of mel bins combined into a single patch. This |
| controls the resolution along the frequency axis. |
| sample_rate: Original waveform sampling rate. Used to map latent |
| indices back to seconds so downstream consumers can align audio |
| and video cues. |
| hop_length: Window hop length used for the spectrogram. Determines |
| how many real-time samples separate two consecutive latent frames. |
| audio_latent_downsample_factor: Ratio between spectrogram frames and |
| latent frames; compensates for additional downsampling inside the |
| VAE encoder. |
| is_causal: When True, timing is shifted to account for causal |
| receptive fields so timestamps do not peek into the future. |
| shift: Integer offset applied to the latent indices. Enables |
| constructing overlapping windows from the same latent sequence. |
| """ |
| self.hop_length = hop_length |
| self.sample_rate = sample_rate |
| self.audio_latent_downsample_factor = audio_latent_downsample_factor |
| self.is_causal = is_causal |
| self.shift = shift |
| self._patch_size = (1, patch_size, patch_size) |
|
|
| @property |
| def patch_size(self) -> Tuple[int, int, int]: |
| return self._patch_size |
|
|
| def get_token_count(self, tgt_shape: AudioLatentShape) -> int: |
| return tgt_shape.frames |
|
|
| def _get_audio_latent_time_in_sec( |
| self, |
| start_latent: int, |
| end_latent: int, |
| dtype: torch.dtype, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Converts latent indices into real-time seconds while honoring causal |
| offsets and the configured hop length. |
| Args: |
| start_latent: Inclusive start index inside the latent sequence. This |
| sets the first timestamp returned. |
| end_latent: Exclusive end index. Determines how many timestamps get |
| generated. |
| dtype: Floating-point dtype used for the returned tensor, allowing |
| callers to control precision. |
| device: Target device for the timestamp tensor. When omitted the |
| computation occurs on CPU to avoid surprising GPU allocations. |
| """ |
| if device is None: |
| device = torch.device("cpu") |
|
|
| audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) |
|
|
| audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor |
|
|
| if self.is_causal: |
| |
| |
| causal_offset = 1 |
| audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) |
|
|
| return audio_mel_frame * self.hop_length / self.sample_rate |
|
|
| def _compute_audio_timings( |
| self, |
| batch_size: int, |
| num_steps: int, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. |
| This helper method underpins `get_patch_grid_bounds` for the audio patchifier. |
| Args: |
| batch_size: Number of sequences to broadcast the timings over. |
| num_steps: Number of latent frames (time steps) to convert into timestamps. |
| device: Device on which the resulting tensor should reside. |
| """ |
| resolved_device = device |
| if resolved_device is None: |
| resolved_device = torch.device("cpu") |
|
|
| start_timings = self._get_audio_latent_time_in_sec( |
| self.shift, |
| num_steps + self.shift, |
| torch.float32, |
| resolved_device, |
| ) |
| start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) |
|
|
| end_timings = self._get_audio_latent_time_in_sec( |
| self.shift + 1, |
| num_steps + self.shift + 1, |
| torch.float32, |
| resolved_device, |
| ) |
| end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) |
|
|
| return torch.stack([start_timings, end_timings], dim=-1) |
|
|
| def patchify( |
| self, |
| audio_latents: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` |
| to derive timestamps for each latent frame based on the configured hop |
| length and downsampling. |
| Args: |
| audio_latents: Latent tensor to patchify. |
| Returns: |
| Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the |
| corresponding timing metadata when needed. |
| """ |
| audio_latents = einops.rearrange( |
| audio_latents, |
| "b c t f -> b t (c f)", |
| ) |
|
|
| return audio_latents |
|
|
| def unpatchify( |
| self, |
| audio_latents: torch.Tensor, |
| output_shape: AudioLatentShape, |
| ) -> torch.Tensor: |
| """ |
| Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. |
| Use `get_patch_grid_bounds` to recompute the timestamps that describe each |
| frame's position in real time. |
| Args: |
| audio_latents: Latent tensor to unpatchify. |
| output_shape: Shape of the unpatched output tensor. |
| Returns: |
| Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing |
| metadata associated with the restored latents. |
| """ |
| |
| audio_latents = einops.rearrange( |
| audio_latents, |
| "b t (c f) -> b c t f", |
| c=output_shape.channels, |
| f=output_shape.mel_bins, |
| ) |
|
|
| return audio_latents |
|
|
| def get_patch_grid_bounds( |
| self, |
| output_shape: AudioLatentShape | VideoLatentShape, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Return the temporal bounds `[inclusive start, exclusive end)` for every |
| patch emitted by `patchify`. For audio this corresponds to timestamps in |
| seconds aligned with the original spectrogram grid. |
| The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: |
| - axis 1 (size 1) represents the temporal dimension |
| - axis 3 (size 2) stores the `[start, end)` timestamps per patch |
| Args: |
| output_shape: Audio grid specification describing the number of time steps. |
| device: Target device for the returned tensor. |
| """ |
| if not isinstance(output_shape, AudioLatentShape): |
| raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") |
|
|
| return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) |
|
|