Nayefleb commited on
Commit
2aeeb88
·
verified ·
1 Parent(s): a8a6a4c

Update modeling/lance/qwen2_navit.py

Browse files
Files changed (1) hide show
  1. modeling/lance/qwen2_navit.py +28 -28
modeling/lance/qwen2_navit.py CHANGED
@@ -499,34 +499,34 @@ class PackedAttentionMoT(Qwen2Attention):
499
  cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
500
 
501
  if FLASH_ATTN_AVAILABLE:
502
- packed_attn_output = flash_attn_varlen_func(
503
- q=packed_query_states,
504
- k=merged_key_states,
505
- v=merged_value_states,
506
- cu_seqlens_q=cu_seqlens_q.to(torch.int32),
507
- cu_seqlens_k=cu_seqlens_k.to(torch.int32),
508
- max_seqlen_q=max(query_lens).item(),
509
- max_seqlen_k=max(key_values_lens).item(),
510
- causal=is_causal,
511
- )
512
- else:
513
- q = packed_query_states.transpose(0, 1).unsqueeze(0)
514
- k = merged_key_states.transpose(0, 1).unsqueeze(0)
515
- v = merged_value_states.transpose(0, 1).unsqueeze(0)
516
-
517
- packed_attn_output = scaled_dot_product_attention(
518
- q,
519
- k,
520
- v,
521
- is_causal=is_causal,
522
- )
523
-
524
- packed_attn_output = (
525
- packed_attn_output
526
- .squeeze(0)
527
- .transpose(0, 1)
528
- .contiguous()
529
- )
530
  packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
531
  if mode == "und":
532
  packed_attn_output = self.o_proj(packed_attn_output)
 
499
  cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
500
 
501
  if FLASH_ATTN_AVAILABLE:
502
+ packed_attn_output = flash_attn_varlen_func(
503
+ q=packed_query_states,
504
+ k=merged_key_states,
505
+ v=merged_value_states,
506
+ cu_seqlens_q=cu_seqlens_q.to(torch.int32),
507
+ cu_seqlens_k=cu_seqlens_k.to(torch.int32),
508
+ max_seqlen_q=max(query_lens).item(),
509
+ max_seqlen_k=max(key_values_lens).item(),
510
+ causal=is_causal,
511
+ )
512
+ else:
513
+ q = packed_query_states.transpose(0, 1).unsqueeze(0)
514
+ k = merged_key_states.transpose(0, 1).unsqueeze(0)
515
+ v = merged_value_states.transpose(0, 1).unsqueeze(0)
516
+
517
+ with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
518
+ packed_attn_output = scaled_dot_product_attention(
519
+ q,
520
+ k,
521
+ v,
522
+ is_causal=is_causal,
523
+ )
524
+
525
+ packed_attn_output = (
526
+ packed_attn_output.squeeze(0)
527
+ .transpose(0, 1)
528
+ .contiguous()
529
+ )
530
  packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
531
  if mode == "und":
532
  packed_attn_output = self.o_proj(packed_attn_output)