File size: 10,353 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""Video modality tiling helpers.
Provides :class:`VideoModalityTilingHelper` — a stateless helper that
tiles and blends video :class:`Modality` token sequences by
spatial/temporal region.  Tile geometry is represented by the existing
:class:`Tile` NamedTuple from :mod:`ltx_core.tiling`; no distributed
primitives are required.
"""

from __future__ import annotations

from dataclasses import dataclass, replace

import torch

from ltx_core.model.transformer.modality import Modality
from ltx_core.tiling import Tile, TileCountConfig, create_tiles, identity_mapping_operation, split_by_count
from ltx_core.tools import VideoLatentTools
from ltx_core.types import VideoLatentShape


@dataclass(frozen=True)
class TilingContext:
    """Immutable hand-off between tiling and blending.

    Instances are created by :meth:`VideoModalityTilingHelper.tile_modality`
    and are meant to be passed back, unchanged, to
    :meth:`VideoModalityTilingHelper.blend`.

    Attributes:
        keep_mask: Token-level keep mask for the tile the context was
            produced for.
        cond_blend_weights: ``(num_kept_cond,)`` — weight for each kept
            conditioning token, equal to
            ``1 / num_tiles_that_keep_this_token``.  ``None`` when there
            are no conditioning tokens.
    """

    # Field order is part of the positional constructor signature.
    keep_mask: torch.Tensor
    cond_blend_weights: torch.Tensor | None


class VideoModalityTilingHelper:
    """Stateless helper that tiles and blends video :class:`Modality` sequences.

    Constructed once with a :class:`TileCountConfig` and
    :class:`VideoLatentTools`.  Tiles are computed at construction and
    available via the :attr:`tiles` property.  Use :meth:`tile_modality`
    and :meth:`blend` with any tile from that list.

    The token axis is treated as ``num_generated_tokens`` generated tokens
    (indexed row-major over frames, height, width) followed by any
    conditioning tokens.

    Usage::

        helper = VideoModalityTilingHelper(tiling, video_tools)
        for tile in helper.tiles:
            tiled_mod, ctx = helper.tile_modality(modality, tile)
            result = run_model(tiled_mod)
            helper.blend(result, tile, ctx, output=output)
    """

    def __init__(self, tiling: TileCountConfig, video_tools: VideoLatentTools) -> None:
        """Capture the latent geometry and precompute the tile layout.

        Args:
            tiling: Per-axis tiling configuration; ``num_tiles`` and
                ``overlap`` are read for ``frames``, ``height`` and
                ``width``.
            video_tools: Provides the patchifier and the target latent
                shape the tiles are laid out over.
        """
        self._patchifier = video_tools.patchifier
        self._latent_shape = video_tools.target_shape
        self._num_generated_tokens = self._patchifier.get_token_count(self._latent_shape)
        # One splitter/mapper per axis, in (frames, height, width) order.
        self._tiles = create_tiles(
            torch.Size([self._latent_shape.frames, self._latent_shape.height, self._latent_shape.width]),
            splitters=[
                split_by_count(tiling.frames.num_tiles, tiling.frames.overlap),
                split_by_count(tiling.height.num_tiles, tiling.height.overlap),
                split_by_count(tiling.width.num_tiles, tiling.width.overlap),
            ],
            mappers=[identity_mapping_operation] * 3,
        )

    @property
    def tiles(self) -> list[Tile]:
        """All tiles for the configured tiling layout."""
        return self._tiles

    # -- tile modality -----------------------------------------------------

    def tile_modality(self, modality: Modality, tile: Tile) -> tuple[Modality, TilingContext]:
        """Slice *modality* to the tokens covered by *tile*.

        Selects generated tokens belonging to the tile's spatial region
        and conditioning tokens that overlap with the tile (or have
        negative time coordinates).

        Args:
            modality: Full modality; ``latent``, ``timesteps``,
                ``positions`` and (when present) ``attention_mask`` are
                sliced along their token axes.
            tile: Tile to extract, normally one of :attr:`tiles`.

        Returns:
            A ``(tiled_modality, context)`` tuple.  Pass *context* to
            :meth:`blend` together with the model output.
        """
        keep_mask = self._keep_mask(modality, tile)

        tile_attention_mask = None
        if modality.attention_mask is not None:
            # Restrict both token axes of the mask to the kept tokens.
            keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(1)
            tile_attention_mask = modality.attention_mask[:, keep_indices, :][:, :, keep_indices]

        tiled = replace(
            modality,
            latent=modality.latent[:, keep_mask, :],
            timesteps=modality.timesteps[:, keep_mask],
            positions=modality.positions[:, :, keep_mask, :],
            attention_mask=tile_attention_mask,
        )

        cond_blend_weights = None
        # Anything past the generated tokens is a conditioning token.
        if modality.latent.shape[1] > self._num_generated_tokens:
            cond_blend_weights = self._cond_blend_weights(modality, keep_mask)

        return tiled, TilingContext(keep_mask=keep_mask, cond_blend_weights=cond_blend_weights)

    # -- blend -------------------------------------------------------------

    def blend(
        self,
        tile_to_blend: torch.Tensor,
        tile: Tile,
        context: TilingContext,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Blend-weight tile results and accumulate into the full token space.

        Premultiplied (blend-weighted) data is **added** to *output*,
        allowing multiple tiles to be accumulated into the same buffer.

        Args:
            tile_to_blend: Denoised tile tensor ``(B, num_tile_tokens, D)``,
                where the first ``_tile_generated_token_count(tile)``
                entries are generated tokens and the remainder are
                conditioning tokens.
            tile: The :class:`Tile` that was used in :meth:`tile_modality`.
            context: The :class:`TilingContext` returned by :meth:`tile_modality`.
            output: Optional pre-allocated output tensor.  When provided
                its shape must be ``(B, num_total_tokens, D)`` and the
                blended tile is **added** into it.  When ``None`` a new
                zero-filled tensor is created.

        Returns:
            The output tensor with the blended tile added at the correct
            positions.  This is *output* itself when it was provided.

        Raises:
            ValueError: If *output* is provided with an unexpected shape.
        """
        batch, _, dim = tile_to_blend.shape
        num_tile_gen = self._tile_generated_token_count(tile)
        gen_indices = self._generated_token_indices(tile)

        num_total_tokens = context.keep_mask.shape[0]
        expected_shape = (batch, num_total_tokens, dim)

        if output is not None:
            if output.shape != expected_shape:
                raise ValueError(f"Expected output shape {expected_shape}, got {output.shape}")
            result = output
        else:
            result = torch.zeros(*expected_shape, device=tile_to_blend.device, dtype=tile_to_blend.dtype)

        # Blend mask is (tile_F, tile_H, tile_W) — one weight per token in row-major order.
        blend_weights = tile.blend_mask.reshape(-1).to(device=tile_to_blend.device, dtype=tile_to_blend.dtype)
        tile_gen = tile_to_blend[:, :num_tile_gen, :] * blend_weights[None, :, None]

        result[:, gen_indices, :] += tile_gen

        # Scatter kept conditioning tokens, weighted by 1/N where N is
        # the number of tiles that keep each token (so they sum to 1).
        if num_total_tokens > self._num_generated_tokens and context.cond_blend_weights is not None:
            cond_keep = context.keep_mask[self._num_generated_tokens :]
            cond_indices = self._num_generated_tokens + cond_keep.nonzero(as_tuple=False).squeeze(1)
            weights = context.cond_blend_weights.to(device=tile_to_blend.device, dtype=tile_to_blend.dtype)
            result[:, cond_indices, :] += tile_to_blend[:, num_tile_gen:, :] * weights[None, :, None]

        return result

    # -- private -----------------------------------------------------------

    def _cond_blend_weights(self, modality: Modality, keep_mask: torch.Tensor) -> torch.Tensor:
        """Return ``1 / N`` per kept conditioning token, N = number of tiles keeping it.

        Args:
            modality: Full modality, used to evaluate every tile's keep mask.
            keep_mask: Keep mask of the tile being processed.

        Returns:
            Float32 tensor of shape ``(num_kept_cond,)``.
        """
        cond_keep = keep_mask[self._num_generated_tokens :]
        cond_counts = torch.zeros(int(cond_keep.sum()), dtype=torch.float32, device=cond_keep.device)
        for other_tile in self._tiles:
            other_cond = self._keep_mask(modality, other_tile)[self._num_generated_tokens :]
            # Project the other tile's kept cond tokens onto this tile's kept subset.
            cond_counts += other_cond[cond_keep].float()
        # A token kept by the current tile is kept by at least one tile.  The
        # clamp guards against a zero count (e.g. a tile that is not part of
        # ``self._tiles``), which would otherwise produce ``inf`` weights.
        return 1.0 / cond_counts.clamp(min=1.0)

    def _tile_generated_token_count(self, tile: Tile) -> int:
        """Number of generated tokens in *tile*, as reported by the patchifier."""
        frame_slice, height_slice, width_slice = tile.in_coords
        tile_shape = VideoLatentShape(
            batch=self._latent_shape.batch,
            channels=self._latent_shape.channels,
            frames=frame_slice.stop - frame_slice.start,
            height=height_slice.stop - height_slice.start,
            width=width_slice.stop - width_slice.start,
        )
        return self._patchifier.get_token_count(tile_shape)

    def _generated_token_indices(self, tile: Tile) -> torch.Tensor:
        """Flat token indices of *tile*'s generated tokens in the full sequence.

        Indices are computed as ``f * H * W + h * W + w`` over the tile's
        ``in_coords`` slices, i.e. row-major over the full latent grid.
        NOTE(review): this assumes one token per latent position — confirm
        it agrees with the patchifier's token count.
        """
        frame_slice, height_slice, width_slice = tile.in_coords
        f = torch.arange(frame_slice.start, frame_slice.stop)
        h = torch.arange(height_slice.start, height_slice.stop)
        w = torch.arange(width_slice.start, width_slice.stop)
        # Broadcast the three axes into a (tile_F, tile_H, tile_W) grid of flat indices.
        return (
            f[:, None, None] * self._latent_shape.height * self._latent_shape.width
            + h[None, :, None] * self._latent_shape.width
            + w[None, None, :]
        ).reshape(-1)

    def _keep_mask(self, modality: Modality, tile: Tile) -> torch.Tensor:
        """Boolean mask ``(num_total_tokens,)`` — True for tokens the tile processes.

        Generated tokens are selected by grid position.  Conditioning
        tokens are kept when their ``[start, end)`` intervals overlap
        the tile in all three dimensions, or when they have a negative
        time coordinate (reference tokens).  A conditioning token is kept
        if any batch element satisfies the condition.
        """
        num_total = modality.latent.shape[1]
        mask = torch.zeros(num_total, dtype=torch.bool)

        gen_indices = self._generated_token_indices(tile)
        mask[gen_indices] = True

        if num_total > self._num_generated_tokens:
            # The tile's extent is the bounding interval of its generated tokens' positions.
            gen_positions = modality.positions[:, :, gen_indices, :]  # (B, 3, num_tile_gen, 2)
            tile_start = gen_positions[..., 0].amin(dim=2)  # (B, 3)
            tile_end = gen_positions[..., 1].amax(dim=2)  # (B, 3)

            cond_positions = modality.positions[:, :, self._num_generated_tokens :, :]  # (B, 3, num_cond, 2)

            # Half-open interval overlap test, evaluated per axis.
            overlaps = (cond_positions[..., 0] < tile_end.unsqueeze(2)) & (
                cond_positions[..., 1] > tile_start.unsqueeze(2)
            )  # (B, 3, num_cond)
            overlaps_all_dims = overlaps.all(dim=1)  # (B, num_cond)

            # Axis 0 of the position triple is time; a negative start marks a reference token.
            has_negative_time = cond_positions[:, 0, :, 0] < 0  # (B, num_cond)

            keep_cond = (overlaps_all_dims | has_negative_time).any(dim=0)  # (num_cond,)
            mask[self._num_generated_tokens :] = keep_cond

        return mask