Upload 150 files
Browse files- source/unet_hacked.py +8 -1
- source/vae_hacked.py +8 -1
source/unet_hacked.py
CHANGED
|
@@ -527,7 +527,14 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
| 527 |
)
|
| 528 |
|
| 529 |
# actually compute the attention, what we cannot get enough of
|
| 530 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
if exists(mask):
|
| 533 |
raise NotImplementedError
|
|
|
|
| 527 |
)
|
| 528 |
|
| 529 |
# actually compute the attention, what we cannot get enough of
|
| 530 |
+
try:
|
| 531 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
| 532 |
+
except (NotImplementedError, RuntimeError):
|
| 533 |
+
# Fallback to standard attention for CPU or unsupported configs
|
| 534 |
+
scale = self.dim_head ** -0.5
|
| 535 |
+
attn_weights = torch.bmm(q * scale, k.transpose(-2, -1))
|
| 536 |
+
attn_weights = torch.softmax(attn_weights, dim=-1)
|
| 537 |
+
out = torch.bmm(attn_weights, v)
|
| 538 |
|
| 539 |
if exists(mask):
|
| 540 |
raise NotImplementedError
|
source/vae_hacked.py
CHANGED
|
@@ -260,7 +260,14 @@ class MemoryEfficientAttnBlock(nn.Module):
|
|
| 260 |
.contiguous(),
|
| 261 |
(q, k, v),
|
| 262 |
)
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
out = (
|
| 266 |
out.unsqueeze(0)
|
|
|
|
| 260 |
.contiguous(),
|
| 261 |
(q, k, v),
|
| 262 |
)
|
| 263 |
+
try:
|
| 264 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
| 265 |
+
except (NotImplementedError, RuntimeError):
|
| 266 |
+
# Fallback to standard attention for CPU or unsupported configs
|
| 267 |
+
scale = C ** -0.5
|
| 268 |
+
attn_weights = torch.bmm(q * scale, k.transpose(-2, -1))
|
| 269 |
+
attn_weights = torch.softmax(attn_weights, dim=-1)
|
| 270 |
+
out = torch.bmm(attn_weights, v)
|
| 271 |
|
| 272 |
out = (
|
| 273 |
out.unsqueeze(0)
|