| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn as nn |
|
|
| from transformers.models.distilbert import modeling_distilbert |
| from .configuration_distilbert_ane import DistilBertConfig |
|
|
| |
| |
| EPS = 1e-7 |
|
|
| WARN_MSG_FOR_TRAINING_ATTEMPT = \ |
| "This model is optimized for on-device execution only. " \ |
| "Please use the original implementation from Hugging Face for training" |
|
|
| WARN_MSG_FOR_DICT_RETURN = \ |
| "coremltools does not support dict outputs. Please set return_dict=False" |
|
|
|
|
| class LayerNormANE(nn.Module): |
| """ LayerNorm optimized for Apple Neural Engine (ANE) execution |
| |
| Note: This layer only supports normalization over the final dim. It expects `num_channels` |
| as an argument and not `normalized_shape` which is used by `torch.nn.LayerNorm`. |
| """ |
|
|
| def __init__(self, |
| num_channels, |
| clip_mag=None, |
| eps=1e-5, |
| elementwise_affine=True): |
| """ |
| Args: |
| num_channels: Number of channels (C) where the expected input data format is BC1S. S stands for sequence length. |
| clip_mag: Optional float value to use for clamping the input range before layer norm is applied. |
| If specified, helps reduce risk of overflow. |
| eps: Small value to avoid dividing by zero |
| elementwise_affine: If true, adds learnable channel-wise shift (bias) and scale (weight) parameters |
| """ |
| super().__init__() |
| |
| self.expected_rank = len('BC1S') |
|
|
| self.num_channels = num_channels |
| self.eps = eps |
| self.clip_mag = clip_mag |
| self.elementwise_affine = elementwise_affine |
|
|
| if self.elementwise_affine: |
| self.weight = nn.Parameter(torch.Tensor(num_channels)) |
| self.bias = nn.Parameter(torch.Tensor(num_channels)) |
|
|
| self._reset_parameters() |
|
|
| def _reset_parameters(self): |
| if self.elementwise_affine: |
| nn.init.ones_(self.weight) |
| nn.init.zeros_(self.bias) |
|
|
| def forward(self, inputs): |
| input_rank = len(inputs.size()) |
|
|
| |
| |
| if input_rank == 3 and inputs.size(2) == self.num_channels: |
| inputs = inputs.transpose(1, 2).unsqueeze(2) |
| input_rank = len(inputs.size()) |
|
|
| assert input_rank == self.expected_rank |
| assert inputs.size(1) == self.num_channels |
|
|
| if self.clip_mag is not None: |
| inputs.clamp_(-self.clip_mag, self.clip_mag) |
|
|
| channels_mean = inputs.mean(dim=1, keepdims=True) |
|
|
| zero_mean = inputs - channels_mean |
|
|
| zero_mean_sq = zero_mean * zero_mean |
|
|
| denom = (zero_mean_sq.mean(dim=1, keepdims=True) + self.eps).rsqrt() |
|
|
| out = zero_mean * denom |
|
|
| if self.elementwise_affine: |
| out = (out + self.bias.view(1, self.num_channels, 1, 1) |
| ) * self.weight.view(1, self.num_channels, 1, 1) |
|
|
| return out |
|
|
|
|
| class Embeddings(modeling_distilbert.Embeddings): |
| """ Embeddings module optimized for Apple Neural Engine |
| """ |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'LayerNorm', LayerNormANE(config.dim, eps=EPS)) |
|
|
|
|
| class MultiHeadSelfAttention(modeling_distilbert.MultiHeadSelfAttention): |
| """ MultiHeadSelfAttention module optimized for Apple Neural Engine |
| """ |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| setattr( |
| self, 'q_lin', |
| nn.Conv2d( |
| in_channels=config.dim, |
| out_channels=config.dim, |
| kernel_size=1, |
| )) |
|
|
| setattr( |
| self, 'k_lin', |
| nn.Conv2d( |
| in_channels=config.dim, |
| out_channels=config.dim, |
| kernel_size=1, |
| )) |
|
|
| setattr( |
| self, 'v_lin', |
| nn.Conv2d( |
| in_channels=config.dim, |
| out_channels=config.dim, |
| kernel_size=1, |
| )) |
|
|
| setattr( |
| self, 'out_lin', |
| nn.Conv2d( |
| in_channels=config.dim, |
| out_channels=config.dim, |
| kernel_size=1, |
| )) |
|
|
| def prune_heads(self, heads): |
| raise NotImplementedError |
|
|
| def forward(self, |
| query, |
| key, |
| value, |
| mask, |
| head_mask=None, |
| output_attentions=False): |
| """ |
| Parameters: |
| query: torch.tensor(bs, dim, 1, seq_length) |
| key: torch.tensor(bs, dim, 1, seq_length) |
| value: torch.tensor(bs, dim, 1, seq_length) |
| mask: torch.tensor(bs, seq_length) or torch.tensor(bs, seq_length, 1, 1) |
| |
| Returns: |
| weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, |
| dim, 1, seq_length) Contextualized layer. Optional: only if `output_attentions=True` |
| """ |
| |
| assert len(query.size()) == 4 and len(key.size()) == 4 and len( |
| value.size()) == 4 |
|
|
| bs, dim, dummy, seqlen = query.size() |
| |
| |
| |
|
|
| |
| q = self.q_lin(query) |
| k = self.k_lin(key) |
| v = self.v_lin(value) |
|
|
| |
| if mask is not None: |
| expected_mask_shape = [bs, seqlen, 1, 1] |
| if mask.dtype == torch.bool: |
| mask = mask.logical_not().float() * -1e4 |
| elif mask.dtype == torch.int64: |
| mask = (1 - mask).float() * -1e4 |
| elif mask.dtype != torch.float32: |
| raise TypeError(f"Unexpected dtype for mask: {mask.dtype}") |
|
|
| if len(mask.size()) == 2: |
| mask = mask.unsqueeze(2).unsqueeze(2) |
|
|
| if list(mask.size()) != expected_mask_shape: |
| raise RuntimeError( |
| f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())}" |
| ) |
|
|
| if head_mask is not None: |
| raise NotImplementedError |
|
|
| |
| dim_per_head = self.dim // self.n_heads |
| mh_q = q.split( |
| dim_per_head, |
| dim=1) |
| mh_k = k.transpose(1, 3).split( |
| dim_per_head, |
| dim=3) |
| mh_v = v.split( |
| dim_per_head, |
| dim=1) |
|
|
| normalize_factor = float(dim_per_head)**-0.5 |
| attn_weights = [ |
| torch.einsum('bchq,bkhc->bkhq', [qi, ki]) * normalize_factor |
| for qi, ki in zip(mh_q, mh_k) |
| ] |
|
|
| if mask is not None: |
| for head_idx in range(self.n_heads): |
| attn_weights[head_idx] = attn_weights[head_idx] + mask |
|
|
| attn_weights = [aw.softmax(dim=1) for aw in attn_weights |
| ] |
| attn = [ |
| torch.einsum('bkhq,bchk->bchq', wi, vi) |
| for wi, vi in zip(attn_weights, mh_v) |
| ] |
|
|
| attn = torch.cat(attn, dim=1) |
|
|
| attn = self.out_lin(attn) |
|
|
| if output_attentions: |
| return attn, attn_weights.cat(dim=2) |
| else: |
| return (attn, ) |
|
|
|
|
| class FFN(modeling_distilbert.FFN): |
| """ FFN module optimized for Apple Neural Engine |
| """ |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.seq_len_dim = 3 |
|
|
| setattr( |
| self, 'lin1', |
| nn.Conv2d( |
| in_channels=config.dim, |
| out_channels=config.hidden_dim, |
| kernel_size=1, |
| )) |
|
|
| setattr( |
| self, 'lin2', |
| nn.Conv2d( |
| in_channels=config.hidden_dim, |
| out_channels=config.dim, |
| kernel_size=1, |
| )) |
|
|
|
|
| class TransformerBlock(modeling_distilbert.TransformerBlock): |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'attention', MultiHeadSelfAttention(config)) |
| setattr(self, 'sa_layer_norm', LayerNormANE(config.dim, eps=EPS)) |
| setattr(self, 'ffn', FFN(config)) |
| setattr(self, 'output_layer_norm', LayerNormANE(config.dim, eps=EPS)) |
|
|
|
|
| class Transformer(modeling_distilbert.Transformer): |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr( |
| self, 'layer', |
| nn.ModuleList( |
| [TransformerBlock(config) for _ in range(config.n_layers)])) |
|
|
|
|
| class DistilBertModel(modeling_distilbert.DistilBertModel): |
| config_class = DistilBertConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'embeddings', Embeddings(config)) |
| setattr(self, 'transformer', Transformer(config)) |
|
|
| |
| self._register_load_state_dict_pre_hook(linear_to_conv2d_map) |
|
|
| def _prune_heads(self, heads_to_prune): |
| raise NotImplementedError |
|
|
|
|
| class DistilBertForMaskedLM(modeling_distilbert.DistilBertForMaskedLM): |
| config_class = DistilBertConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| from transformers.activations import get_activation |
| setattr(self, 'activation', get_activation(config.activation)) |
| setattr(self, 'distilbert', DistilBertModel(config)) |
| setattr(self, 'vocab_transform', nn.Conv2d(config.dim, config.dim, 1)) |
| setattr(self, 'vocab_layer_norm', LayerNormANE(config.dim, eps=EPS)) |
| setattr(self, 'vocab_projector', |
| nn.Conv2d(config.dim, config.vocab_size, 1)) |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| if self.training or labels is not None: |
| raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT) |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if return_dict: |
| raise ValueError(WARN_MSG_FOR_DICT_RETURN) |
|
|
| dlbrt_output = self.distilbert( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=False, |
| ) |
| hidden_states = dlbrt_output[0] |
| prediction_logits = self.vocab_transform( |
| hidden_states) |
| prediction_logits = self.activation( |
| prediction_logits) |
| prediction_logits = self.vocab_layer_norm( |
| prediction_logits) |
| prediction_logits = self.vocab_projector( |
| prediction_logits) |
| prediction_logits = prediction_logits.squeeze(-1).squeeze( |
| -1) |
|
|
| output = (prediction_logits, ) + dlbrt_output[1:] |
| mlm_loss = None |
|
|
| return ((mlm_loss, ) + output) if mlm_loss is not None else output |
|
|
|
|
| class DistilBertForSequenceClassification( |
| modeling_distilbert.DistilBertForSequenceClassification): |
| config_class = DistilBertConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'distilbert', DistilBertModel(config)) |
| setattr(self, 'pre_classifier', nn.Conv2d(config.dim, config.dim, 1)) |
| setattr(self, 'classifier', nn.Conv2d(config.dim, config.num_labels, |
| 1)) |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| if labels is not None or self.training: |
| raise NotImplementedError(WARN_MSG_FOR_TRAINING_ATTEMPT) |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if return_dict: |
| raise ValueError(WARN_MSG_FOR_DICT_RETURN) |
|
|
| distilbert_output = self.distilbert( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=False, |
| ) |
| hidden_state = distilbert_output[0] |
| pooled_output = hidden_state[:, :, :, 0:1] |
| pooled_output = self.pre_classifier(pooled_output) |
| pooled_output = nn.ReLU()(pooled_output) |
| logits = self.classifier(pooled_output) |
| logits = logits.squeeze(-1).squeeze(-1) |
|
|
| output = (logits, ) + distilbert_output[1:] |
| loss = None |
|
|
| return ((loss, ) + output) if loss is not None else output |
|
|
|
|
| class DistilBertForQuestionAnswering( |
| modeling_distilbert.DistilBertForQuestionAnswering): |
| config_class = DistilBertConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'distilbert', DistilBertModel(config)) |
| setattr(self, 'qa_outputs', nn.Conv2d(config.dim, config.num_labels, |
| 1)) |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| head_mask=None, |
| inputs_embeds=None, |
| start_positions=None, |
| end_positions=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
|
|
| if self.training or start_positions is not None or end_positions is not None: |
| raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT) |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if return_dict: |
| raise ValueError(WARN_MSG_FOR_DICT_RETURN) |
|
|
| distilbert_output = self.distilbert( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=False, |
| ) |
| hidden_states = distilbert_output[0] |
|
|
| hidden_states = self.dropout( |
| hidden_states) |
| logits = self.qa_outputs(hidden_states) |
| start_logits, end_logits = logits.split( |
| 1, dim=1) |
| start_logits = start_logits.squeeze().contiguous( |
| ) |
| end_logits = end_logits.squeeze().contiguous() |
|
|
| output = (start_logits, end_logits) + distilbert_output[1:] |
| total_loss = None |
|
|
| return ((total_loss, ) + output) if total_loss is not None else output |
|
|
|
|
| class DistilBertForTokenClassification( |
| modeling_distilbert.DistilBertForTokenClassification): |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'distilbert', DistilBertModel(config)) |
| setattr(self, 'classifier', |
| nn.Conv2d(config.hidden_size, config.num_labels, 1)) |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| if self.training or labels is not None: |
| raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT) |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if return_dict: |
| raise ValueError(WARN_MSG_FOR_DICT_RETURN) |
|
|
| outputs = self.distilbert( |
| input_ids, |
| attention_mask=attention_mask, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=False, |
| ) |
|
|
| sequence_output = outputs[0] |
| logits = self.classifier( |
| sequence_output) |
| logits = logits.squeeze(2).transpose(1, 2) |
|
|
| output = (logits, ) + outputs[1:] |
| loss = None |
| return ((loss, ) + output) if loss is not None else output |
|
|
|
|
| class DistilBertForMultipleChoice( |
| modeling_distilbert.DistilBertForMultipleChoice): |
| config_class = DistilBertConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| setattr(self, 'distilbert', DistilBertModel(config)) |
| setattr(self, 'pre_classifier', nn.Conv2d(config.dim, config.dim, 1)) |
| setattr(self, 'classifier', nn.Conv2d(config.dim, 1, 1)) |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| if self.training or labels is not None: |
| raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT) |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if return_dict: |
| raise ValueError(WARN_MSG_FOR_DICT_RETURN) |
|
|
| num_choices = input_ids.shape[ |
| 1] if input_ids is not None else inputs_embeds.shape[1] |
|
|
| input_ids = input_ids.view( |
| -1, input_ids.size(-1)) if input_ids is not None else None |
| attention_mask = attention_mask.view( |
| -1, |
| attention_mask.size(-1)) if attention_mask is not None else None |
| inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), |
| inputs_embeds.size(-1)) |
| if inputs_embeds is not None else None) |
|
|
| outputs = self.distilbert( |
| input_ids, |
| attention_mask=attention_mask, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=False, |
| ) |
|
|
| hidden_state = outputs[0] |
| pooled_output = hidden_state[:, :, :, |
| 0:1] |
| pooled_output = self.pre_classifier( |
| pooled_output) |
| pooled_output = nn.ReLU()( |
| pooled_output) |
| logits = self.classifier(pooled_output) |
| logits = logits.squeeze() |
|
|
| reshaped_logits = logits.view(-1, num_choices) |
|
|
| output = (reshaped_logits, ) + outputs[1:] |
| loss = None |
|
|
| return ((loss, ) + output) if loss is not None else output |
|
|
|
|
| def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs): |
| """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights |
| """ |
| for k in state_dict: |
| is_internal_proj = all(substr in k for substr in ['lin', '.weight']) |
| is_output_proj = all(substr in k |
| for substr in ['classifier', '.weight']) |
| if is_internal_proj or is_output_proj: |
| if len(state_dict[k].shape) == 2: |
| state_dict[k] = state_dict[k][:, :, None, None] |
|
|