import torch, os
from einops import rearrange

# Probe which optional attention backends are installed; each flag gates a branch below.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False
try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
    XFORMERS_AVAILABLE = False
def initialize_attention_priority():
    # An explicit DIFFSYNTH_ATTENTION_IMPLEMENTATION environment variable wins;
    # otherwise fall through the detected backends in priority order.
    if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
        return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
    elif FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    elif FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    elif SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    elif XFORMERS_AVAILABLE:
        return "xformers"
    else:
        return "torch"

ATTENTION_IMPLEMENTATION = initialize_attention_priority()
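# Illustrative note (not from the original file): the override is read once at
# import time, so the environment variable must be set before this module is
# imported, e.g. in a shell session:
#   export DIFFSYNTH_ATTENTION_IMPLEMENTATION=sage_attention
# Any of "flash_attention_3", "flash_attention_2", "sage_attention",
# "xformers", or "torch" matches a branch in attention_forward below.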
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if q_pattern != required_in_pattern:
        q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
    if k_pattern != required_in_pattern:
        k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
    if v_pattern != required_in_pattern:
        v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
    return q, k, v
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if out_pattern != required_out_pattern:
        out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
    return out
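# Illustrative sketch (not part of the original file): the pattern/dims
# arguments let callers hand in tensors in other einops layouts, e.g. with the
# heads packed into the channel dimension. Assuming 8 heads:
#   q, k, v = rearrange_qkv(
#       q, k, v,
#       q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)",
#       required_in_pattern="b n s d", dims={"n": 8},
#   )
# einops needs the extra axis size from dims in order to split "(n d)".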
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
    if isinstance(out, tuple):
        # Some FlashAttention-3 builds return (out, softmax_lse); keep only the output.
        out = out[0]
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = sageattn(q, k, v, sm_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = xops.memory_efficient_attention(q, k, v, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    # The optimized wrappers above are not wired up for an attention mask, so
    # masked calls (and compatibility_mode) always go through torch SDPA.
    if compatibility_mode or (attn_mask is not None):
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
    else:
        if ATTENTION_IMPLEMENTATION == "flash_attention_3":
            return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
            return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "sage_attention":
            return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "xformers":
            return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        else:
            return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
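# Minimal usage sketch (not part of the original file; shapes, dtype, and
# device handling are assumptions). attention_forward dispatches to the
# selected backend and falls back to torch SDPA whenever an attn_mask is given
# or compatibility_mode=True. The sketch assumes the selected backend supports
# the current device/dtype.
if __name__ == "__main__":
    b, n, s, d = 2, 8, 128, 64  # batch, heads, sequence length, head dim (assumed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    q = torch.randn(b, s, n, d, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Inputs here are in "b s n d" layout; the dispatcher rearranges them into
    # whatever layout the chosen backend expects and converts the output back.
    out = attention_forward(
        q, k, v,
        q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d",
        out_pattern="b s n d",
        scale=d ** -0.5,
    )
    print(ATTENTION_IMPLEMENTATION, tuple(out.shape))  # e.g. "torch (2, 128, 8, 64)"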