File size: 14,286 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import math
from typing import Optional, Tuple

import einops
import torch

from ltx_core.components.protocols import Patchifier
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape


class VideoLatentPatchifier(Patchifier):
    """
    Patchifier that turns video latents into a flat token sequence and back.
    The temporal patch size is fixed to 1 and a single `patch_size` is shared by
    the height and width axes.
    """

    def __init__(self, patch_size: int):
        """
        Args:
            patch_size: Edge length of the square spatial patch, used for both
                the height and the width axes.
        """
        # Ordered as (temporal, height, width); the temporal entry is always 1.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch sizes ordered as `(temporal, height, width)`."""
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """
        Number of tokens for a latent described by `tgt_shape`: the product of
        every entry of `tgt_shape.to_torch_shape()` after the first two,
        floor-divided by the number of latent cells in one patch.
        Args:
            tgt_shape: Shape descriptor of the latent to be patchified.
        """
        # NOTE(review): the two skipped leading entries are presumably batch and
        # channels - confirm against `VideoLatentShape.to_torch_shape`.
        grid_dims = tgt_shape.to_torch_shape()[2:]
        cells_per_patch = math.prod(self._patch_size)
        return math.prod(grid_dims) // cells_per_patch

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Fold every patch of a 5-D latent into the feature axis and flatten the
        patch grid into one sequence axis.
        Args:
            latents: Tensor laid out as `(b, c, f, h, w)`; each of `f`, `h`, `w`
                must be divisible by the matching patch size.
        Returns:
            Tensor laid out as `(b, f*h*w patches, c * patch volume)`.
        """
        patch_t, patch_h, patch_w = self._patch_size
        return einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=patch_t,
            p2=patch_h,
            p3=patch_w,
        )

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: rebuild the `(b, c, f, h, w)` latent from tokens.
        Args:
            latents: Token tensor laid out as `(b, num_patches, c * p * q)`.
            output_shape: Target shape; its `frames`, `height` and `width` are
                floor-divided by the patch sizes to recover the patch grid.
        """
        patch_t, patch_h, patch_w = self._patch_size
        # The pattern below has no temporal patch factor, so it is only valid
        # when each patch covers exactly one frame.
        assert patch_t == 1, "Temporal patch size must be 1 for symmetric patchifier"

        return einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=output_shape.frames // patch_t,
            h=output_shape.height // patch_h,
            w=output_shape.width // patch_w,
            p=patch_h,
            q=patch_w,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Compute the `[start, end)` extent of every patch along the frame, height
        and width axes, in latent-grid coordinates.
        The result is shaped `[batch_size, 3, num_patches, 2]`:
            - axis 1 (size 3) lists the (frame/time, height, width) dimensions
            - axis 2 enumerates patches with frame varying slowest and width fastest
            - axis 3 (size 2) holds the inclusive start and the exclusive end
        Patch starts are generated with a stride equal to the patch size, so when
        an extent is not a multiple of its patch size the last patch's end exceeds
        that extent.
        Args:
            output_shape: Video grid description providing batch, frames, height, and width.
            device: Device on which the coordinate tensors are created.
        Raises:
            ValueError: If `output_shape` is not a `VideoLatentShape`.
            AssertionError: If frames, height, width or batch is not positive.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames, height, width = output_shape.frames, output_shape.height, output_shape.width
        batch_size = output_shape.batch

        # Reject empty or negative grids before building any tensors.
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # One 1-D tensor of patch start indices per axis, strided by that axis' patch size.
        axis_starts = [
            torch.arange(start=0, end=extent, step=step, device=device)
            for extent, step in zip((frames, height, width), self._patch_size)
        ]

        # "ij" indexing keeps the (frame, height, width) order; stacking yields
        # a (3, grid_f, grid_h, grid_w) tensor of start coordinates.
        starts = torch.stack(torch.meshgrid(*axis_starts, indexing="ij"), dim=0)

        # Per-axis patch length reshaped to (3, 1, 1, 1) so it broadcasts over the grid.
        patch_extent = torch.tensor(
            self._patch_size,
            device=starts.device,
            dtype=starts.dtype,
        ).view(3, 1, 1, 1)
        ends = starts + patch_extent

        # (3, grid_f, grid_h, grid_w, 2) with the last axis holding [start, end].
        bounds = torch.stack((starts, ends), dim=-1)

        # Replicate across the batch and flatten the grid into a single patch axis.
        return einops.repeat(
            bounds,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )


def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Convert latent-space `[start, end)` bounds into pixel-space bounds by multiplying
    each axis along dim 1 by its entry in `scale_factors`.
    Args:
        latent_coords: Tensor of latent bounds whose dim 1 enumerates the axes, e.g.
            `(batch, 3, num_patches, 2)`. It is not modified.
        scale_factors: Per-axis integer scale factors, ordered to match dim 1 of
            `latent_coords` (`(temporal, height, width)` for video).
        causal_fix: When True, every bound on axis 0 (temporal) is shifted by
            `1 - scale_factors[0]` and clamped at 0 after scaling.
    Returns:
        A new tensor with the same shape as `latent_coords` holding the scaled bounds.
    """
    # Build a (1, -1, 1, ...) view shape so the factors line up with dim 1 and
    # broadcast over every other dimension.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    axis_scale = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    # The multiplication allocates a fresh tensor, so the in-place write below
    # never touches the caller's `latent_coords`.
    pixel_coords = latent_coords * axis_scale

    if not causal_fix:
        return pixel_coords

    # A causal VAE encodes the first frame with temporal stride 1 rather than
    # `scale_factors[0]`; shift the temporal bounds accordingly and keep them non-negative.
    temporal_bounds = pixel_coords[:, 0]
    shifted = temporal_bounds + 1 - scale_factors[0]
    pixel_coords[:, 0] = shifted.clamp(min=0)
    return pixel_coords


class AudioPatchifier(Patchifier):
    """
    Patchifier for spectrogram/audio latents: one token per latent time step,
    with helpers that map latent frame indices to timestamps in seconds.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Args:
            patch_size: Patch size reported through the `patch_size` property as
                `(1, patch_size, patch_size)`. The rearrangements in this class
                do not read it.
            sample_rate: Sampling rate of the original waveform; divides the
                sample offset to obtain seconds.
            hop_length: Spectrogram hop length, i.e. the number of waveform
                samples between two consecutive mel frames.
            audio_latent_downsample_factor: Number of mel frames represented by
                one latent frame.
            is_causal: When True, timestamps are shifted by
                `1 - audio_latent_downsample_factor` mel frames and clipped at 0.
            shift: Integer offset added to the latent indices before they are
                converted to timestamps.
        """
        self._patch_size = (1, patch_size, patch_size)
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """The stored patch-size triple `(1, patch_size, patch_size)`."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """Return the number of tokens, which equals `tgt_shape.frames`."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Map the latent indices `start_latent .. end_latent - 1` to timestamps in
        seconds using the downsample factor, hop length and sample rate.
        Args:
            start_latent: First latent index (inclusive).
            end_latent: Last latent index (exclusive); together with
                `start_latent` it fixes how many timestamps are produced.
            dtype: Dtype of the index tensor the timestamps are computed from.
            device: Device for the result; CPU is used when omitted.
        Returns:
            1-D tensor with one timestamp per latent index.
        """
        target_device = torch.device("cpu") if device is None else device

        latent_index = torch.arange(start_latent, end_latent, dtype=dtype, device=target_device)
        mel_index = latent_index * self.audio_latent_downsample_factor

        if self.is_causal:
            # Shift back by (factor - 1) mel frames for causal alignment and
            # clip so early frames never land before time zero.
            mel_index = (mel_index + 1 - self.audio_latent_downsample_factor).clip(min=0)

        # mel frame -> waveform sample offset -> seconds
        return mel_index * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build the `(B, 1, T, 2)` tensor of `[start, end]` timestamps used by
        `get_patch_grid_bounds`. The end of frame `i` is the start timestamp of
        latent index `i + 1`.
        Args:
            batch_size: Number of batch entries the timings are broadcast to.
            num_steps: Number of latent frames (time steps) to produce timings for.
            device: Device for the result; CPU is used when omitted.
        """
        target_device = device if device is not None else torch.device("cpu")

        def _timings_from(first_latent: int) -> torch.Tensor:
            # float32 timestamps for `num_steps` consecutive latent indices,
            # broadcast to (B, 1, T).
            seconds = self._get_audio_latent_time_in_sec(
                first_latent,
                first_latent + num_steps,
                torch.float32,
                target_device,
            )
            return seconds.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        starts = _timings_from(self.shift)
        ends = _timings_from(self.shift + 1)

        return torch.stack([starts, ends], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Turn a `(b, c, t, f)` audio latent into one token per time step by
        merging the channel and frequency axes. Timing metadata is not produced
        here; call `get_patch_grid_bounds` for it.
        Args:
            audio_latents: Latent tensor laid out as `(b, c, t, f)`.
        Returns:
            Token tensor laid out as `(b, t, c * f)`.
        """
        return einops.rearrange(audio_latents, "b c t f -> b t (c f)")

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: split the merged feature axis back into channels
        and mel bins. Timing metadata is not produced here; call
        `get_patch_grid_bounds` for it.
        Args:
            audio_latents: Token tensor laid out as `(b, t, c * f)`.
            output_shape: Supplies `channels` and `mel_bins` for the split.
        Returns:
            Latent tensor laid out as `(b, c, t, f)`.
        """
        return einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Compute the `[start, end)` timestamps, in seconds, of every token
        produced by `patchify`.
        The result is shaped `[batch_size, 1, time_steps, 2]`:
            - axis 1 (size 1) is the single temporal dimension
            - axis 3 (size 2) holds the start and end timestamp of each token
        Args:
            output_shape: Audio grid description providing batch and frames.
            device: Device for the returned tensor.
        Raises:
            ValueError: If `output_shape` is not an `AudioLatentShape`.
        """
        if isinstance(output_shape, AudioLatentShape):
            return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)

        raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")