from typing import List

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F


@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln


# Monkey-patch nn.LayerNorm so model code can construct layer norms directly
# from checkpoint weights via `nn.LayerNorm.load(...)` / `nn.LayerNorm.load_no_bias(...)`.
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

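# Example usage (sketch): loading a checkpointed layer norm through the patched
# constructors. The prefix and eps values below are illustrative, not taken from
# any particular model:
#
#     ln = nn.LayerNorm.load(prefix="transformer.ln_f", weights=weights, eps=1e-5)
#     final_norm = nn.LayerNorm.load_no_bias(prefix="model.norm", weights=weights, eps=1e-6)

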
class FastLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


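# Example usage (sketch): loading a non-sharded projection and applying it.
# The prefix is illustrative:
#
#     proj = FastLinear.load(config, prefix="model.dense", weights=weights, bias=True)
#     y = proj(x)  # equivalent to F.linear(x, proj.weight, proj.bias)

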
def get_linear(weight, bias):
    linear = FastLinear(weight, bias)
    return linear


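# All tensor-parallel layers below construct their inner linear through `get_linear`,
# so an alternative implementation (for example a quantized kernel) would only need
# to be swapped in at this single point.

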
class SuperLayer(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # The weight cannot be evenly sharded across ranks; fall back to
                # loading the full tensor on every rank and skip the gather.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

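    # Example usage (sketch): building the output head; the prefix is illustrative:
    #
    #     lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
    #     logits = lm_head(hidden_states)
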
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


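# Shape sketch for the gathered head above (illustrative numbers): with world_size=2,
# a vocabulary of 32000 and a batch of 4 hidden states, each rank holds a
# [16000, hidden] weight shard and writes local logits of shape [4, 16000] into the
# transposed view `local_out`; all_gather stacks the per-rank [16000, 4] chunks into
# `world_out` of shape [32000, 4], and the final transpose returns [4, 32000] logits
# with rank 0's vocabulary slice in the first 16000 columns.

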
class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        weight = weights.get_multi_weights_col(
            prefixes, dim=dim, quantize=config.quantize
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)


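# Example usage (sketch): fusing several column-parallel projections into one matmul,
# e.g. the attention query/key/value projections. The prefixes are illustrative and
# depend on the checkpoint layout:
#
#     qkv = TensorParallelColumnLinear.load_multi(
#         config,
#         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
#         weights=weights,
#         bias=True,
#         dim=0,
#     )

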
class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # Only rank 0 keeps the bias: the partial outputs are summed with an
            # all_reduce in `forward`, so adding the bias on every rank would add
            # it `world_size` times.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


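# Row parallelism in the layer above: each rank holds a slice of the input features
# and the matching part of the weight, so its local matmul produces a partial sum of
# the full output; the all_reduce adds these partials together, i.e.
# out = sum_r x_r @ W_r.T (+ bias, contributed by rank 0 only).

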
class TensorParallelEmbedding(nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        # Additional 0 entry used for masking
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Shift token ids into this shard's local index space; ids owned by other
        # shards are mapped to the extra zero row appended in `__init__`.
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


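# Worked example for the embedding above (illustrative numbers): with world_size=2 and
# num_embeddings=10, block_size=5, so rank 0 owns ids 0-4 and rank 1 owns ids 5-9.
# For token id 3, rank 1 maps it to null_idx and looks up the zero row, while rank 0
# looks up its local row 3; the all_reduce sums the two, so every rank ends up with
# the correct embedding.

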
try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                # Hidden sizes the fused kernel does not handle fall back to the
                # eager implementation, with the residual added in manually.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                # Fused residual-add + layer norm kernel (dropout probability 0.0).
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    # The fused kernel is optional: FastLayerNorm is only defined when the
    # `dropout_layer_norm` extension (shipped with flash-attention) is installed.
    pass
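

# Usage sketch for FastLayerNorm in a pre-norm transformer block (names illustrative):
# the layer both normalizes and carries the running residual, so a block can do
#
#     normed, residual = self.input_layernorm(hidden_states, residual)
#     attn_out = self.attention(normed, ...)
#     normed, residual = self.post_attention_layernorm(attn_out, residual)
#
# i.e. each call returns (normed_states, updated_residual) as implemented above.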