Spaces:
Running on Zero
Running on Zero
File size: 5,756 Bytes
abd08dc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | import torch, os
from einops import rearrange
# Probe the optional fused-attention backends.
#
# Each probe catches ImportError (the parent class of ModuleNotFoundError).
# A package that is missing raises ModuleNotFoundError. A package that is
# present but fails while loading raises a plain ImportError, for example a
# compiled extension with unresolved symbols. Both cases mark the backend as
# unavailable instead of aborting the import of this module.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ImportError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ImportError:
    SAGE_ATTN_AVAILABLE = False

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False
def initialize_attention_priority():
    """Choose the attention backend name used by ``attention_forward``.

    A non-empty ``DIFFSYNTH_ATTENTION_IMPLEMENTATION`` environment variable
    takes precedence. Its value is stripped of surrounding whitespace,
    lower-cased and returned as-is, without validation. A value that
    ``attention_forward`` does not recognise ends up using torch SDPA.

    An unset or blank variable means "auto". The first importable backend is
    then chosen, in this order: flash_attention_3, flash_attention_2,
    sage_attention, xformers. If none is importable, "torch" is returned.

    Returns:
        str: the selected backend name.
    """
    requested = os.environ.get("DIFFSYNTH_ATTENTION_IMPLEMENTATION")
    # Read the variable once. Treat an empty or whitespace-only value
    # (e.g. `export VAR=`) as unset rather than as a backend name.
    if requested is not None and requested.strip():
        return requested.strip().lower()
    if FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    if FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    if SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    if XFORMERS_AVAILABLE:
        return "xformers"
    return "torch"
# Backend name resolved once, at import time. attention_forward reads this
# module-level name on every call to decide which backend to dispatch to.
ATTENTION_IMPLEMENTATION: str = initialize_attention_priority()
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    """Bring q, k and v into the axis layout a backend expects.

    Each tensor is rearranged with ``einops.rearrange`` from its *own*
    pattern to ``required_in_pattern``. A tensor whose pattern already equals
    ``required_in_pattern`` is returned unchanged (the same object).

    Args:
        q, k, v: query, key and value tensors, laid out as ``q_pattern``,
            ``k_pattern`` and ``v_pattern`` respectively (einops pattern strings).
        required_in_pattern: einops pattern the target backend consumes.
        dims: optional mapping of axis lengths, forwarded to
            ``einops.rearrange`` as keyword arguments.

    Returns:
        tuple: ``(q, k, v)`` in ``required_in_pattern`` layout.
    """
    axis_lengths = {} if dims is None else dims

    def _to_required(tensor, pattern):
        # Skip the einops call when no layout change is needed.
        if pattern == required_in_pattern:
            return tensor
        return rearrange(tensor, f"{pattern} -> {required_in_pattern}", **axis_lengths)

    # v is converted from v_pattern. It previously used q_pattern, which broke
    # any caller whose value layout differed from its query layout.
    return _to_required(q, q_pattern), _to_required(k, k_pattern), _to_required(v, v_pattern)
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    """Convert a backend's output from its native layout to the caller's layout.

    Args:
        out: backend output, laid out as ``required_out_pattern``.
        out_pattern: einops pattern the caller wants back.
        required_out_pattern: einops pattern the backend produced.
        dims: optional mapping of axis lengths, forwarded to
            ``einops.rearrange`` as keyword arguments.

    Returns:
        torch.Tensor: ``out`` itself when the layouts match, otherwise the
        rearranged tensor.
    """
    if out_pattern == required_out_pattern:
        return out
    axis_lengths = dims if dims is not None else {}
    return rearrange(out, f"{required_out_pattern} -> {out_pattern}", **axis_lengths)
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    This is the only backend here that accepts ``attn_mask``. Inputs are
    rearranged to "b n s d", the layout passed to SDPA. The result is mapped
    back to ``out_pattern``.

    Args:
        q, k, v: query/key/value tensors in their respective patterns.
        q_pattern, k_pattern, v_pattern, out_pattern: einops layout strings.
        dims: optional einops axis lengths.
        attn_mask: optional mask forwarded unchanged to SDPA.
        scale: optional softmax scale forwarded to SDPA.
    """
    native_layout = "b n s d"
    query, key, value = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, native_layout, dims)
    attended = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, scale=scale)
    return rearrange_out(attended, out_pattern, native_layout, dims)
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via FlashAttention-3 (``flash_attn_interface.flash_attn_func``).

    Inputs are rearranged to "b s n d" before the call, and the output is
    mapped back to ``out_pattern``. The kernel may return a tuple, in which
    case only its first element is used as the attention output.

    Args:
        q, k, v: query/key/value tensors in their respective patterns.
        q_pattern, k_pattern, v_pattern, out_pattern: einops layout strings.
        dims: optional einops axis lengths.
        scale: optional softmax scale (``softmax_scale``).
    """
    native_layout = "b s n d"
    query, key, value = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, native_layout, dims)
    result = flash_attn_interface.flash_attn_func(query, key, value, softmax_scale=scale)
    attended = result[0] if isinstance(result, tuple) else result
    return rearrange_out(attended, out_pattern, native_layout, dims)
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via FlashAttention-2 (``flash_attn.flash_attn_func``).

    Inputs are rearranged to "b s n d" for the kernel, and the output is
    mapped back to ``out_pattern``.

    Args:
        q, k, v: query/key/value tensors in their respective patterns.
        q_pattern, k_pattern, v_pattern, out_pattern: einops layout strings.
        dims: optional einops axis lengths.
        scale: optional softmax scale (``softmax_scale``).
    """
    native_layout = "b s n d"
    qkv = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, native_layout, dims)
    attended = flash_attn.flash_attn_func(*qkv, softmax_scale=scale)
    return rearrange_out(attended, out_pattern, native_layout, dims)
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via SageAttention (``sageattn``).

    ``sageattn`` is called with its default tensor layout, which this module
    feeds as "b n s d". The output is mapped back to ``out_pattern``.

    Args:
        q, k, v: query/key/value tensors in their respective patterns.
        q_pattern, k_pattern, v_pattern, out_pattern: einops layout strings.
        dims: optional einops axis lengths.
        scale: optional softmax scale (``sm_scale``).
    """
    native_layout = "b n s d"
    query, key, value = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, native_layout, dims)
    attended = sageattn(query, key, value, sm_scale=scale)
    return rearrange_out(attended, out_pattern, native_layout, dims)
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via ``xformers.ops.memory_efficient_attention``.

    Inputs are rearranged to "b s n d" before the call, and the output is
    mapped back to ``out_pattern``.

    Args:
        q, k, v: query/key/value tensors in their respective patterns.
        q_pattern, k_pattern, v_pattern, out_pattern: einops layout strings.
        dims: optional einops axis lengths.
        scale: optional softmax scale (``scale``).
    """
    native_layout = "b s n d"
    query, key, value = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, native_layout, dims)
    attended = xops.memory_efficient_attention(query, key, value, scale=scale)
    return rearrange_out(attended, out_pattern, native_layout, dims)
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    """Run attention with the backend named by ``ATTENTION_IMPLEMENTATION``.

    Torch SDPA is used instead of the configured backend when
    ``compatibility_mode`` is true or when an ``attn_mask`` is supplied,
    because SDPA is the only backend here that accepts a mask. It is also
    used when ``ATTENTION_IMPLEMENTATION`` names no known backend.

    Args:
        q, k, v: query/key/value tensors in their respective patterns.
        q_pattern, k_pattern, v_pattern, out_pattern: einops layout strings.
        dims: optional einops axis lengths.
        attn_mask: optional attention mask (forces torch SDPA).
        scale: optional softmax scale forwarded to the backend.
        compatibility_mode: force torch SDPA regardless of configuration.

    Returns:
        torch.Tensor: attention output laid out as ``out_pattern``.
    """
    if not compatibility_mode and attn_mask is None:
        # Looked up per call, so the module-level setting is honoured even if
        # it is changed at runtime.
        backend = {
            "flash_attention_3": flash_attention_3,
            "flash_attention_2": flash_attention_2,
            "sage_attention": sage_attention,
            "xformers": xformers_attention,
        }.get(ATTENTION_IMPLEMENTATION)
        if backend is not None:
            return backend(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    # Reached for compatibility mode, masked attention, or an unrecognised
    # backend name (attn_mask is None in that last case).
    return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
|