qian43 committed on
Commit
8f55a74
·
verified ·
1 Parent(s): 3ed17f3

Upload 150 files

Browse files
Files changed (2) hide show
  1. source/unet_hacked.py +8 -1
  2. source/vae_hacked.py +8 -1
source/unet_hacked.py CHANGED
@@ -527,7 +527,14 @@ class MemoryEfficientCrossAttention(nn.Module):
527
  )
528
 
529
  # actually compute the attention, what we cannot get enough of
530
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
 
 
 
 
 
 
 
531
 
532
  if exists(mask):
533
  raise NotImplementedError
 
527
  )
528
 
529
  # actually compute the attention, what we cannot get enough of
530
+ try:
531
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
532
+ except (NotImplementedError, RuntimeError):
533
+ # Fallback to standard attention for CPU or unsupported configs
534
+ scale = self.dim_head ** -0.5
535
+ attn_weights = torch.bmm(q * scale, k.transpose(-2, -1))
536
+ attn_weights = torch.softmax(attn_weights, dim=-1)
537
+ out = torch.bmm(attn_weights, v)
538
 
539
  if exists(mask):
540
  raise NotImplementedError
source/vae_hacked.py CHANGED
@@ -260,7 +260,14 @@ class MemoryEfficientAttnBlock(nn.Module):
260
  .contiguous(),
261
  (q, k, v),
262
  )
263
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
 
 
 
 
 
 
 
264
 
265
  out = (
266
  out.unsqueeze(0)
 
260
  .contiguous(),
261
  (q, k, v),
262
  )
263
+ try:
264
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
265
+ except (NotImplementedError, RuntimeError):
266
+ # Fallback to standard attention for CPU or unsupported configs
267
+ scale = C ** -0.5
268
+ attn_weights = torch.bmm(q * scale, k.transpose(-2, -1))
269
+ attn_weights = torch.softmax(attn_weights, dim=-1)
270
+ out = torch.bmm(attn_weights, v)
271
 
272
  out = (
273
  out.unsqueeze(0)