# HOP4NLP / mlm.py
import torch
import torch.nn as nn
from torch.nn.functional import gelu
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
SequenceClassifierOutput,
)
from hopfield import HopfieldLayer
from hf_configuration import BertEnergyConfig
from positional import PositionalEncoding
class EnergyLMHead(nn.Module):
"""
MLM head for the energy backbone.
Architecture:
hidden -> dense -> gelu -> layer_norm -> dropout -> decoder(vocab)
"""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.embedding_dim, config.embedding_dim)
self.layer_norm = nn.LayerNorm(
config.embedding_dim,
eps=config.layer_norm_eps,
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.decoder = nn.Linear(config.embedding_dim, config.vocab_size, bias=True)
@property
def bias(self):
return self.decoder.bias
def forward(self, hidden_states):
x = self.dense(hidden_states)
x = gelu(x)
x = self.layer_norm(x)
x = self.dropout(x)
x = self.decoder(x)
return x
def _tie_weights(self):
pass
class AlbertMLMHead(nn.Module):
"""
ALBERT-style MLM head:
hidden (H) -> dense (H -> E) -> gelu -> layer_norm -> decoder (E -> V)
"""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.embedding_dim)
self.layer_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.embedding_dim, config.vocab_size, bias=True)
def forward(self, hidden_states):
x = self.dense(hidden_states)
x = gelu(x)
x = self.layer_norm(x)
return self.decoder(x)
class MLMHead(nn.Module):
"""
Standard BERT/RoBERTa-style MLM head: dense -> gelu -> layer_norm -> decoder(vocab).
"""
def __init__(self, input_dim, hidden_dim, config):
super().__init__()
self.dense = nn.Linear(input_dim, hidden_dim)
self.layer_norm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
self.decoder = nn.Linear(hidden_dim, config.vocab_size, bias=True)
@property
def bias(self):
return self.decoder.bias
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
x = self.decoder(x)
return x
def _tie_weights(self):
pass
class BertPreTrainedModel(PreTrainedModel):
"""
Common pretrained model base.
"""
config_class = BertEnergyConfig
def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class BertModel(BertPreTrainedModel):
"""
Standard transformer backbone.
Outputs: last hidden state and, in eval mode only, the per-layer hidden-state history.
"""
config_class = BertEnergyConfig
def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs):
super().__init__(config)
self.Emb_in = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=pad_idx)
self.posn = (
PositionalEncoding(
config.embedding_dim,
max_len=config.max_position_embeddings,
)
if config.positional
else None
)
self.embed_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
self.num_layers = config.num_hidden_layers
self.share_layers = config.share_layers
if self.share_layers:
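# ALBERT-style parameter sharing: project embeddings from embedding_dim to hidden_size and reuse a single encoder layer for all num_hidden_layers steps.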
self.embedding_hidden_in = nn.Linear(config.embedding_dim, config.hidden_size)
layer = nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_attention_heads,
activation=config.activation,
dim_feedforward=config.hidden_size,
dropout=config.hidden_dropout_prob,
layer_norm_eps=config.layer_norm_eps,
batch_first=True,
norm_first=True,
)
self.layers = nn.ModuleList([layer])
self.output_dim = config.hidden_size
else:
self.embedding_hidden_in = None
self.layers = nn.ModuleList(
[
nn.TransformerEncoderLayer(
d_model=config.embedding_dim,
nhead=config.num_attention_heads,
dim_feedforward=config.intermediate_size,
dropout=config.hidden_dropout_prob,
layer_norm_eps=config.layer_norm_eps,
batch_first=True,
norm_first=True,
)
for _ in range(config.num_hidden_layers)
]
)
self.output_dim = config.embedding_dim
self.post_init()
def get_input_embeddings(self):
return self.Emb_in
def set_input_embeddings(self, new_embeddings):
self.Emb_in = new_embeddings
def forward(self, input_ids, attention_mask=None, **kwargs):
x = self.Emb_in(input_ids)
if self.posn is not None:
x = x + self.posn(x)
x = self.embed_norm(x)
x = self.embed_dropout(x)
if self.share_layers:
x = self.embedding_hidden_in(x)
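# per-layer hidden states are collected only in eval mode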
history = None if self.training else [x]
pad_mask = None
if attention_mask is not None:
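# src_key_padding_mask expects True at positions to ignore, so invert the HF-style mask (1 = keep, 0 = pad)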
pad_mask = ~attention_mask.to(torch.bool)
for i in range(self.num_layers):
layer = self.layers[0] if self.share_layers else self.layers[i]
x = layer(x, src_key_padding_mask=pad_mask)
if not self.training:
history.append(x)
return BaseModelOutput(
last_hidden_state=x,
hidden_states=history,
attentions=None,
)
class BertModelForMaskedLM(BertPreTrainedModel):
"""
Standard transformer model for MLM.
"""
config_class = BertEnergyConfig
ignore_index = -100
_tied_weights_keys = ["lm_head.decoder.weight"]
def __init__(self, config, add_pooling_layer=True, pad_idx=None):
super().__init__(config)
self.config = config
self.model = BertModel(config, pad_idx=pad_idx)
if config.share_layers:
self.lm_head = AlbertMLMHead(config)
else:
self.lm_head = MLMHead(config.embedding_dim, config.embedding_dim, config)
self.post_init()
if self.config.tie_word_embeddings:
self.tie_weights()
def get_input_embeddings(self):
return self.model.Emb_in
def set_input_embeddings(self, new_embeddings):
self.model.set_input_embeddings(new_embeddings)
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
if attention_mask is not None:
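# padded positions are set to ignore_index so they never contribute to the MLM loss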
labels = labels.masked_fill(attention_mask == 0, self.ignore_index)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class BertModelForSequenceClassification(BertPreTrainedModel):
"""
Standard transformer model for sequence classification.
"""
config_class = BertEnergyConfig
def __init__(
self,
config,
add_pooling_layer=True,
pad_idx=None,
num_labels=2,
classifier_dropout=None,
return_dict=True,
):
super().__init__(config)
self.config = config
self.num_labels = num_labels
self.return_dict = return_dict
self.model = BertModel(config, pad_idx=pad_idx)
output_dim = self.model.output_dim
dropout = classifier_dropout if classifier_dropout is not None else config.hidden_dropout_prob
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(output_dim, eps=config.layer_norm_eps)
self.classifier = nn.Linear(output_dim, num_labels)
self.post_init()
def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
if return_dict is None:
return_dict = self.return_dict
outputs = self.model(input_ids, **kwargs)
last_hidden_state = self.norm(outputs.last_hidden_state)
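# sequence-level representation: take the first token's hidden state (CLS-style pooling)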
x = last_hidden_state[:, 0, :]
x = self.dropout(x)
logits = self.classifier(x)
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
else:
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits, outputs.hidden_states, outputs.attentions)
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class BertEnergyModel(BertPreTrainedModel):
"""
Energy-based backbone: a single weight-tied Hopfield layer is applied num_hidden_layers times.
Update rule per step:
g = LayerNorm(X)
X <- X - alpha * layer(g)
"""
config_class = BertEnergyConfig
def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs):
super().__init__(config)
self.config = config
self.num_layers = config.num_hidden_layers
self.alpha = config.alpha
self.Emb_in = nn.Embedding(
config.vocab_size,
config.embedding_dim,
padding_idx=pad_idx,
)
self.posn = (
PositionalEncoding(
config.embedding_dim,
max_len=config.max_position_embeddings,
)
if config.positional
else None
)
self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
# External normalization, as in the original ET implementation
self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
self.layer = HopfieldLayer(
embedding_dim=config.embedding_dim,
nheads=config.num_attention_heads,
forward_memories=config.hidden_size,
forward_activation=config.activation,
bias=config.bias,
beta=config.beta,
device=None,
dropout=0.0,
initializer_range=config.initializer_hopfield_range,
)
self.post_init()
def get_input_embeddings(self):
return self.Emb_in
def set_input_embeddings(self, new_embeddings):
self.Emb_in = new_embeddings
def forward(self, input_ids, attention_mask=None, **kwargs):
x = self.Emb_in(input_ids)
if self.posn is not None:
x = x + self.posn(x)
x = self.embed_dropout(x)
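# HF attention masks use 1 for real tokens, so the boolean cast yields True = keep; the Hopfield layer receives it unchanged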
keep_mask = attention_mask.to(torch.bool) if attention_mask is not None else None
history = None if self.training else [x]
for _ in range(self.num_layers):
g = self.norm(x)
update = self.layer(
g,
attention_mask=keep_mask,
)
x = x - self.alpha * update
if not self.training:
history.append(x)
return BaseModelOutput(
last_hidden_state=x,
hidden_states=history,
attentions=None,
)
class BertEnergyModelForMaskedLM(BertPreTrainedModel):
"""
Energy-based model for MLM.
"""
config_class = BertEnergyConfig
ignore_index = -100
_tied_weights_keys = ["lm_head.decoder.weight"]
def __init__(self, config, add_pooling_layer=True, pad_idx=None):
super().__init__(config)
self.config = config
self.model = BertEnergyModel(config, pad_idx=pad_idx)
self.lm_head = EnergyLMHead(config)
self.post_init()
if self.config.tie_word_embeddings:
self.tie_weights()
def get_input_embeddings(self):
return self.model.Emb_in
def set_input_embeddings(self, new_embeddings):
self.model.set_input_embeddings(new_embeddings)
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
if attention_mask is not None:
labels = labels.masked_fill(attention_mask == 0, self.ignore_index)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
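# Usage sketch: a minimal masked-LM forward pass. This assumes BertEnergyConfig
# (see hf_configuration.py) supplies defaults for the attributes read above
# (vocab_size, embedding_dim, num_hidden_layers, alpha, beta, ...); check that
# file for the actual signature before running.
#
# config = BertEnergyConfig()
# model = BertEnergyModelForMaskedLM(config, pad_idx=0)
# input_ids = torch.randint(0, config.vocab_size, (2, 16))
# attention_mask = torch.ones_like(input_ids)
# out = model(input_ids, attention_mask=attention_mask, labels=input_ids)
# out.loss, out.logits.shape  # scalar loss, (2, 16, config.vocab_size)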