| import torch |
| import torch.nn as nn |
| from torch.nn.functional import gelu |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import ( |
| BaseModelOutput, |
| MaskedLMOutput, |
| SequenceClassifierOutput, |
| ) |
|
|
| from hopfield import HopfieldLayer |
| from hf_configuration import BertEnergyConfig |
| from positional import PositionalEncoding |
|
|
|
|
class EnergyLMHead(nn.Module):
    """
    Masked-language-modeling head used on top of the energy backbone.

    Pipeline: hidden -> dense -> gelu -> layer_norm -> dropout -> decoder(vocab).
    """

    def __init__(self, config):
        super().__init__()
        dim = config.embedding_dim
        self.dense = nn.Linear(dim, dim)
        self.layer_norm = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Projection to vocabulary logits; its weight may be tied to the
        # input embedding matrix by the wrapping model.
        self.decoder = nn.Linear(dim, config.vocab_size, bias=True)

    @property
    def bias(self):
        # Expose the decoder bias so HF weight-tying utilities can find it.
        return self.decoder.bias

    def forward(self, hidden_states):
        """Map hidden states of shape (..., embedding_dim) to vocab logits."""
        transformed = self.layer_norm(gelu(self.dense(hidden_states)))
        return self.decoder(self.dropout(transformed))

    def _tie_weights(self):
        # The bias lives on the decoder itself; nothing extra to re-link.
        pass
|
|
class AlbertMLMHead(nn.Module):
    """
    ALBERT-style MLM head:
        hidden (H) -> dense -> gelu -> layer_norm (E) -> decoder (V)

    The dense layer first projects the hidden size down to the (typically
    smaller) embedding dimension, which keeps the decoder tieable to the
    input embedding matrix.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.embedding_dim)
        self.layer_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(config.embedding_dim, config.vocab_size, bias=True)

    @property
    def bias(self):
        # Consistency fix: expose the decoder bias like the sibling heads
        # (MLMHead, EnergyLMHead) so HF weight-tying/resizing utilities can
        # locate it uniformly across all head types used in this file.
        return self.decoder.bias

    def forward(self, hidden_states):
        """Map hidden states of shape (..., hidden_size) to vocab logits."""
        x = self.dense(hidden_states)
        x = gelu(x)
        x = self.layer_norm(x)
        return self.decoder(x)

    def _tie_weights(self):
        # The bias lives on the decoder itself; no extra re-linking needed.
        # Present for parity with the other LM heads in this file.
        pass
|
|
|
|
class MLMHead(nn.Module):
    """
    BERT/RoBERTa-style masked-language-modeling head.

    Transforms features with a dense + gelu + layer-norm stack, then decodes
    to vocabulary logits.
    """

    def __init__(self, input_dim, hidden_dim, config):
        super().__init__()
        self.dense = nn.Linear(input_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(hidden_dim, config.vocab_size, bias=True)

    @property
    def bias(self):
        # Surface the decoder bias for HF weight-tying utilities.
        return self.decoder.bias

    def forward(self, features, **kwargs):
        """Map features of shape (..., input_dim) to vocab logits."""
        transformed = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(transformed)

    def _tie_weights(self):
        # Decoder already owns its bias; nothing to re-link.
        pass
|
|
|
|
class BertPreTrainedModel(PreTrainedModel):
    """
    Shared pretrained-model base: wires the config class and the weight
    initialization used by every model in this file.
    """

    config_class = BertEnergyConfig

    def _init_weights(self, module):
        """Initialize weights HF-style: normal(0, initializer_range) for
        linear/embedding weights, zeros for biases and the padding row,
        identity scale for layer norms."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
|
|
|
|
class BertModel(BertPreTrainedModel):
    """
    Standard transformer backbone.

    Embeds token ids, optionally adds positional information, then applies
    ``config.num_hidden_layers`` TransformerEncoder passes. With
    ``config.share_layers`` a single encoder layer (width ``hidden_size``)
    is applied repeatedly, ALBERT-style, after projecting the embeddings up
    to ``hidden_size``; otherwise a distinct layer of width
    ``embedding_dim`` is built per step.

    Outputs: last hidden state, plus the per-step hidden-state history when
    the module is in eval mode (history is skipped during training).
    """


    config_class = BertEnergyConfig


    def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs):
        # NOTE(review): `add_pooling_layer` and **kwargs are accepted but
        # unused here — presumably kept for HF-signature compatibility.
        super().__init__(config)


        # Token embedding table; the `pad_idx` row is frozen at zero.
        self.Emb_in = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=pad_idx)
        self.posn = (
            PositionalEncoding(
                config.embedding_dim,
                max_len=config.max_position_embeddings,
            )
            if config.positional
            else None
        )
        self.embed_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)


        self.num_layers = config.num_hidden_layers
        self.share_layers = config.share_layers


        if self.share_layers:
            # Project embeddings (E) up to the transformer width (H) before
            # the shared encoder layer, ALBERT-style.
            self.embedding_hidden_in = nn.Linear(config.embedding_dim, config.hidden_size)


            # NOTE(review): this shared branch uses
            # dim_feedforward=config.hidden_size and passes config.activation,
            # while the unshared branch below uses config.intermediate_size
            # and the default activation — confirm the asymmetry is intended.
            layer = nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                activation=config.activation,
                dim_feedforward=config.hidden_size,
                dropout=config.hidden_dropout_prob,
                layer_norm_eps=config.layer_norm_eps,
                batch_first=True,
                norm_first=True,
            )
            # A single shared layer, applied num_layers times in forward().
            self.layers = nn.ModuleList([layer])
            self.output_dim = config.hidden_size
        else:
            self.embedding_hidden_in = None
            # One independent encoder layer per depth step, operating
            # directly at the embedding dimension.
            self.layers = nn.ModuleList(
                [
                    nn.TransformerEncoderLayer(
                        d_model=config.embedding_dim,
                        nhead=config.num_attention_heads,
                        dim_feedforward=config.intermediate_size,
                        dropout=config.hidden_dropout_prob,
                        layer_norm_eps=config.layer_norm_eps,
                        batch_first=True,
                        norm_first=True,
                    )
                    for _ in range(config.num_hidden_layers)
                ]
            )
            self.output_dim = config.embedding_dim


        self.post_init()


    def get_input_embeddings(self):
        """Return the token embedding module (HF convention)."""
        return self.Emb_in


    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module (HF convention)."""
        self.Emb_in = new_embeddings


    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Encode `input_ids` (batch, seq); `attention_mask` uses HF
        semantics — 1/True = real token, 0 = padding."""
        x = self.Emb_in(input_ids)


        if self.posn is not None:
            # Assumes PositionalEncoding returns only the positional term,
            # which is added to the embeddings here — TODO confirm.
            x = x + self.posn(x)


        x = self.embed_norm(x)
        x = self.embed_dropout(x)


        if self.share_layers:
            x = self.embedding_hidden_in(x)


        # Collect per-layer states only in eval mode to save memory.
        history = None if self.training else [x]


        pad_mask = None
        if attention_mask is not None:
            # Invert to src_key_padding_mask semantics (True = ignore).
            pad_mask = ~attention_mask.to(torch.bool)


        for i in range(self.num_layers):
            # Shared mode reuses layers[0]; otherwise one layer per step.
            layer = self.layers[0] if self.share_layers else self.layers[i]
            x = layer(x, src_key_padding_mask=pad_mask)


            if not self.training:
                history.append(x)


        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=history,
            attentions=None,
        )
|
|
|
|
class BertModelForMaskedLM(BertPreTrainedModel):
    """
    Masked-LM wrapper around the standard transformer backbone.

    Uses the ALBERT-style head when layers are shared (backbone outputs
    hidden_size) and the BERT-style head otherwise (backbone stays at
    embedding_dim), optionally tying the decoder to the input embeddings.
    """

    config_class = BertEnergyConfig
    ignore_index = -100
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config, add_pooling_layer=True, pad_idx=None):
        super().__init__(config)
        self.config = config

        self.model = BertModel(config, pad_idx=pad_idx)

        if config.share_layers:
            # Shared backbone emits hidden_size; this head maps it back
            # down to the embedding dimension before decoding.
            self.lm_head = AlbertMLMHead(config)
        else:
            self.lm_head = MLMHead(config.embedding_dim, config.embedding_dim, config)

        self.post_init()

        if self.config.tie_word_embeddings:
            self.tie_weights()

    def get_input_embeddings(self):
        return self.model.Emb_in

    def set_input_embeddings(self, new_embeddings):
        self.model.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the backbone and LM head; when `labels` is given, compute a
        cross-entropy loss with padded positions excluded via ignore_index."""
        encoded = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        logits = self.lm_head(encoded.last_hidden_state)

        loss = None
        if labels is not None:
            if attention_mask is not None:
                # Drop padding positions (mask == 0) from the loss.
                labels = labels.masked_fill(attention_mask == 0, self.ignore_index)
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
|
|
|
|
class BertModelForSequenceClassification(BertPreTrainedModel):
    """
    Sequence classification on top of the standard transformer backbone.

    Pools the first token of the layer-normalized last hidden state and
    feeds it through dropout + a linear classifier. Loss selection follows
    the HF convention driven by ``config.problem_type``.
    """

    config_class = BertEnergyConfig

    def __init__(
        self,
        config,
        add_pooling_layer=True,
        pad_idx=None,
        num_labels=2,
        classifier_dropout=None,
        return_dict=True,
    ):
        super().__init__(config)
        self.config = config
        self.num_labels = num_labels
        self.return_dict = return_dict

        self.model = BertModel(config, pad_idx=pad_idx)

        output_dim = self.model.output_dim
        # Fall back to the backbone's hidden dropout when no classifier
        # dropout is given.
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob

        self.dropout = nn.Dropout(classifier_dropout)
        self.norm = nn.LayerNorm(output_dim, eps=config.layer_norm_eps)
        self.classifier = nn.Linear(output_dim, num_labels)

        self.post_init()

    def _compute_loss(self, logits, labels):
        """Select the HF-conventional loss from config.problem_type,
        inferring and caching the problem type on first use."""
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem = self.config.problem_type
        if problem == "regression":
            if self.num_labels == 1:
                return MSELoss()(logits.squeeze(), labels.squeeze())
            return MSELoss()(logits, labels)
        if problem == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        return BCEWithLogitsLoss()(logits, labels)

    def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
        """Classify each sequence; extra kwargs (e.g. attention_mask) are
        forwarded to the backbone."""
        if return_dict is None:
            return_dict = self.return_dict

        outputs = self.model(input_ids, **kwargs)
        normalized = self.norm(outputs.last_hidden_state)

        # CLS-style pooling: classify from the first token's representation.
        pooled = self.dropout(normalized[:, 0, :])
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels.to(logits.device))

        if not return_dict:
            output = (logits, outputs.hidden_states, outputs.attentions)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class BertEnergyModel(BertPreTrainedModel):
    """
    Energy-based backbone.

    A single shared HopfieldLayer is applied ``num_hidden_layers`` times as
    a gradient-descent-style update on the token states:

        g = LayerNorm(X)
        X <- X - alpha * layer(g)

    Outputs: last hidden state, plus the per-step state history when the
    module is in eval mode (history is skipped during training).
    """

    config_class = BertEnergyConfig

    def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs):
        # NOTE(review): `add_pooling_layer` and **kwargs are accepted but
        # unused here — presumably kept for HF-signature compatibility.
        super().__init__(config)

        self.config = config
        self.num_layers = config.num_hidden_layers
        # Step size of the energy-descent update in forward().
        self.alpha = config.alpha

        # Token embedding table; the `pad_idx` row is frozen at zero.
        self.Emb_in = nn.Embedding(
            config.vocab_size,
            config.embedding_dim,
            padding_idx=pad_idx,
        )

        self.posn = (
            PositionalEncoding(
                config.embedding_dim,
                max_len=config.max_position_embeddings,
            )
            if config.positional
            else None
        )

        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Single norm shared across all update steps (applied to X before
        # each Hopfield update).
        self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)

        # One shared Hopfield layer reused at every step.
        self.layer = HopfieldLayer(
            embedding_dim=config.embedding_dim,
            nheads=config.num_attention_heads,
            forward_memories=config.hidden_size,
            forward_activation=config.activation,
            bias=config.bias,
            beta=config.beta,
            device=None,
            dropout=0.0,
            initializer_range=config.initializer_hopfield_range,
        )

        self.post_init()

    def get_input_embeddings(self):
        # Consistency fix: the setter below existed without the matching
        # getter; HF utilities (tie_weights, resize_token_embeddings) and
        # the sibling BertModel both rely on this accessor pair.
        return self.Emb_in

    def set_input_embeddings(self, new_embeddings):
        self.Emb_in = new_embeddings

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Encode `input_ids` (batch, seq); `attention_mask` uses HF
        semantics — 1/True = real token, 0 = padding."""
        x = self.Emb_in(input_ids)

        if self.posn is not None:
            # Assumes PositionalEncoding returns only the positional term,
            # which is added to the embeddings here — TODO confirm.
            x = x + self.posn(x)

        x = self.embed_dropout(x)

        # Passed through as a keep-mask (True = real token) — note this is
        # the opposite convention of src_key_padding_mask; assumes
        # HopfieldLayer expects keep-semantics — TODO confirm.
        keep_mask = attention_mask.to(torch.bool) if attention_mask is not None else None
        # Collect per-step states only in eval mode to save memory.
        history = None if self.training else [x]

        for _ in range(self.num_layers):
            g = self.norm(x)
            update = self.layer(
                g,
                attention_mask=keep_mask,
            )
            # Gradient-descent-style state update with step size alpha.
            x = x - self.alpha * update

            if not self.training:
                history.append(x)

        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=history,
            attentions=None,
        )
|
|
|
|
class BertEnergyModelForMaskedLM(BertPreTrainedModel):
    """
    Masked-LM wrapper around the energy-based backbone.

    Runs BertEnergyModel, decodes with EnergyLMHead, and optionally ties
    the decoder weight to the input embeddings.
    """

    config_class = BertEnergyConfig
    ignore_index = -100
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config, add_pooling_layer=True, pad_idx=None):
        super().__init__(config)
        self.config = config

        self.model = BertEnergyModel(config, pad_idx=pad_idx)
        self.lm_head = EnergyLMHead(config)

        self.post_init()

        if self.config.tie_word_embeddings:
            self.tie_weights()

    def get_input_embeddings(self):
        return self.model.Emb_in

    def set_input_embeddings(self, new_embeddings):
        self.model.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the energy backbone and LM head; when `labels` is given,
        compute a cross-entropy loss with padded positions excluded via
        ignore_index."""
        encoded = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        logits = self.lm_head(encoded.last_hidden_state)

        loss = None
        if labels is not None:
            if attention_mask is not None:
                # Drop padding positions (mask == 0) from the loss.
                labels = labels.masked_fill(attention_mask == 0, self.ignore_index)
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )