"""Video modality tiling helpers. Provides :class:`VideoModalityTilingHelper` — a stateless helper that tiles and blends video :class:`Modality` token sequences by spatial/temporal region. Tile geometry is represented by the existing :class:`Tile` NamedTuple from :mod:`ltx_core.tiling`; no distributed primitives are required. """ from __future__ import annotations from dataclasses import dataclass, replace import torch from ltx_core.model.transformer.modality import Modality from ltx_core.tiling import Tile, TileCountConfig, create_tiles, identity_mapping_operation, split_by_count from ltx_core.tools import VideoLatentTools from ltx_core.types import VideoLatentShape @dataclass(frozen=True) class TilingContext: """Opaque context produced by :meth:`VideoModalityTilingHelper.tile_modality`. Carries the token-level keep mask and per-conditioning-token blend weights needed by :meth:`~VideoModalityTilingHelper.blend`. """ keep_mask: torch.Tensor cond_blend_weights: torch.Tensor | None """``(num_kept_cond,)`` — weight for each kept conditioning token, equal to ``1 / num_tiles_that_keep_this_token``. ``None`` when there are no conditioning tokens.""" class VideoModalityTilingHelper: """Stateless helper that tiles and blends video :class:`Modality` sequences. Constructed once with a :class:`TileCountConfig` and :class:`VideoLatentTools`. Tiles are computed at construction and available via the :attr:`tiles` property. Use :meth:`tile_modality` and :meth:`blend` with any tile from that list. Usage:: helper = VideoModalityTilingHelper(tiling, video_tools) for tile in helper.tiles: tiled_mod, ctx = helper.tile_modality(modality, tile) result = run_model(tiled_mod) helper.blend(result, tile, ctx, output=output) """ def __init__(self, tiling: TileCountConfig, video_tools: VideoLatentTools) -> None: self._patchifier = video_tools.patchifier self._latent_shape = video_tools.target_shape self._num_generated_tokens = self._patchifier.get_token_count(self._latent_shape) self._tiles = create_tiles( torch.Size([self._latent_shape.frames, self._latent_shape.height, self._latent_shape.width]), splitters=[ split_by_count(tiling.frames.num_tiles, tiling.frames.overlap), split_by_count(tiling.height.num_tiles, tiling.height.overlap), split_by_count(tiling.width.num_tiles, tiling.width.overlap), ], mappers=[identity_mapping_operation] * 3, ) @property def tiles(self) -> list[Tile]: """All tiles for the configured tiling layout.""" return self._tiles # -- tile modality ----------------------------------------------------- def tile_modality(self, modality: Modality, tile: Tile) -> tuple[Modality, TilingContext]: """Slice *modality* to the tokens covered by *tile*. Selects generated tokens belonging to the tile's spatial region and conditioning tokens that overlap with the tile (or have negative time coordinates). Returns: A ``(tiled_modality, context)`` tuple. Pass *context* to :meth:`blend` together with the model output. """ keep_mask = self._keep_mask(modality, tile) tile_attention_mask = None if modality.attention_mask is not None: keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(1) tile_attention_mask = modality.attention_mask[:, keep_indices, :][:, :, keep_indices] tiled = replace( modality, latent=modality.latent[:, keep_mask, :], timesteps=modality.timesteps[:, keep_mask], positions=modality.positions[:, :, keep_mask, :], attention_mask=tile_attention_mask, ) cond_blend_weights = None num_total = modality.latent.shape[1] if num_total > self._num_generated_tokens: cond_keep = keep_mask[self._num_generated_tokens :] # Count how many tiles keep each conditioning token. cond_counts = torch.zeros(cond_keep.sum(), dtype=torch.float32) for t in self._tiles: other_mask = self._keep_mask(modality, t) other_cond = other_mask[self._num_generated_tokens :] # Map other tile's kept cond tokens into this tile's kept subset. cond_counts += other_cond[cond_keep].float() cond_blend_weights = 1.0 / cond_counts return tiled, TilingContext(keep_mask=keep_mask, cond_blend_weights=cond_blend_weights) # -- blend ------------------------------------------------------------- def blend( self, tile_to_blend: torch.Tensor, tile: Tile, context: TilingContext, output: torch.Tensor | None = None, ) -> torch.Tensor: """Blend-weight tile results and accumulate into the full token space. Premultiplied (blend-weighted) data is **added** to *output*, allowing multiple tiles to be accumulated into the same buffer. Args: tile_to_blend: Denoised tile tensor ``(B, num_tile_tokens, D)``, where the first ``_tile_generated_token_count(tile)`` entries are generated tokens and the remainder are conditioning tokens. tile: The :class:`Tile` that was used in :meth:`tile_modality`. context: The :class:`TilingContext` returned by :meth:`tile_modality`. output: Optional pre-allocated output tensor. When provided its shape must be ``(B, num_total_tokens, D)`` and the blended tile is **added** into it. When ``None`` a new zero-filled tensor is created. Returns: The output tensor with the blended tile added at the correct positions. """ batch, _, dim = tile_to_blend.shape num_tile_gen = self._tile_generated_token_count(tile) gen_indices = self._generated_token_indices(tile) num_total_tokens = context.keep_mask.shape[0] expected_shape = (batch, num_total_tokens, dim) if output is not None: if output.shape != expected_shape: raise ValueError(f"Expected output shape {expected_shape}, got {output.shape}") result = output else: result = torch.zeros(*expected_shape, device=tile_to_blend.device, dtype=tile_to_blend.dtype) # Blend mask is (tile_F, tile_H, tile_W) — one weight per token in row-major order. blend_weights = tile.blend_mask.reshape(-1).to(device=tile_to_blend.device, dtype=tile_to_blend.dtype) tile_gen = tile_to_blend[:, :num_tile_gen, :] * blend_weights[None, :, None] result[:, gen_indices, :] += tile_gen # Scatter kept conditioning tokens, weighted by 1/N where N is # the number of tiles that keep each token (so they sum to 1). if num_total_tokens > self._num_generated_tokens and context.cond_blend_weights is not None: cond_keep = context.keep_mask[self._num_generated_tokens :] cond_indices = self._num_generated_tokens + cond_keep.nonzero(as_tuple=False).squeeze(1) weights = context.cond_blend_weights.to(device=tile_to_blend.device, dtype=tile_to_blend.dtype) result[:, cond_indices, :] += tile_to_blend[:, num_tile_gen:, :] * weights[None, :, None] return result # -- private ----------------------------------------------------------- def _tile_generated_token_count(self, tile: Tile) -> int: """Number of generated tokens in *tile*.""" frame_slice, height_slice, width_slice = tile.in_coords tile_shape = VideoLatentShape( batch=self._latent_shape.batch, channels=self._latent_shape.channels, frames=frame_slice.stop - frame_slice.start, height=height_slice.stop - height_slice.start, width=width_slice.stop - width_slice.start, ) return self._patchifier.get_token_count(tile_shape) def _generated_token_indices(self, tile: Tile) -> torch.Tensor: """Flat token indices of *tile*'s generated tokens in the full sequence.""" frame_slice, height_slice, width_slice = tile.in_coords f = torch.arange(frame_slice.start, frame_slice.stop) h = torch.arange(height_slice.start, height_slice.stop) w = torch.arange(width_slice.start, width_slice.stop) return ( f[:, None, None] * self._latent_shape.height * self._latent_shape.width + h[None, :, None] * self._latent_shape.width + w[None, None, :] ).reshape(-1) def _keep_mask(self, modality: Modality, tile: Tile) -> torch.Tensor: """Boolean mask ``(num_total_tokens,)`` — True for tokens the tile processes. Generated tokens are selected by grid position. Conditioning tokens are kept when their ``[start, end)`` intervals overlap the tile in all three dimensions, or when they have a negative time coordinate (reference tokens). """ num_total = modality.latent.shape[1] mask = torch.zeros(num_total, dtype=torch.bool) gen_indices = self._generated_token_indices(tile) mask[gen_indices] = True if num_total > self._num_generated_tokens: gen_positions = modality.positions[:, :, gen_indices, :] # (B, 3, num_tile_gen, 2) tile_start = gen_positions[..., 0].amin(dim=2) # (B, 3) tile_end = gen_positions[..., 1].amax(dim=2) # (B, 3) cond_positions = modality.positions[:, :, self._num_generated_tokens :, :] # (B, 3, num_cond, 2) overlaps = (cond_positions[..., 0] < tile_end.unsqueeze(2)) & ( cond_positions[..., 1] > tile_start.unsqueeze(2) ) # (B, 3, num_cond) overlaps_all_dims = overlaps.all(dim=1) # (B, num_cond) has_negative_time = cond_positions[:, 0, :, 0] < 0 # (B, num_cond) keep_cond = (overlaps_all_dims | has_negative_time).any(dim=0) # (num_cond,) mask[self._num_generated_tokens :] = keep_cond return mask