| import sys |
|
|
|
|
| import os |
| import os.path as osp |
| from typing import Any, Union, List, Dict |
|
|
| import torch |
| import torch.nn as nn |
| from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings, get_contriever_embeddings |
| from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings |
| from stark_qa.evaluator import Evaluator |
|
|
|
|
class ModelForSTaRKQA(nn.Module):
    """Base class for retrieval models evaluated on STaRK QA benchmarks.

    Provides query-embedding caching (persisted under ``query_emb_dir``)
    and metric evaluation via ``Evaluator``; subclasses implement
    ``forward``.
    """

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information
                (must expose ``candidate_ids`` and ``num_candidates``).
            query_emb_dir (str): Directory holding the query-embedding
                cache file ``query_emb_dict.pt``.
        """
        super(ModelForSTaRKQA, self).__init__()
        self.skb = skb

        self.candidate_ids = skb.candidate_ids
        self.num_candidates = skb.num_candidates
        self.query_emb_dir = query_emb_dir

        # Reload previously cached query embeddings so repeated runs do not
        # re-query the (possibly paid / rate-limited) embedding model.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # cache files from trusted locations.
        query_emb_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
        if osp.exists(query_emb_path):  # use osp consistently (was os.path)
            print(f'Load query embeddings from {query_emb_path}')
            self.query_emb_dict = torch.load(query_emb_path)
        else:
            self.query_emb_dict = {}
        self.evaluator = Evaluator(self.candidate_ids)

    def forward(self,
                query: Union[str, List[str]],
                candidates: List[int] = None,
                query_id: Union[int, List[int]] = None,
                **kwargs: Any) -> Dict[str, Any]:
        """
        Forward pass to compute predictions for the given query.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            candidates (Union[list, None]): A list of candidate ids (optional).
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary of predicted scores or answer ids.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def get_query_emb(self,
                      query: Union[str, List[str]],
                      query_id: Union[int, List[int]],
                      emb_model: str = 'text-embedding-ada-002',
                      **encode_kwargs) -> torch.Tensor:
        """
        Retrieves cached embeddings for the given queries, computing and
        caching any that are missing.

        Args:
            query (Union[str, List[str]]): Query string(s).
            query_id (Union[int, List[int], None]): Query index(es); when
                None, embeddings are computed but never cached.
            emb_model (str): Embedding model to use.
            **encode_kwargs: Extra arguments forwarded to the encoder.

        Returns:
            query_emb (torch.Tensor): Query embeddings, shape ``(len(query), dim)``.
        """
        # Normalize scalar inputs to lists so there is a single code path.
        if isinstance(query_id, int):
            query_id = [query_id]
        if isinstance(query, str):
            query = [query]

        if query_id is None:
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
        elif set(query_id).issubset(self.query_emb_dict.keys()):
            # Full cache hit: stack the stored (1, dim) rows.
            query_emb = torch.concat([self.query_emb_dict[qid] for qid in query_id], dim=0)
        else:
            # Cache miss: compute, update the in-memory cache, persist to disk.
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
            for qid, emb in zip(query_id, query_emb):
                self.query_emb_dict[qid] = emb.view(1, -1)
            torch.save(self.query_emb_dict, osp.join(self.query_emb_dir, 'query_emb_dict.pt'))

        query_emb = query_emb.view(len(query), -1)
        return query_emb

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: List[str] = None,
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): A list of metrics to be evaluated, including
                'mrr', 'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'.
                Defaults to ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        # None sentinel avoids the shared-mutable-default-argument pitfall.
        if metrics is None:
            metrics = ['mrr', 'hit@3', 'recall@20']
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: List[str] = None,
                       **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates a batch of predictions via the evaluator.

        Args:
            pred_ids (List[int]): Candidate ids the score tensor refers to.
            pred (torch.Tensor): Predicted scores (layout defined by
                ``Evaluator.evaluate_batch`` -- verify against that API).
            answer_ids: Ground truth answer ids.
            metrics (List[str]): Metrics to evaluate; defaults to
                ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        if metrics is None:
            metrics = ['mrr', 'hit@3', 'recall@20']
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
|
|
|
|
def get_embeddings(text, model_name, **encode_kwargs):
    """
    Embed the given text with the encoder selected by the model name.

    Args:
        text (Union[str, List[str]]): The input text to be embedded;
            a single string is treated as a one-element batch.
        model_name (str): Model name; matched by substring to pick
            the encoder backend, falling back to the API encoder.
        **encode_kwargs: Extra arguments forwarded to encoders that
            accept them (GritLM, LLM2Vec, API).

    Returns:
        torch.Tensor: Embeddings of shape ``(len(text), dim)``.
    """
    texts = [text] if isinstance(text, str) else text

    # Substring-matched routes, checked in order; each entry defers the
    # actual encoder call until the marker matches.
    routes = (
        ('GritLM', lambda: get_gritlm_embeddings(texts, model_name, **encode_kwargs)),
        ('LLM2Vec', lambda: get_llm2vec_embeddings(texts, model_name, **encode_kwargs)),
        ('all-mpnet-base-v2', lambda: get_sentence_transformer_embeddings(texts)),
        ('contriever', lambda: get_contriever_embeddings(texts)),
    )
    for marker, encode in routes:
        if marker in model_name:
            emb = encode()
            break
    else:
        # No local backend matched: fall back to the remote API encoder.
        emb = get_api_embeddings(texts, model_name, **encode_kwargs)

    return emb.view(len(texts), -1)
| |
|
|