import os

import torch
import torch.distributed as dist

try:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
except ImportError:
    import sys

    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
finally:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

import verl.utils.torch_functional as verl_F

compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

# Tunables read from the environment; os.environ values are strings, so parse them explicitly.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
VERIFY_TORCH_SELF = bool(int(os.environ.get("VERIFY_TORCH_SELF", "0")))
LOW_MEMORY = bool(int(os.environ.get("LOW_MEMORY", "0")))
LOW_MEMORY_DIV_FACTOR = int(os.environ.get("LOW_MEMORY_DIV_FACTOR", 16))


def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> tuple[torch.Tensor, torch.Tensor]:
    # Flatten [batch, tokens, hidden] inputs to 2D and labels to 1D.
    if len(hidden.shape) > 2:
        hidden = hidden.view(-1, hidden.shape[-1])
    if len(labels.shape) > 1:
        labels = labels.view(-1)
    # Accept the weight in either [hidden, vocab] or [vocab, hidden] layout.
    logits = torch.matmul(
        hidden.to(torch.float32),
        weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32),
    )
    logits /= temperature
    # entropy = logsumexp(logits) - sum(softmax(logits) * logits)
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy_a = torch.logsumexp(logits, dim=-1)
    entropy_b = torch.sum(pd * logits, dim=-1)
    entropy = entropy_a - entropy_b
    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)
    logprobs = torch.neg(logprobs)
    return logprobs, entropy


class TorchEntropyTP(torch.autograd.Function):
    """Tensor-parallel reference implementation in plain PyTorch.

    It is used only to verify the correctness of the fused kernel; it is not
    efficient and is not recommended for use in practice.
    """

    @staticmethod
    def forward(
        ctx,
        hidden: torch.Tensor,
        weight: torch.Tensor,
        labels: torch.Tensor,
        temperature: float,
        dist_process_group: torch.distributed.ProcessGroup,
    ):
        ctx.original_hidden_shape = hidden.shape
        if len(hidden.shape) > 2:
            hidden = hidden.view(-1, hidden.shape[-1])
        if len(labels.shape) > 1:
            labels = labels.view(-1)

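        # Each rank computes logits only for its own vocab shard; the shards are
        # then all-gathered along the vocab dimension to form the full logits.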
        logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T)
        logits /= temperature
        whole_logits = torch.empty(
            (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),
            dtype=logits.dtype,
            device=logits.device,
        )
        whole_logits_ref = [
            whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]
            for i in range(dist.get_world_size(dist_process_group))
        ]
        dist.all_gather(whole_logits_ref, logits, group=dist_process_group)

        pd = torch.nn.functional.softmax(whole_logits, dim=-1)
        entropy_a = torch.logsumexp(whole_logits, dim=-1)
        entropy_b = torch.sum(pd * whole_logits, dim=-1)
        entropy = entropy_a - entropy_b

        logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none")
        logprobs = torch.neg(logprobs)

        ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)
        ctx.dist_process_group = dist_process_group
        ctx.temperature = temperature
        return logprobs, entropy

    @staticmethod
    def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):
        hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors
        dist_process_group = ctx.dist_process_group
        temperature = ctx.temperature
        batch_size, hidden_size = hidden.shape
        vocab_size, hidden_size = weight.shape
        rank = dist.get_rank(dist_process_group)

        # Recompute the softmax probabilities from the saved logits (numerically stable).
        maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)
        exp_logits = torch.exp(whole_logits - maximum)
        accumulate = exp_logits.sum(dim=-1, keepdim=True)
        pd = exp_logits / accumulate
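
        # Entropy gradient. With entropy = logsumexp(logits) - sum_j pd_j * logits_j,
        #   d entropy / d logits_j = -pd_j * (logits_j - sum_k pd_k * logits_k)
        #                          = -pd_j * (logits_j - entropy_b),
        # scaled by the upstream gradient g_entropy.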
        d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))
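
        # Log-prob gradient: cross-entropy gives d CE / d logits = pd - one_hot(labels);
        # since logprobs = -CE, the sign flip is folded into g_logprobs first.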
        one_hot = torch.zeros_like(whole_logits)
        one_hot.scatter_(1, labels.unsqueeze(1), 1)
        g_logprobs = torch.neg(g_logprobs)
        d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)

        # Combine both contributions and undo the temperature scaling.
        d_logits = d_logits_entropy + d_logits_logprobs
        d_logits /= temperature

        # Keep only the slice of the logits gradient for this rank's vocab shard.
        local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size]

        # d_hidden is a partial sum here; the caller is expected to all-reduce it.
        d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))
        d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))
        d_hidden = d_hidden.view(ctx.original_hidden_shape)

        return d_hidden, d_weight, None, None, None


run_torch_entropy_tp = TorchEntropyTP.apply


class TestLinearCrossEntropy_TensorParallel:
    def __init__(self):
        dist.init_process_group(backend="nccl")
        self.group = dist.group.WORLD

        self.local_rank = dist.get_rank(self.group)
        self.world_size = dist.get_world_size(self.group)
        device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(device)
        print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}")

    def initialize(self, test_case_idx: int, temperature: float = 1.5):
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def shutdown(self):
        dist.destroy_process_group()

    def cleanup(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        import gc

        gc.collect()
        torch.cuda.synchronize()

    def generate_hyper(self):
        global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES

        # The (hidden_size, vocab_size) pairs below roughly mirror common LLM head shapes.
        self.dtype = torch.bfloat16
        if self.test_case_idx == 0:
            self.batch_size = 1
            self.num_tokens = 1937
            self.hidden_size = 3584
            self.vocab_size = 152064
        elif self.test_case_idx == 1:
            self.batch_size = 1
            self.num_tokens = 2169
            self.hidden_size = 896
            self.vocab_size = 151936
        elif self.test_case_idx == 2:
            self.batch_size = 1
            self.num_tokens = 1530
            self.hidden_size = 2048
            self.vocab_size = 32256
        elif self.test_case_idx == 3:
            self.batch_size = 1
            self.num_tokens = 1388
            self.hidden_size = 4096
            self.vocab_size = 102400
        elif self.test_case_idx == 4:
            self.batch_size = 1
            self.num_tokens = 8192
            self.hidden_size = 4096
            self.vocab_size = 102400
        else:
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        if LOW_MEMORY:
            # Shrink the vocab so the test fits on smaller GPUs.
            self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR)
        assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    def verify_torch_itself(self, iterations: int = 5):
        self.cleanup()
        self.generate_hyper()

        for i in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

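            # Re-broadcast so activations and labels are identical on all ranks,
            # while the weight stays rank-local (each rank owns one vocab shard).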
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

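            # Build the full (unsharded) weight for the single-GPU reference by
            # all-gathering each rank's shard along the vocab dimension.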
            whole_weight = torch.empty(
                (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device
            )
            # Per-rank views into the gathered buffer, so all_gather fills it in place.
            whole_weight_views = [
                whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)
            ]
            dist.all_gather(whole_weight_views, weight, group=self.group)
            # Track gradients on the gathered copy for the reference path.
            whole_weight.requires_grad_()

            (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)
            (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)

            torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)

            # Use identical upstream gradients on every rank.
            g_entropy, g_logprobs = self.generate_backward_inputs()
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            (single_d_hidden, single_d_weight) = torch.autograd.grad(
                (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False
            )

            (tp_d_hidden, tp_d_weight) = torch.autograd.grad(
                (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            # The TP backward returns a partial d_hidden; sum it across ranks.
            dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)
            # Each rank only holds the weight gradient for its own vocab shard, so
            # compare it against the matching slice of the reference gradient.
            torch.testing.assert_close(
                tp_d_weight,
                single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],
                atol=1e-2,
                rtol=1e-4,
            )

        if self.local_rank == 0:
            print("[PASS] torch TP correctness is verified")

    def check_torch_storage(self):
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # Keep activations and labels identical on all ranks.
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
        torch.cuda.synchronize()
        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_tp_hidden, d_tp_weight) = torch.autograd.grad(
            (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB")
            print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB")

    def verify_kernel_correctness(self, iterations: int = 5):
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = list()
        torch_backward_latency = list()
        kernel_forward_latency = list()
        kernel_backward_latency = list()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        for i in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

            # Keep activations and labels identical on all ranks.
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy_tp(
                hidden, weight, labels, self.temperature, self.group
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
                hidden, weight, labels, self.temperature, "none", self.group
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)
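
            # Time the backward passes the same way; the partial d_hidden from each
            # implementation is all-reduced outside the timed region.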
            g_entropy, g_logprobs = self.generate_backward_inputs()
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            start_event.record()
            (torch_d_hidden, torch_d_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))
            dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            start_event.record()
            (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))
            dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)

        # Drop the first iteration to exclude warm-up effects.
        torch_forward_latency = torch_forward_latency[1:]
        torch_backward_latency = torch_backward_latency[1:]
        kernel_forward_latency = kernel_forward_latency[1:]
        kernel_backward_latency = kernel_backward_latency[1:]

        if self.local_rank == 0:
            print("\n[PASS]: Verified kernel forward & backward correctness.")
            print(
                f"[INFO]: Forward pass: Torch implementation average time: "
                f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: Torch implementation average time: "
                f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Forward pass: Kernel implementation average time: "
                f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: Kernel implementation average time: "
                f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
            )

    def check_kernel_storage(self):
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # Keep activations and labels identical on all ranks.
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
            hidden, weight, labels, self.temperature, "none", self.group
        )
        torch.cuda.synchronize()
        kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
            (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB")
            print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB")


if __name__ == "__main__":
    # Launch with torchrun, e.g.:
    #   torchrun --standalone --nproc_per_node=<tp_size> <this_script.py>
    assert int(os.environ.get("WORLD_SIZE", "1")) > 1, (
        "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to "
        "execute this script."
    )
    # Seed each rank differently so every rank samples a distinct weight shard;
    # hidden states and labels are re-broadcast from rank 0 inside the tests.
    torch.manual_seed(233376 + int(os.environ.get("RANK", 0)))

    test = TestLinearCrossEntropy_TensorParallel()
    for test_case_idx in range(MAX_TEST_CASES):
        print(f"[INFO] Running test case {test_case_idx}")
        test.initialize(test_case_idx)
        if VERIFY_TORCH_SELF:
            test.verify_torch_itself()
            test.check_torch_storage()
        test.verify_kernel_correctness()
        test.check_kernel_storage()

    test.shutdown()
|