File size: 3,037 Bytes
2966f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Measure CanLex retrieval quality against a curated question set.

Each item in data/eval/questions.json pairs a realistic legal question with the
provision(s) or case(s) that answer it. This runs every question through the
retrieval index and reports Hit@k and MRR. Re-run it after any retrieval change
-- a new reranker, different embeddings, a chunking tweak -- to see whether
quality moved, and read the "Misses" list to see exactly what to fix.

    py -m canlex.eval
"""
import json
import sys

from .config import ROOT
from .index import LegislationIndex

# Curated evaluation set, parsed as JSON by run(); run() reports on stderr and
# returns early when this file does not exist.
QUESTIONS = ROOT / "data" / "eval" / "questions.json"
EVAL_TOP_K = 20      # search depth, so ranks past the usual 6 are still visible


def _matches(result, answers):
    """Report whether a search result corresponds to any gold answer.

    ``answers`` is an iterable of ``(act, section)`` pairs. The act is compared
    case-insensitively against the result's ``act_short`` and ``act_code``
    fields. The section must equal the result's ``section`` exactly, unless the
    gold section is the empty string, in which case any chunk of that act/case
    counts (case-law chunks carry no section number).
    """
    found_acts = set()
    for key in ("act_short", "act_code"):
        # A missing field contributes "" rather than raising.
        found_acts.add(result.get(key, "").lower())
    found_section = result.get("section", "")

    return any(
        gold_act.lower() in found_acts and gold_section in ("", found_section)
        for gold_act, gold_section in answers
    )


def _format_ref(act, section):
    """Render an act/section pair for the miss report, e.g. ``"Act s.12"``.

    When there is no section (empty or ``None``) only the act/case name is
    returned. The suffix is added conditionally instead of being stripped
    afterwards: ``str.rstrip(" s.")`` removes a *set of characters*, so it also
    ate the trailing "s" of names such as "R v Oakes".
    """
    if section is None or section == "":
        return act
    return f"{act} s.{section}"


def run():
    """Evaluate retrieval against the curated question set and print a report.

    Loads the questions from ``QUESTIONS``, searches a ``LegislationIndex``
    ``EVAL_TOP_K`` results deep for each query, and records the 1-based rank of
    the first result that ``_matches`` a gold answer (0 when none does).

    Prints Hit@1/3/5/10 and MRR to stdout, followed by every question whose
    gold answer ranked worse than 5 or was absent, together with the top result
    that was returned instead. If the question file does not exist, a notice is
    written to stderr and the function returns without evaluating.
    """
    if not QUESTIONS.exists():
        print(f"No question set at {QUESTIONS}.", file=sys.stderr)
        return
    items = json.loads(QUESTIONS.read_text(encoding="utf-8"))
    index = LegislationIndex()
    ranks = []          # rank of the first gold hit per question (0 = miss)
    misses = []
    for item in items:
        answers = [tuple(a) for a in item["answers"]]
        results = index.search(item["query"], top_k=EVAL_TOP_K)
        rank = next(
            (i for i, result in enumerate(results, start=1)
             if _matches(result, answers)),
            0,
        )
        ranks.append(rank)
        if rank == 0 or rank > 5:
            top = results[0] if results else None
            misses.append((item["query"], answers, rank, top))

    # Guard against an empty question set so the averages do not divide by zero.
    n = len(ranks) or 1

    def hit(k):
        """Fraction of questions whose first gold hit ranked within the top k."""
        return sum(1 for r in ranks if 0 < r <= k) / n

    # Misses (rank 0) contribute nothing to the reciprocal-rank sum.
    mrr = sum(1.0 / r for r in ranks if r) / n
    print(f"CanLex retrieval evaluation -- {len(ranks)} questions\n")
    print(f"  Hit@1:   {hit(1):.2f}")
    print(f"  Hit@3:   {hit(3):.2f}")
    print(f"  Hit@5:   {hit(5):.2f}")
    print(f"  Hit@10:  {hit(10):.2f}")
    print(f"  MRR:     {mrr:.2f}")

    if misses:
        print(f"\n{len(misses)} miss(es) -- gold answer ranked >5 or absent:")
        for query, answers, rank, top in misses:
            gold = ", ".join(_format_ref(a, s) for a, s in answers)
            where = f"ranked #{rank}" if rank else f"absent (searched {EVAL_TOP_K})"
            got = (_format_ref(top.get("act_short", ""), top.get("section", ""))
                   if top else "nothing")
            print(f"  [{where}] {query}")
            print(f"      gold: {gold}   |   top result: {got}")
    print()


# Script entry point: run the evaluation when the module is executed directly
# (e.g. via ``-m canlex.eval`` as shown in the module docstring).
if __name__ == "__main__":
    run()