"""PyTorch modeling code for the RelikReader span-extraction and relation-extraction readers."""

from typing import Any, Dict, Optional

import torch
from transformers import AutoModel, PreTrainedModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PoolerEndLogits

from .configuration_relik import RelikReaderConfig


class RelikReaderSample:
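    """Dict-backed record with attribute access for reader inputs and outputs.

    Unknown attributes resolve to ``None`` instead of raising, so downstream
    code can probe optional fields freely. A minimal usage sketch (the field
    names are illustrative, not part of the class):

    >>> sample = RelikReaderSample(text="Rome is in Italy")
    >>> sample.text
    'Rome is in Italy'
    >>> sample.predicted_spans is None
    True
    """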
    def __init__(self, **kwargs):
        super().__setattr__("_d", {})
        self._d = kwargs

    def __getattribute__(self, item):
        return super(RelikReaderSample, self).__getattribute__(item)

    def __getattr__(self, item):
        if item.startswith("__") and item.endswith("__"):
            # dunder lookups (e.g. __deepcopy__) must follow standard Python
            # behavior, otherwise copy/pickle would silently receive None
            raise AttributeError(item)
        elif item in self._d:
            return self._d[item]
        else:
            return None

    def __setattr__(self, key, value):
        if key == "_d":
            # the backing dict itself is a real instance attribute
            super().__setattr__(key, value)
        else:
            # everything else lives only in the backing dict, so reads and
            # writes always see the same value
            self._d[key] = value


activation2functions = {
    "relu": torch.nn.ReLU(),
    "gelu": GELUActivation(),
    "gelu_10": ClippedGELUActivation(-10, 10),
}


class PoolerEndLogitsBi(PoolerEndLogits):
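    """`PoolerEndLogits` variant that emits two logits per position, so span
    ends can be scored as independent binary decisions rather than as a single
    softmax over the sequence."""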
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # override the final projection: two logits per position instead of one
        self.dense_1 = torch.nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        if p_mask is not None:
            p_mask = p_mask.unsqueeze(-1)
        logits = super().forward(
            hidden_states,
            start_states,
            start_positions,
            p_mask,
        )
        return logits


class RelikReaderSpanModel(PreTrainedModel):
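    """Reader for span-based entity linking: detects mention spans with the
    NED start/end heads and disambiguates each span against the candidate
    entities encoded as special symbols in the input."""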
    config_class = RelikReaderConfig

    def __init__(self, config: RelikReaderConfig, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(self.config.transformer_model)
            if self.config.num_layers is None
            else AutoModel.from_pretrained(
                self.config.transformer_model, num_hidden_layers=self.config.num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + self.config.additional_special_symbols
        )

        self.activation = self.config.activation
        self.linears_hidden_size = self.config.linears_hidden_size
        self.use_last_k_layers = self.config.use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        if self.config.binary_end_logits:
            self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
        else:
            self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # entity disambiguation layers
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        self.training = self.config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
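        """Builds a two-layer MLP head: Dropout -> Linear -> activation ->
        Dropout -> Linear, optionally followed by LayerNorm."""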
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                (
                    self.transformer_model.config.hidden_size * self.use_last_k_layers
                    if input_hidden is None
                    else input_hidden
                ),
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            # stay within fp16 range; -65500 is close to the fp16 minimum
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            # concatenate the last k hidden states along the feature dimension
            model_features = torch.cat(
                model_output[1][-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
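        """For every gold (at training time) or predicted span start, scores
        each position as the span end. Returns None when the batch contains no
        starts at all."""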
        # use gold starts while training, predicted starts otherwise
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            # replicate the sequence features once per start in that sequence
            expanded_features = model_features.repeat_interleave(
                torch.sum(start_positions > 0, dim=-1), dim=0
            )
            expanded_prediction_mask = prediction_mask.repeat_interleave(
                torch.sum(start_positions > 0, dim=-1), dim=0
            )
            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_classification_logits(
        self,
        model_features_start,
        model_features_end,
        special_symbols_features,
    ) -> torch.Tensor:
        model_start_features = self.ed_start_projector(model_features_start)
        model_end_features = self.ed_end_projector(model_features_end)
        model_start_features_symbols = self.ed_start_projector(special_symbols_features)
        model_end_features_symbols = self.ed_end_projector(special_symbols_features)

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )
        special_symbols_representation = torch.cat(
            [model_start_features_symbols, model_end_features_symbols], dim=-1
        )

        # similarity between each candidate span and each candidate entity
        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        # mask out padded spans (padded with -100)
        logits = self._mask_logits(
            logits, (model_features_start == -100).all(2).long()
        )
        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
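        """Runs mention detection (NED start/end) and entity disambiguation
        (ED), returning a dict of logits, probabilities and predictions plus,
        when labels are given at training time, the individual and summed
        losses."""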
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        ned_start_labels = None

        # named entity detection
        if use_predefined_spans:
            # no prediction needed: take the spans from the (required) labels
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                (
                    torch.clone(start_labels)
                    if start_labels is not None
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                (
                    torch.clone(end_labels)
                    if end_labels is not None
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]

        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # binarize the start labels: 1 if the token starts a span, -100 if ignored
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            # end boundary prediction, conditioned on the gold/predicted starts
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                if not self.config.binary_end_logits:
                    ned_end_predictions = torch.argmax(
                        ned_end_probabilities, dim=-1, keepdim=True
                    )
                    ned_end_predictions = torch.zeros_like(
                        ned_end_probabilities
                    ).scatter_(1, ned_end_predictions, 1)
                else:
                    ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(
                    batch_size, seq_len
                )

            if not self.training:
                # if a start prediction ends up with no end prediction,
                # drop that start prediction as well
                end_preds_count = ned_end_predictions.sum(1)
                if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
                    ned_start_predictions[ned_start_predictions == 1] = (
                        end_preds_count != 0
                    ).long()
                    ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )
        start_counts = (start_position > 0).sum(1)
        if (start_counts > 0).any():
            ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

        # entity disambiguation
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            model_entity_start = torch.repeat_interleave(
                model_features[start_position > 0], ends_count, dim=0
            )
            model_entity_end = torch.repeat_interleave(
                model_features, start_counts, dim=0
            )[end_position > 0]
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)

            model_entity_start = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_entity_start, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            model_entity_end = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_entity_end, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            ed_logits = self.compute_classification_logits(
                model_entity_start,
                model_entity_end,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )
            ed_probabilities = torch.softmax(ed_logits, dim=-1)
            ed_predictions = torch.argmax(ed_probabilities, dim=-1)
        else:
            ed_logits, ed_probabilities, ed_predictions = (
                None,
                ned_start_predictions.new_zeros(batch_size, seq_len),
                ned_start_predictions.new_zeros(batch_size),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute the losses if labels are available
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection start loss
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # named entity detection end loss and entity disambiguation loss
            if ned_end_logits is not None:
                ed_labels = end_labels.clone()
                ed_labels = torch.nn.utils.rnn.pad_sequence(
                    torch.split(ed_labels[ed_labels > 0], ents_count.tolist()),
                    batch_first=True,
                    padding_value=-100,
                )
                end_labels[end_labels > 0] = 1
                if not self.config.binary_end_logits:
                    # transform the labels to position indices in the sequence
                    end_labels = end_labels.argmax(dim=-1)
                    ned_end_loss = self.criterion(
                        ned_end_logits.view(-1, ned_end_logits.shape[-1]),
                        end_labels.view(-1),
                    )
                else:
                    ned_end_loss = self.criterion(
                        ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),
                        end_labels.reshape(-1).long(),
                    )

                # entity disambiguation loss
                ed_loss = self.criterion(
                    ed_logits.view(-1, ed_logits.shape[-1]),
                    ed_labels.view(-1).long(),
                )

            else:
                ned_end_loss = 0
                ed_loss = 0

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict


class RelikReaderREModel(PreTrainedModel):
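    """Reader for relation extraction: detects entity spans and classifies
    every (subject, object) span pair against the candidate relation types
    encoded as special symbols in the input, with optional entity typing."""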
    config_class = RelikReaderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(config.transformer_model)
            if config.num_layers is None
            else AutoModel.from_pretrained(
                config.transformer_model, num_hidden_layers=config.num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + config.additional_special_symbols
            + config.additional_special_symbols_types,
        )

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            config.activation, last_hidden=2, layer_norm=False
        )

        self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)

        self.relation_disambiguation_loss = getattr(
            config, "relation_disambiguation_loss", False
        )

        if self.config.entity_type_loss and self.config.add_entity_embedding:
            input_hidden_ents = 3
        else:
            input_hidden_ents = 2

        self.re_projector = self._get_projection_layer(
            config.activation,
            input_hidden=input_hidden_ents * self.transformer_model.config.hidden_size,
            hidden=input_hidden_ents * self.config.linears_hidden_size,
            last_hidden=2 * self.config.linears_hidden_size,
        )

        self.re_relation_projector = self._get_projection_layer(
            config.activation,
            input_hidden=self.transformer_model.config.hidden_size,
        )

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
            self.re_entities_projector = self._get_projection_layer(
                config.activation,
                input_hidden=2 * self.transformer_model.config.hidden_size,
            )
            self.re_definition_projector = self._get_projection_layer(
                config.activation,
            )

        self.re_classifier = self._get_projection_layer(
            config.activation,
            input_hidden=config.linears_hidden_size,
            last_hidden=2,
            layer_norm=False,
        )

        self.training = config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()
        self.criterion_type = torch.nn.BCEWithLogitsLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                (
                    self.transformer_model.config.hidden_size
                    * self.config.use_last_k_layers
                    if input_hidden is None
                    else input_hidden
                ),
                self.config.linears_hidden_size if hidden is None else hidden,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.config.linears_hidden_size if hidden is None else hidden,
                self.config.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    (
                        self.config.linears_hidden_size
                        if last_hidden is None
                        else last_hidden
                    ),
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.config.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.config.use_last_k_layers > 1:
            model_features = torch.cat(
                model_output[1][-self.config.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
        mask_preceding: bool = False,
    ) -> Optional[torch.Tensor]:
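        """Like the span model's end classifier, with an extra option to mask
        every position that precedes the span start. Returns None when the
        batch contains no starts at all."""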
        # use gold starts while training, predicted starts otherwise
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            # replicate the sequence features once per start in that sequence
            expanded_features = model_features.repeat_interleave(
                torch.sum(start_positions > 0, dim=-1), dim=0
            )
            expanded_prediction_mask = prediction_mask.repeat_interleave(
                torch.sum(start_positions > 0, dim=-1), dim=0
            )
            if mask_preceding:
                # an end cannot come before its start
                expanded_prediction_mask[
                    torch.arange(
                        expanded_prediction_mask.shape[1],
                        device=expanded_prediction_mask.device,
                    )
                    < start_positions_indices.unsqueeze(1)
                ] = 1
            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_relation_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
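        """Scores every (subject span, object span, relation symbol) triple:
        projected subject/object features and relation-symbol features are
        combined elementwise via einsum and fed to a binary classifier."""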
        model_subject_object_features = self.re_projector(model_entity_features)
        # the projector packs subject and object representations side by side
        model_subject_features = model_subject_object_features[
            :, :, : model_subject_object_features.shape[-1] // 2
        ]
        model_object_features = model_subject_object_features[
            :, :, model_subject_object_features.shape[-1] // 2 :
        ]
        special_symbols_start_representation = self.re_relation_projector(
            special_symbols_features
        )
        # elementwise product over every (subject, object, relation) triple
        re_logits = torch.einsum(
            "bse,bde,bfe->bsdfe",
            model_subject_features,
            model_object_features,
            special_symbols_start_representation,
        )
        re_logits = self.re_classifier(re_logits)

        return re_logits

    def compute_entity_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        model_ed_features = self.re_entities_projector(model_entity_features)
        special_symbols_ed_representation = self.re_definition_projector(
            special_symbols_features
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_ed_representation, (0, 2, 1)),
        )
        logits = self._mask_logits(
            logits, (model_entity_features == -100).all(2).long()
        )
        return logits

    def compute_loss(self, logits, labels, mask=None):
        logits = logits.reshape(-1, logits.shape[-1])
        labels = labels.reshape(-1).long()
        if mask is not None:
            return self.criterion(logits[mask], labels[mask])
        return self.criterion(logits, labels)

    def compute_ned_type_loss(
        self,
        disambiguation_labels,
        re_ned_entities_logits,
        ned_type_logits,
        re_entities_logits,
        entity_types,
        mask,
    ):
        if self.config.entity_type_loss and self.relation_disambiguation_loss:
            return self.criterion_type(
                re_ned_entities_logits[disambiguation_labels != -100],
                disambiguation_labels[disambiguation_labels != -100],
            )
        if self.config.entity_type_loss:
            return self.criterion_type(
                ned_type_logits[mask],
                disambiguation_labels[:, :, :entity_types][mask],
            )

        if self.relation_disambiguation_loss:
            return self.criterion_type(
                re_entities_logits[disambiguation_labels != -100],
                disambiguation_labels[disambiguation_labels != -100],
            )
        return 0

    def compute_relation_loss(self, relation_labels, re_logits):
        return self.compute_loss(
            re_logits, relation_labels, relation_labels.view(-1) != -100
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        relation_threshold: Optional[float] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
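        """Runs span detection, optional entity typing/disambiguation, and
        relation classification, returning a dict of logits, probabilities and
        predictions plus, when labels are given, the weighted loss."""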
        relation_threshold = (
            self.config.threshold if relation_threshold is None else relation_threshold
        )

        batch_size = input_ids.shape[0]

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection
        if use_predefined_spans:
            # no prediction needed: take the spans from the (required) labels
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.zeros_like(start_labels),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.zeros_like(end_labels),
            )

            ned_start_predictions[start_labels > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
            ned_start_labels = start_labels
            ned_start_labels[start_labels > 0] = 1
        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            if is_validation or is_prediction:
                ned_start_logits = self._mask_logits(
                    ned_start_logits, prediction_mask
                )
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # binarize the start labels: 1 if the token starts a span, -100 if ignored
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            # end boundary prediction, masking every position that precedes
            # the corresponding start
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
                True,
            )

            if ned_end_logits is not None:
                # for each start prediction, one end prediction is made via
                # binary classification, i.e. argmax at each position
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = torch.zeros_like(ned_start_predictions)

            if is_prediction or is_validation:
                # if a start prediction ends up with no end prediction,
                # drop that start prediction as well
                end_preds_count = ned_end_predictions.sum(1)
                if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
                    ned_start_predictions[ned_start_predictions == 1] = (
                        end_preds_count != 0
                    ).long()
                    ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if (not is_prediction and not is_validation)
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        if (start_counts > 0).any():
            ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
        else:
            ned_end_predictions = [
                torch.empty(0, input_ids.shape[1], dtype=torch.int64)
                for _ in range(batch_size)
            ]

        # relation extraction
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            model_subject_features = torch.cat(
                [
                    torch.repeat_interleave(
                        model_features[start_position > 0], ends_count, dim=0
                    ),
                    torch.repeat_interleave(model_features, start_counts, dim=0)[
                        end_position > 0
                    ],
                ],
                dim=-1,
            )
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)
            model_subject_features = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_subject_features, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            # entity typing / disambiguation, if enabled
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                re_ned_entities_logits = self.compute_entity_logits(
                    model_subject_features,
                    model_features[
                        special_symbols_mask | special_symbols_mask_entities
                    ].view(batch_size, -1, model_features.shape[-1]),
                )
                entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
                ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
                re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

                if self.config.entity_type_loss:
                    ned_type_probabilities = torch.sigmoid(ned_type_logits)
                    ned_type_predictions = ned_type_probabilities.argmax(dim=-1)

                    if self.config.add_entity_embedding:
                        special_symbols_representation = model_features[
                            special_symbols_mask_entities
                        ].view(batch_size, entity_types, -1)

                        entities_representation = torch.einsum(
                            "bsp,bpe->bse",
                            ned_type_probabilities,
                            special_symbols_representation,
                        )
                        model_subject_features = torch.cat(
                            [model_subject_features, entities_representation], dim=-1
                        )
                re_entities_probabilities = torch.sigmoid(re_entities_logits)
                re_entities_predictions = re_entities_probabilities.round()
            else:
                (
                    ned_type_logits,
                    ned_type_probabilities,
                    re_entities_logits,
                    re_entities_probabilities,
                ) = (None, None, None, None)
                ned_type_predictions, re_entities_predictions = (
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                )

            # relation classification
            re_logits = self.compute_relation_logits(
                model_subject_features,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )

            re_probabilities = torch.softmax(re_logits, dim=-1)
            # a tunable threshold is used instead of argmax
            re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
            re_probabilities = re_probabilities[:, :, :, :, 1]
        else:
            (
                ned_type_logits,
                ned_type_probabilities,
                re_entities_logits,
                re_entities_probabilities,
            ) = (None, None, None, None)
            ned_type_predictions, re_entities_predictions = (
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            )
            re_logits, re_probabilities, re_predictions = (
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ned_type_logits=ned_type_logits,
            ned_type_probabilities=ned_type_probabilities,
            ned_type_predictions=ned_type_predictions,
            re_entities_logits=re_entities_logits,
            re_entities_probabilities=re_entities_probabilities,
            re_entities_predictions=re_entities_predictions,
            re_logits=re_logits,
            re_probabilities=re_probabilities,
            re_predictions=re_predictions,
        )

        if (
            start_labels is not None
            and end_labels is not None
            and relation_labels is not None
            and is_prediction is False
        ):
            ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
            end_labels[end_labels > 0] = 1
            ned_end_loss = self.compute_loss(ned_end_logits, end_labels)
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                ned_type_loss = self.compute_ned_type_loss(
                    disambiguation_labels,
                    re_ned_entities_logits,
                    ned_type_logits,
                    re_entities_logits,
                    entity_types,
                    (model_subject_features != -100).all(2),
                )
            relation_loss = self.compute_relation_loss(relation_labels, re_logits)
            # compute the weighted loss
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
                ) / 4
                output_dict["ned_type_loss"] = ned_type_loss
            else:
                output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
                    (9 / 10) * relation_loss
                )
            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["re_loss"] = relation_loss

        return output_dict