# ZeroWan2GP / modify_model / modify_wan.py
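"""Patch Wan transformer blocks to run self-attention through SageAttention.

`SageWanAttnProcessor` mirrors diffusers' stock Wan attention processor but
delegates the actual attention call to a user-supplied, SDPA-compatible
`attn_func` (for example, a wrapper around SageAttention's `sageattn`).
"""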
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
from diffusers.models import WanTransformer3DModel
# NOTE: the leading-underscore helpers are private to diffusers and may move
# or change between releases.
from diffusers.models.transformers.transformer_wan import (
    WanAttention,
    _get_qkv_projections,
    _get_added_kv_projections,
)


class SageWanAttnProcessor:
    """Wan attention processor that delegates the attention call to
    `attn_func`, an SDPA-compatible callable in (B, H, N, D) layout."""

    def __init__(self, attn_func):
        self.attn_func = attn_func
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SageWanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )
def __call__(
self,
attn: "WanAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
        # (B, N, H * D) -> (B, N, H, D)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))
        if rotary_emb is not None:

            def apply_rotary_emb(
                hidden_states: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
                # Interleaved RoPE: rotate each (even, odd) channel pair by
                # the per-position cos/sin frequencies.
                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
                out[..., 0::2] = x1 * cos - x2 * sin
                out[..., 1::2] = x1 * sin + x2 * cos
                return out.type_as(hidden_states)

            query = apply_rotary_emb(query, *rotary_emb)
            key = apply_rotary_emb(key, *rotary_emb)
# ---- transpose to (B, H, N, D) for sageattn/sdpa ----
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
        # I2V task: attend separately to the image-context tokens and add the
        # result to the text-conditioned output below.
        hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = self.attn_func(
query,
key_img,
value_img,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = self.attn_func(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
        # Output projection (linear) followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
return hidden_states


def set_sage_attn_wan(
    model: WanTransformer3DModel,
    attn_func,
) -> None:
    """Route each block's self-attention (attn1) through `attn_func`.

    Cross-attention (attn2) is left on its default processor.
    """
    for block in model.blocks:
        block.attn1.processor = SageWanAttnProcessor(attn_func)
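

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; the wrapper below is an assumption, not part of
# SageAttention's or this module's API). SageAttention's `sageattn` expects
# (B, H, N, D) tensors with tensor_layout="HND" and supports no attention
# mask, so this hypothetical wrapper falls back to PyTorch SDPA for masked
# calls. Assumes the `sageattention` package is installed.
# ---------------------------------------------------------------------------
def sage_sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    from sageattention import sageattn

    if attn_mask is not None:
        # sageattn cannot apply a mask; use PyTorch SDPA instead.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )
    # dropout_p is ignored here; this module always calls with dropout_p=0.0.
    return sageattn(query, key, value, tensor_layout="HND", is_causal=is_causal)


# Example wiring (checkpoint id is a placeholder):
#   transformer = WanTransformer3DModel.from_pretrained(
#       "<wan-checkpoint>", subfolder="transformer", torch_dtype=torch.bfloat16
#   )
#   set_sage_attn_wan(transformer, sage_sdpa_wrapper)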