"""Measure CanLex retrieval quality against a curated question set. Each item in data/eval/questions.json pairs a realistic legal question with the provision(s) or case(s) that answer it. This runs every question through the retrieval index and reports Hit@k and MRR. Re-run it after any retrieval change -- a new reranker, different embeddings, a chunking tweak -- to see whether quality moved, and read the "Misses" list to see exactly what to fix. py -m canlex.eval """ import json import sys from .config import ROOT from .index import LegislationIndex QUESTIONS = ROOT / "data" / "eval" / "questions.json" EVAL_TOP_K = 20 # search depth, so ranks past the usual 6 are still visible def _matches(result, answers): """True if a search result is one of the gold answers (act + section). A gold answer is [act, section]; an empty section matches any chunk of that act/case (used for case-law answers, whose chunks carry no section number). """ r_acts = {result.get("act_short", "").lower(), result.get("act_code", "").lower()} r_sec = result.get("section", "") for act, section in answers: if act.lower() in r_acts and (section == r_sec or section == ""): return True return False def run(): if not QUESTIONS.exists(): print(f"No question set at {QUESTIONS}.", file=sys.stderr) return items = json.loads(QUESTIONS.read_text(encoding="utf-8")) index = LegislationIndex() ranks = [] # rank of the first gold hit per question (0 = miss) misses = [] for item in items: answers = [tuple(a) for a in item["answers"]] results = index.search(item["query"], top_k=EVAL_TOP_K) rank = 0 for i, result in enumerate(results, start=1): if _matches(result, answers): rank = i break ranks.append(rank) if rank == 0 or rank > 5: top = results[0] if results else None misses.append((item["query"], answers, rank, top)) n = len(ranks) or 1 hit = lambda k: sum(1 for r in ranks if 0 < r <= k) / n mrr = sum(1.0 / r for r in ranks if r) / n print(f"CanLex retrieval evaluation -- {len(ranks)} questions\n") print(f" Hit@1: {hit(1):.2f}") print(f" Hit@3: {hit(3):.2f}") print(f" Hit@5: {hit(5):.2f}") print(f" Hit@10: {hit(10):.2f}") print(f" MRR: {mrr:.2f}") if misses: print(f"\n{len(misses)} miss(es) -- gold answer ranked >5 or absent:") for query, answers, rank, top in misses: gold = ", ".join(f"{a} s.{s}".rstrip(" s.") for a, s in answers) where = f"ranked #{rank}" if rank else f"absent (searched {EVAL_TOP_K})" got = (f"{top.get('act_short', '')} s.{top.get('section', '')}".rstrip(" s.") if top else "nothing") print(f" [{where}] {query}") print(f" gold: {gold} | top result: {got}") print() if __name__ == "__main__": run()