import torch
import flash_attn

torch.manual_seed(0)
|
|
|
|
def _attention_torch(query, key, value, *, backend):
    # SDPA expects (batch, heads, seq_len, head_dim); the inputs here are
    # (batch, seq_len, heads, head_dim), so transpose in and transpose back out.
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out
|
|
|
|
def test_flash_attn():
    """Test standard flash attention with mha_fwd"""
    print("===== Testing mha_fwd =====")

    batch_size = 1
    seq_len = 4224
    num_attention_heads = 24
    attention_head_dim = 128

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)

    print(f"Testing shape: {shape}")
    print(f"Batch size: {batch_size}, Seq len: {seq_len}")
    print(f"Num heads: {num_attention_heads}, Head dim: {attention_head_dim}")

    query = torch.randn(shape, device="cuda", dtype=torch.float16)
    key = torch.randn(shape, device="cuda", dtype=torch.float16)
    value = torch.randn(shape, device="cuda", dtype=torch.float16)
|
|
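    # Reference output from PyTorch SDPA using the math backend.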
    golden_truth = _attention_torch(
        query, key, value, backend=torch.nn.attention.SDPBackend.MATH
    )

    print(f"Golden truth shape: {golden_truth.shape}")
    print(f"Query sum: {query.sum().item()}")
|
|
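    # FlashAttention forward pass (mha_fwd), non-causal.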
    out, softmax_lse, p, rng_state = flash_attn.fwd(
        q=query,
        k=key,
        v=value,
        is_causal=False,
    )

    print(f"Flash attention output shape: {out.shape}")
    print(f"Query sum after attention: {query.sum().item()}")
|
|
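    # The kernel output should closely match the fp16 reference.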
    diff = (out - golden_truth).abs().max()
    print(f"Max absolute difference (non-causal): {diff.item()}")

    assert out.shape == shape
    assert diff < 1e-2, f"Difference too large: {diff.item()}"
|
|
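    # Re-run the same inputs with causal masking enabled.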
| print("\n--- Testing with causal=True ---") |
| out_causal, _, _, _ = flash_attn.fwd( |
| q=query, |
| k=key, |
| v=value, |
| is_causal=True, |
| ) |
|
|
| print(f"Causal attention output shape: {out_causal.shape}") |
| assert out_causal.shape == shape |
|
|
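    # Causal masking should change the result for these inputs.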
    diff_causal = (out - out_causal).abs().max()
    print(f"Difference between causal and non-causal: {diff_causal.item()}")
    assert diff_causal > 1e-3, "Causal and non-causal should produce different results"

    print("✓ mha_fwd test passed!")
|
|
|
|
def test_mha_varlen_fwd():
    """Test variable-length sequences with mha_varlen_fwd"""
    print("\n===== Testing mha_varlen_fwd =====")
|
|
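    # Three packed sequences of different lengths.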
    seq_lens = [512, 1024, 256]
    total_seq_len = sum(seq_lens)
    num_attention_heads = 16
    attention_head_dim = 64
|
|
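    # Cumulative sequence lengths: [0, 512, 1536, 1792].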
    cu_seqlens = torch.tensor(
        [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))],
        device="cuda",
        dtype=torch.int32,
    )

    print(f"Sequence lengths: {seq_lens}")
    print(f"Cumulative sequence lengths: {cu_seqlens}")
    print(f"Total sequence length: {total_seq_len}")
|
|
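    # Packed (total_tokens, heads, head_dim) tensors with no batch dimension.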
    query = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    key = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    value = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )

    print(f"Query shape: {query.shape}")
    print(f"Key shape: {key.shape}")
    print(f"Value shape: {value.shape}")
|
|
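    # Reference: run each sequence separately through the math SDPA backend.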
    golden_truth_parts = []
    for i, seq_len in enumerate(seq_lens):
        start_idx = cu_seqlens[i]
        end_idx = cu_seqlens[i + 1]
|
|
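        # Slice out one sequence and add a batch dimension.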
        q_seq = query[start_idx:end_idx].unsqueeze(0)
        k_seq = key[start_idx:end_idx].unsqueeze(0)
        v_seq = value[start_idx:end_idx].unsqueeze(0)
|
|
        golden_seq = _attention_torch(
            q_seq, k_seq, v_seq, backend=torch.nn.attention.SDPBackend.MATH
        )
        golden_truth_parts.append(golden_seq.squeeze(0))
|
|
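    # Concatenate the per-sequence outputs back into the packed layout.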
    golden_truth = torch.cat(golden_truth_parts, dim=0)
    print(f"Golden truth shape: {golden_truth.shape}")
|
|
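    # Variable-length FlashAttention forward pass (mha_varlen_fwd), non-causal.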
    out, softmax_lse, p, rng_state = flash_attn.varlen_fwd(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        is_causal=False,
    )

    print(f"Flash attention varlen output shape: {out.shape}")
    print(f"Output should match input: {out.shape == query.shape}")
|
|
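    # The packed kernel output should match the per-sequence reference.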
    diff = (out - golden_truth).abs().max()
    print(f"Max absolute difference (non-causal): {diff.item()}")

    assert out.shape == (total_seq_len, num_attention_heads, attention_head_dim)
    assert diff < 1e-2, f"Difference too large: {diff.item()}"
|
|
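    # Re-run the same packed inputs with causal masking enabled.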
| print("\n--- Testing with causal=True ---") |
| out_causal, _, _, _ = flash_attn.varlen_fwd( |
| q=query, |
| k=key, |
| v=value, |
| cu_seqlens_q=cu_seqlens, |
| cu_seqlens_k=cu_seqlens, |
| max_seqlen_q=max(seq_lens), |
| max_seqlen_k=max(seq_lens), |
| is_causal=True, |
| ) |
|
|
| print(f"Causal attention output shape: {out_causal.shape}") |
| assert out_causal.shape == (total_seq_len, num_attention_heads, attention_head_dim) |
|
|
    diff_causal = (out - out_causal).abs().max()
    print(f"Difference between causal and non-causal: {diff_causal.item()}")
    assert diff_causal > 1e-3, "Causal and non-causal should produce different results"

    print("✓ mha_varlen_fwd test passed!")
|
|
|
|
| if __name__ == "__main__": |
| test_flash_attn() |
| test_mha_varlen_fwd() |
|
|