| import os
|
| import ujson
|
| import torch
|
| import random
|
|
|
| from collections import defaultdict, OrderedDict
|
|
|
| from colbert.parameters import DEVICE
|
| from colbert.modeling.colbert import ColBERT
|
| from colbert.utils.utils import print_message, load_checkpoint
|
| from colbert.evaluation.load_model import load_model
|
| from colbert.utils.runs import Run
|
|
|
|
|
def load_queries(queries_path):
    """Load a tab-separated queries file into an ordered ``{qid: query}`` map.

    Every non-blank line must contain at least two tab-separated fields: an
    integer query ID followed by the query text. Any additional fields on the
    line are ignored. Blank lines are skipped.

    Args:
        queries_path: Path to the tab-separated queries file.

    Returns:
        An ``OrderedDict`` mapping each integer QID to its query string, in
        the order the queries appear in the file.

    Raises:
        AssertionError: If the same QID appears on more than one line.
        ValueError: If a line has fewer than two fields or its QID is not an
            integer.
    """
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default, which can corrupt or reject non-ASCII queries.
    with open(queries_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue

            qid, query, *_ = line.split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
|
|
|
|
|
def load_qrels(qrels_path):
    """Load relevance judgments into an ordered ``{qid: [pid, ...]}`` map.

    Every non-blank line must contain exactly four tab-separated integers
    ``qid, 0, pid, 1``; the second and fourth fields are checked against those
    constants. Blank lines are skipped.

    Args:
        qrels_path: Path to the tab-separated qrels file, or ``None``.

    Returns:
        ``None`` when ``qrels_path`` is ``None``; otherwise an ``OrderedDict``
        mapping each integer QID to the list of its positive PIDs, in file
        order.

    Raises:
        AssertionError: If a line's second/fourth fields are not ``0``/``1``,
            or if a PID is listed more than once for the same QID.
        ValueError: If a line does not hold exactly four integer fields.
    """
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue

            qid, x, pid, y = map(int, line.split('\t'))
            assert x == 0 and y == 1
            qrels.setdefault(qid, []).append(pid)

    # No PID may be judged twice for the same query.
    assert all(len(pids) == len(set(pids)) for pids in qrels.values())

    # Guard the average so an empty qrels file does not divide by zero.
    if qrels:
        avg_positive = round(sum(map(len, qrels.values())) / len(qrels), 2)
    else:
        avg_positive = 0.0

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels
|
|
|
|
|
def load_topK(topK_path):
    """Load a top-k file that carries the query and passage text inline.

    Every line must contain exactly four tab-separated fields:
    ``qid, pid, query, passage``. Lines are split without stripping, so the
    last field (the passage) keeps its trailing newline.

    Args:
        topK_path: Path to the tab-separated top-k file.

    Returns:
        A tuple ``(queries, topK_docs, topK_pids)`` of ``OrderedDict`` objects
        keyed by integer QID: the query text, the list of passage strings, and
        the list of integer PIDs, each in file order.

    Raises:
        AssertionError: If one QID appears with two different query texts, or
            if a PID is repeated for the same QID.
        ValueError: If a line does not have exactly four fields, if the
            QID/PID are not integers, or if the file contains no lines.
    """
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default; the file carries free-form query/passage text.
    with open(topK_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            # Progress marker every 10M lines (never for the first line).
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            # A QID must always be paired with the same query text.
            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs.setdefault(qid, []).append(passage)
            topK_pids.setdefault(qid, []).append(pid)

    print()

    # No PID may be listed twice for the same query.
    assert all(len(pids) == len(set(pids)) for pids in topK_pids.values())

    Ks = [len(pids) for pids in topK_pids.values()]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids
|
|
|
|
|
def load_topK_pids(topK_path, qrels):
    """Load the top-k PIDs per query, plus positives if the file is annotated.

    Every line holds ``qid, pid`` followed by one to three further
    tab-separated fields. When more than one further field is present, the
    last one is read as a 0/1 relevance label for that ``(qid, pid)`` pair.

    Args:
        topK_path: Path to the tab-separated top-k file.
        qrels: Externally loaded positives, or ``None``. It is returned as the
            positives when the file itself yields none.

    Returns:
        A tuple ``(topK_pids, topK_positives)``. ``topK_pids`` is a
        ``defaultdict(list)`` mapping QID to its PIDs in file order.
        ``topK_positives`` maps QID to its positive PIDs when the file yields
        any positive labels; otherwise it is ``qrels`` (possibly ``None``).

    Raises:
        AssertionError: On a malformed field count or label, on repeated PIDs
            within a query, or when both ``qrels`` and in-file annotations are
            available.
    """
    pids_by_qid = defaultdict(list)
    labeled_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            # Progress marker every 10M lines (never for the first line).
            if line_idx > 0 and line_idx % 10_000_000 == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            pids_by_qid[qid].append(pid)

            assert 1 <= len(rest) <= 3

            if len(rest) <= 1:
                # A single extra column carries no label.
                continue

            label = int(rest[-1])
            assert label in (0, 1)

            if label >= 1:
                labeled_positives[qid].append(pid)

    print()

    # Neither the ranked PIDs nor the positives may repeat within a query.
    assert all(len(pids) == len(set(pids)) for pids in pids_by_qid.values())
    assert all(len(pids) == len(set(pids)) for pids in labeled_positives.values())

    # Sets make later membership checks cheap.
    positives = {qid: set(pids) for qid, pids in labeled_positives.items()}

    Ks = [len(pids) for pids in pids_by_qid.values()]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(pids_by_qid), "unique queries.\n")

    if not positives:
        positives = None
    else:
        assert len(pids_by_qid) >= len(positives)

        # Queries without any positive label still get an (empty) entry.
        unlabeled = set(pids_by_qid.keys()).difference(set(positives.keys()))
        for qid in unlabeled:
            positives[qid] = []

        assert len(pids_by_qid) == len(positives)

        total_positives = sum(map(len, positives.values()))
        avg_positive = round(total_positives / len(pids_by_qid), 2)

        print_message("#> Concurrently got annotations for", len(positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or positives is None, "Cannot have both qrels and an annotated top-K file!"

    # Fall back to the externally supplied qrels when the file had no labels.
    if positives is None:
        positives = qrels

    return pids_by_qid, positives
|
|
|
|
|
def load_collection(collection_path):
    """Load a tab-separated passage collection into a list.

    Every line must contain at least ``pid`` and ``passage``, optionally
    followed by a title. The PID must either equal the zero-based line index
    or be the literal ``id``. When a title is present, the stored text is
    ``title + ' | ' + passage``.

    Args:
        collection_path: Path to the tab-separated collection file.

    Returns:
        A list with one passage string per line of the file, in file order.

    Raises:
        AssertionError: If a PID is neither ``id`` nor the line's index.
        ValueError: If a line has fewer than two fields or a PID that is
            neither ``id`` nor an integer.
    """
    print_message("#> Loading collection...")

    collection = []

    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default; passages are free-form text.
    with open(collection_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            # Progress marker every 1M lines, starting with "0M" on line 0.
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if rest:
                # The first extra column is the title; prepend it.
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection
|
|
|
|
|
def load_colbert(args, do_print=True):
    """Load the model via ``load_model`` and report checkpoint/arg mismatches.

    For each of a fixed set of settings, a warning is issued through
    ``Run.warn`` when the value stored under ``checkpoint['arguments']``
    differs from the same-named attribute on ``args``. When the checkpoint has
    an ``'arguments'`` entry and ``args.rank < 1``, that entry is printed as
    indented JSON.

    Args:
        args: Run arguments; forwarded to ``load_model`` and compared against
            the checkpoint's stored arguments. ``args.rank`` is read only when
            the checkpoint has an ``'arguments'`` entry.
        do_print: Forwarded to ``load_model``; when truthy, blank lines are
            also printed before returning.

    Returns:
        The ``(colbert, checkpoint)`` pair produced by ``load_model``.
    """
    colbert, checkpoint = load_model(args, do_print)

    checked_settings = ('query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp')

    if 'arguments' in checkpoint:
        saved_args = checkpoint['arguments']

        for name in checked_settings:
            # Only compare settings present on both sides.
            if not hasattr(args, name) or name not in saved_args:
                continue

            saved, current = saved_args[name], getattr(args, name)
            if saved != current:
                Run.warn(f"Got checkpoint['arguments']['{name}'] != args.{name} (i.e., {saved} != {current})")

        if args.rank < 1:
            print(ujson.dumps(saved_args, indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
|
|
|