File size: 5,649 Bytes
"""Coordinate-descent tuning sweep for the four retrieval knobs.
Runs the 141-question eval repeatedly, varying one knob at a time while holding
the others at their current best. Picks the value that maximises Hit@5 (with
MRR as the tiebreak) and continues to the next knob. One pass through all four
knobs is usually enough to settle.
The knobs are read from env vars at index-load time (see canlex/index.py), so
each combination is exercised in a fresh subprocess of canlex.eval. Outputs go
to data/eval/sweep.log and a compact JSON summary to data/eval/sweep.json.
py -m canlex.sweep
"""
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
# Repository root: two directory levels above this module's resolved path.
ROOT = Path(__file__).resolve().parent.parent
# Both sweep artefacts are written under <ROOT>/data/eval.
LOG = ROOT.joinpath("data", "eval", "sweep.log")
SUMMARY = ROOT.joinpath("data", "eval", "sweep.json")
# One entry per tunable knob: (environment variable, candidate values,
# current default). Per the module docstring the knobs are read from env vars
# in canlex/index.py, and the defaults here are meant to mirror its literals.
# Every default also appears in its own candidate list, with the remaining
# candidates bracketing it on a roughly geometric grid.
KNOBS = [
    (
        "CANLEX_MN_WEIGHT",
        [0.0012, 0.0024, 0.005, 0.01],
        0.0024,
    ),
    (
        "CANLEX_MN_CAP",
        [0.006, 0.012, 0.024, 0.05],
        0.012,
    ),
    (
        "CANLEX_REG_PENALTY",
        [0.004, 0.008, 0.016, 0.032],
        0.008,
    ),
    (
        "CANLEX_BACKMATTER_PENALTY",
        [0.004, 0.008, 0.016, 0.032],
        0.008,
    ),
]
# Pulls the five summary numbers out of the eval's stdout, in the order
# Hit@1, Hit@3, Hit@5, Hit@10, MRR. DOTALL lets the lazy ``.*?`` gaps span
# newlines, so the metrics may sit on one line or on separate lines.
_METRIC_RE = re.compile(
    r"Hit@1:\s*([\d.]+).*?"
    r"Hit@3:\s*([\d.]+).*?"
    r"Hit@5:\s*([\d.]+).*?"
    r"Hit@10:\s*([\d.]+).*?"
    r"MRR:\s*([\d.]+)",
    re.DOTALL,
)
def _run_eval(env_overrides: dict[str, float]) -> dict[str, float]:
    """Run ``canlex.eval`` once in a subprocess and return its parsed metrics.

    Args:
        env_overrides: Mapping of environment-variable name to value. Each
            value is rendered with ``f"{value}"`` and layered over a copy of
            this process's environment; only the child sees the overrides.

    Returns:
        A dict with the float entries ``hit1``, ``hit3``, ``hit5``, ``hit10``
        and ``mrr`` parsed from the child's stdout, plus ``stdout`` holding
        the complete captured stdout text (a ``str``, despite the
        annotation).

    Raises:
        RuntimeError: If the child exits with a non-zero status, or if the
            metrics cannot be found in (or converted from) its stdout.
    """
    env = dict(os.environ)
    env.update({name: f"{value}" for name, value in env_overrides.items()})

    proc = subprocess.run(
        [sys.executable, "-u", "-m", "canlex.eval"],
        capture_output=True, text=True, env=env, cwd=ROOT,
    )
    if proc.returncode != 0:
        # Fall back to stdout when stderr is empty so the error message still
        # carries whatever the child printed before it died.
        detail = proc.stderr if proc.stderr.strip() else proc.stdout
        raise RuntimeError(f"eval failed (exit {proc.returncode}):\n"
                           f"{detail[-800:]}")

    match = _METRIC_RE.search(proc.stdout)
    if not match:
        raise RuntimeError(f"could not parse eval output:\n{proc.stdout[-800:]}")
    try:
        h1, h3, h5, h10, mrr = (float(text) for text in match.groups())
    except ValueError as err:
        # The pattern's ``[\d.]+`` groups can capture tokens such as "0.5."
        # that float() rejects; report that as the same parse failure.
        raise RuntimeError(
            f"could not parse eval output:\n{proc.stdout[-800:]}") from err

    return {"hit1": h1, "hit3": h3, "hit5": h5, "hit10": h10, "mrr": mrr,
            "stdout": proc.stdout}
def _score(metrics: dict[str, float]) -> tuple[float, float]:
    """Build the ranking key for one eval run.

    Hit@5 is the primary criterion and MRR the tiebreak. Higher is better for
    both, so the returned tuples can be compared directly with ``>``.
    """
    primary = metrics["hit5"]
    tiebreak = metrics["mrr"]
    return primary, tiebreak
# Numeric metric keys kept in the JSON summary (the raw stdout is left out).
_METRIC_KEYS = ("hit1", "hit3", "hit5", "hit10", "mrr")


def _numeric_metrics(metrics: dict) -> dict:
    """Return only the numeric metric entries of an eval result."""
    return {key: metrics[key] for key in _METRIC_KEYS}


def main():
    """Run the coordinate-descent sweep and write the log and JSON summary.

    Measures a baseline at the defaults from ``KNOBS``, then visits each knob
    in turn. Every candidate value is evaluated with the other knobs held at
    their current best, and a knob moves only when a candidate strictly beats
    the current score from ``_score`` (Hit@5, then MRR). Progress is printed
    and mirrored to ``LOG``; the final settings, their metrics, and every run
    are written to ``SUMMARY`` as JSON.

    Raises:
        RuntimeError: Propagated from ``_run_eval`` when an eval run fails or
            its output cannot be parsed. No summary is written in that case.
    """
    LOG.parent.mkdir(parents=True, exist_ok=True)

    # The context manager closes the log even when an eval run raises, so the
    # file handle is not leaked and the partial log reaches disk.
    with LOG.open("w", encoding="utf-8") as log:
        log.write(f"# CanLex tuning sweep -- {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        print(f"# CanLex tuning sweep -- writing log to {LOG}\n")

        current = {name: default for name, _values, default in KNOBS}
        all_runs: list[dict] = []

        # Baseline at current defaults.
        print("Baseline:", current)
        log.write(f"\nBaseline: {current}\n")
        baseline = _run_eval(current)
        print(f" Hit@5={baseline['hit5']:.3f} MRR={baseline['mrr']:.3f}")
        log.write(f" Hit@5={baseline['hit5']:.3f} MRR={baseline['mrr']:.3f}\n")
        all_runs.append({"values": dict(current),
                         "metrics": _numeric_metrics(baseline)})

        # Metrics measured at exactly the settings held in ``current``.
        best_metrics = baseline
        for name, values, _default in KNOBS:
            print(f"\nSweeping {name} in {values} (others held at {current})...")
            log.write(f"\nSweeping {name} in {values} (others held at {current})\n")

            # Best candidate seen so far for this knob, seeded with the
            # already-measured current setting.
            best_value = current[name]
            best_score = _score(best_metrics)
            best_run = best_metrics
            for v in values:
                run_values = dict(current, **{name: v})
                if v == current[name]:
                    # Re-use the already-measured result at this knob value.
                    metrics = best_metrics
                else:
                    metrics = _run_eval(run_values)
                row = (f" {name}={v!r:<8s} -> Hit@1={metrics['hit1']:.3f} "
                       f"Hit@3={metrics['hit3']:.3f} Hit@5={metrics['hit5']:.3f} "
                       f"Hit@10={metrics['hit10']:.3f} MRR={metrics['mrr']:.3f}")
                print(row)
                log.write(row + "\n")
                # Flush per row so progress survives an interrupted sweep.
                log.flush()
                all_runs.append({"values": run_values,
                                 "metrics": _numeric_metrics(metrics)})
                score = _score(metrics)
                # Strict comparison: ties keep the earlier (current) value.
                if score > best_score:
                    best_value, best_score, best_run = v, score, metrics

            if best_value != current[name]:
                print(f" ! {name}: {current[name]} -> {best_value} "
                      f"(Hit@5 {best_metrics['hit5']:.3f} -> {best_run['hit5']:.3f})")
                log.write(f" ! {name}: {current[name]} -> {best_value}\n")
                current[name] = best_value
                best_metrics = best_run

        print(f"\nBest: {current}")
        print(f" Hit@1={best_metrics['hit1']:.3f} Hit@3={best_metrics['hit3']:.3f} "
              f"Hit@5={best_metrics['hit5']:.3f} Hit@10={best_metrics['hit10']:.3f} "
              f"MRR={best_metrics['mrr']:.3f}")
        log.write(f"\nBest: {current}\n Hit@1={best_metrics['hit1']:.3f} "
                  f"Hit@5={best_metrics['hit5']:.3f} MRR={best_metrics['mrr']:.3f}\n")

    SUMMARY.write_text(json.dumps({
        "best": current,
        "best_metrics": _numeric_metrics(best_metrics),
        "runs": all_runs,
    }, indent=2), encoding="utf-8")
    print(f"\nLog: {LOG}\nSummary: {SUMMARY}")
# Script entry point: run the sweep only when executed directly
# (e.g. via ``-m canlex.sweep``), not when the module is imported.
if __name__ == "__main__":
    main()
|