| |
|
|
| import torch |
| import time |
| from HybridTensor.utils.utils import arg_parser, generate_random_BH_index |
| from HybridTensor.utils.profiling import cuda_profiler |
| from HybridTensor.utils.generation import InferenceParams |
| from HybridTensor.utils.utils import sparse_index |
| from HybridTensor.utils.utils import _get_device |
| from HybridTensor.models.create_sparse_model import create_block |
|
|
class Config:
    """Hyper-parameters for building a single transformer block in this benchmark.

    Only the hidden size (and, optionally, the per-head width) is configurable
    at construction time. Every other field is a fixed default that callers
    may overwrite on the instance before passing it to ``create_block``.

    Args:
        in_features: Model hidden size. Must be a positive multiple of
            ``head_dim``.
        head_dim: Width of each attention head; the number of heads is derived
            as ``in_features // head_dim``. Defaults to 128.

    Raises:
        ValueError: If ``head_dim`` or ``in_features`` is not positive, or if
            ``in_features`` is not a multiple of ``head_dim`` (which would
            otherwise produce zero heads or an inconsistent ``head_dim``).
    """

    def __init__(self, in_features=8192, head_dim=128):
        # Check head_dim first so the modulo below can never divide by zero.
        if head_dim <= 0 or in_features <= 0 or in_features % head_dim:
            raise ValueError(
                f"in_features ({in_features}) must be a positive multiple of "
                f"head_dim ({head_dim})"
            )

        # Shape of the attention layer.
        self.hidden_size = in_features
        self.num_attention_heads = in_features // head_dim
        self.head_dim = head_dim

        # Attention scaling options.
        self.scale_attn_weights = True
        self.mup_scale_qk_dot_by_d = False
        self.mup_attn_multiplier = 1.0
        self.scale_attn_by_inverse_layer_idx = False
        self.attn_dwconv = False
        self.qkv_proj_bias = True
        self.out_proj_bias = True

        # Positional encoding. A rotary fraction of 0.0 presumably disables
        # rotary embeddings — confirm against the mixer implementation.
        self.rotary_emb_fraction = 0.0
        self.rotary_emb_base = 10000.0
        self.rotary_emb_scale_base = None
        self.rotary_emb_interleaved = False
        self.use_alibi = False
        # NOTE(review): (-1, -1) looks like "no local attention window" — confirm.
        self.window_size = (-1, -1)

        # Kernel / sparsity switches.
        self.use_flash_attn = True
        self.fused_bias_fc = True
        self.mlp_sparse = True
        self.att_sparse = True
        self.attn_pdrop = 0.1

        # MLP settings. n_inner=None leaves the inner width to the builder.
        self.n_inner = None
        self.activation_function = "relu"
        self.fused_mlp = True
        self.mlp_checkpoint_lvl = 0

        # Block-level settings.
        self.sequence_parallel = False
        self.layer_norm_epsilon = 1e-5
        self.residual_in_fp32 = False
        self.fused_dropout_add_ln = True
        self.resid_pdrop = 0.1
        self.embd_pdrop = 0.1
        self.prenorm = True
        self.parallel_block = False
|
|
class SparseConfig:
    """Sparsity settings passed to ``create_block`` alongside ``Config``.

    All arguments are optional and default to the values this benchmark has
    always used. Callers may also overwrite the attributes after construction.

    Args:
        mlp_low_rank_dim: Low-rank dimension used for the MLP sparsity path.
            Presumably the width of the neuron-activation predictor — confirm
            in ``create_block``. Defaults to 1024.
        attn_low_rank_dim: Low-rank dimension used for the attention sparsity
            path. Presumably the head-predictor width — confirm. Defaults to 128.
        attn_topk: Attention top-k setting. Looks like a fraction of heads to
            keep — TODO confirm. Defaults to 0.5.
    """

    def __init__(self, mlp_low_rank_dim=1024, attn_low_rank_dim=128, attn_topk=0.5):
        self.mlp_low_rank_dim = mlp_low_rank_dim
        self.attn_low_rank_dim = attn_low_rank_dim
        self.attn_topk = attn_topk
| |
def _first_output(out):
    """Return the hidden-states tensor from a block's output.

    A block may return a tuple (e.g. hidden states plus residual). Only the
    first element is used here; a bare tensor is returned unchanged.
    """
    return out[0] if isinstance(out, tuple) else out


def _capture_decode_graph(block, static_input, mixer_kwargs, seqlen_offset):
    """Capture one decode step of ``block`` into a CUDA graph.

    Runs one eager warm-up call on a side stream, as the PyTorch CUDA-graph
    docs prescribe before capture. It then records a second call whose result
    is copied into a static output buffer. Replaying the returned graph re-runs
    the captured kernels on whatever values currently live in ``static_input``.

    Args:
        block: Module called as ``block(x, mixer_kwargs=mixer_kwargs)``.
        static_input: Input tensor whose storage the graph will read on replay.
        mixer_kwargs: Keyword dict containing ``"inference_params"``.
        seqlen_offset: Value written to ``inference_params.seqlen_offset``
            before the warm-up call and before capture.

    Returns:
        ``(graph, static_output)``. ``static_output`` holds the block's output
        after each replay.
    """
    inference_params = mixer_kwargs["inference_params"]

    inference_params.seqlen_offset = seqlen_offset
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        warmup_out = _first_output(block(static_input, mixer_kwargs=mixer_kwargs))
    torch.cuda.current_stream().wait_stream(side_stream)

    # Size the static buffer from a real output so it matches whatever
    # shape/dtype the block produces.
    static_output = torch.empty_like(warmup_out)
    torch.cuda.synchronize()

    inference_params.seqlen_offset = seqlen_offset
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = _first_output(block(static_input, mixer_kwargs=mixer_kwargs))
        static_output.copy_(out)
    return graph, static_output


def _time_graph_replays(graph, num_replays):
    """Return the mean wall-clock time in milliseconds per replay of ``graph``.

    The device is synchronized before the clock starts and after the last
    replay, so the measurement covers exactly ``num_replays`` replays.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_replays):
        graph.replay()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000 / num_replays


if __name__ == "__main__":
    args = arg_parser()

    device = _get_device(args.device)
    dtype = torch.float16

    # Build from in_features directly so hidden_size, num_attention_heads and
    # head_dim are mutually consistent.
    config = Config(in_features=args.in_features)
    config.use_heuristic = False
    sp_config = SparseConfig()
    sp_config.attn_topk = args.attn_topk

    sparse_block = create_block(
        config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype
    )
    sparse_block.eval()
    sparse_block.mlp_topk = args.index_size
    sparse_block.mlp.use_heuristic = False

    # Use a separate Config instance rather than an alias of `config`. Flipping
    # the sparsity flags on an alias would also change the config the sparse
    # block was built from.
    regular_config = Config(in_features=args.in_features)
    regular_config.use_heuristic = False
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(
        regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype
    )
    regular_block.eval()

    batch_size = args.batch_size
    seq_len = args.seq_len
    in_features = args.in_features
    heads = config.num_attention_heads
    head_dim = config.head_dim
    selected_heads = heads // 2

    inference_params = InferenceParams(max_seqlen=seq_len + 128, max_batch_size=batch_size)
    # Single-process run with no sequence parallelism, so the mixer receives
    # only the inference params (no "seqlen" kwarg).
    mixer_kwargs = {"inference_params": inference_params}

    # Sparse-index fixtures. These are not currently passed to either block.
    total_neurons = in_features * 4
    test_index_vec = torch.zeros((total_neurons,), device=device, dtype=torch.int32)
    test_index_vec[:args.index_size] = sparse_index(args.index_size, total_neurons)[0]
    test_bh_idx = generate_random_BH_index(batch_size, heads, selected_heads)

    with torch.no_grad():
        # Seed the KV cache with random keys/values for `seq_len` prior tokens.
        kv = torch.rand(batch_size, seq_len, 2, heads, head_dim, device=device, dtype=dtype)
        # NOTE(review): both blocks use layer_idx=0 and the same inference_params,
        # so they likely share one KV-cache entry — confirm in the mixer.
        sparse_block.mixer._update_kv_cache(kv, inference_params)
        regular_block.mixer._update_kv_cache(kv, inference_params)

        # A single new token per sequence (decode step).
        input_x = torch.randn(batch_size, 1, in_features, device=device, dtype=dtype)

        inference_params.seqlen_offset = seq_len
        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)
        inference_params.seqlen_offset = seq_len
        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        print("Without CUDA Graphs")
        out_decode_regular, regular_time = cuda_profiler(
            regular_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2
        )
        print(f"Regular time: {regular_time} ms")

        out_decode_sparse, sparse_time = cuda_profiler(
            sparse_block, input_x, mixer_kwargs=mixer_kwargs, warmup_runs=1, timed_runs=2
        )
        print(f"Sparse time: {sparse_time} ms")
        print(f"Speedup: {regular_time / sparse_time}")

        # CUDA graphs replay against fixed storage, so capture with a dedicated
        # static copy of the input.
        input_x_static = input_x.clone()
        graph_regular, output_regular_static = _capture_decode_graph(
            regular_block, input_x_static, mixer_kwargs, seq_len
        )
        graph_sparse, output_sparse_static = _capture_decode_graph(
            sparse_block, input_x_static, mixer_kwargs, seq_len
        )

        # Warm up the replays before timing.
        for _ in range(5):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        num_replays = 10
        regular_graph_time = _time_graph_replays(graph_regular, num_replays)
        sparse_graph_time = _time_graph_replays(graph_sparse, num_replays)

        print()
        print("With CUDA Graphs")
        print(f"Regular block time (CUDA Graphs): {regular_graph_time} ms")
        print(f"Sparse block time (CUDA Graphs): {sparse_graph_time} ms")
        print(f"Speedup (CUDA Graphs): {regular_graph_time/sparse_graph_time}")

        if args.check_results:
            # The graph outputs came from the same static input as the eager
            # runs, so they should agree with the eager outputs.
            comparisons = (
                ("Regular", out_decode_regular, output_regular_static),
                ("Sparse", out_decode_sparse, output_sparse_static),
            )
            for label, eager_out, graph_out in comparisons:
                eager_out = _first_output(eager_out)
                outputs_match = torch.allclose(eager_out, graph_out, rtol=1e-3, atol=1e-5)
                max_diff = (eager_out - graph_out).abs().max()
                print(f"\nComparison for {label} Block:")
                print(f"Outputs match: {outputs_match}")
                print(f"Max difference: {max_diff}")
|
|
|
|
|
|