Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple | |
| from diffusers.models import WanTransformer3DModel | |
| from diffusers.models.transformers.transformer_wan import WanAttention, _get_qkv_projections, _get_added_kv_projections | |
class SageWanAttnProcessor:
    """Attention processor for Wan attention modules with a pluggable kernel.

    The projections, Q/K norms, optional rotary embedding, optional image
    (added K/V) branch and output layers are taken from ``attn``; only the
    attention kernel itself is delegated to ``attn_func``. Query/key/value are
    passed to ``attn_func`` in ``(batch, heads, seq_len, head_dim)`` layout,
    and its result is expected in that same layout.

    Args:
        attn_func: Callable invoked as
            ``attn_func(q, k, v, attn_mask=..., dropout_p=0.0, is_causal=False)``
            (e.g. a SageAttention kernel or ``F.scaled_dot_product_attention``).
            NOTE(review): whether ``attn_mask`` is honoured depends on the
            kernel supplied -- confirm for the kernel in use.
        text_context_length: Number of trailing tokens of
            ``encoder_hidden_states`` that belong to the text encoder when the
            module has added K/V projections; all earlier tokens are treated
            as image context. Defaults to 512 (the previously hard-coded value).

    Raises:
        ImportError: If ``torch.nn.functional`` lacks
            ``scaled_dot_product_attention`` (PyTorch < 2.0).
    """

    def __init__(self, attn_func, *, text_context_length: int = 512):
        self.attn_func = attn_func
        self.text_context_length = text_context_length
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SageWanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    @staticmethod
    def _apply_rotary_emb(
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate each adjacent channel pair ``(x[2i], x[2i+1])`` of the last dim.

        Cosines are read from the even channels of ``freqs_cos`` and sines
        from the odd channels of ``freqs_sin`` (presumably each frequency is
        stored twice, once per pair member -- TODO confirm against the rotary
        embedding producer). The result has the dtype of ``hidden_states``.
        """
        x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
        cos = freqs_cos[..., 0::2]
        sin = freqs_sin[..., 1::2]
        out = torch.empty_like(hidden_states)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return out.type_as(hidden_states)

    def _attend(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``attn_func`` on ``(B, H, N, D)`` inputs and merge heads to ``(B, N, H*D)``."""
        out = self.attn_func(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        # Back to (B, N, H, D), then fold heads into the channel dim.
        return out.transpose(1, 2).flatten(2, 3).type_as(query)

    def __call__(
        self,
        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Compute attention for one Wan attention module.

        Args:
            attn: Module supplying ``add_k_proj``, ``norm_q``/``norm_k``,
                ``norm_added_k``, ``heads`` and ``to_out``.
            hidden_states: Query-side input.
            encoder_hidden_states: Key/value-side input, or ``None`` (in which
                case ``_get_qkv_projections`` presumably projects K/V from
                ``hidden_states`` -- confirm in diffusers). Required when
                ``attn.add_k_proj`` is set; its first tokens are then split off
                as image context.
            attention_mask: Passed as ``attn_mask`` to the main attention call
                only; the image branch always uses ``attn_mask=None``.
            rotary_emb: Optional ``(freqs_cos, freqs_sin)`` applied to query
                and key before attention.

        Returns:
            ``attn.to_out[1](attn.to_out[0](out))`` where ``out`` is the main
            attention output plus, if present, the image-branch output.

        Raises:
            ValueError: If ``attn.add_k_proj`` is set but
                ``encoder_hidden_states`` is ``None``.
        """
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            if encoder_hidden_states is None:
                raise ValueError(
                    "encoder_hidden_states is required when the attention module has added K/V projections (I2V)."
                )
            # Text tokens are the last `text_context_length` entries; everything
            # before them is image context for the added K/V branch.
            image_context_length = encoder_hidden_states.shape[1] - self.text_context_length
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]

        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Split the channel dim into heads: (..., C) at dim 2 -> (..., heads, C/heads).
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:
            query = self._apply_rotary_emb(query, *rotary_emb)
            key = self._apply_rotary_emb(key, *rotary_emb)

        # ---- transpose to (B, H, N, D) for sageattn/sdpa ----
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # I2V task: extra attention from the same queries onto image context.
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            hidden_states_img = self._attend(query, key_img, value_img, None)

        hidden_states = self._attend(query, key, value, attention_mask)
        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
def set_sage_attn_wan(
    model: WanTransformer3DModel,
    attn_func,
):
    """Install a ``SageWanAttnProcessor`` on ``attn1`` of every transformer block.

    Only ``block.attn1`` (presumably the self-attention module -- confirm) of
    each entry in ``model.blocks`` is patched in place; other attention
    modules keep their existing processors. A fresh processor instance is
    created per block.

    Args:
        model: The Wan transformer whose blocks are modified in place.
        attn_func: Attention kernel forwarded to ``SageWanAttnProcessor``.
    """
    for block in model.blocks:
        block.attn1.processor = SageWanAttnProcessor(attn_func)