| import torch |
| import random |
| import torch.nn.functional as F |
|
|
| import flash_mla |
|
|
| |
|
|
|
|
def test_flash_mla():
    """Exercise ``flash_mla.get_mla_metadata`` on a synthetic paged KV cache.

    Builds a query tensor and a block-paged KV cache (both created without an
    explicit device), fills every cache slot past each sequence's length with
    NaN, moves the sequence lengths to CUDA, requests tile-scheduler metadata
    from ``flash_mla`` and prints it.

    NOTE(review): the function ends with ``assert False``, so it always fails
    under pytest -- presumably a deliberate stop so the captured prints are
    shown. The attention kernel is never invoked, and ``q`` / ``blocked_v`` /
    ``causal`` are built but not consumed; confirm this is work in progress.
    """
    # Problem configuration. These names are echoed verbatim by the
    # self-documenting f-string below, so they must not be renamed.
    b, s_q, mean_sk = 16, 16, 16
    h_q, h_kv = 16, 1
    d, dv = 576, 512
    causal, varlen = True, False
    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")

    # Per-sequence KV lengths: all equal to mean_sk, unless varlen is set, in
    # which case each is sampled from N(mean_sk, mean_sk / 2) and clamped
    # below at s_q (storing into the int32 tensor truncates the float).
    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for seq_idx in range(b):
            sample = random.normalvariate(mean_sk, mean_sk / 2)
            cache_seqlens[seq_idx] = max(sample, s_q)

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    # Round the longest length up to a multiple of 256, so every sequence
    # owns the same whole number of cache blocks.
    max_seqlen_pad = (max_seqlen + 255) & ~255
    q = torch.randn(b, s_q, h_q, d)

    # Paged KV cache with a trivial block table: sequence i owns blocks
    # [i * blocks_per_seq, (i + 1) * blocks_per_seq).
    block_size = 64
    blocks_per_seq = max_seqlen_pad // block_size
    block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32).view(
        b, blocks_per_seq
    )
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    print(blocked_k.shape)

    # Write NaN into the padded tail of every sequence through a per-sequence
    # view of the same storage (presumably so out-of-range reads would show
    # up as NaN in any downstream result).
    per_seq_k = blocked_k.view(b, max_seqlen_pad, h_kv, d)
    for seq_idx, seqlen in enumerate(cache_seqlens.tolist()):
        per_seq_k[seq_idx, seqlen:] = float("nan")

    # V is a slice of K (its first dv channels) and shares K's storage.
    blocked_v = blocked_k[..., :dv]
    print(blocked_k.shape, blocked_v.shape)

    # Only the sequence lengths are explicitly moved to the GPU here.
    cache_seqlens = cache_seqlens.to("cuda")

    # Query heads are folded into the query length per KV head
    # (s_q * h_q // h_kv). TODO(review): confirm these keyword names match
    # the flash_mla build in use.
    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
        seqlens_k=cache_seqlens,
        s_q=s_q * h_q // h_kv,
        h_kv=h_kv,
    )
    print(tile_scheduler_metadata, num_splits)

    # Unconditional failure kept as in the original (see docstring NOTE).
    assert False
|
|