|
|
| |
|
|
| from typing import Optional, Tuple |
|
|
| import numpy as np |
| import pickle |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import CrossEntropyLoss, MSELoss |
|
|
| from torch import Tensor |
|
|
| import copy |
|
|
| from dataclasses import dataclass |
| from transformers.activations import ACT2FN |
| from transformers.file_utils import ModelOutput |
|
|
| from transformers.models.bert.modeling_bert import ( |
| BertAttention, |
| BertEmbeddings, |
| BertEncoder, |
| BertIntermediate, |
| BertLayer, |
| BertModel, |
| BertOutput, |
| BertPooler, |
| BertPreTrainedModel, |
| ) |
|
|
| import logging |
| logger = logging.getLogger(__name__) |
|
|
|
|
def use_experts(layer_idx):
    """Decide whether the encoder layer at ``layer_idx`` gets a mixture-of-experts FFN.

    Every layer is currently converted, so the index does not influence the answer.

    Returns:
        Always ``True``.
    """
    del layer_idx  # intentionally unused: all layers use experts
    return True
|
|
|
|
def process_ffn(model):
    """Walk the encoder layers of a BERT-based model.

    Args:
        model: object whose ``config.model_type`` must be ``"bert"``; its layers are
            read from ``model.bert.encoder.layer`` for indices
            ``0 .. config.num_hidden_layers - 1``.

    Raises:
        ValueError: if ``config.model_type`` is not ``"bert"``.

    NOTE(review): the loop only fetches each layer and does nothing with it, so the
    function currently has no effect besides validating the model type and that the
    configured layer indices exist. It looks like per-layer processing was removed —
    confirm whether this is intentional.
    """
    if model.config.model_type != "bert":
        raise ValueError("Model type not recognized.")
    inner_model = model.bert

    for layer_index in range(model.config.num_hidden_layers):
        # Result unused; the indexing is kept so a too-short layer list still raises.
        _ = inner_model.encoder.layer[layer_index]
|
|
|
|
class FeedForward(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear -> dropout -> residual + LayerNorm.

    Args:
        config: provides ``hidden_size``, ``hidden_act`` (either an ``ACT2FN`` key or
            a callable) and ``layer_norm_eps``.
        intermediate_size: width of the inner projection.
        dropout: dropout probability applied to the second projection's output.
    """

    def __init__(self, config, intermediate_size, dropout):
        nn.Module.__init__(self)
        hidden = config.hidden_size
        # Sub-modules are registered in the same order as before (fc1, activation,
        # fc2, LayerNorm, dropout) so initialisation order and state_dict keys are stable.
        self.fc1 = nn.Linear(hidden, intermediate_size)
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.fc2 = nn.Linear(intermediate_size, hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: Tensor):
        """Apply the block; the last dimension of ``hidden_states`` must be ``hidden_size``.

        Returns a tensor of the same shape as ``hidden_states``.
        """
        projected = self.fc2(self.intermediate_act_fn(self.fc1(hidden_states)))
        # Residual connection, then layer normalisation.
        return self.LayerNorm(self.dropout(projected) + hidden_states)
|
|
|
|
@dataclass
class MoEModelOutput(ModelOutput):
    """Encoder output extended with the mixture-of-experts gate loss.

    Instances are also indexed tuple-style by callers (``outputs[0]``,
    ``outputs[1:]``), so the field order below is significant — keep it stable.
    """

    # Hidden states of the last encoder layer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer decoder cache; populated only when caching is enabled.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Input to every layer plus the final output, when hidden states are requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer self-attention outputs, when attentions are requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer cross-attention outputs, when requested and cross-attention is configured.
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Sum of per-layer routing losses; the encoder starts this at the float 0.0,
    # so it may stay a Python float when no layer contributes a tensor.
    gate_loss: Optional[torch.FloatTensor] = None
|
|
|
|
@dataclass
class MoEModelOutputWithPooling(ModelOutput):
    """Model output (with pooled representation) extended with the MoE gate loss.

    Field order matters because the output may be consumed tuple-style; do not
    reorder the fields.
    """

    # Hidden states of the last encoder layer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Pooler output, or None when the model was built without a pooling layer.
    pooler_output: Optional[torch.FloatTensor] = None
    # Input to every layer plus the final output, when hidden states are requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer decoder cache; populated only when caching is enabled.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Per-layer self-attention outputs, when attentions are requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer cross-attention outputs, when requested and cross-attention is configured.
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Summed routing loss forwarded from the encoder (may be a Python float 0.0).
    gate_loss: Optional[torch.FloatTensor] = None
|
|
|
|
| |
|
|
|
|
class MoELayer(nn.Module):
    """Mixture-of-experts block that routes every input to exactly one expert (top-1).

    ``num_experts`` deep copies of ``expert`` are created. Routing is selected by
    ``route_method``:

    * ``"gate-token"``: a bias-free linear gate scores every token; each token goes to
      its highest-probability expert and the expert output is scaled by that
      probability. A load-balancing loss is returned.
    * ``"gate-sentence"``: the gate scores the attention-masked mean of each sequence
      and whole sequences are routed; outputs are scaled by the chosen probability and
      a load-balancing loss is returned (for every batch size).
    * ``"hash-random"``: a random token-id -> expert table with ``vocab_size`` entries.
    * ``"hash-balance"``: the token-id -> expert table is unpickled from the file path
      given as ``hash_list``.

    ``forward`` returns ``(output, balance_loss, gate_load)`` where ``gate_load[i]``
    counts the tokens (or sequences) routed to expert ``i``. Hash routing returns a
    balance loss of ``0.0``.

    NOTE(review): ``self.hash_list`` is a plain tensor attribute, not a registered
    buffer, so it is not part of ``state_dict``; a ``"hash-random"`` table is redrawn
    every time the layer is constructed.

    Raises:
        KeyError: if ``route_method`` is not one of the strings above.
    """

    def __init__(self, hidden_size, num_experts, expert, route_method, vocab_size, hash_list):
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(num_experts)])
        self.route_method = route_method
        if route_method in ["gate-token", "gate-sentence"]:
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif route_method == "hash-random":
            self.hash_list = self._random_hash_list(vocab_size)
        elif route_method == "hash-balance":
            self.hash_list = self._balance_hash_list(hash_list)
        else:
            raise KeyError("Routing method not supported.")

    def _random_hash_list(self, vocab_size):
        """Return a ``(vocab_size,)`` int64 tensor of expert ids drawn uniformly at random."""
        return torch.randint(low=0, high=self.num_experts, size=(vocab_size,))

    def _balance_hash_list(self, hash_list):
        """Load a precomputed token-id -> expert table from the pickle file ``hash_list``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only pass paths to
        trusted files.
        """
        with open(hash_list, "rb") as file:
            table = pickle.load(file)
        return torch.tensor(table, dtype=torch.int64)

    def _dispatch(self, gate):
        """Group routing decisions by expert.

        Args:
            gate: 1-D tensor of expert ids, one per routed row.

        Returns:
            ``(order, counts, sizes)``: ``order`` sorts rows by expert id (undo it with
            ``order.argsort(0)``), ``counts[i]`` is the number of rows sent to expert
            ``i`` (int64 tensor) and ``sizes`` is ``counts`` as a Python list for
            ``Tensor.split``.
        """
        order = gate.argsort(0)
        counts = F.one_hot(gate, self.num_experts).sum(0)
        return order, counts, counts.tolist()

    def _balance_loss(self, prob_gate, counts):
        """Load-balancing loss ``E * sum_i P_i * f_i`` (the Switch-Transformer auxiliary form).

        ``P_i`` is the mean gate probability of expert ``i`` and ``f_i`` the fraction
        of rows routed to it.
        """
        mean_prob = prob_gate.mean(0)
        routed = counts.float()
        fraction = routed / routed.sum(0, keepdim=True)
        return self.num_experts * torch.sum(mean_prob * fraction)

    def _sentence_gate(self, x, attention_mask):
        """Return ``(prob_gate, gate)`` for sentence routing from the masked mean of ``x``.

        Sequences whose mask is all zeros get a zero mean vector instead of a 0/0 NaN.
        """
        if attention_mask is None:
            raise ValueError("gate-sentence routing requires an attention mask (expert_attention_mask).")
        mask = attention_mask.unsqueeze(-1)
        summed = (x * mask).sum(1)
        length = mask.sum(1)
        length = length.masked_fill(length == 0, 1)
        prob_gate = F.softmax(self.gate(summed / length), dim=-1)
        return prob_gate, torch.argmax(prob_gate, dim=-1)

    def _forward_gate_token(self, x):
        """Token-level gated routing; ``x`` is unpacked as ``(batch, seq_len, hidden)``."""
        bsz, seq_len, dim = x.size()
        flat = x.reshape(-1, dim)
        prob_gate = F.softmax(self.gate(flat), dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        order, gate_load, sizes = self._dispatch(gate)
        chunks = flat[order].split(sizes, dim=0)
        balance_loss = self._balance_loss(prob_gate, gate_load)

        # Probability of the chosen expert per token, grouped like `chunks`.
        chosen_prob = prob_gate.gather(dim=1, index=gate.unsqueeze(1))[order].split(sizes, dim=0)

        # Every expert is called, including on empty chunks (unlike the sentence path).
        outputs = [self.experts[i](chunks[i]) * chosen_prob[i] for i in range(self.num_experts)]
        out = torch.vstack(outputs)[order.argsort(0)]
        return out.reshape(bsz, seq_len, dim), balance_loss, gate_load

    def _forward_gate_sentence(self, x, attention_mask):
        """Sentence-level gated routing: each whole sequence of ``x`` goes to one expert."""
        prob_gate, gate = self._sentence_gate(x, attention_mask)

        order, gate_load, sizes = self._dispatch(gate)
        chunks = x[order].split(sizes, dim=0)
        balance_loss = self._balance_loss(prob_gate, gate_load)

        chosen_prob = prob_gate.gather(dim=1, index=gate.unsqueeze(1))[order].split(sizes, dim=0)

        # Only experts that received at least one sequence are run.
        outputs = [
            self.experts[i](chunks[i]) * chosen_prob[i].unsqueeze(-1)
            for i in range(self.num_experts)
            if sizes[i] > 0
        ]
        result = torch.vstack(outputs)[order.argsort(0)]
        return result, balance_loss, gate_load

    def _forward_sentence_single_expert(self, x, attention_mask):
        """Fast path of sentence routing for a batch of one sequence.

        Produces the same output, balance loss and load as ``_forward_gate_sentence``
        (the output is scaled by the gate probability), without sort/split overhead.
        """
        prob_gate, gate = self._sentence_gate(x, attention_mask)
        gate_load = F.one_hot(gate, self.num_experts).sum(0)
        balance_loss = self._balance_loss(prob_gate, gate_load)
        chosen_prob = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        out = self.experts[gate.item()](x) * chosen_prob.unsqueeze(-1)
        return out, balance_loss, gate_load

    def _forward_hash(self, x, input_ids):
        """Hash routing: each token goes to ``hash_list[token_id]``; no gate, no balance loss."""
        if input_ids is None:
            raise ValueError("Hash routing requires token ids (expert_input_ids).")
        bsz, seq_len, dim = x.size()
        flat = x.reshape(-1, dim)
        # hash_list is a plain attribute (not a buffer), so move it to x's device lazily.
        self.hash_list = self.hash_list.to(x.device)
        gate = self.hash_list[input_ids.reshape(-1)]

        order, gate_load, sizes = self._dispatch(gate)
        chunks = flat[order].split(sizes, dim=0)

        outputs = [self.experts[i](chunks[i]) for i in range(self.num_experts)]
        out = torch.vstack(outputs)[order.argsort(0)]
        return out.reshape(bsz, seq_len, dim), 0.0, gate_load

    def forward(self, x, input_ids, attention_mask):
        """Route ``x`` through the experts.

        Args:
            x: hidden states, unpacked as ``(batch, seq_len, hidden)`` by the routing paths.
            input_ids: token ids with ``batch * seq_len`` elements (hash routing only).
            attention_mask: ``(batch, seq_len)`` padding mask (sentence routing only).

        Returns:
            ``(output, balance_loss, gate_load)``.

        Raises:
            KeyError: if ``route_method`` is unsupported.
            ValueError: if the routing method's required input is ``None``.
        """
        if self.route_method == "gate-token":
            return self._forward_gate_token(x)
        if self.route_method == "gate-sentence":
            if x.size(0) == 1:
                return self._forward_sentence_single_expert(x, attention_mask)
            return self._forward_gate_sentence(x, attention_mask)
        if self.route_method in ["hash-random", "hash-balance"]:
            return self._forward_hash(x, input_ids)
        raise KeyError("Routing method not supported.")
|
|
| |
|
|
|
|
|
|
def symmetric_KL_loss(p, q):
    """ symmetric KL-divergence 1/2*(KL(p||q)+KL(q||p)) 

    Uses the identity ``KL(p||q) + KL(q||p) = sum((p - q) * (log p - log q))``.
    Both inputs are cast to float32 and the result is summed over all elements.

    Probabilities are clamped from below to the smallest positive normal float32
    before taking the log. Entries where ``p == q == 0`` therefore contribute 0 instead
    of NaN (``0 * (log 0 - log 0)``). Values at or above that threshold are
    unaffected.

    Returns:
        0-dim float32 tensor.
    """
    p, q = p.float(), q.float()
    tiny = torch.finfo(torch.float32).tiny
    log_ratio = torch.log(p.clamp(min=tiny)) - torch.log(q.clamp(min=tiny))
    return 0.5 * ((p - q) * log_ratio).sum()
|
|
|
|
def softmax(x):
    """Softmax over the last dimension, always computed and returned in float32."""
    return x.softmax(dim=-1, dtype=torch.float32)
|
|
|
|
class MoEBertLayer(BertLayer):
    """BERT layer whose feed-forward sub-block is a mixture of experts.

    Attention is handled by ``BertAttention`` modules (plus cross-attention for
    decoders configured with ``add_cross_attention``). The feed-forward part is
    ``self.experts``. It is a :class:`MoELayer` when ``use_experts(layer_idx)`` is true
    and a dense :class:`FeedForward` otherwise.

    ``forward`` returns ``((layer_output, gate_loss), *attention_outputs)`` with the
    present key/value cache appended for decoders.
    """

    def __init__(self, config, layer_idx=-100):
        # BertLayer.__init__ is bypassed; sub-modules are registered explicitly in a
        # fixed order (this order determines parameter init order and state_dict layout).
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = BertAttention(config)
        # NOTE(review): these dense FFN modules are never used by forward(); presumably
        # kept so pretrained BERT FFN weights load by name — confirm.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.use_experts = use_experts(layer_idx)
        if self.use_experts:
            expert_dropout = config.moebert_expert_dropout
            expert_template = FeedForward(config, config.moebert_expert_dim, expert_dropout)
            self.experts = MoELayer(
                hidden_size=config.hidden_size,
                expert=expert_template,
                num_experts=config.moebert_expert_num,
                route_method=config.moebert_route_method,
                vocab_size=config.vocab_size,
                hash_list=config.moebert_route_hash_list,
            )
        else:
            dense_dropout = config.hidden_dropout_prob
            self.experts = FeedForward(config, config.intermediate_size, dense_dropout)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        expert_input_ids=None,
        expert_attention_mask=None,
    ):
        """Run attention (and cross-attention for decoders), then the expert FFN.

        ``expert_input_ids`` / ``expert_attention_mask`` are forwarded to the experts
        (used by hash and sentence-level routing respectively).
        """
        # The self-attention cache is stored in the first two entries of past_key_value.
        self_attn_past = None if past_key_value is None else past_key_value[:2]
        self_attn = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past,
        )
        attention_output = self_attn[0]

        if self.is_decoder:
            # For decoders the trailing element is the present key/value cache.
            extras = self_attn[1:-1]
            present_key_value = self_attn[-1]
        else:
            extras = self_attn[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            assert hasattr(self, "crossattention"), (
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
            )
            # The cross-attention cache is stored in the last two entries.
            cross_attn_past = None if past_key_value is None else past_key_value[-2:]
            cross_attn = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past,
                output_attentions,
            )
            attention_output = cross_attn[0]
            extras = extras + cross_attn[1:-1]
            # Extend the self-attention cache with the cross-attention cache.
            present_key_value = present_key_value + cross_attn[-1]

        ffn_result = self.feed_forward(attention_output, expert_input_ids, expert_attention_mask)
        outputs = (ffn_result,) + extras
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward(self, attention_output, expert_input_ids, expert_attention_mask):
        """Apply the FFN sub-block and return ``(layer_output, gate_loss)``.

        The dense path reports a gate loss of ``0.0``; the expert path drops the
        per-expert load statistics returned by :class:`MoELayer`.
        """
        if self.use_experts:
            layer_output, gate_loss, _gate_load = self.experts(
                attention_output, expert_input_ids, expert_attention_mask
            )
            return layer_output, gate_loss
        return self.experts(attention_output), 0.0
|
|
|
|
class MoEBertEncoder(BertEncoder):
    """Stack of :class:`MoEBertLayer` modules that also sums the per-layer MoE gate losses."""

    def __init__(self, config):
        # BertEncoder.__init__ is skipped so the dense BertLayer stack is never built.
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([MoEBertLayer(config, i) for i in range(config.num_hidden_layers)])
        # Module-level switch for activation checkpointing (in addition to the legacy
        # config.gradient_checkpointing flag).
        self.gradient_checkpointing = False

    def _gradient_checkpointing_enabled(self):
        """Return whether activation checkpointing should be used for this forward pass.

        Honours both ``self.gradient_checkpointing`` (set by
        ``gradient_checkpointing_enable()`` in newer ``transformers`` releases — verify
        for the pinned version) and the legacy ``config.gradient_checkpointing`` flag.
        Checkpointing is only applied in training mode.
        """
        enabled = getattr(self, "gradient_checkpointing", False) or getattr(
            self.config, "gradient_checkpointing", False
        )
        return bool(enabled) and self.training

    @staticmethod
    def _checkpointed_layer_forward(
        layer_module,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        past_key_value,
        output_attentions,
        expert_input_ids,
        expert_attention_mask,
    ):
        """Run one layer under activation checkpointing.

        The routing inputs (``expert_input_ids`` / ``expert_attention_mask``) must be
        forwarded explicitly, otherwise hash and sentence-level routing receive ``None``.
        The layer's nested ``((hidden_states, gate_loss), ...)`` output is flattened
        for the checkpoint call and re-nested afterwards. The reentrant checkpoint only
        attaches autograd history to tensors at the top level of the returned tuple.
        """
        # Local import: `import torch` alone does not guarantee the submodule is loaded.
        from torch.utils.checkpoint import checkpoint

        def custom_forward(*inputs):
            outputs = layer_module(
                *inputs, past_key_value, output_attentions, expert_input_ids, expert_attention_mask
            )
            (layer_hidden, layer_gate_loss), rest = outputs[0], outputs[1:]
            return (layer_hidden, layer_gate_loss) + tuple(rest)

        flat = checkpoint(
            custom_forward,
            hidden_states,
            attention_mask,
            layer_head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )
        return ((flat[0], flat[1]),) + tuple(flat[2:])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        expert_input_ids=None,
        expert_attention_mask=None,
    ):
        """Run all layers and accumulate the routing loss.

        Returns:
            :class:`MoEModelOutput` whose ``gate_loss`` is the sum of per-layer gate
            losses (starts at the float ``0.0``), or a tuple of the non-None values of
            ``(last_hidden_state, past_key_values, hidden_states, attentions,
            cross_attentions)`` when ``return_dict`` is false.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        gate_loss = 0.0
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self._gradient_checkpointing_enabled():
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False
                layer_outputs = self._checkpointed_layer_forward(
                    layer_module,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    expert_input_ids,
                    expert_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    expert_input_ids,
                    expert_attention_mask,
                )

            # layer_outputs = ((hidden_states, gate_loss), *attention_outputs[, present_key_value])
            hidden_states, layer_gate_loss = layer_outputs[0]
            gate_loss = gate_loss + layer_gate_loss
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # NOTE(review): gate_loss is not part of the tuple output; use
            # return_dict=True to read it.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return MoEModelOutput(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
            gate_loss=gate_loss,
        )
|
|
|
|
class MoEBertModel(BertModel):
    """BERT model whose encoder is :class:`MoEBertEncoder`; outputs also carry the gate loss."""

    def __init__(self, config, add_pooling_layer=True):
        # NOTE(review): BertModel.__init__ already builds embeddings, a dense encoder and
        # a pooler, which are replaced just below. Kept so initialisation matches the
        # existing behaviour; BertPreTrainedModel.__init__ would avoid the double build.
        BertModel.__init__(self, config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = MoEBertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        expert_input_ids=None,
        expert_attention_mask=None,
    ):
        """Embed the inputs, run the MoE encoder and (optionally) the pooler.

        Arguments mirror ``BertModel.forward`` plus the routing inputs:

        Args:
            expert_input_ids: token ids used by hash routing. Defaults to ``input_ids``.
            expert_attention_mask: 2-D padding mask used by sentence-level routing.
                Defaults to ``attention_mask`` (after its all-ones default is applied)
                when that mask is 2-D and there is no ``past_key_values`` cache.

        Returns:
            :class:`MoEModelOutputWithPooling` (``gate_loss`` is the encoder's summed
            routing loss), or ``(sequence_output, pooled_output, *encoder_outputs[1:])``
            when ``return_dict`` is false (that tuple does not include ``gate_loss``).

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Length of the cached prefix (dimension 2 of the first cached key tensor).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Default the routing inputs to the model inputs so hash / sentence-level routing
        # work without extra arguments. Explicitly passed values are never overridden.
        if expert_input_ids is None:
            expert_input_ids = input_ids
        if expert_attention_mask is None and past_key_values_length == 0 and attention_mask.dim() == 2:
            expert_attention_mask = attention_mask

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            expert_input_ids=expert_input_ids,
            expert_attention_mask=expert_attention_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return MoEModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
            gate_loss=encoder_outputs.gate_loss,
        )
|
|