| |
|
|
| import math |
| import torch |
| from HybridTensor.modules.SelectiveMHA import SMHA, _update_kv_cache |
| from HybridTensor.utils.utils import arg_parser, generate_random_BH_index |
| from HybridTensor.utils.profiling import cuda_profiler |
| from HybridTensor.utils.generation import InferenceParams |
|
|
def main():
    """Benchmark one decode step of SMHA with all heads vs. a head subset.

    Reads ``seq_len``, ``batch_size``, ``device``, ``in_features`` and
    ``head_density`` from ``arg_parser()``. A random KV tensor covering
    ``seq_len`` positions is pushed into the inference cache, after which a
    single-token input is run through the same module twice via
    ``cuda_profiler``: once with ``batch_head_idx=None`` and once with the
    index returned by ``generate_random_BH_index``. Both timings and their
    ratio are printed.

    Raises:
        ValueError: if ``in_features`` is not a positive multiple of the
            fixed head width used to build the KV tensor.
    """
    args = arg_parser()

    # Per-head width used consistently for the head count, the softmax
    # scale and the shape of the synthetic KV tensor.
    head_dim = 128
    # Extra cache positions reserved beyond the pre-filled context.
    decode_headroom = 128

    # Floor division below would otherwise silently drop the remainder and
    # yield a head count that no longer matches embed_dim (or zero heads).
    if args.in_features < head_dim or args.in_features % head_dim != 0:
        raise ValueError(
            f"in_features ({args.in_features}) must be a positive multiple "
            f"of head_dim ({head_dim})"
        )

    device = torch.device(f"cuda:{args.device}")
    dtype = torch.float16
    nheads = args.in_features // head_dim

    inference_params = InferenceParams(
        max_seqlen=args.seq_len + decode_headroom,
        max_batch_size=args.batch_size,
    )

    mha = SMHA(
        embed_dim=args.in_features,
        num_heads=nheads,
        num_heads_kv=None,
        causal=True,
        layer_idx=0,
        use_flash_attn=True,
        softmax_scale=1 / (head_dim ** 0.5),
        return_residual=False,
        rotary_emb_dim=0,
        device=device,
        dtype=dtype,
    )

    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated; also pass the device
    # explicitly so the reset applies to the GPU this benchmark uses
    # rather than whichever device happens to be current.
    torch.cuda.reset_peak_memory_stats(device)

    with torch.no_grad():
        # Synthetic key/value history for the first `seq_len` positions.
        kv = torch.randn(
            args.batch_size, args.seq_len, 2, nheads, head_dim,
            device=device, dtype=dtype,
        )
        # Third argument is 0, matching layer_idx=0 on the module above
        # (presumably the layer slot of the cache -- confirm in SelectiveMHA).
        _update_kv_cache(kv, inference_params, 0)

        # Advance the offset past the pre-filled context so the next call
        # is treated as continuing after `seq_len` tokens.
        inference_params.seqlen_offset += args.seq_len

        # A single new token per sequence.
        input_x = torch.randn(
            args.batch_size, 1, args.in_features,
            device=device, dtype=dtype,
        )

        # Round up so that any non-zero density selects at least one head.
        selected_heads = math.ceil(nheads * args.head_density)
        batch_head_index = generate_random_BH_index(
            args.batch_size, nheads, selected_heads, device=device
        )

        # Baseline: no head index supplied.
        _, standard_time_ms = cuda_profiler(
            mha, input_x, inference_params=inference_params, batch_head_idx=None
        )
        print(f"Standard MHA time: {standard_time_ms:.3f} ms")

        # Same module and input, restricted by the generated head index.
        _, select_time_ms = cuda_profiler(
            mha,
            input_x,
            inference_params=inference_params,
            batch_head_idx=batch_head_index,
        )
        print(f"SelectMHA time: {select_time_ms:.3f} ms")

        speedup = standard_time_ms / select_time_ms
        print(f"Speedup: {speedup:.3f}")


if __name__ == "__main__":
    main()
| |
|
|