CanLex / canlex /eval.py
Beemer
Upgrade retrieval: bge-small embeddings + promote-only reranking
2966f10
"""Measure CanLex retrieval quality against a curated question set.
Each item in data/eval/questions.json pairs a realistic legal question with the
provision(s) or case(s) that answer it. This runs every question through the
retrieval index and reports Hit@k and MRR. Re-run it after any retrieval change
-- a new reranker, different embeddings, a chunking tweak -- to see whether
quality moved, and read the "Misses" list to see exactly what to fix.
Usage: py -m canlex.eval
"""
import json
import sys
from .config import ROOT
from .index import LegislationIndex
# Curated question set. run() loads it with json.loads and reads each item's
# "query" string and "answers" list of [act, section] pairs.
QUESTIONS = ROOT / "data" / "eval" / "questions.json"
# Results requested per question (search depth), so ranks past the usual 6 are
# still visible; a gold answer not found within this depth is reported as
# "absent (searched EVAL_TOP_K)".
EVAL_TOP_K = 20
def _field_text(value):
    """Return *value* as a stripped string, treating None as ""."""
    return "" if value is None else str(value).strip()


def _matches(result, answers):
    """True if a search result is one of the gold answers (act + section).

    A gold answer is [act, section]; an empty (or null) section matches any
    chunk of that act/case (used for case-law answers, whose chunks carry no
    section number).

    The gold act is compared case-insensitively against both the result's
    "act_short" and "act_code" fields. Sections are compared as stripped
    text, so a gold section of 7 and a result section of "7" are equal.

    Result fields that are missing or explicitly None are treated as empty.
    Such empty values are never used as a match, so a chunk without an
    act_code cannot satisfy a blank gold act.
    """
    # dict.get's default only covers *missing* keys; a key present with a
    # None value would otherwise reach .lower() and raise AttributeError.
    r_acts = {_field_text(result.get(key)).lower()
              for key in ("act_short", "act_code")}
    r_acts.discard("")
    r_sec = _field_text(result.get("section"))
    for act, section in answers:
        gold_sec = _field_text(section)
        if (_field_text(act).lower() in r_acts
                and (gold_sec == "" or gold_sec == r_sec)):
            return True
    return False
def _label(act, section):
    """Format a gold answer or search result as "Act s.N", or just "Act".

    The section suffix is omitted when *section* is None or "" (case-law
    chunks carry no section number). This is built conditionally rather than
    with ``.rstrip(" s.")``: rstrip removes any trailing run of ' ', 's' and
    '.' characters, so it would also eat the end of names such as
    "R v Jones" or "Smith v. Jones".
    """
    act = "" if act is None else str(act)
    if section is None or section == "":
        return act
    return f"{act} s.{section}"


def _first_hit_rank(results, answers):
    """Return the 1-based rank of the first result matching a gold answer.

    Returns 0 when no result in *results* matches any of *answers*.
    """
    for position, result in enumerate(results, start=1):
        if _matches(result, answers):
            return position
    return 0


def run():
    """Evaluate every curated question against the index and print a report.

    Each question is searched with depth EVAL_TOP_K. The function prints
    Hit@1/3/5/10 and MRR over the whole set, then lists every question whose
    first gold answer ranked deeper than the miss cutoff or was absent.

    Returns:
        0 after printing the report, or 1 if the question file does not
        exist, so ``sys.exit(run())`` gives callers a usable exit status.
    """
    if not QUESTIONS.exists():
        print(f"No question set at {QUESTIONS}.", file=sys.stderr)
        return 1

    items = json.loads(QUESTIONS.read_text(encoding="utf-8"))
    index = LegislationIndex()
    miss_cutoff = 5  # gold answers ranked deeper than this are listed as misses

    ranks = []  # rank of the first gold hit per question (0 = miss)
    misses = []
    for item in items:
        answers = [tuple(a) for a in item["answers"]]
        results = index.search(item["query"], top_k=EVAL_TOP_K)
        rank = _first_hit_rank(results, answers)
        ranks.append(rank)
        if rank == 0 or rank > miss_cutoff:
            top = results[0] if results else None
            misses.append((item["query"], answers, rank, top))

    n = len(ranks) or 1  # avoid dividing by zero on an empty question set
    mrr = sum(1.0 / r for r in ranks if r) / n
    print(f"CanLex retrieval evaluation -- {len(ranks)} questions\n")
    for k in (1, 3, 5, 10):
        hit_at_k = sum(1 for r in ranks if 0 < r <= k) / n
        print(f" Hit@{k}: {hit_at_k:.2f}")
    print(f" MRR: {mrr:.2f}")

    if misses:
        print(f"\n{len(misses)} miss(es) -- gold answer ranked "
              f">{miss_cutoff} or absent:")
        for query, answers, rank, top in misses:
            gold = ", ".join(_label(act, sec) for act, sec in answers)
            where = f"ranked #{rank}" if rank else f"absent (searched {EVAL_TOP_K})"
            if top:
                # Fall back to act_code (the other field _matches compares)
                # so a chunk without act_short doesn't print as "None".
                got = _label(top.get("act_short") or top.get("act_code"),
                             top.get("section"))
            else:
                got = "nothing"
            print(f" [{where}] {query}")
            print(f" gold: {gold} | top result: {got}")
    print()
    return 0
if __name__ == "__main__":
    # Propagate run()'s status so a missing question set fails the process;
    # sys.exit(None) still exits 0 if run() returns nothing.
    sys.exit(run())