| """Video modality tiling helpers. |
| Provides :class:`VideoModalityTilingHelper` — a stateless helper that |
| tiles and blends video :class:`Modality` token sequences by |
| spatial/temporal region. Tile geometry is represented by the existing |
| :class:`Tile` NamedTuple from :mod:`ltx_core.tiling`; no distributed |
| primitives are required. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, replace |
|
|
| import torch |
|
|
| from ltx_core.model.transformer.modality import Modality |
| from ltx_core.tiling import Tile, TileCountConfig, create_tiles, identity_mapping_operation, split_by_count |
| from ltx_core.tools import VideoLatentTools |
| from ltx_core.types import VideoLatentShape |
|
|
|
|
@dataclass(frozen=True)
class TilingContext:
    """Per-tile bookkeeping handed from ``tile_modality`` to ``blend``.

    Instances are returned by
    :meth:`VideoModalityTilingHelper.tile_modality` and are meant to be
    passed back unchanged to :meth:`VideoModalityTilingHelper.blend`;
    callers should treat the contents as opaque.

    Attributes:
        keep_mask: Token-level keep mask — marks which tokens of the
            full sequence were selected for the tile.
        cond_blend_weights: ``(num_kept_cond,)`` — weight for each kept
            conditioning token, equal to
            ``1 / num_tiles_that_keep_this_token``. ``None`` when there
            are no conditioning tokens.
    """

    # Which tokens of the full sequence the tile processes.
    keep_mask: torch.Tensor
    # Per kept conditioning token: 1 / (number of tiles keeping it), or None.
    cond_blend_weights: torch.Tensor | None
|
|
|
|
| class VideoModalityTilingHelper: |
| """Stateless helper that tiles and blends video :class:`Modality` sequences. |
| Constructed once with a :class:`TileCountConfig` and |
| :class:`VideoLatentTools`. Tiles are computed at construction and |
| available via the :attr:`tiles` property. Use :meth:`tile_modality` |
| and :meth:`blend` with any tile from that list. |
| Usage:: |
| helper = VideoModalityTilingHelper(tiling, video_tools) |
| for tile in helper.tiles: |
| tiled_mod, ctx = helper.tile_modality(modality, tile) |
| result = run_model(tiled_mod) |
| helper.blend(result, tile, ctx, output=output) |
| """ |
|
|
| def __init__(self, tiling: TileCountConfig, video_tools: VideoLatentTools) -> None: |
| self._patchifier = video_tools.patchifier |
| self._latent_shape = video_tools.target_shape |
| self._num_generated_tokens = self._patchifier.get_token_count(self._latent_shape) |
| self._tiles = create_tiles( |
| torch.Size([self._latent_shape.frames, self._latent_shape.height, self._latent_shape.width]), |
| splitters=[ |
| split_by_count(tiling.frames.num_tiles, tiling.frames.overlap), |
| split_by_count(tiling.height.num_tiles, tiling.height.overlap), |
| split_by_count(tiling.width.num_tiles, tiling.width.overlap), |
| ], |
| mappers=[identity_mapping_operation] * 3, |
| ) |
|
|
| @property |
| def tiles(self) -> list[Tile]: |
| """All tiles for the configured tiling layout.""" |
| return self._tiles |
|
|
| |
|
|
| def tile_modality(self, modality: Modality, tile: Tile) -> tuple[Modality, TilingContext]: |
| """Slice *modality* to the tokens covered by *tile*. |
| Selects generated tokens belonging to the tile's spatial region |
| and conditioning tokens that overlap with the tile (or have |
| negative time coordinates). |
| Returns: |
| A ``(tiled_modality, context)`` tuple. Pass *context* to |
| :meth:`blend` together with the model output. |
| """ |
| keep_mask = self._keep_mask(modality, tile) |
|
|
| tile_attention_mask = None |
| if modality.attention_mask is not None: |
| keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(1) |
| tile_attention_mask = modality.attention_mask[:, keep_indices, :][:, :, keep_indices] |
|
|
| tiled = replace( |
| modality, |
| latent=modality.latent[:, keep_mask, :], |
| timesteps=modality.timesteps[:, keep_mask], |
| positions=modality.positions[:, :, keep_mask, :], |
| attention_mask=tile_attention_mask, |
| ) |
|
|
| cond_blend_weights = None |
| num_total = modality.latent.shape[1] |
| if num_total > self._num_generated_tokens: |
| cond_keep = keep_mask[self._num_generated_tokens :] |
| |
| cond_counts = torch.zeros(cond_keep.sum(), dtype=torch.float32) |
| for t in self._tiles: |
| other_mask = self._keep_mask(modality, t) |
| other_cond = other_mask[self._num_generated_tokens :] |
| |
| cond_counts += other_cond[cond_keep].float() |
| cond_blend_weights = 1.0 / cond_counts |
|
|
| return tiled, TilingContext(keep_mask=keep_mask, cond_blend_weights=cond_blend_weights) |
|
|
| |
|
|
| def blend( |
| self, |
| tile_to_blend: torch.Tensor, |
| tile: Tile, |
| context: TilingContext, |
| output: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| """Blend-weight tile results and accumulate into the full token space. |
| Premultiplied (blend-weighted) data is **added** to *output*, |
| allowing multiple tiles to be accumulated into the same buffer. |
| Args: |
| tile_to_blend: Denoised tile tensor ``(B, num_tile_tokens, D)``, |
| where the first ``_tile_generated_token_count(tile)`` |
| entries are generated tokens and the remainder are |
| conditioning tokens. |
| tile: The :class:`Tile` that was used in :meth:`tile_modality`. |
| context: The :class:`TilingContext` returned by :meth:`tile_modality`. |
| output: Optional pre-allocated output tensor. When provided |
| its shape must be ``(B, num_total_tokens, D)`` and the |
| blended tile is **added** into it. When ``None`` a new |
| zero-filled tensor is created. |
| Returns: |
| The output tensor with the blended tile added at the correct |
| positions. |
| """ |
| batch, _, dim = tile_to_blend.shape |
| num_tile_gen = self._tile_generated_token_count(tile) |
| gen_indices = self._generated_token_indices(tile) |
|
|
| num_total_tokens = context.keep_mask.shape[0] |
| expected_shape = (batch, num_total_tokens, dim) |
|
|
| if output is not None: |
| if output.shape != expected_shape: |
| raise ValueError(f"Expected output shape {expected_shape}, got {output.shape}") |
| result = output |
| else: |
| result = torch.zeros(*expected_shape, device=tile_to_blend.device, dtype=tile_to_blend.dtype) |
|
|
| |
| blend_weights = tile.blend_mask.reshape(-1).to(device=tile_to_blend.device, dtype=tile_to_blend.dtype) |
| tile_gen = tile_to_blend[:, :num_tile_gen, :] * blend_weights[None, :, None] |
|
|
| result[:, gen_indices, :] += tile_gen |
|
|
| |
| |
| if num_total_tokens > self._num_generated_tokens and context.cond_blend_weights is not None: |
| cond_keep = context.keep_mask[self._num_generated_tokens :] |
| cond_indices = self._num_generated_tokens + cond_keep.nonzero(as_tuple=False).squeeze(1) |
| weights = context.cond_blend_weights.to(device=tile_to_blend.device, dtype=tile_to_blend.dtype) |
| result[:, cond_indices, :] += tile_to_blend[:, num_tile_gen:, :] * weights[None, :, None] |
|
|
| return result |
|
|
| |
|
|
| def _tile_generated_token_count(self, tile: Tile) -> int: |
| """Number of generated tokens in *tile*.""" |
| frame_slice, height_slice, width_slice = tile.in_coords |
| tile_shape = VideoLatentShape( |
| batch=self._latent_shape.batch, |
| channels=self._latent_shape.channels, |
| frames=frame_slice.stop - frame_slice.start, |
| height=height_slice.stop - height_slice.start, |
| width=width_slice.stop - width_slice.start, |
| ) |
| return self._patchifier.get_token_count(tile_shape) |
|
|
| def _generated_token_indices(self, tile: Tile) -> torch.Tensor: |
| """Flat token indices of *tile*'s generated tokens in the full sequence.""" |
| frame_slice, height_slice, width_slice = tile.in_coords |
| f = torch.arange(frame_slice.start, frame_slice.stop) |
| h = torch.arange(height_slice.start, height_slice.stop) |
| w = torch.arange(width_slice.start, width_slice.stop) |
| return ( |
| f[:, None, None] * self._latent_shape.height * self._latent_shape.width |
| + h[None, :, None] * self._latent_shape.width |
| + w[None, None, :] |
| ).reshape(-1) |
|
|
| def _keep_mask(self, modality: Modality, tile: Tile) -> torch.Tensor: |
| """Boolean mask ``(num_total_tokens,)`` — True for tokens the tile processes. |
| Generated tokens are selected by grid position. Conditioning |
| tokens are kept when their ``[start, end)`` intervals overlap |
| the tile in all three dimensions, or when they have a negative |
| time coordinate (reference tokens). |
| """ |
| num_total = modality.latent.shape[1] |
| mask = torch.zeros(num_total, dtype=torch.bool) |
|
|
| gen_indices = self._generated_token_indices(tile) |
| mask[gen_indices] = True |
|
|
| if num_total > self._num_generated_tokens: |
| gen_positions = modality.positions[:, :, gen_indices, :] |
| tile_start = gen_positions[..., 0].amin(dim=2) |
| tile_end = gen_positions[..., 1].amax(dim=2) |
|
|
| cond_positions = modality.positions[:, :, self._num_generated_tokens :, :] |
|
|
| overlaps = (cond_positions[..., 0] < tile_end.unsqueeze(2)) & ( |
| cond_positions[..., 1] > tile_start.unsqueeze(2) |
| ) |
| overlaps_all_dims = overlaps.all(dim=1) |
|
|
| has_negative_time = cond_positions[:, 0, :, 0] < 0 |
|
|
| keep_cond = (overlaps_all_dims | has_negative_time).any(dim=0) |
| mask[self._num_generated_tokens :] = keep_cond |
|
|
| return mask |
|
|