| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Distributed ops for supporting sequence parallel. |
| """ |
|
|
| from collections import defaultdict |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| import torch |
| import torch.distributed as dist |
| from torch import Tensor |
|
|
| from common.cache import Cache |
| from common.distributed.advanced import ( |
| get_sequence_parallel_group, |
| get_sequence_parallel_rank, |
| get_sequence_parallel_world_size, |
| ) |
|
|
| from .basic import get_device |
|
|
# Module-level mutable state for the sequence-parallel broadcast pipeline,
# keyed by a user-chosen name so several independent pipelines can coexist.
# Per-name triple [buf_a, buf_b, local_result]: two slots double-buffer the
# data received per step, index -1 holds this rank's own result (see SPDistForward).
_SEQ_DATA_BUF = defaultdict(lambda: [None, None, None])
# Per-name nested shapes of the tensors being broadcast (one entry per sp rank).
_SEQ_DATA_META_SHAPES = defaultdict()
# Per-name nested dtypes of the tensors being broadcast.
_SEQ_DATA_META_DTYPES = defaultdict()
# Per-name list of in-flight async broadcast handles to wait on before reuse.
_SEQ_DATA_ASYNC_COMMS = defaultdict(list)
# Per-name scratch buffers; not referenced in this chunk — presumably used elsewhere.
_SYNC_BUFFER = defaultdict(dict)
|
|
|
|
def single_all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    """
    All-to-all on a tensor via ``dist.all_to_all_single``.

    Scatters ``local_input`` along ``scatter_dim`` and gathers along
    ``gather_dim`` across ``group``.  When ``async_op`` is True, returns
    ``(output, comm, prev_scatter_dim)`` where ``output`` is the raw comm
    buffer (leading world-size dim still present) and ``comm`` is the async
    work handle: the caller must wait and finish the reassembly itself
    (see ``SeqAllToAll.backward`` for the matching post-processing).
    Otherwise returns the fully reassembled tensor.
    """
    seq_world_size = dist.get_world_size(group)
    prev_scatter_dim = scatter_dim
    # all_to_all_single scatters along dim 0, so move scatter_dim to the front.
    if scatter_dim != 0:
        local_input = local_input.transpose(0, scatter_dim)
        if gather_dim == 0:
            # The transpose moved old dim 0 to scatter_dim, so a gather_dim
            # of 0 now lives at the old scatter position.
            gather_dim = scatter_dim
        scatter_dim = 0

    inp_shape = list(local_input.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    # Expose an explicit leading world-size dim: one slice per destination rank.
    input_t = local_input.reshape(
        [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]
    ).contiguous()
    output = torch.empty_like(input_t)
    comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
    if async_op:
        # Defer post-processing until the communication completes.
        return output, comm, prev_scatter_dim

    # Fold the leading world-size dim into gather_dim (+1 offsets for that dim).
    output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0)
    if prev_scatter_dim:
        # Undo the initial transpose (no-op when the original scatter_dim was 0).
        output = output.transpose(0, prev_scatter_dim).contiguous()
    return output
|
|
|
|
def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
):
    """Synchronous all-to-all: scatter ``local_input`` along ``scatter_dim``
    and gather the received shards along ``gather_dim``."""
    world_size = dist.get_world_size(group)
    send_shards = [
        shard.contiguous()
        for shard in torch.tensor_split(local_input, world_size, scatter_dim)
    ]
    recv_shards = [torch.empty_like(send_shards[0]) for _ in range(world_size)]
    dist.all_to_all(recv_shards, send_shards, group=group)
    return torch.cat(recv_shards, dim=gather_dim).contiguous()
|
|
|
|
class SeqAllToAll(torch.autograd.Function):
    """Autograd-aware all-to-all used to swap sequence/head shardings.

    forward scatters along ``scatter_dim`` and gathers along ``gather_dim``;
    backward runs the reverse collective so gradients return to the
    original layout.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        # Stash comm parameters for backward.
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        if async_op:
            # Async path: return the raw comm buffer plus the work handle;
            # the caller waits and post-processes (backward below mirrors
            # the reassembly for the gradient of that raw buffer).
            output, comm, prev_scatter_dim = single_all_to_all(
                local_input, scatter_dim, gather_dim, group, async_op=async_op
            )
            ctx.prev_scatter_dim = prev_scatter_dim
            return output, comm

        return _all_to_all(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        if ctx.async_op:
            # The async forward returned an un-postprocessed buffer with a
            # leading world-size dim; fold it into gather_dim (+1 for that
            # leading dim) and undo the transpose from single_all_to_all.
            input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0)
            if ctx.prev_scatter_dim:
                input_t = input_t.transpose(0, ctx.prev_scatter_dim)
        else:
            input_t = grad_output[0]
        # Reverse collective: scatter along gather_dim, gather along scatter_dim.
        # One None per non-tensor forward arg (group, dims, async flag).
        return (
            None,
            _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
            None,
        )
|
|
|
|
class Slice(torch.autograd.Function):
    """Autograd-aware shard selection: forward keeps only this rank's slice
    along ``dim``; backward all-gathers the gradient back to full size."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        world = dist.get_world_size(group)
        ctx.seq_world_size = world
        ctx.dim = dim
        shard_len = local_input.shape[dim] // world
        return local_input.split(shard_len, dim=dim)[ctx.rank].contiguous()

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
        # _all_gather_base concatenates along dim 0, so gather there first
        # and then re-stack the per-rank pieces along the sliced dim.
        full_shape = list(grad_output.size())
        piece = full_shape[0]
        full_shape[0] *= ctx.seq_world_size
        gathered = torch.empty(
            full_shape, dtype=grad_output.dtype, device=torch.cuda.current_device()
        )
        dist._all_gather_base(gathered, grad_output, group=ctx.group)
        return (None, torch.cat(gathered.split(piece), dim=ctx.dim), None)
|
|
|
|
class Gather(torch.autograd.Function):
    """Autograd-aware all-gather along ``dim``; backward slices out this
    rank's gradient shard, optionally scaled by the world size."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        dim: int,
        grad_scale: Optional[bool] = False,
    ) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        world = dist.get_world_size(group)
        ctx.seq_world_size = world
        # _all_gather_base concatenates along dim 0; remember the per-rank
        # extents so the result can be re-stacked along the requested dim.
        full_shape = list(local_input.size())
        piece = full_shape[0]
        ctx.part_size = full_shape[dim]
        full_shape[0] *= world
        gathered = torch.empty(
            full_shape, dtype=local_input.dtype, device=torch.cuda.current_device()
        )
        dist._all_gather_base(gathered, local_input.contiguous(), group=ctx.group)
        return torch.cat(gathered.split(piece), dim=dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
        grad = grad_output
        if ctx.grad_scale:
            grad = grad * ctx.seq_world_size
        return (
            None,
            grad.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
            None,
            None,
        )
|
|
|
|
def gather_seq_scatter_heads_qkv(
    qkv_tensor: Tensor,
    *,
    seq_dim: int,
    qkv_shape: Optional[Tensor] = None,
    cache: Cache = Cache(disable=True),
    restore_shape: bool = True,
):
    """
    A func to sync the split qkv tensor across the sequence-parallel group.

    qkv_tensor: the tensor we want to do all-to-all with. The last dim must
        be the projection dim, which is split into 3 parts (q, k, v); after
        splitting, the scatter dim is the per-part projection dim.
    seq_dim: gather dim for the all-to-all comm.
    qkv_shape: per-sample shape tensor used to compute the unpadded sequence
        length so sequence-parallel padding can be removed afterwards.
        # assumes product over its last dim gives per-sample token counts — TODO confirm
    cache: memoizes the unpadded length across calls.
    restore_shape: if True, output will have the same shape length as input.
    """
    group = get_sequence_parallel_group()
    if not group:
        # No sequence parallelism: nothing to sync.
        return qkv_tensor
    world = get_sequence_parallel_world_size()
    orig_shape = qkv_tensor.shape
    # After the view below adds a dim, index `dim()` is the per-part
    # projection dim, which is the dim scattered across ranks.
    scatter_dim = qkv_tensor.dim()
    bef_all2all_shape = list(orig_shape)
    qkv_proj_dim = bef_all2all_shape[-1]
    bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
    qkv_tensor = qkv_tensor.view(bef_all2all_shape)
    qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False)
    if restore_shape:
        out_shape = list(orig_shape)
        out_shape[seq_dim] *= world  # sequence dim was gathered from all ranks
        out_shape[-1] = qkv_proj_dim // world  # projection dim was scattered
        qkv_tensor = qkv_tensor.view(out_shape)

    # Strip the padding added when the true sequence length is not a
    # multiple of the world size.
    if qkv_shape is not None:
        unpad_dim_size = cache(
            "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item()
        )
        if unpad_dim_size % world != 0:
            padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size
            qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size)
    return qkv_tensor
|
|
|
|
def slice_inputs(x: Tensor, dim: int, padding: bool = True):
    """
    Keep only this rank's chunk of ``x`` along ``dim`` for sequence parallel,
    zero-padding first (when ``padding`` is set) so the length divides evenly.
    """
    group = get_sequence_parallel_group()
    if group is None:
        return x
    rank = get_sequence_parallel_rank()
    world = get_sequence_parallel_world_size()
    length = x.shape[dim]
    chunk = (length + world - 1) // world  # ceil division: per-rank length
    remainder = length % world
    if padding and remainder:
        x = _pad_tensor(x, dim, world - remainder)
    index = [slice(None)] * x.dim()
    index[dim] = slice(chunk * rank, chunk * (rank + 1))
    return x[tuple(index)]
|
|
|
|
def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
    """
    Strip sequence-parallel padding so ``x`` regains its original extent
    ``unpad_dim_size`` along ``dim``; no-op when no padding was needed.
    """
    group = get_sequence_parallel_group()
    if group is None:
        return x
    world = get_sequence_parallel_world_size()
    remainder = unpad_dim_size % world
    if remainder == 0:
        return x
    pad_len = world - remainder
    assert (pad_len + unpad_dim_size) % world == 0
    return _unpad_tensor(x, dim=dim, padding_size=pad_len)
|
|
|
|
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
    """
    All-to-all for the attention result: gather heads, scatter the sequence
    (padding the sequence dim to a multiple of the world size first).
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    world = get_sequence_parallel_world_size()
    remainder = x.size(seq_dim) % world
    if remainder:
        x = _pad_tensor(x, seq_dim, world - remainder)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
|
|
|
|
def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor:
    """
    All-to-all for the embedding input: gather the sequence, scatter heads.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return SeqAllToAll.apply(group, x, head_dim, seq_dim, False)
|
|
|
|
def scatter_heads(x: Tensor, dim: int) -> Tensor:
    """
    Split the heads dim across the sequence-parallel group before attention.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return Slice.apply(group, x, dim)
|
|
|
|
def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor:
    """
    Collect the attention result's heads from every sequence-parallel rank.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return Gather.apply(group, x, dim, grad_scale)
|
|
|
|
def gather_outputs(
    x: Tensor,
    *,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_shape: Optional[Tensor] = None,
    cache: Cache = Cache(disable=True),
    scale_grad=True,
):
    """
    Gather the model output across the sequence-parallel group and, when
    ``padding_dim`` is given, drop the sequence-parallel padding using the
    unpadded length derived from ``unpad_shape`` (memoized via ``cache``).
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    gathered = Gather.apply(group, x, gather_dim, scale_grad)
    if padding_dim is None:
        return gathered
    unpad_dim_size = cache(
        "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item()
    )
    return remove_seqeunce_parallel_padding(gathered, padding_dim, unpad_dim_size)
|
|
|
|
| def _pad_tensor(x: Tensor, dim: int, padding_size: int): |
| shape = list(x.shape) |
| shape[dim] = padding_size |
| pad = torch.zeros(shape, dtype=x.dtype, device=x.device) |
| return torch.cat([x, pad], dim=dim) |
|
|
|
|
| def _unpad_tensor(x: Tensor, dim: int, padding_size): |
| slc = [slice(None)] * len(x.shape) |
| slc[dim] = slice(0, -padding_size) |
| return x[slc] |
|
|
|
|
| def _broadcast_data(data, shape, dtype, src, group, async_op): |
| comms = [] |
| if isinstance(data, (list, tuple)): |
| for i, sub_shape in enumerate(shape): |
| comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op) |
| elif isinstance(data, dict): |
| for key, sub_data in data.items(): |
| comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op) |
| elif isinstance(data, Tensor): |
| comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op)) |
| return comms |
|
|
|
|
| def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]: |
| if isinstance(data, (list, tuple)): |
| return [_traverse(sub_data, op) for sub_data in data] |
| elif isinstance(data, dict): |
| return {key: _traverse(sub_data, op) for key, sub_data in data.items()} |
| elif isinstance(data, Tensor): |
| return op(data) |
| else: |
| return None |
|
|
|
|
def _get_shapes(data):
    """Map every Tensor leaf of ``data`` to its ``.shape``."""
    return _traverse(data, op=lambda tensor: tensor.shape)
|
|
|
|
def _get_dtypes(data):
    """Map every Tensor leaf of ``data`` to its ``.dtype``."""
    return _traverse(data, op=lambda tensor: tensor.dtype)
|
|
|
|
| def _construct_broadcast_buffer(shapes, dtypes, device): |
| if isinstance(shapes, torch.Size): |
| return torch.empty(shapes, dtype=dtypes, device=device) |
|
|
| if isinstance(shapes, (list, tuple)): |
| buffer = [] |
| for i, sub_shape in enumerate(shapes): |
| buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device)) |
| elif isinstance(shapes, dict): |
| buffer = {} |
| for key, sub_shape in shapes.items(): |
| buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device) |
| else: |
| return None |
| return buffer |
|
|
|
|
class SPDistForward:
    """A forward tool to sync different results across the sp group.

    Calling an instance returns a generator that yields, once per sp rank,
    the inputs that rank contributed: step k yields rank k's data on every
    rank.  Broadcasts are double-buffered and overlapped — while step k's
    data is consumed, step k+1's broadcast runs asynchronously.

    Args:
        name: a distinct str keying the module-level meta/buffer/comm state
        comm_shape: if different ranks have different shapes, set True so
            shapes are all-gathered once up front
        device: the device for the current rank, can be empty (auto-detected)
    """

    def __init__(
        self,
        name: str,
        comm_shape: bool,
        device: torch.device = None,
    ):
        self.name = name
        self.comm_shape = comm_shape
        if device:
            self.device = device
        else:
            self.device = get_device()

    def __call__(self, inputs) -> Any:
        group = get_sequence_parallel_group()
        if not group:
            # No sp group: single yield of the caller's own inputs.
            yield inputs
        else:
            device = self.device
            sp_world = get_sequence_parallel_world_size()
            sp_rank = get_sequence_parallel_rank()
            for local_step in range(sp_world):
                src_rank = dist.get_global_rank(group, local_step)
                is_src = sp_rank == local_step
                local_shapes = []
                local_dtypes = []
                if local_step == 0:
                    # First step: stash our own result in slot -1 and
                    # establish the per-rank shape/dtype metadata.
                    local_result = inputs
                    _SEQ_DATA_BUF[self.name][-1] = local_result
                    local_shapes = _get_shapes(local_result)
                    local_dtypes = _get_dtypes(local_result)
                    if self.comm_shape:
                        # Ranks may differ in shape: gather everyone's.
                        group_shapes_lists = [None] * sp_world
                        dist.all_gather_object(group_shapes_lists, local_shapes, group=group)
                        _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists
                    else:
                        # Uniform shapes: reuse our own for every rank.
                        _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world
                    _SEQ_DATA_META_DTYPES[self.name] = local_dtypes
                shapes = _SEQ_DATA_META_SHAPES[self.name][local_step]
                dtypes = _SEQ_DATA_META_DTYPES[self.name]
                # Alternate between the two buffer slots each step.
                buf_id = local_step % 2
                if local_step == 0:
                    # Step 0 is broadcast synchronously; later steps are
                    # prefetched asynchronously by the block below.
                    sync_data = (
                        local_result
                        if is_src
                        else _construct_broadcast_buffer(shapes, dtypes, device)
                    )
                    _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False)
                    _SEQ_DATA_BUF[self.name][buf_id] = sync_data

                # Wait for the async prefetch issued on the previous step so
                # buf_id's contents are complete before we yield them.
                if _SEQ_DATA_ASYNC_COMMS[self.name]:
                    for comm in _SEQ_DATA_ASYNC_COMMS[self.name]:
                        comm.wait()

                # Prefetch the next rank's data into the other buffer slot.
                if local_step < sp_world - 1:
                    next_buf_id = 1 - buf_id
                    shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1]
                    src_rank = dist.get_global_rank(group, local_step + 1)
                    is_src = sp_rank == local_step + 1
                    next_sync_data = (
                        _SEQ_DATA_BUF[self.name][-1]
                        if is_src
                        else _construct_broadcast_buffer(shapes, dtypes, device)
                    )
                    _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data(
                        next_sync_data, shapes, dtypes, src_rank, group, True
                    )
                    _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data
                yield _SEQ_DATA_BUF[self.name][buf_id]
|
|
|
|
| sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True) |
|
|
|
|
def sync_data(data, sp_idx, name="tmp"):
    """Broadcast an arbitrary picklable object from sp-rank ``sp_idx`` to
    every rank in the sequence-parallel group; no-op without a group.
    (``name`` is accepted for interface compatibility but unused.)"""
    group = get_sequence_parallel_group()
    if group is None:
        return data

    src_rank = dist.get_global_rank(group, sp_idx)
    payload = [data if get_sequence_parallel_rank() == sp_idx else None]
    dist.broadcast_object_list(payload, src=src_rank, group=group)

    return payload[0]
|
|