| |
|
|
| from typing import Tuple, Optional, Union |
|
|
| import torch |
|
|
| import triton |
| import triton.language as tl |
|
|
| |
| |
| |
| |
# Backward compatibility: older torch releases expose this collective only
# under its private name `_all_gather_base`.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
|
|
|
# Forward kernel: one program per (row, column-block). Computes the per-row
# cross-entropy loss (optionally label-smoothed and logit-scaled), the
# block-local logsumexp (lse), and the z-loss (lse_square_scale * lse^2).
# When SPLIT is set, each program stores only partial quantities (loss without
# the lse term, no z-loss); the caller is expected to combine the column
# blocks / ranks afterwards.
@triton.heuristics(
    {
        # Compile a separate variant when label smoothing is active.
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # out: (n_splits, n_rows) flattened, float32
    lse_ptr,  # out: per-block logsumexp, same layout as loss_ptr
    z_loss_ptr,  # out: z-loss, same layout (only written when not SPLIT)
    logits_ptr,  # in: (n_rows, n_cols)
    labels_ptr,  # in: (n_rows,) integer class indices (global, pre-shard)
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,  # world_size * n_cols when vocab is sharded
    class_start_idx,  # first class index owned by this shard
    n_cols,
    n_rows,
    logits_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # If SPLIT, store partial loss (lse and z-loss omitted; added by caller).
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # int64 arithmetic so large (rows * stride) offsets don't overflow.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    # Out-of-range columns load -inf so they vanish in max/exp below.
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        # Sum of (scaled) logits over this block, needed for the smoothed term.
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    # Numerically stable logsumexp over this column block.
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    # Buffers are laid out (n_splits, n_rows): stride n_rows per column block.
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Convert global label to this shard's local column index.
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            # The target class lives in this block: subtract its logit.
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # Target class is in another block/shard; only the smoothed
            # uniform term (if any) contributes from here.
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            # z-loss penalizes large normalizers; in SPLIT mode the caller
            # computes it from the combined lse instead.
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
|
|
|
|
# Backward kernel: one program per (row, column-block). Writes
# dlogits = dloss * logit_scale * (softmax-derived residual), where the
# residual is probs (plus the z-loss term) minus the target distribution
# (one-hot, or smoothed one-hot when HAS_SMOOTHING).
@triton.heuristics(
    {
        # Compile a separate variant when label smoothing is active.
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # out: (n_rows, n_cols) gradient wrt logits
    dloss_ptr,  # in: (n_rows,) upstream gradient of the loss
    logits_ptr,  # in: (n_rows, n_cols)
    lse_ptr,  # in: (n_rows,) combined logsumexp from the forward pass
    labels_ptr,  # in: (n_rows,) integer class indices (global, pre-shard)
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,
    n_cols,
    logits_row_stride,
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # int64 arithmetic so large (rows * stride) offsets don't overflow.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    # Rows flagged with ignore_index receive an all-zero gradient.
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    # z-loss contribution: d(lse_square_scale * lse^2)/dlogits
    #   = 2 * lse_square_scale * lse * softmax.
    probs += 2.0 * lse_square_scale * lse * probs
    # Convert global label to this shard's local column index.
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        # Fix: use the named `smooth_positive` (it was computed but unused;
        # the original inlined the identical expression `1 - smoothing`).
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
|
|
|
|
class CrossEntropyLoss(torch.autograd.Function):
    """Fused cross-entropy loss backed by the Triton kernels above.

    Supports label smoothing, logit scaling, z-loss (``lse_square_scale``),
    an ``ignore_index``, and Tensor Parallel vocab sharding via
    ``process_group``. ``forward`` returns ``(losses, z_losses)``; the
    z-losses are marked non-differentiable (logging only).
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        # Tensor Parallel: each rank holds n_cols classes; this rank owns
        # global class indices [class_start_idx, class_start_idx + n_cols).
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        # The kernels index columns with unit stride.
        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Cap the per-program tile; larger vocabularies are split across
        # several programs along the column axis (SPLIT mode).
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        # NOTE(review): BLOCK_SIZE is capped at 64K above, so the
        # `< 128 * 1024` / 32-warp branch looks unreachable — kept as-is.
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # SPLIT: the kernel stores partial per-block losses (without the lse
        # and z-loss terms); they are reassembled below across blocks/ranks.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)

        # Pin the launch device to where the logits live.
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,
                lse,
                z_losses,
                logits,
                labels,
                smoothing,
                logit_scale,
                lse_square_scale,
                ignore_index,
                total_classes,
                class_start_idx,
                n_cols,
                n_rows,
                logits.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # Reassemble partial results: lse combines via logsumexp, the
            # partial losses via plain sums.
            if n_splits > 1:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                # Overlap the loss all-reduce with the lse combination.
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            # The summed partial losses exclude the lse term (SPLIT=True in
            # the kernel); add the combined lse, then the z-loss if requested.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            # Ignored rows contribute zero loss.
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)  # returned for logging only
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are non-differentiable

        logits, lse, labels = ctx.saved_tensors
        # Optionally reuse the logits buffer for the gradient to save memory.
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignore_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,
                logits.stride(0),
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
        # BUG FIX: must return exactly one gradient per forward input (8 after
        # ctx); the original returned 9 values, which makes autograd raise
        # "returned an incorrect number of gradients (expected 8, got 9)".
        return dlogits, None, None, None, None, None, None, None
|
|
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Functional entry point for the fused Triton cross-entropy loss.

    Args:
        logits: ``(batch, vocab_size)`` tensor of unnormalized scores.
        labels: ``(batch,)`` tensor of target class indices.
        label_smoothing: amount of label smoothing to apply.
        logit_scale: logits are multiplied by this scale before the loss.
        lse_square_scale: if > 0, adds ``lse_square_scale * lse(logits)^2``
            to the loss (the "z-loss").
        ignore_index: rows whose label equals this value get a loss of 0.0.
        inplace_backward: if True, the backward pass reuses the logits
            buffer for the gradient, saving memory.
        process_group: if not None, Tensor Parallel mode — each process owns
            one shard of the vocab and the loss is aggregated across them.

    Returns:
        A ``(losses, z_losses)`` pair, each of shape ``(batch,)`` and float.
    """
    args = (
        logits,
        labels,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
    return CrossEntropyLoss.apply(*args)
|
|