# GLEN-model: src/tevatron/metrics.py
import os
import pandas as pd
from typing import List, Optional
from collections import defaultdict
from transformers import AutoTokenizer
from torch.utils.data import Dataset
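
# All three metric functions parse the retrieval results written to
# args.res1_save_path: a TSV whose columns are
#   query \t comma-separated predicted ids \t gold id(s) \t rank
# Predictions are taken from the first line seen for each query; later lines
# for the same query only contribute additional gold ids.
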
def compute_recall(args, cutoff: List[int] = [1, 10, 100]):
"""
Compute recall@k
Args:
args: arguments
cutoff: list of cutoffs
Returns:
metrics: dict of metrics
"""
q_gt, q_pred = {}, {}
with open(args.res1_save_path, "r") as f:
prev_q = ""
for line in f.readlines():
query, pred, gt, rank = line[:-1].split("\t")
if query != prev_q:
q_pred[query] = pred.split(",")
prev_q = query
if query in q_gt:
if len(q_gt[query]) <= 100:
q_gt[query].add(gt)
else:
                q_gt[query] = set(gt.split(","))
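    # After parsing, q_pred maps each query to its ranked prediction list and
    # q_gt maps it to a set of gold ids, e.g. (illustrative values):
    #   q_pred["..."] = ["d12", "d7", ...]   and   q_gt["..."] = {"d7", "d31"}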
    do_seen_unseen = (
        args.unseen_query_set is not None and args.seen_query_set is not None
    )
metrics = {}
lines = []
lines.append("#####################")
for i in cutoff:
recall_k, seen_recall_k, unseen_recall_k = [], [], []
for q in q_pred:
            # recall@k for this query: |gold ∩ top-k predictions| / |gold|,
            # e.g. 1 of 2 gold ids found in the top k gives 0.5
            tmp_recall = (
                len(set(q_gt[q]) & set(q_pred[q][:i])) / len(q_gt[q])
                if len(q_gt[q]) > 0
                else 0
            )
            recall_k.append(tmp_recall)
            if do_seen_unseen:
                if q in args.seen_query_set:
                    seen_recall_k.append(tmp_recall)
                elif q in args.unseen_query_set:
                    unseen_recall_k.append(tmp_recall)
recall_avg = sum(recall_k) / len(recall_k)
if do_seen_unseen:
seen_recall_avg = (
sum(seen_recall_k) / len(seen_recall_k) if len(seen_recall_k) > 0 else 0
)
unseen_recall_avg = (
sum(unseen_recall_k) / len(unseen_recall_k)
if len(unseen_recall_k) > 0
else 0
)
metrics.update(
{
f"recall@{i}": recall_avg,
f"recall_unseen@{i}": unseen_recall_avg,
f"recall_seen@{i}": seen_recall_avg,
}
)
lines.append(
f"recall@{i} : {recall_avg:.4f} | recall_unseen@{i} : {unseen_recall_avg:.4f} | recall_seen@{i} : {seen_recall_avg:.4f}"
)
else:
metrics.update({f"recall@{i}": recall_avg})
lines.append(f"recall@{i} : {recall_avg:.4f}")
lines.append("-------------------------")
print("\n".join(lines))
return metrics
def compute_mrr(args, cutoff: List[int] = [10, 100]):
"""
Compute MRR@k
Args:
args: arguments
cutoff: list of cutoffs
Returns:
metrics: dict of metrics
"""
q_gt, q_pred = {}, {}
with open(args.res1_save_path, "r") as f:
prev_q = ""
for line in f.readlines():
query, pred, gt, rank = line[:-1].split("\t")
if query != prev_q:
q_pred[query] = pred.split(",")
prev_q = query
if query in q_gt:
if len(q_gt[query]) <= 100:
q_gt[query].add(gt)
else:
                q_gt[query] = set(gt.split(","))
    do_seen_unseen = (
        args.unseen_query_set is not None and args.seen_query_set is not None
    )
metrics = {}
lines = []
for i in cutoff:
mrr_k, seen_mrr_k, unseen_mrr_k = [], [], []
for query in q_pred:
            # reciprocal rank of the first relevant prediction in the top i
            # (e.g. first hit at rank 3 contributes 1/3)
            score = 0
            for j, p in enumerate(q_pred[query][:i]):
                if p in q_gt[query]:
                    score = 1 / (j + 1)
                    break
mrr_k.append(score)
if do_seen_unseen:
if query in args.seen_query_set:
seen_mrr_k.append(score)
elif query in args.unseen_query_set:
unseen_mrr_k.append(score)
mrr = sum(mrr_k) / len(mrr_k)
if do_seen_unseen:
seen_mrr = sum(seen_mrr_k) / len(seen_mrr_k) if len(seen_mrr_k) > 0 else 0
unseen_mrr = (
sum(unseen_mrr_k) / len(unseen_mrr_k) if len(unseen_mrr_k) > 0 else 0
)
metrics.update(
{
f"MRR@{i}": mrr,
f"MRR_unseen@{i}": unseen_mrr,
f"MRR_seen@{i}": seen_mrr,
}
)
lines.append(
f"MRR@{i} : {mrr:.4f} | MRR_unseen@{i} : {unseen_mrr:.4f} | MRR_seen@{i} : {seen_mrr:.4f}"
)
else:
metrics.update({f"MRR@{i}": mrr})
lines.append(f"MRR@{i} : {mrr:.4f}")
print("\n".join(lines))
return metrics
def evaluate_beir(args, tokenizer: AutoTokenizer, dataset: Optional[Dataset]):
"""
Evaluate BEIR dataset using beir library
Args:
args: arguments
tokenizer: tokenizer
dataset: dataset
Returns:
metrics: dict of metrics
"""
q_gt, q_pred = {}, {}
with open(args.res1_save_path, "r") as f:
prev_q = ""
for line in f.readlines():
query, pred, gt, rank = line[:-1].split("\t")
if query != prev_q:
                q_pred[query] = pred.split(",")
prev_q = query
if query in q_gt:
if len(q_gt[query]) <= 100:
q_gt[query].add(gt)
else:
                q_gt[query] = set(gt.split(","))
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
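    # Load the test-split qrels for the BEIR dataset and the query -> queryid
    # mapping stored alongside it in dev_doc_newid.tsv.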
data_path = os.path.join("data/BEIR_dataset", args.dataset_name)
_, _, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
fname = os.path.join(data_path, "dev_doc_newid.tsv")
df = pd.read_csv(fname, encoding="utf-8", sep="\t", dtype=str).loc[
:, ["query", "queryid"]
]
df_unique_q = df.drop_duplicates(subset=["query", "queryid"])
query2qid = {}
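    # Re-encode and decode each BEIR query so that query2qid is keyed by the
    # same cleaned, truncated query text that appears as the keys of q_pred.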
for query, qid in df_unique_q[["query", "queryid"]].values:
input_ = dataset.clean_text(query)
output_ = tokenizer.batch_encode_plus(
[input_],
max_length=156,
padding="max_length",
truncation=True,
return_tensors="pt",
)
query = tokenizer.decode(
output_["input_ids"][0].numpy(), skip_special_tokens=True
)
query2qid[query] = qid
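    # No retrieval model is needed here; EvaluateRetrieval is only used for its
    # evaluate() call, so the model argument is None.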
retriever = EvaluateRetrieval(None, score_function="dot")
results = defaultdict(dict)
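    # Convert the ranked predictions into BEIR's results format, scoring each
    # document by its reciprocal rank so the predicted ordering is preserved.
    # The original corpus id precedes the "<->" separator in each predicted id.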
for q in q_pred:
qid = query2qid[q]
for rank, d in enumerate(q_pred[q]):
score = 1 / (rank + 1)
oldid = d.split("<->")[0]
results[qid][oldid] = score
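    # retriever.evaluate returns NDCG, MAP, Recall and Precision dicts keyed
    # like "NDCG@10" / "Recall@100"; only NDCG and Recall are reported below.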
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, args.ndcg_num)
metrics = {}
print("#####################")
for k in args.ndcg_num:
metrics.update({f"NDCG@{k}": ndcg[f"NDCG@{k}"]})
score = ndcg[f"NDCG@{k}"]
print(f"NDCG@{k} : {score}")
print("#####################")
for k in args.recall_num:
metrics.update({f"Recall@{k}": recall[f"Recall@{k}"]})
score = recall[f"Recall@{k}"]
print(f"Recall@{k} : {score}")
print("#####################")
return metrics
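

# Minimal usage sketch (illustrative only): the Namespace fields mirror the
# attributes accessed above, and "run.results.tsv" is a hypothetical path to a
# result file in the TSV format described at the top of this module.
if __name__ == "__main__":
    from argparse import Namespace

    demo_args = Namespace(
        res1_save_path="run.results.tsv",  # hypothetical result file
        seen_query_set=None,
        unseen_query_set=None,
    )
    compute_recall(demo_args, cutoff=[1, 10, 100])
    compute_mrr(demo_args, cutoff=[10, 100])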