| import os |
| import os.path as osp |
| import subprocess |
| from typing import Any, Union, List, Dict, Optional |
| from collections import defaultdict |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from colbert.infra import Run, RunConfig, ColBERTConfig |
| from colbert.data import Queries, Collection |
| from colbert import Indexer, Searcher |
|
|
| from models.model import ModelForSTaRKQA |
| from stark_qa import load_qa |
|
|
|
|
class Colbertv2(ModelForSTaRKQA):
    """
    ColBERTv2 Model for STaRK QA.

    Wraps the ``colbert`` package's ``Indexer`` and ``Searcher`` to rank knowledge-base
    candidates against the queries of a STaRK question-answering dataset.

    All heavy work happens in ``__init__``: queries and candidate documents are written
    to TSV files, the ColBERTv2 checkpoint is downloaded if it is missing, the documents
    are indexed, and the top-``k`` ranking of every query is computed (or loaded from a
    previous run). ``forward`` is then a plain lookup into those precomputed scores.
    """

    url = "https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz"

    # Score recorded for a ranked candidate whose ranking line has no score column.
    _MISSING_SCORE = -999

    def __init__(self,
                 skb: Any,
                 dataset_name: str,
                 human_generated_eval: bool,
                 add_rel: bool = False,
                 download_dir: str = 'output',
                 save_dir: str = 'output/colbertv2.0',
                 nbits: int = 2,
                 k: int = 100):
        """
        Initialize the ColBERTv2 model with the given knowledge base and parameters.

        Args:
            skb (Any): The knowledge base containing candidate documents.
            dataset_name (str): The name of the dataset being used.
            human_generated_eval (bool): Whether to use human-generated queries for evaluation.
            add_rel (bool, optional): Whether to add relational information to the document. Defaults to False.
            download_dir (str, optional): Directory where the ColBERTv2 model is downloaded. Defaults to 'output'.
            save_dir (str, optional): Directory where the query/document TSV files are saved.
                Defaults to 'output/colbertv2.0'.
            nbits (int, optional): Number of bits for indexing. Defaults to 2.
            k (int, optional): Number of top candidates to retrieve. Defaults to 100.

        Raises:
            subprocess.CalledProcessError: If downloading or extracting the checkpoint fails.
            FileNotFoundError: If no ranking file can be located after running the search.
        """
        super(Colbertv2, self).__init__(skb)

        self.k = k
        self.nbits = nbits

        # Human-generated evaluation uses its own query file and experiment name so
        # that it never shares a ranking with the synthesized queries.
        query_tsv_name = 'query_hg.tsv' if human_generated_eval else 'query.tsv'
        self.exp_name = f"{dataset_name}_hg" if human_generated_eval else dataset_name

        self.save_dir = save_dir
        self.download_dir = download_dir
        self.experiments_dir = './experiments'

        self.model_ckpt_dir = osp.join(self.download_dir, 'colbertv2.0')
        self.query_tsv_path = osp.join(self.save_dir, query_tsv_name)
        self.doc_tsv_path = osp.join(self.save_dir, 'doc.tsv')
        self.index_ckpt_path = osp.join(self.save_dir, 'index.faiss')
        self.ranking_path = osp.join(self.save_dir, 'ranking.tsv')

        os.makedirs(self.download_dir, exist_ok=True)
        os.makedirs(self.experiments_dir, exist_ok=True)

        # Fetch the checkpoint BEFORE creating save_dir: with the default arguments
        # save_dir is the same path as the checkpoint directory, and _download treats
        # an existing checkpoint directory as "already downloaded".
        self._download()

        # The TSV writers below open files inside save_dir, so it has to exist.
        os.makedirs(self.save_dir, exist_ok=True)

        qa_dataset = load_qa(dataset_name, human_generated_eval=human_generated_eval)
        self._check_query_csv(qa_dataset, self.query_tsv_path)
        self._check_doc_csv(skb, self.doc_tsv_path, add_rel)

        self.queries = Queries(self.query_tsv_path)
        self.collection = Collection(self.doc_tsv_path)

        self._prepare_indexer()

        # Rankings for every query are computed once up front; forward() only looks them up.
        self.score_dict = self.run_all()

    @staticmethod
    def _num_ranks() -> int:
        """Return the number of ranks to launch: one per visible GPU, but never fewer than one."""
        # torch.cuda.device_count() is 0 without a visible GPU; a run with zero ranks
        # cannot do any work, so fall back to a single rank.
        return max(1, torch.cuda.device_count())

    @staticmethod
    def _find_file(name: str, path: str) -> Optional[str]:
        """
        Find a file by name anywhere below a directory.

        Args:
            name (str): The name of the file to find.
            path (str): The directory to search recursively.

        Returns:
            Optional[str]: Path of the first match in ``os.walk`` order, None if there is none.
        """
        for root, _dirs, files in os.walk(path):
            if name in files:
                return osp.join(root, name)
        return None

    def _check_query_csv(self, qa_dataset: Any, query_tsv_path: str) -> None:
        """
        Check if the query TSV file exists; if not, create it from the QA dataset.

        Each line is ``<query id>\\t<query text>``. An existing file is reused as-is.

        Args:
            qa_dataset (Any): The question-answer dataset. Each item is indexed as
                ``item[0]`` (query text) and ``item[1]`` (query id).
            query_tsv_path (str): Path to the query TSV file.
        """
        if osp.exists(query_tsv_path):
            print(f'Loaded existing queries from {query_tsv_path}')
            return

        queries = {}
        for i in range(len(qa_dataset)):
            item = qa_dataset[i]
            # A query must stay on one line and in one column, so every character
            # that would break the TSV row is replaced by a space.
            text = item[0].replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
            queries[item[1]] = text

        lines = [f"{qid}\t{q}" for qid, q in queries.items()]
        with open(query_tsv_path, 'w') as file:
            file.write('\n'.join(lines))

    def _check_doc_csv(self, skb: Any, doc_tsv_path: str, add_rel: bool) -> None:
        """
        Check if the document TSV file exists; if not, create it from the knowledge base.

        Also builds ``self.docid2pid`` / ``self.pid2docid``, which map knowledge-base
        candidate ids to the consecutive passage ids (0, 1, 2, ...) used in the TSV
        file, and back. Each line of the file is ``<passage id>\\t<document text>``.

        Args:
            skb (Any): The knowledge base containing candidate documents.
            doc_tsv_path (str): Path to the document TSV file.
            add_rel (bool): Whether to add relational information to the documents.
        """
        indices = skb.candidate_ids
        self.docid2pid = {idx: i for i, idx in enumerate(indices)}
        self.pid2docid = {i: idx for i, idx in enumerate(indices)}

        if osp.exists(doc_tsv_path):
            print(f'Loaded existing documents from {doc_tsv_path}')
            return

        # NOTE(review): document text is written verbatim; this relies on
        # get_doc_info(..., compact=True) returning text without newlines or tabs.
        corpus = {self.docid2pid[idx]: skb.get_doc_info(idx, add_rel=add_rel, compact=True)
                  for idx in tqdm(indices, desc="Gathering documents")}

        lines = [f"{pid}\t{doc}" for pid, doc in corpus.items()]
        with open(doc_tsv_path, 'w') as file:
            file.write('\n'.join(lines))

    def _download(self) -> None:
        """
        Download the ColBERTv2 model if not already available.

        The checkpoint counts as available when ``self.model_ckpt_dir`` exists. Otherwise
        the archive at ``self.url`` is fetched with ``wget`` into ``self.download_dir``
        and unpacked there with ``tar``.

        Raises:
            subprocess.CalledProcessError: If ``wget`` or ``tar`` exits with a non-zero status.
        """
        if osp.exists(self.model_ckpt_dir):
            return

        # Argument lists (no shell) keep paths with spaces or shell metacharacters intact.
        subprocess.run(["wget", self.url, "-P", self.download_dir], check=True)

        archive_path = osp.join(self.download_dir, 'colbertv2.0.tar.gz')
        subprocess.run(["tar", "-xvzf", archive_path, "-C", self.download_dir], check=True)

    def _prepare_indexer(self) -> None:
        """
        Build the ColBERT index for the document corpus.

        The index is named ``"<exp_name>.nbits=<nbits>"`` and is built from the document
        TSV file using the downloaded checkpoint.
        """
        with Run().context(RunConfig(nranks=self._num_ranks(), experiment=self.exp_name)):
            config = ColBERTConfig(nbits=self.nbits, root=self.experiments_dir)
            indexer = Indexer(checkpoint=self.model_ckpt_dir, config=config)
            # overwrite='reuse' is intended to keep an index of the same name from an
            # earlier run instead of rebuilding it.
            indexer.index(name=f"{self.exp_name}.nbits={self.nbits}",
                          collection=self.doc_tsv_path,
                          overwrite='reuse')

    def run_all(self) -> Dict[int, Dict[int, float]]:
        """
        Run the retrieval for all queries and store the rankings.

        If a ``ranking.tsv`` already exists anywhere below the experiment directory it
        is reused; otherwise all queries are searched and the ranking is saved first.
        ``self.ranking_path`` is set to the file that was actually loaded.

        Returns:
            Dict[int, Dict[int, float]]: A dictionary mapping query IDs to a dictionary
            of passage IDs and their scores.

        Raises:
            FileNotFoundError: If no ``ranking.tsv`` can be found after running the search.
            ValueError: If a line of the ranking file is malformed.
        """
        exp_root = osp.join(self.experiments_dir, self.exp_name)
        ranking_path = self._find_file('ranking.tsv', exp_root)

        if ranking_path is None:
            with Run().context(RunConfig(nranks=self._num_ranks(), experiment=self.exp_name)):
                config = ColBERTConfig(root=self.experiments_dir)
                searcher = Searcher(index=f"{self.exp_name}.nbits={self.nbits}", config=config)
                queries = Queries(self.query_tsv_path)
                ranking = searcher.search_all(queries, k=self.k)
                ranking.save('ranking.tsv')

            # The saved file's final location is chosen by ColBERT's run context, so it
            # is looked up again rather than assumed.
            ranking_path = self._find_file('ranking.tsv', exp_root)
            if ranking_path is None:
                raise FileNotFoundError(
                    f"No 'ranking.tsv' found under {exp_root} after running the search")

        self.ranking_path = ranking_path
        return self._load_ranking(ranking_path)

    def _load_ranking(self, ranking_path: str) -> Dict[int, Dict[int, float]]:
        """
        Parse a ranking TSV file into nested score dictionaries.

        Each non-blank line is ``qid\\tpid\\trank`` with an optional fourth ``score`` column.

        Args:
            ranking_path (str): Path to the ranking TSV file.

        Returns:
            Dict[int, Dict[int, float]]: ``{qid: {pid: score}}``. A line without a score
            column is stored with the score ``-999``.

        Raises:
            ValueError: If a line has a wrong number of columns or a non-numeric field.
        """
        score_dict = defaultdict(dict)
        with open(ranking_path) as f:
            for line_no, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    # Tolerate blank lines such as a trailing newline at end of file.
                    continue

                fields = stripped.split('\t')
                if len(fields) not in (3, 4):
                    raise ValueError(
                        f"{ranking_path}:{line_no}: expected 3 or 4 tab-separated fields, "
                        f"got {len(fields)}")

                # The rank is converted only to validate it; the score dictionary does
                # not store it.
                qid, pid, _rank = int(fields[0]), int(fields[1]), int(fields[2])
                score_dict[qid][pid] = float(fields[3]) if len(fields) == 4 else self._MISSING_SCORE

        return score_dict

    def forward(self,
                query: Union[str, None],
                query_id: int,
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to retrieve rankings for the given query.

        The query text is not used: scores are looked up by ``query_id`` in the ranking
        precomputed by ``run_all``.

        Args:
            query (str): The query string (unused).
            query_id (int): The query index.

        Returns:
            Dict[int, float]: A dictionary of candidate IDs and their corresponding
            similarity scores; empty if ``query_id`` has no precomputed ranking.
        """
        # .get() avoids inserting an empty entry into the defaultdict for unknown ids.
        scores = self.score_dict.get(query_id, {})
        return {self.pid2docid[pid]: score for pid, score in scores.items()}
|
|