from enum import Enum from typing import Protocol import torch from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb memory_efficient_attention = None flash_attn_interface = None try: from xformers.ops import memory_efficient_attention except ImportError: memory_efficient_attention = None try: # FlashAttention3 and XFormersAttention cannot be used together if memory_efficient_attention is None: import flash_attn_interface except ImportError: flash_attn_interface = None class AttentionCallable(Protocol): def __call__( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None ) -> torch.Tensor: ... class PytorchAttention(AttentionCallable): def __call__( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None ) -> torch.Tensor: b, _, dim_head = q.shape dim_head //= heads q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)) if mask is not None: # add a batch dimension if there isn't already one if mask.ndim == 2: mask = mask.unsqueeze(0) # add a heads dimension if there isn't already one if mask.ndim == 3: mask = mask.unsqueeze(1) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) return out class XFormersAttention(AttentionCallable): def __call__( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None, ) -> torch.Tensor: if memory_efficient_attention is None: raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.") b, _, dim_head = q.shape dim_head //= heads # xformers expects [B, M, H, K] q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) # Use v.dtype as the target since q/k get cast to v.dtype for xformers target_dtype = v.dtype if mask is not None: # add a singleton batch dimension if mask.ndim == 2: mask = mask.unsqueeze(0) # add a singleton heads dimension if mask.ndim == 3: mask = mask.unsqueeze(1) # pad to a multiple of 8 pad = 8 - mask.shape[-1] % 8 # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk) # but when using separated heads, the shape has to be (B, H, Nq, Nk) # in flux, this matrix ends up being over 1GB # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full) mask_out = torch.empty( [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=target_dtype, device=q.device ) mask_out[..., : mask.shape[-1]] = mask # doesn't this remove the padding again?? mask = mask_out[..., : mask.shape[-1]] mask = mask.expand(b, heads, -1, -1) out = memory_efficient_attention(q.to(target_dtype), k.to(target_dtype), v, attn_bias=mask, p=0.0) out = out.reshape(b, -1, heads * dim_head) return out class FlashAttention3(AttentionCallable): def __call__( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None, ) -> torch.Tensor: if flash_attn_interface is None: raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.") b, _, dim_head = q.shape dim_head //= heads q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) if mask is not None: raise NotImplementedError("Mask is not supported for FlashAttention3") out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v) out = out.reshape(b, -1, heads * dim_head) return out class AttentionFunction(Enum): PYTORCH = "pytorch" XFORMERS = "xformers" FLASH_ATTENTION_3 = "flash_attention_3" DEFAULT = "default" def to_callable(self) -> AttentionCallable: """Resolve to a concrete callable. Use this at module init time so that torch.compile can trace through the attention call without graph breaks.""" if self is AttentionFunction.PYTORCH: return PytorchAttention() elif self is AttentionFunction.XFORMERS: return XFormersAttention() elif self is AttentionFunction.FLASH_ATTENTION_3: return FlashAttention3() else: # Default behavior: XFormers if installed else - PyTorch return XFormersAttention() if memory_efficient_attention is not None else PytorchAttention() class Attention(torch.nn.Module): def __init__( self, query_dim: int, context_dim: int | None = None, heads: int = 8, dim_head: int = 64, norm_eps: float = 1e-6, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT, apply_gated_attention: bool = False, ) -> None: super().__init__() self.rope_type = rope_type self.attention_function = ( attention_function.to_callable() if isinstance(attention_function, AttentionFunction) else attention_function ) inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim self.heads = heads self.dim_head = dim_head self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) # Optional per-head gating if apply_gated_attention: self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) else: self.to_gate_logits = None self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) def forward( self, x: torch.Tensor, context: torch.Tensor | None = None, mask: torch.Tensor | None = None, pe: torch.Tensor | None = None, k_pe: torch.Tensor | None = None, perturbation_mask: torch.Tensor | None = None, all_perturbed: bool = False, ) -> torch.Tensor: """Multi-head attention with optional RoPE, perturbation masking, and per-head gating. When ``perturbation_mask`` is all zeros, the expensive query/key path (linear projections, RMSNorm, RoPE) is skipped entirely and only the value projection is used as a pass-through. Args: x: Query input tensor of shape ``(B, T, query_dim)``. context: Key/value context tensor of shape ``(B, S, context_dim)``. Falls back to ``x`` (self-attention) when *None*. mask: Optional attention mask. Interpretation depends on the attention backend (additive bias for xformers/PyTorch SDPA). pe: Rotary positional embeddings applied to both ``q`` and ``k``. k_pe: Separate rotary positional embeddings for ``k`` only. When *None*, ``pe`` is reused for keys. perturbation_mask: Optional mask in ``[0, 1]`` that blends the attention output with the raw value projection: ``out = attn_out * mask + v * (1 - mask)``. **1** keeps the full attention output, **0** bypasses attention and passes the value projection through unchanged. *None* or all-ones means standard attention; all-zeros skips the query/key path entirely for efficiency. all_perturbed: Whether all perturbations are active for this block. Returns: Output tensor of shape ``(B, T, query_dim)``. """ context = x if context is None else context use_attention = not all_perturbed v = self.to_v(context) if not use_attention: out = v else: q = self.to_q(x) k = self.to_k(context) q = self.q_norm(q) k = self.k_norm(k) if pe is not None: q = apply_rotary_emb(q, pe, self.rope_type) k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) out = self.attention_function(q, k, v, self.heads, mask) # (B, T, H*D) if perturbation_mask is not None: out = out * perturbation_mask + v * (1 - perturbation_mask) # Apply per-head gating if enabled if self.to_gate_logits is not None: gate_logits = self.to_gate_logits(x) # (B, T, H) b, t, _ = out.shape # Reshape to (B, T, H, D) for per-head gating out = out.view(b, t, self.heads, self.dim_head) # Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0) gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H) out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1) # Reshape back to (B, T, H*D) out = out.view(b, t, self.heads * self.dim_head) return self.to_out(out)