File size: 5,649 Bytes
df55f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Coordinate-descent tuning sweep for the four retrieval knobs.

Runs the 141-question eval repeatedly, varying one knob at a time while holding
the others at their current best. Picks the value that maximises Hit@5 (with
MRR as the tiebreak) and continues to the next knob. One pass through all four
knobs is usually enough to settle.

The knobs are read from env vars at index-load time (see canlex/index.py), so
each combination is evaluated in a fresh canlex.eval subprocess. A run log goes
to data/eval/sweep.log and a compact JSON summary to data/eval/sweep.json.

    py -m canlex.sweep
"""
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
LOG = ROOT.joinpath("data", "eval", "sweep.log")
SUMMARY = ROOT.joinpath("data", "eval", "sweep.json")

# One entry per tunable knob: (env var name, candidate values, current default).
# Each default mirrors the literal used in canlex/index.py, and its candidates
# bracket it on a roughly geometric grid.
KNOBS = [
    ("CANLEX_MN_WEIGHT", [0.0012, 0.0024, 0.005, 0.01], 0.0024),
    ("CANLEX_MN_CAP", [0.006, 0.012, 0.024, 0.05], 0.012),
    ("CANLEX_REG_PENALTY", [0.004, 0.008, 0.016, 0.032], 0.008),
    ("CANLEX_BACKMATTER_PENALTY", [0.004, 0.008, 0.016, 0.032], 0.008),
]


# Extracts Hit@1, Hit@3, Hit@5, Hit@10 and MRR (as five groups, in that order)
# from canlex.eval's stdout. DOTALL lets the lazy ".*?" gaps cross the line
# breaks between metrics.
_METRIC_RE = re.compile(
    r"Hit@1:\s*([\d.]+)"
    r".*?Hit@3:\s*([\d.]+)"
    r".*?Hit@5:\s*([\d.]+)"
    r".*?Hit@10:\s*([\d.]+)"
    r".*?MRR:\s*([\d.]+)",
    re.DOTALL,
)


def _run_eval(env_overrides: dict[str, float]) -> dict[str, float]:
    """Run ``canlex.eval`` once in a fresh subprocess and parse its metrics.

    Args:
        env_overrides: Env-var name -> value. Each value is stringified and
            layered over a copy of this process's environment; the parent's
            own ``os.environ`` is not modified.

    Returns:
        Floats under ``"hit1"``, ``"hit3"``, ``"hit5"``, ``"hit10"`` and
        ``"mrr"`` parsed from the eval report, plus ``"stdout"`` holding the
        subprocess's full stdout as a str (despite the float annotation).

    Raises:
        RuntimeError: if the eval exits non-zero, or if its stdout does not
            contain the metric lines matched by ``_METRIC_RE``.
    """
    env = {**os.environ, **{k: str(v) for k, v in env_overrides.items()}}
    proc = subprocess.run(
        [sys.executable, "-u", "-m", "canlex.eval"],
        capture_output=True, text=True, env=env, cwd=ROOT,
    )
    if proc.returncode != 0:
        # Some failures only print to stdout; fall back to it so the error
        # message is never empty.
        tail = (proc.stderr or proc.stdout)[-800:]
        raise RuntimeError(f"eval failed (exit {proc.returncode}):\n{tail}")
    match = _METRIC_RE.search(proc.stdout)
    if match is None:
        raise RuntimeError(f"could not parse eval output:\n{proc.stdout[-800:]}")
    h1, h3, h5, h10, mrr = (float(x) for x in match.groups())
    return {"hit1": h1, "hit3": h3, "hit5": h5, "hit10": h10, "mrr": mrr,
            "stdout": proc.stdout}


def _score(metrics: dict[str, float]) -> tuple[float, float]:
    """Return the comparison key for a run: Hit@5 first, MRR to break ties.

    Larger keys are better, so two runs can be compared directly with ``>``.
    """
    hit5 = metrics["hit5"]
    mrr = metrics["mrr"]
    return hit5, mrr


# Numeric metric fields kept in the JSON summary (everything but raw stdout).
_METRIC_KEYS = ("hit1", "hit3", "hit5", "hit10", "mrr")


def _metric_fields(metrics: dict[str, float]) -> dict[str, float]:
    """Return only the numeric metrics from a ``_run_eval`` result."""
    return {k: metrics[k] for k in _METRIC_KEYS}


def main():
    """Run the coordinate-descent sweep, then write the log and JSON summary.

    Starts from the defaults in ``KNOBS`` and measures a baseline. It then
    makes a single pass over ``KNOBS``. For each knob it tries every candidate
    value while the other knobs stay at their current best, and keeps the
    value with the highest ``_score``. Ties keep the existing value. Progress
    is printed and written to ``LOG``. The best values, their metrics, and
    every run row go to ``SUMMARY``.
    """
    LOG.parent.mkdir(parents=True, exist_ok=True)
    current = {name: default for name, _values, default in KNOBS}
    all_runs: list[dict] = []

    # ``with`` guarantees the log is flushed and closed even when an eval run
    # raises partway through the sweep.
    with LOG.open("w", encoding="utf-8") as log:
        log.write(f"# CanLex tuning sweep -- {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        print(f"# CanLex tuning sweep -- writing log to {LOG}\n")

        # Baseline at current defaults.
        print("Baseline:", current)
        log.write(f"\nBaseline: {current}\n")
        baseline = _run_eval(current)
        print(f"  Hit@5={baseline['hit5']:.3f}  MRR={baseline['mrr']:.3f}")
        log.write(f"  Hit@5={baseline['hit5']:.3f}  MRR={baseline['mrr']:.3f}\n")
        all_runs.append({"values": dict(current), "metrics": _metric_fields(baseline)})
        best_metrics = baseline

        for name, values, _default in KNOBS:
            print(f"\nSweeping {name} in {values} (others held at {current})...")
            log.write(f"\nSweeping {name} in {values} (others held at {current})\n")
            # (value, score, metrics) of the best candidate so far. Seeding it
            # with the current value means a tie never moves the knob.
            local_best = (current[name], _score(best_metrics), best_metrics)
            for v in values:
                run_values = dict(current, **{name: v})
                if v == current[name]:
                    # best_metrics was measured at exactly these values; reuse it.
                    metrics = best_metrics
                else:
                    metrics = _run_eval(run_values)
                row = (f"  {name}={v!r:<8s} -> Hit@1={metrics['hit1']:.3f} "
                       f"Hit@3={metrics['hit3']:.3f} Hit@5={metrics['hit5']:.3f} "
                       f"Hit@10={metrics['hit10']:.3f} MRR={metrics['mrr']:.3f}")
                print(row)
                log.write(row + "\n")
                log.flush()
                all_runs.append({"values": run_values,
                                 "metrics": _metric_fields(metrics)})
                score = _score(metrics)
                if score > local_best[1]:
                    local_best = (v, score, metrics)
            if local_best[0] != current[name]:
                change = (f"  ! {name}: {current[name]} -> {local_best[0]}  "
                          f"(Hit@5 {best_metrics['hit5']:.3f} -> "
                          f"{local_best[2]['hit5']:.3f})")
                print(change)
                log.write(change + "\n")
            current[name] = local_best[0]
            best_metrics = local_best[2]

        # Same summary line for console and log.
        summary = (f"  Hit@1={best_metrics['hit1']:.3f}  Hit@3={best_metrics['hit3']:.3f}  "
                   f"Hit@5={best_metrics['hit5']:.3f}  Hit@10={best_metrics['hit10']:.3f}  "
                   f"MRR={best_metrics['mrr']:.3f}")
        print(f"\nBest: {current}")
        print(summary)
        log.write(f"\nBest: {current}\n{summary}\n")

    SUMMARY.write_text(json.dumps({
        "best": current,
        "best_metrics": _metric_fields(best_metrics),
        "runs": all_runs,
    }, indent=2), encoding="utf-8")
    print(f"\nLog: {LOG}\nSummary: {SUMMARY}")


if __name__ == "__main__":
    main()