| from typing import List, Optional |
|
|
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig |
| import torch |
| from haystack.nodes.base import BaseComponent |
| from haystack.modeling.utils import initialize_device_settings |
| from haystack.schema import Document |
|
|
|
|
class EntailmentChecker(BaseComponent):
    """
    Check the entailment between the content of every document and the query.

    Each processed document gets an ``entailment_info`` entry in its metadata,
    mapping every label of the model to its probability. The node also returns
    aggregate entailment information: the per-label probabilities averaged over
    the processed documents, weighted by the score of each document.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        entailment_contradiction_threshold: float = 0.5,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for the full list of available models.
        :param model_version: The version of the model to use from the HuggingFace model hub.
            Can be a tag name, a branch name, or a commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as the model).
            Defaults to ``model_name_or_path``.
        :param use_gpu: Whether to use the GPU (if available).
        :param batch_size: Number of Documents to be processed at a time. The value is stored
            on the instance; ``run`` currently processes the documents one by one.
        :param entailment_contradiction_threshold: If the first N documents already give strong
            evidence of entailment/contradiction (the aggregate entailment or contradiction is
            greater than the threshold), the less relevant documents are not taken into account.
        :raises ValueError: If the labels of the model do not include ``entailment``.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.model.to(str(self.devices[0]))

        # Take the labels from the config of the model that was actually loaded, so that
        # they match the requested `model_version` and no second config load is needed.
        id2label = self.model.config.id2label
        # Lower-cased label names, ordered by class index (i.e. by logit position).
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError(
                "The model config must contain entailment value in the id2label dict."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Compute the entailment between the query and each document.

        Documents are processed in the given order. After each document the
        score-weighted aggregates are updated; as soon as the aggregate entailment
        or contradiction exceeds ``entailment_contradiction_threshold``, the
        remaining documents are skipped.

        :param query: The statement to check; it is used as the NLI hypothesis.
        :param documents: The documents whose content is used as the NLI premise.
            ``doc.score`` is used as the weight of each document, and
            ``doc.meta["entailment_info"]`` is set on every processed document.
        :return: A tuple ``(output, "output_1")`` where ``output`` contains
            ``documents`` (only the documents that were processed) and
            ``aggregate_entailment_info`` (weighted averages rounded to 2 decimals).
            When there are no documents, or the scores sum to zero, the weighted
            average is undefined and all aggregates are reported as ``0.0``.
        """
        total_score, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
        processed = 0
        for doc in documents:
            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
            doc.meta["entailment_info"] = entailment_info
            processed += 1

            total_score += doc.score
            agg_con += entailment_info["contradiction"] * doc.score
            agg_neu += entailment_info["neutral"] * doc.score
            agg_ent += entailment_info["entailment"] * doc.score

            # If the first documents already give strong evidence of entailment or
            # contradiction, there is no need to consider the less relevant ones.
            # The `total_score` guard avoids dividing by zero when all scores are 0.
            if (
                total_score
                and max(agg_con, agg_ent) / total_score
                > self.entailment_contradiction_threshold
            ):
                break

        if total_score:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / total_score, 2),
                "neutral": round(agg_neu / total_score, 2),
                "entailment": round(agg_ent / total_score, 2),
            }
        else:
            # No documents, or no weight at all: nothing to average.
            aggregate_entailment_info = {
                "contradiction": 0.0,
                "neutral": 0.0,
                "entailment": 0.0,
            }

        entailment_checker_result = {
            "documents": documents[:processed],
            "aggregate_entailment_info": aggregate_entailment_info,
        }

        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[Document]):
        """
        Batch processing is not implemented: this method does nothing and returns ``None``.
        """
        pass

    def get_entailment(self, premise, hypotesis) -> dict:
        """
        Run the NLI model on a single premise/hypothesis pair.

        :param premise: The text assumed to be true (here: the content of a document).
        :param hypotesis: The statement to verify against the premise. The misspelled
            parameter name is kept for backward compatibility with keyword callers.
        :return: A dict mapping every lower-cased label of the model to its probability,
            as a plain Python float.
        """
        with torch.inference_mode():
            # Encode the two texts as a sentence pair so that the tokenizer inserts the
            # model-specific separators (and segment ids, where the model uses them),
            # and truncate pairs that exceed the maximum input length of the model.
            inputs = self.tokenizer(
                premise, hypotesis, return_tensors="pt", truncation=True
            ).to(self.devices[0])
            logits = self.model(**inputs).logits
            # Single example in the batch: take row 0 and convert it to Python floats,
            # which keeps the document metadata JSON-serialisable.
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().tolist()
        return dict(zip(self.labels, probs))
|
|