| import torch |
| from typing import Optional, Tuple |
|
|
| from .triton_attention import ( |
| fused_mha_with_paged_cache, fused_mha_with_cache |
| ) |
|
|
# Integer dtype used for the metadata tensors (cache_loc, seq_len, seq_start,
# positions) that fused_mha_interface builds for the paged-cache op.
dtype_int: torch.dtype = torch.int32
|
|
def fused_mha_interface(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    max_seq_len: Optional[int] = None,
) -> torch.Tensor:
    """Run attention through the Triton cached-MHA custom ops.

    Replacement for ``_flash_attention_forward(...)``. Dispatches to
    ``torch.ops.attention.fused_mha_with_cache`` when ``page_table`` is
    ``None`` and to ``torch.ops.attention.fused_mha_with_paged_cache``
    otherwise. Heads are flattened into the last dimension before the op
    call and unflattened afterwards.

    Args:
        query_states: Query tensor unpacked as ``(batch, q_len, n_heads, head_dim)``.
        key_states: Key tensor unpacked as ``(batch, kv_len, n_kv_heads, head_dim)``.
        value_states: Value tensor reshaped with the same dims as ``key_states``.
        k_cache: Key cache, passed through to the op unchanged.
        v_cache: Value cache, passed through to the op unchanged.
        position_ids: Optional positions indexed as ``(batch, ...)``. For a
            single-token query the last column is used, otherwise the first.
            When omitted, every sequence uses position 0.
        page_table: Optional page table; selects the paged-cache op when given.
        max_seq_len: Maximum sequence length forwarded to the paged op.
            Required when ``page_table`` is given.

    Returns:
        Attention output reshaped to ``(batch, q_len, n_heads, head_dim)``.

    Raises:
        ValueError: If ``page_table`` is given without ``max_seq_len``.
    """
    b, ql, n_heads, head_dim = query_states.shape
    _, kvl, n_kv_heads, _ = key_states.shape

    # The Triton ops take heads flattened into the last dimension.
    q = query_states.reshape(b, ql, n_heads * head_dim)
    k = key_states.reshape(b, kvl, n_kv_heads * head_dim)
    v = value_states.reshape(b, kvl, n_kv_heads * head_dim)

    # One position per sequence for the incoming tokens (presumably the cache
    # offset at which the op writes them — confirm against the kernel).
    if position_ids is None:
        input_pos = torch.zeros(b, device=q.device, dtype=dtype_int)
    elif ql == 1:
        input_pos = position_ids[:, -1]
    else:
        input_pos = position_ids[:, 0]

    # No rotary frequencies are passed to the op.
    freqs_cis = None

    if page_table is None:
        y = torch.ops.attention.fused_mha_with_cache(
            q, k, v,
            input_pos,
            k_cache, v_cache,
            freqs_cis,
        )
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if max_seq_len is None:
            raise ValueError("max_seq_len must be provided when using paged attention.")

        # Batch entry i uses cache location i.
        cache_loc = torch.arange(b, device=q.device, dtype=dtype_int)
        # Forward the caller's positions; previously this was always zeros,
        # which silently discarded `position_ids` on the paged path.
        input_positions = input_pos.to(device=q.device, dtype=dtype_int)
        # Every sequence contributes `kvl` new tokens.
        seq_len = torch.full((b,), kvl, device=q.device, dtype=dtype_int)
        # Exclusive prefix sum of seq_len (presumably each sequence's start
        # offset in the flattened token stream). cumsum promotes int32 to
        # int64, hence the cast back.
        seq_start = (seq_len.cumsum(0) - seq_len).to(dtype=dtype_int)

        y = torch.ops.attention.fused_mha_with_paged_cache(
            q, k, v,
            input_positions, cache_loc,
            seq_len, seq_start,
            page_table, max_seq_len,
            k_cache, v_cache,
            freqs_cis,
        )

    # `reshape` (not `view`) also copes with a non-contiguous op output.
    return y.reshape(b, ql, n_heads, head_dim)
|
|
|
|
|
|
def main() -> None:
    """Smoke-test fused_mha_interface on the non-paged cache path.

    Builds random CUDA tensors for a single-token step. Prints whether the
    output has shape ``(batch, q_len, num_heads, head_dim)``. Prints a skip
    message and returns when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("[test] CUDA not available; skipping.")
        return
    device = "cuda"

    batch_size = 1
    q_len = 1
    kv_len = 1
    num_heads = 16
    n_kv_heads = 16
    head_dim = 128

    max_batch_size = 1
    max_seq_len = 1024

    query_states = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
    # Keys, values and caches are sized by the KV head count (equal to num_heads here).
    key_states = torch.randn(batch_size, kv_len, n_kv_heads, head_dim, device=device)
    value_states = torch.randn(batch_size, kv_len, n_kv_heads, head_dim, device=device)

    k_cache = torch.randn(max_batch_size, max_seq_len, n_kv_heads, head_dim, device=device)
    v_cache = torch.randn(max_batch_size, max_seq_len, n_kv_heads, head_dim, device=device)

    attn_out = fused_mha_interface(
        query_states,
        key_states,
        value_states,
        k_cache=k_cache,
        v_cache=v_cache,
    )

    expected_shape = (batch_size, q_len, num_heads, head_dim)
    print(f"[test] output shape: {attn_out.shape} (expected {expected_shape})")

    # torch.Size is a tuple subclass, so it compares equal to a plain tuple.
    if attn_out.shape == expected_shape:
        print("[test] ✅ Success: output tensor has correct shape.")
    else:
        print("[test] ❌ Failure: shape mismatch.")


if __name__ == "__main__":
    main()