Update modeling/lance/qwen2_navit.py
Browse files- modeling/lance/qwen2_navit.py +28 -28
modeling/lance/qwen2_navit.py
CHANGED
|
@@ -499,34 +499,34 @@ class PackedAttentionMoT(Qwen2Attention):
|
|
| 499 |
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
|
| 500 |
|
| 501 |
if FLASH_ATTN_AVAILABLE:
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
else:
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
|
| 531 |
if mode == "und":
|
| 532 |
packed_attn_output = self.o_proj(packed_attn_output)
|
|
|
|
| 499 |
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
|
| 500 |
|
| 501 |
if FLASH_ATTN_AVAILABLE:
|
| 502 |
+
packed_attn_output = flash_attn_varlen_func(
|
| 503 |
+
q=packed_query_states,
|
| 504 |
+
k=merged_key_states,
|
| 505 |
+
v=merged_value_states,
|
| 506 |
+
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
|
| 507 |
+
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
|
| 508 |
+
max_seqlen_q=max(query_lens).item(),
|
| 509 |
+
max_seqlen_k=max(key_values_lens).item(),
|
| 510 |
+
causal=is_causal,
|
| 511 |
+
)
|
| 512 |
+
else:
|
| 513 |
+
q = packed_query_states.transpose(0, 1).unsqueeze(0)
|
| 514 |
+
k = merged_key_states.transpose(0, 1).unsqueeze(0)
|
| 515 |
+
v = merged_value_states.transpose(0, 1).unsqueeze(0)
|
| 516 |
+
|
| 517 |
+
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
|
| 518 |
+
packed_attn_output = scaled_dot_product_attention(
|
| 519 |
+
q,
|
| 520 |
+
k,
|
| 521 |
+
v,
|
| 522 |
+
is_causal=is_causal,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
packed_attn_output = (
|
| 526 |
+
packed_attn_output.squeeze(0)
|
| 527 |
+
.transpose(0, 1)
|
| 528 |
+
.contiguous()
|
| 529 |
+
)
|
| 530 |
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
|
| 531 |
if mode == "und":
|
| 532 |
packed_attn_output = self.o_proj(packed_attn_output)
|