import os
from collections import defaultdict
from typing import List, Optional

import pandas as pd
from torch.utils.data import Dataset
from transformers import AutoTokenizer


def compute_recall(args, cutoff: List[int] = [1, 10, 100]):
    """
    Compute recall@k.

    Args:
        args: arguments
        cutoff: list of cutoffs

    Returns:
        metrics: dict of metrics
    """
    # Parse the results file: one "query\tpredictions\tground_truth\trank"
    # line per retrieved document. Predictions are a comma-separated ranked
    # list (identical on every line of a query); ground-truth ids accumulate
    # across that query's lines (capped at ~100).
    q_gt, q_pred = {}, {}
    with open(args.res1_save_path, "r") as f:
        prev_q = ""
        for line in f.readlines():
            query, pred, gt, rank = line[:-1].split("\t")
            if query != prev_q:
                q_pred[query] = pred.split(",")
                prev_q = query
            if query in q_gt:
                if len(q_gt[query]) <= 100:
                    q_gt[query].add(gt)
            else:
                q_gt[query] = set(gt.split(","))

    # Report a seen/unseen breakdown only when both query sets are given.
    do_seen_unseen = (
        args.unseen_query_set is not None and args.seen_query_set is not None
    )

    metrics = {}
    lines = []
    lines.append("#####################")
    for i in cutoff:
        recall_k, seen_recall_k, unseen_recall_k = [], [], []
        for q in q_pred:
            # Fraction of this query's ground-truth ids found in the top-i.
            tmp_recall = (
                len(set(q_gt[q]) & set(q_pred[q][: int(i)])) / len(q_gt[q])
                if len(q_gt[q]) > 0
                else 0
            )
            recall_k.append(tmp_recall)
            if do_seen_unseen:
                if q in args.seen_query_set:
                    seen_recall_k.append(tmp_recall)
                elif q in args.unseen_query_set:
                    unseen_recall_k.append(tmp_recall)
        recall_avg = sum(recall_k) / len(recall_k)
        if do_seen_unseen:
            seen_recall_avg = (
                sum(seen_recall_k) / len(seen_recall_k)
                if len(seen_recall_k) > 0
                else 0
            )
            unseen_recall_avg = (
                sum(unseen_recall_k) / len(unseen_recall_k)
                if len(unseen_recall_k) > 0
                else 0
            )
            metrics.update(
                {
                    f"recall@{i}": recall_avg,
                    f"recall_unseen@{i}": unseen_recall_avg,
                    f"recall_seen@{i}": seen_recall_avg,
                }
            )
            lines.append(
                f"recall@{i} : {recall_avg:.4f} | "
                f"recall_unseen@{i} : {unseen_recall_avg:.4f} | "
                f"recall_seen@{i} : {seen_recall_avg:.4f}"
            )
        else:
            metrics.update({f"recall@{i}": recall_avg})
            lines.append(f"recall@{i} : {recall_avg:.4f}")
    lines.append("-------------------------")
    print("\n".join(lines))
    return metrics


def compute_mrr(args, cutoff: List[int] = [10, 100]):
    """
    Compute MRR@k.

    Args:
        args: arguments
        cutoff: list of cutoffs

    Returns:
        metrics: dict of metrics
    """
    # Same results-file parsing as compute_recall.
    q_gt, q_pred = {}, {}
    with open(args.res1_save_path, "r") as f:
        prev_q = ""
        for line in f.readlines():
            query, pred, gt, rank = line[:-1].split("\t")
            if query != prev_q:
                q_pred[query] = pred.split(",")
                prev_q = query
            if query in q_gt:
                if len(q_gt[query]) <= 100:
                    q_gt[query].add(gt)
            else:
                q_gt[query] = set(gt.split(","))

    do_seen_unseen = (
        args.unseen_query_set is not None and args.seen_query_set is not None
    )

    metrics = {}
    lines = []
    for i in cutoff:
        mrr_k, seen_mrr_k, unseen_mrr_k = [], [], []
        for query in q_pred:
            # Reciprocal rank of the first relevant document in the top-i.
            score = 0
            for j, p in enumerate(q_pred[query][: int(i)]):
                if p in q_gt[query]:
                    score = 1 / (j + 1)
                    break
            mrr_k.append(score)
            if do_seen_unseen:
                if query in args.seen_query_set:
                    seen_mrr_k.append(score)
                elif query in args.unseen_query_set:
                    unseen_mrr_k.append(score)
        mrr = sum(mrr_k) / len(mrr_k)
        if do_seen_unseen:
            seen_mrr = (
                sum(seen_mrr_k) / len(seen_mrr_k) if len(seen_mrr_k) > 0 else 0
            )
            unseen_mrr = (
                sum(unseen_mrr_k) / len(unseen_mrr_k) if len(unseen_mrr_k) > 0 else 0
            )
            metrics.update(
                {
                    f"MRR@{i}": mrr,
                    f"MRR_unseen@{i}": unseen_mrr,
                    f"MRR_seen@{i}": seen_mrr,
                }
            )
            lines.append(
                f"MRR@{i} : {mrr:.4f} | "
                f"MRR_unseen@{i} : {unseen_mrr:.4f} | "
                f"MRR_seen@{i} : {seen_mrr:.4f}"
            )
        else:
            metrics.update({f"MRR@{i}": mrr})
            lines.append(f"MRR@{i} : {mrr:.4f}")
    print("\n".join(lines))
    return metrics
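# --- Usage sketch (illustrative; not part of the original module) ----------
# A minimal smoke test for compute_recall / compute_mrr, assuming `args`
# only needs the attributes read above: `res1_save_path` pointing to the
# tab-separated results file described in compute_recall, plus
# `seen_query_set` / `unseen_query_set` (None disables that breakdown).
def _demo_recall_mrr():
    import tempfile
    from types import SimpleNamespace

    with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False) as f:
        # "what is x" has two ground-truth docs (one per line); the ranked
        # prediction list is repeated on every line of a query.
        f.write("what is x\td1,d2,d3\td2\t1\n")
        f.write("what is x\td1,d2,d3\td9\t2\n")
        f.write("who is y\td4,d5,d6\td4\t1\n")

    args = SimpleNamespace(
        res1_save_path=f.name, seen_query_set=None, unseen_query_set=None
    )
    compute_recall(args, cutoff=[1, 3])  # recall@1 = 0.5, recall@3 = 0.75
    compute_mrr(args, cutoff=[3])  # MRR@3 = 0.75
    os.remove(f.name)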
metrics.update({f"MRR@{i}": mrr}) lines.append(f"MRR@{i} : {mrr:.4f}") print("\n".join(lines)) return metrics def evaluate_beir(args, tokenizer: AutoTokenizer, dataset: Optional[Dataset]): """ Evaluate BEIR dataset using beir library Args: args: arguments tokenizer: tokenizer dataset: dataset Returns: metrics: dict of metrics """ q_gt, q_pred = {}, {} with open(args.res1_save_path, "r") as f: prev_q = "" for line in f.readlines(): query, pred, gt, rank = line[:-1].split("\t") if query != prev_q: q_pred[query] = pred.split(",") q_pred[query] = q_pred[query] prev_q = query if query in q_gt: if len(q_gt[query]) <= 100: q_gt[query].add(gt) else: q_gt[query] = gt.split(",") q_gt[query] = set(q_gt[query]) from beir.datasets.data_loader import GenericDataLoader from beir.retrieval.evaluation import EvaluateRetrieval data_path = os.path.join("data/BEIR_dataset", args.dataset_name) _, _, qrels = GenericDataLoader(data_folder=data_path).load(split="test") fname = os.path.join(data_path, "dev_doc_newid.tsv") df = pd.read_csv(fname, encoding="utf-8", sep="\t", dtype=str).loc[ :, ["query", "queryid"] ] df_unique_q = df.drop_duplicates(subset=["query", "queryid"]) query2qid = {} for query, qid in df_unique_q[["query", "queryid"]].values: input_ = dataset.clean_text(query) output_ = tokenizer.batch_encode_plus( [input_], max_length=156, padding="max_length", truncation=True, return_tensors="pt", ) query = tokenizer.decode( output_["input_ids"][0].numpy(), skip_special_tokens=True ) query2qid[query] = qid retriever = EvaluateRetrieval(None, score_function="dot") results = defaultdict(dict) for q in q_pred: qid = query2qid[q] for rank, d in enumerate(q_pred[q]): score = 1 / (rank + 1) oldid = d.split("<->")[0] results[qid][oldid] = score ndcg, _map, recall, precision = retriever.evaluate(qrels, results, args.ndcg_num) metrics = {} print("#####################") for k in args.ndcg_num: metrics.update({f"NDCG@{k}": ndcg[f"NDCG@{k}"]}) score = ndcg[f"NDCG@{k}"] print(f"NDCG@{k} : {score}") print("#####################") for k in args.recall_num: metrics.update({f"Recall@{k}": recall[f"Recall@{k}"]}) score = recall[f"Recall@{k}"] print(f"Recall@{k} : {score}") print("#####################") return metrics