Spaces:
Running on Zero
Running on Zero
| import math | |
| from typing import Optional, Tuple | |
| import einops | |
| import torch | |
| from ltx_core.components.protocols import Patchifier | |
| from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape | |
class VideoLatentPatchifier(Patchifier):
    """
    Patchifier for video latents laid out as `(batch, channels, frames, height, width)`.
    Spatial patches are square (`patch_size` x `patch_size`); the temporal patch size is
    fixed to 1, so every latent frame contributes its own set of tokens.
    """

    def __init__(self, patch_size: int):
        """
        Args:
            patch_size: Edge length of the square spatial patch, applied to both height and width.
        """
        # Stored in (temporal, height, width) order; the temporal entry is always 1.
        self._patch_size = (1, patch_size, patch_size)

    # NOTE(review): exposed as a plain method in this view of the file; confirm whether the
    # Patchifier protocol expects a property before changing how callers access it.
    def patch_size(self) -> Tuple[int, int, int]:
        """Return the `(temporal, height, width)` patch sizes."""
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """
        Number of tokens `patchify` yields for a latent of shape `tgt_shape`.
        Args:
            tgt_shape: Target latent shape. Every dimension after the first two entries of
                `to_torch_shape()` is counted (presumably frames, height, width - confirm).
        """
        grid_volume = math.prod(tgt_shape.to_torch_shape()[2:])
        patch_volume = math.prod(self._patch_size)
        return grid_volume // patch_volume

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Fold every patch into the feature axis and flatten the patch grid into a sequence.
        Args:
            latents: Tensor shaped `(b, c, f * p1, h * p2, w * p3)`.
        Returns:
            Tensor shaped `(b, f * h * w, c * p1 * p2 * p3)`.
        """
        temporal_patch, height_patch, width_patch = self._patch_size
        return einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=temporal_patch,
            p2=height_patch,
            p3=width_patch,
        )

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: rebuild the `(b, c, f, h, w)` latent from a token sequence.
        Args:
            latents: Token tensor shaped `(b, f * h * w, c * p * q)`.
            output_shape: Shape of the latent to restore; supplies frames, height, and width.
        """
        temporal_patch, height_patch, width_patch = self._patch_size
        # The rearrange pattern below carries no temporal patch axis, so it is only valid
        # when the temporal patch size is exactly 1.
        assert temporal_patch == 1, "Temporal patch size must be 1 for symmetric patchifier"
        return einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=output_shape.frames // temporal_patch,
            h=output_shape.height // height_patch,
            w=output_shape.width // width_patch,
            p=height_patch,
            q=width_patch,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Compute `[start, end)` bounds along (frame, height, width) for every patch, in
        latent grid coordinates and in the same order in which `patchify` emits tokens.
        Args:
            output_shape: Video latent shape providing batch, frames, height, and width.
            device: Device on which the coordinate tensors are created.
        Returns:
            Tensor shaped `(batch_size, 3, num_patches, 2)`; axis 1 indexes
            (frame, height, width) and the last axis holds `[start, end)`.
        Raises:
            ValueError: If `output_shape` is not a `VideoLatentShape`.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames, height, width = output_shape.frames, output_shape.height, output_shape.width
        batch_size = output_shape.batch

        # Reject empty or negative grids up front.
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Start index of each patch along every axis: 0, stride, 2 * stride, ...
        axis_starts = [
            torch.arange(start=0, end=extent, step=stride, device=device)
            for extent, stride in zip((frames, height, width), self._patch_size)
        ]

        # "ij" indexing keeps the (frame, height, width) axis order; stacking gives
        # a tensor of shape (3, grid_f, grid_h, grid_w).
        patch_starts = torch.stack(torch.meshgrid(*axis_starts, indexing="ij"), dim=0)

        # Per-axis patch extent, shaped (3, 1, 1, 1) so it broadcasts over the whole grid.
        patch_extent = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)
        patch_ends = patch_starts + patch_extent

        # Pair each start with its end: (3, grid_f, grid_h, grid_w, 2).
        bounds = torch.stack((patch_starts, patch_ends), dim=-1)

        # Copy across the batch and flatten the grid into a single patch axis.
        return einops.repeat(
            bounds,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Convert latent-space `[start, end)` bounds into pixel-space bounds by multiplying each
    coordinate axis by its scale factor. The input tensor is left untouched; a new tensor
    is returned.
    Args:
        latent_coords: Latent bounds; axis 1 enumerates the coordinate axes and must match
            the number of entries in `scale_factors` (e.g. `(batch, 3, num_patches, 2)`).
        scale_factors: Per-axis scale factors, one entry per coordinate axis. Entry 0 is
            used as the temporal factor.
        causal_fix: When True, the temporal bounds are shifted by `1 - scale_factors[0]`
            and clamped at zero, for VAEs whose first frame has temporal stride 1.
    Returns:
        Tensor with the same shape as `latent_coords` holding pixel-space bounds.
    """
    # Shape the factors as (1, -1, 1, ..., 1) so they line up with axis 1 of the input.
    view_shape = [1 for _ in range(latent_coords.ndim)]
    view_shape[1] = -1
    per_axis_scale = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    # The multiplication allocates a fresh tensor, so the in-place write below never
    # touches the caller's `latent_coords`.
    pixel_coords = latent_coords * per_axis_scale
    if not causal_fix:
        return pixel_coords

    # The first frame advances time by 1 rather than by `scale_factors[0]`: shift the
    # temporal axis accordingly and clamp so no bound becomes negative.
    temporal_bounds = pixel_coords[:, 0, ...]
    pixel_coords[:, 0, ...] = (temporal_bounds + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords
class AudioPatchifier(Patchifier):
    """
    Patchifier for audio latents laid out as `(batch, channels, time, mel_bins)`.
    Each latent time step becomes one token, and `get_patch_grid_bounds` reports the
    `[start, end)` timestamps in seconds covered by every token.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Args:
            patch_size: Patch size recorded for the two non-temporal entries of
                `patch_size()`. It is not consulted by `patchify`/`unpatchify`.
            sample_rate: Waveform sampling rate; divides sample counts to obtain seconds.
            hop_length: Number of waveform samples between consecutive spectrogram frames.
            audio_latent_downsample_factor: Number of spectrogram frames represented by
                one latent frame.
            is_causal: When True, timestamps are shifted back by
                `audio_latent_downsample_factor - 1` spectrogram frames and clipped at zero.
            shift: Integer offset added to the latent indices before they are converted
                to timestamps.
        """
        # Stored in (temporal, ., .) order with a temporal patch size of 1.
        self._patch_size = (1, patch_size, patch_size)
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift

    # NOTE(review): exposed as a plain method in this view of the file; confirm whether the
    # Patchifier protocol expects a property before changing how callers access it.
    def patch_size(self) -> Tuple[int, int, int]:
        """Return the stored patch-size triple."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """Return the number of tokens for `tgt_shape`: one per latent frame."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Map the latent indices in `[start_latent, end_latent)` to timestamps in seconds.
        Args:
            start_latent: First latent index (inclusive).
            end_latent: Last latent index (exclusive).
            dtype: Dtype of the index tensor created with `torch.arange`.
            device: Device for the result; CPU is used when omitted.
        Returns:
            1-D tensor with one timestamp per latent index.
        """
        target_device = torch.device("cpu") if device is None else device

        latent_indices = torch.arange(start_latent, end_latent, dtype=dtype, device=target_device)
        # Latent frame index -> spectrogram (mel) frame index.
        mel_indices = latent_indices * self.audio_latent_downsample_factor

        if self.is_causal:
            # Move back by (downsample_factor - 1) mel frames for causal alignment;
            # indices that would fall before the start are clipped to zero.
            causal_offset = 1
            mel_indices = (mel_indices + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frame index -> waveform sample index -> seconds.
        return mel_indices * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Produce the `(batch_size, 1, num_steps, 2)` tensor of `[start, end)` timestamps
        returned by `get_patch_grid_bounds`.
        Args:
            batch_size: Number of batch entries the timings are expanded over.
            num_steps: Number of latent frames to generate timestamps for.
            device: Device for the result; CPU is used when omitted.
        """
        target_device = device if device is not None else torch.device("cpu")

        # Offset 0 yields the start of each latent frame; offset 1 yields the start of the
        # following frame, which serves as the exclusive end. Each result is expanded to
        # (batch_size, 1, num_steps) before the two are stacked.
        per_bound = [
            self._get_audio_latent_time_in_sec(
                self.shift + offset,
                num_steps + self.shift + offset,
                torch.float32,
                target_device,
            )
            .unsqueeze(0)
            .expand(batch_size, -1)
            .unsqueeze(1)
            for offset in (0, 1)
        ]
        return torch.stack(per_bound, dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Turn each time step into one token by merging channels and frequency bins.
        Timing metadata for the tokens is available from `get_patch_grid_bounds`.
        Args:
            audio_latents: Tensor shaped `(b, c, t, f)`.
        Returns:
            Tensor shaped `(b, t, c * f)`.
        """
        return einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: restore the `(b, c, t, f)` latent from per-step tokens.
        Timing metadata for the restored latents is available from `get_patch_grid_bounds`.
        Args:
            audio_latents: Token tensor shaped `(b, t, c * f)`.
            output_shape: Supplies `channels` and `mel_bins` for splitting the feature axis.
        Returns:
            Tensor shaped `(b, c, t, f)`.
        """
        # The feature axis is packed channel-major: index = channel * mel_bins + mel_bin.
        return einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return `[start, end)` timestamps in seconds for every token emitted by `patchify`.
        Args:
            output_shape: Audio latent shape providing the batch size and number of frames.
            device: Device for the result; CPU is used when omitted.
        Returns:
            Tensor shaped `(batch_size, 1, frames, 2)`; axis 1 is the single temporal
            axis and the last axis holds `[start, end)`.
        Raises:
            ValueError: If `output_shape` is not an `AudioLatentShape`.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)