import torch import torch.nn as nn from torch.nn.functional import gelu from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PreTrainedModel from transformers.modeling_outputs import ( BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, ) from hopfield import HopfieldLayer from hf_configuration import BertEnergyConfig from positional import PositionalEncoding class EnergyLMHead(nn.Module): """ MLM head for the energy backbone. Architecture: hidden -> dense -> gelu -> layer_norm -> decoder(vocab) """ def __init__(self, config): super().__init__() self.dense = nn.Linear(config.embedding_dim, config.embedding_dim) self.layer_norm = nn.LayerNorm( config.embedding_dim, eps=config.layer_norm_eps, ) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.decoder = nn.Linear(config.embedding_dim, config.vocab_size, bias=True) @property def bias(self): return self.decoder.bias def forward(self, hidden_states): x = self.dense(hidden_states) x = gelu(x) x = self.layer_norm(x) x = self.dropout(x) x = self.decoder(x) return x def _tie_weights(self): pass class AlbertMLMHead(nn.Module): """ ALBERT-style MLM head: hidden (H) -> embedding (E) -> LN -> vocab (V) """ def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.embedding_dim) self.layer_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps) self.decoder = nn.Linear(config.embedding_dim, config.vocab_size, bias=True) def forward(self, hidden_states): x = self.dense(hidden_states) x = gelu(x) x = self.layer_norm(x) return self.decoder(x) class MLMHead(nn.Module): """ Standard BERT/RoBERTa-style MLM head. """ def __init__(self, input_dim, hidden_dim, config): super().__init__() self.dense = nn.Linear(input_dim, hidden_dim) self.layer_norm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps) self.decoder = nn.Linear(hidden_dim, config.vocab_size, bias=True) @property def bias(self): return self.decoder.bias def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) x = self.decoder(x) return x def _tie_weights(self): pass class BertPreTrainedModel(PreTrainedModel): """ Common pretrained model base. """ config_class = BertEnergyConfig def _init_weights(self, module): if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class BertModel(BertPreTrainedModel): """ Standard transformer backbone. Outputs: last hidden state, optional hidden state history. """ config_class = BertEnergyConfig def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs): super().__init__(config) self.Emb_in = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=pad_idx) self.posn = ( PositionalEncoding( config.embedding_dim, max_len=config.max_position_embeddings, ) if config.positional else None ) self.embed_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps) self.embed_dropout = nn.Dropout(config.hidden_dropout_prob) self.num_layers = config.num_hidden_layers self.share_layers = config.share_layers if self.share_layers: self.embedding_hidden_in = nn.Linear(config.embedding_dim, config.hidden_size) layer = nn.TransformerEncoderLayer( d_model=config.hidden_size, nhead=config.num_attention_heads, activation=config.activation, dim_feedforward=config.hidden_size, dropout=config.hidden_dropout_prob, layer_norm_eps=config.layer_norm_eps, batch_first=True, norm_first=True, ) self.layers = nn.ModuleList([layer]) self.output_dim = config.hidden_size else: self.embedding_hidden_in = None self.layers = nn.ModuleList( [ nn.TransformerEncoderLayer( d_model=config.embedding_dim, nhead=config.num_attention_heads, dim_feedforward=config.intermediate_size, dropout=config.hidden_dropout_prob, layer_norm_eps=config.layer_norm_eps, batch_first=True, norm_first=True, ) for _ in range(config.num_hidden_layers) ] ) self.output_dim = config.embedding_dim self.post_init() def get_input_embeddings(self): return self.Emb_in def set_input_embeddings(self, new_embeddings): self.Emb_in = new_embeddings def forward(self, input_ids, attention_mask=None, **kwargs): x = self.Emb_in(input_ids) if self.posn is not None: x = x + self.posn(x) x = self.embed_norm(x) x = self.embed_dropout(x) if self.share_layers: x = self.embedding_hidden_in(x) history = None if self.training else [x] pad_mask = None if attention_mask is not None: pad_mask = ~attention_mask.to(torch.bool) for i in range(self.num_layers): layer = self.layers[0] if self.share_layers else self.layers[i] x = layer(x, src_key_padding_mask=pad_mask) if not self.training: history.append(x) return BaseModelOutput( last_hidden_state=x, hidden_states=history, attentions=None, ) class BertModelForMaskedLM(BertPreTrainedModel): """ Standard transformer model for MLM. """ config_class = BertEnergyConfig ignore_index = -100 _tied_weights_keys = ["lm_head.decoder.weight"] def __init__(self, config, add_pooling_layer=True, pad_idx=None): super().__init__(config) self.config = config self.model = BertModel(config, pad_idx=pad_idx) if config.share_layers: self.lm_head = AlbertMLMHead(config) else: self.lm_head = MLMHead(config.embedding_dim, config.embedding_dim, config) self.post_init() if self.config.tie_word_embeddings: self.tie_weights() def get_input_embeddings(self): return self.model.Emb_in def set_input_embeddings(self, new_embeddings): self.model.set_input_embeddings(new_embeddings) def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs) logits = self.lm_head(outputs.last_hidden_state) loss = None if labels is not None: if attention_mask is not None: labels = labels.masked_fill(attention_mask == 0, self.ignore_index) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class BertModelForSequenceClassification(BertPreTrainedModel): """ Standard transformer model for sequence classification. """ config_class = BertEnergyConfig def __init__( self, config, add_pooling_layer=True, pad_idx=None, num_labels=2, classifier_dropout=None, return_dict=True, ): super().__init__(config) self.config = config self.num_labels = num_labels self.return_dict = return_dict self.model = BertModel(config, pad_idx=pad_idx) output_dim = self.model.output_dim dropout = classifier_dropout if classifier_dropout is not None else config.hidden_dropout_prob self.dropout = nn.Dropout(dropout) self.norm = nn.LayerNorm(output_dim, eps=config.layer_norm_eps) self.classifier = nn.Linear(output_dim, num_labels) self.post_init() def forward(self, input_ids, labels=None, return_dict=None, **kwargs): if return_dict is None: return_dict = self.return_dict outputs = self.model(input_ids, **kwargs) last_hidden_state = self.norm(outputs.last_hidden_state) x = last_hidden_state[:, 0, :] x = self.dropout(x) logits = self.classifier(x) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() loss = loss_fct(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) else: loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits, outputs.hidden_states, outputs.attentions) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class BertEnergyModel(BertPreTrainedModel): """ Energy-based backbone. Update rule: g = LayerNorm(X) X <- X - alpha * layer(g) """ config_class = BertEnergyConfig def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs): super().__init__(config) self.config = config self.num_layers = config.num_hidden_layers self.alpha = config.alpha self.Emb_in = nn.Embedding( config.vocab_size, config.embedding_dim, padding_idx=pad_idx, ) self.posn = ( PositionalEncoding( config.embedding_dim, max_len=config.max_position_embeddings, ) if config.positional else None ) self.embed_dropout = nn.Dropout(config.hidden_dropout_prob) # External normalization, as in the original ET implementation self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps) self.layer = HopfieldLayer( embedding_dim=config.embedding_dim, nheads=config.num_attention_heads, forward_memories=config.hidden_size, forward_activation=config.activation, bias=config.bias, beta=config.beta, device=None, dropout=0.0, initializer_range=config.initializer_hopfield_range, ) self.post_init() def set_input_embeddings(self, new_embeddings): self.Emb_in = new_embeddings def forward(self, input_ids, attention_mask=None, **kwargs): x = self.Emb_in(input_ids) if self.posn is not None: x = x + self.posn(x) x = self.embed_dropout(x) keep_mask = attention_mask.to(torch.bool) if attention_mask is not None else None history = None if self.training else [x] for _ in range(self.num_layers): g = self.norm(x) update = self.layer( g, attention_mask=keep_mask, ) x = x - self.alpha * update if not self.training: history.append(x) return BaseModelOutput( last_hidden_state=x, hidden_states=history, attentions=None, ) class BertEnergyModelForMaskedLM(BertPreTrainedModel): """ Energy-based model for MLM. """ config_class = BertEnergyConfig ignore_index = -100 _tied_weights_keys = ["lm_head.decoder.weight"] def __init__(self, config, add_pooling_layer=True, pad_idx=None): super().__init__(config) self.config = config self.model = BertEnergyModel(config, pad_idx=pad_idx) self.lm_head = EnergyLMHead(config) self.post_init() if self.config.tie_word_embeddings: self.tie_weights() def get_input_embeddings(self): return self.model.Emb_in def set_input_embeddings(self, new_embeddings): self.model.set_input_embeddings(new_embeddings) def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs) logits = self.lm_head(outputs.last_hidden_state) loss = None if labels is not None: if attention_mask is not None: labels = labels.masked_fill(attention_mask == 0, self.ignore_index) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )