Spaces:
Running on Zero
Running on Zero
| """Utilities for building 2D self-attention masks for conditioning items.""" | |
| from __future__ import annotations | |
| from typing import TYPE_CHECKING | |
| import torch | |
| if TYPE_CHECKING: | |
| from ltx_core.types import LatentState | |
| def resolve_cross_mask( | |
| attention_mask: float | int | torch.Tensor, | |
| num_new_tokens: int, | |
| batch_size: int, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> torch.Tensor: | |
| """Convert an attention_mask (scalar or tensor) to a (B, M) cross_mask tensor. | |
| Args: | |
| attention_mask: Scalar value applied uniformly, 1D tensor of shape (M,) | |
| broadcast across batch, or 2D tensor of shape (B, M). | |
| num_new_tokens: Number of new conditioning tokens M. | |
| batch_size: Batch size B. | |
| device: Device for the output tensor. | |
| dtype: Data type for the output tensor. | |
| Returns: | |
| Cross-mask tensor of shape (B, M). | |
| """ | |
| if isinstance(attention_mask, (int, float)): | |
| return torch.full( | |
| (batch_size, num_new_tokens), | |
| fill_value=float(attention_mask), | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| mask = attention_mask.to(device=device, dtype=dtype) | |
| # Handle scalar (0-D) tensor like a Python scalar. | |
| if mask.dim() == 0: | |
| return torch.full( | |
| (batch_size, num_new_tokens), | |
| fill_value=float(mask.item()), | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| if mask.dim() == 1: | |
| if mask.shape[0] != num_new_tokens: | |
| raise ValueError( | |
| f"1-D attention_mask length must equal num_new_tokens ({num_new_tokens}), got shape {tuple(mask.shape)}" | |
| ) | |
| mask = mask.unsqueeze(0).expand(batch_size, -1) | |
| elif mask.dim() == 2: | |
| b, m = mask.shape | |
| if m != num_new_tokens: | |
| raise ValueError( | |
| f"2-D attention_mask second dimension must equal num_new_tokens ({num_new_tokens}), " | |
| f"got shape {tuple(mask.shape)}" | |
| ) | |
| if b not in (batch_size, 1): | |
| raise ValueError( | |
| f"2-D attention_mask batch dimension must equal batch_size ({batch_size}) or 1, " | |
| f"got shape {tuple(mask.shape)}" | |
| ) | |
| if b == 1 and batch_size > 1: | |
| mask = mask.expand(batch_size, -1) | |
| else: | |
| raise ValueError( | |
| f"attention_mask tensor must be 0-D, 1-D, or 2-D, got {mask.dim()}-D with shape {tuple(mask.shape)}" | |
| ) | |
| return mask | |
| def update_attention_mask( | |
| latent_state: LatentState, | |
| attention_mask: float | torch.Tensor | None, | |
| num_noisy_tokens: int, | |
| num_new_tokens: int, | |
| batch_size: int, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> torch.Tensor | None: | |
| """Build or update the self-attention mask for newly appended conditioning tokens. | |
| If *attention_mask* is ``None`` and no existing mask is present, returns | |
| ``None``. If *attention_mask* is ``None`` but an existing mask is present, | |
| the mask is expanded with full attention (1s) for the new tokens so that | |
| its dimensions stay consistent with the growing latent sequence. Otherwise, | |
| resolves *attention_mask* to a per-token cross-mask and expands the 2-D | |
| attention mask via :func:`build_attention_mask`. | |
| Args: | |
| latent_state: Current latent state (provides the existing mask and total | |
| existing-token count). | |
| attention_mask: Per-token attention weight. Scalar, 1-D ``(M,)``, 2-D | |
| ``(B, M)`` tensor, or ``None`` (no-op). | |
| num_noisy_tokens: Number of original noisy tokens (from | |
| ``latent_tools.target_shape.token_count()``). | |
| num_new_tokens: Number of new conditioning tokens being appended. | |
| batch_size: Batch size. | |
| device: Device for the output tensor. | |
| dtype: Data type for the output tensor. | |
| Returns: | |
| Updated attention mask of shape ``(B, N+M, N+M)``, or ``None`` if no | |
| masking is needed. | |
| """ | |
| if attention_mask is None: | |
| if latent_state.attention_mask is None: | |
| return None | |
| # Existing mask present but no new mask requested: pad with 1s (full | |
| # attention) so the mask dimensions stay consistent with the growing | |
| # latent sequence. | |
| cross_mask = torch.ones(batch_size, num_new_tokens, device=device, dtype=dtype) | |
| return build_attention_mask( | |
| existing_mask=latent_state.attention_mask, | |
| num_noisy_tokens=num_noisy_tokens, | |
| num_new_tokens=num_new_tokens, | |
| num_existing_tokens=latent_state.latent.shape[1], | |
| cross_mask=cross_mask, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| cross_mask = resolve_cross_mask(attention_mask, num_new_tokens, batch_size, device, dtype) | |
| return build_attention_mask( | |
| existing_mask=latent_state.attention_mask, | |
| num_noisy_tokens=num_noisy_tokens, | |
| num_new_tokens=num_new_tokens, | |
| num_existing_tokens=latent_state.latent.shape[1], | |
| cross_mask=cross_mask, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| def build_attention_mask( | |
| existing_mask: torch.Tensor | None, | |
| num_noisy_tokens: int, | |
| num_new_tokens: int, | |
| num_existing_tokens: int, | |
| cross_mask: torch.Tensor, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> torch.Tensor: | |
| """ | |
| Expand the attention mask to include newly appended conditioning tokens. | |
| Each conditioning item appends M new reference tokens to the sequence. This function | |
| builds a (B, N+M, N+M) attention mask with the following block structure: | |
| noisy prev_ref new_ref | |
| (N_noisy) (N-N_noisy) (M) | |
| βββββββββββββ¬ββββββββββββ¬ββββββββββββ | |
| noisy β β β β | |
| (N_noisy) β existing β existing β cross β | |
| β β β β | |
| βββββββββββββΌββββββββββββΌββββββββββββ€ | |
| prev_ref β β β β | |
| (N-N_noisy)β existing β existing β 0 β | |
| β β β β | |
| βββββββββββββΌββββββββββββΌββββββββββββ€ | |
| new_ref β β β β | |
| (M) β cross β 0 β 1 β | |
| β β β β | |
| βββββββββββββ΄ββββββββββββ΄ββββββββββββ | |
| Where: | |
| - **existing**: preserved from the previous mask (or 1.0 if first conditioning) | |
| - **cross**: values from *cross_mask* (shape B, M), in [0, 1] | |
| - **0**: no attention between different reference groups | |
| Args: | |
| existing_mask: Current attention mask of shape (B, N, N), or None if no mask exists yet. | |
| When None, the top-left NxN block is filled with 1s (full attention between all | |
| existing tokens including any prior reference tokens that had no mask). | |
| num_noisy_tokens: Number of original noisy tokens (always at positions [0:num_noisy_tokens]). | |
| num_new_tokens: Number of new conditioning tokens M being appended. | |
| num_existing_tokens: Total number of current tokens N (noisy + any prior conditioning tokens). | |
| cross_mask: Per-token attention weight of shape (B, M) controlling attention between | |
| new reference tokens and noisy tokens. Values in [0, 1]. | |
| device: Device for the output tensor. | |
| dtype: Data type for the output tensor. | |
| Returns: | |
| Attention mask of shape (B, N+M, N+M) with values in [0, 1]. | |
| """ | |
| batch_size = cross_mask.shape[0] | |
| total = num_existing_tokens + num_new_tokens | |
| # Start with zeros | |
| mask = torch.zeros((batch_size, total, total), device=device, dtype=dtype) | |
| # Top-left: preserve existing mask or fill with 1s for noisy tokens | |
| if existing_mask is not None: | |
| mask[:, :num_existing_tokens, :num_existing_tokens] = existing_mask | |
| else: | |
| mask[:, :num_existing_tokens, :num_existing_tokens] = 1.0 | |
| # Bottom-right: new reference tokens fully attend to themselves | |
| mask[:, num_existing_tokens:, num_existing_tokens:] = 1.0 | |
| # Cross-attention between noisy tokens and new reference tokens | |
| # cross_mask shape: (B, M) -> broadcast to (B, N_noisy, M) and (B, M, N_noisy) | |
| # Noisy tokens attending to new reference tokens: [0:N_noisy, N:N+M] | |
| # Each column j in this block gets cross_mask[:, j] | |
| mask[:, :num_noisy_tokens, num_existing_tokens:] = cross_mask.unsqueeze(1) | |
| # New reference tokens attending to noisy tokens: [N:N+M, 0:N_noisy] | |
| # Each row i in this block gets cross_mask[:, i] | |
| mask[:, num_existing_tokens:, :num_noisy_tokens] = cross_mask.unsqueeze(2) | |
| # [N_noisy:N, N:N+M] and [N:N+M, N_noisy:N] remain 0 (no cross-ref attention) | |
| return mask | |