Dramabox / patches /attention.py
Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
from enum import Enum
from typing import Protocol
import torch
from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb
memory_efficient_attention = None
flash_attn_interface = None
try:
from xformers.ops import memory_efficient_attention
except ImportError:
memory_efficient_attention = None
try:
# FlashAttention3 and XFormersAttention cannot be used together
if memory_efficient_attention is None:
import flash_attn_interface
except ImportError:
flash_attn_interface = None
class AttentionCallable(Protocol):
def __call__(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
) -> torch.Tensor: ...
class PytorchAttention(AttentionCallable):
def __call__(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
) -> torch.Tensor:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
return out
class XFormersAttention(AttentionCallable):
def __call__(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
heads: int,
mask: torch.Tensor | None = None,
) -> torch.Tensor:
if memory_efficient_attention is None:
raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.")
b, _, dim_head = q.shape
dim_head //= heads
# xformers expects [B, M, H, K]
q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
# Use v.dtype as the target since q/k get cast to v.dtype for xformers
target_dtype = v.dtype
if mask is not None:
# add a singleton batch dimension
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a singleton heads dimension
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# pad to a multiple of 8
pad = 8 - mask.shape[-1] % 8
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
# in flux, this matrix ends up being over 1GB
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
mask_out = torch.empty(
[mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=target_dtype, device=q.device
)
mask_out[..., : mask.shape[-1]] = mask
# doesn't this remove the padding again??
mask = mask_out[..., : mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = memory_efficient_attention(q.to(target_dtype), k.to(target_dtype), v, attn_bias=mask, p=0.0)
out = out.reshape(b, -1, heads * dim_head)
return out
class FlashAttention3(AttentionCallable):
def __call__(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
heads: int,
mask: torch.Tensor | None = None,
) -> torch.Tensor:
if flash_attn_interface is None:
raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
if mask is not None:
raise NotImplementedError("Mask is not supported for FlashAttention3")
out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
out = out.reshape(b, -1, heads * dim_head)
return out
class AttentionFunction(Enum):
PYTORCH = "pytorch"
XFORMERS = "xformers"
FLASH_ATTENTION_3 = "flash_attention_3"
DEFAULT = "default"
def to_callable(self) -> AttentionCallable:
"""Resolve to a concrete callable. Use this at module init time so that
torch.compile can trace through the attention call without graph breaks."""
if self is AttentionFunction.PYTORCH:
return PytorchAttention()
elif self is AttentionFunction.XFORMERS:
return XFormersAttention()
elif self is AttentionFunction.FLASH_ATTENTION_3:
return FlashAttention3()
else:
# Default behavior: XFormers if installed else - PyTorch
return XFormersAttention() if memory_efficient_attention is not None else PytorchAttention()
class Attention(torch.nn.Module):
def __init__(
self,
query_dim: int,
context_dim: int | None = None,
heads: int = 8,
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT,
apply_gated_attention: bool = False,
) -> None:
super().__init__()
self.rope_type = rope_type
self.attention_function = (
attention_function.to_callable()
if isinstance(attention_function, AttentionFunction)
else attention_function
)
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.heads = heads
self.dim_head = dim_head
self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
def forward(
self,
x: torch.Tensor,
context: torch.Tensor | None = None,
mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
k_pe: torch.Tensor | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool = False,
) -> torch.Tensor:
"""Multi-head attention with optional RoPE, perturbation masking, and per-head gating.
When ``perturbation_mask`` is all zeros, the expensive query/key path
(linear projections, RMSNorm, RoPE) is skipped entirely and only the
value projection is used as a pass-through.
Args:
x: Query input tensor of shape ``(B, T, query_dim)``.
context: Key/value context tensor of shape ``(B, S, context_dim)``.
Falls back to ``x`` (self-attention) when *None*.
mask: Optional attention mask. Interpretation depends on the attention
backend (additive bias for xformers/PyTorch SDPA).
pe: Rotary positional embeddings applied to both ``q`` and ``k``.
k_pe: Separate rotary positional embeddings for ``k`` only. When
*None*, ``pe`` is reused for keys.
perturbation_mask: Optional mask in ``[0, 1]`` that
blends the attention output with the raw value projection:
``out = attn_out * mask + v * (1 - mask)``.
**1** keeps the full attention output, **0** bypasses attention
and passes the value projection through unchanged.
*None* or all-ones means standard attention; all-zeros skips
the query/key path entirely for efficiency.
all_perturbed: Whether all perturbations are active for this block.
Returns:
Output tensor of shape ``(B, T, query_dim)``.
"""
context = x if context is None else context
use_attention = not all_perturbed
v = self.to_v(context)
if not use_attention:
out = v
else:
q = self.to_q(x)
k = self.to_k(context)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe, self.rope_type)
k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
out = self.attention_function(q, k, v, self.heads, mask) # (B, T, H*D)
if perturbation_mask is not None:
out = out * perturbation_mask + v * (1 - perturbation_mask)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
# Reshape to (B, T, H, D) for per-head gating
out = out.view(b, t, self.heads, self.dim_head)
# Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
# Reshape back to (B, T, H*D)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)