CanLex / canlex /sweep.py
Beemer
Coordinate-descent tuning sweep over the four retrieval knobs
df55f26
"""Coordinate-descent tuning sweep for the four retrieval knobs.
Runs the 141-question eval repeatedly, varying one knob at a time while holding
the others at their current best. Picks the value that maximises Hit@5 (with
MRR as the tiebreak) and continues to the next knob. One pass through all four
knobs is usually enough to settle.
The knobs are read from env vars at index-load time (see canlex/index.py), so
each combination is exercised in a fresh subprocess of canlex.eval. Outputs go
to data/eval/sweep.log and a compact JSON summary to data/eval/sweep.json.
Usage:
    py -m canlex.sweep
"""
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
# Repository root: two directory levels above this file.
ROOT = Path(__file__).resolve().parent.parent
# Human-readable log of every sweep row, and the compact machine-readable summary.
LOG = ROOT.joinpath("data", "eval", "sweep.log")
SUMMARY = ROOT.joinpath("data", "eval", "sweep.json")

# One entry per knob: (env var name, candidate values to try, current default).
# Per the original note, each default mirrors the literal in canlex/index.py and
# the candidates bracket it on a roughly geometric grid.
KNOBS = [
    (
        "CANLEX_MN_WEIGHT",
        [0.0012, 0.0024, 0.005, 0.01],
        0.0024,
    ),
    (
        "CANLEX_MN_CAP",
        [0.006, 0.012, 0.024, 0.05],
        0.012,
    ),
    (
        "CANLEX_REG_PENALTY",
        [0.004, 0.008, 0.016, 0.032],
        0.008,
    ),
    (
        "CANLEX_BACKMATTER_PENALTY",
        [0.004, 0.008, 0.016, 0.032],
        0.008,
    ),
]

# Pulls the five headline numbers out of the eval's stdout, in the order
# Hit@1, Hit@3, Hit@5, Hit@10, MRR. DOTALL lets the lazy gaps span newlines.
_METRIC_RE = re.compile(
    r"Hit@1:\s*([\d.]+).*?"
    r"Hit@3:\s*([\d.]+).*?"
    r"Hit@5:\s*([\d.]+).*?"
    r"Hit@10:\s*([\d.]+).*?"
    r"MRR:\s*([\d.]+)",
    re.DOTALL,
)
def _run_eval(env_overrides: dict[str, float]) -> dict[str, float]:
    """Run ``canlex.eval`` once in a subprocess and return its parsed metrics.

    Args:
        env_overrides: Mapping of environment-variable name to knob value. Each
            value is stringified and layered on top of the current process
            environment for the child process only.

    Returns:
        A dict with the float metrics ``hit1``, ``hit3``, ``hit5``, ``hit10``
        and ``mrr``. It also carries the child's full captured output under
        ``"stdout"`` (a ``str``, despite the annotated value type).

    Raises:
        RuntimeError: If the child exits non-zero (message carries the tail of
            its stderr) or its stdout does not contain the expected metric
            lines (message carries the tail of its stdout).
    """
    # Copy so the overrides never leak into this process's own environment.
    env = dict(os.environ)
    for k, v in env_overrides.items():
        env[k] = f"{v}"
    proc = subprocess.run(
        [sys.executable, "-u", "-m", "canlex.eval"],
        capture_output=True, text=True, env=env, cwd=ROOT,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"eval failed (exit {proc.returncode}):\n"
                           f"{proc.stderr[-800:]}")
    m = _METRIC_RE.search(proc.stdout)
    if not m:
        raise RuntimeError(f"could not parse eval output:\n{proc.stdout[-800:]}")
    h1, h3, h5, h10, mrr = (float(x) for x in m.groups())
    return {"hit1": h1, "hit3": h3, "hit5": h5, "hit10": h10, "mrr": mrr,
            "stdout": proc.stdout}
def _score(metrics: dict[str, float]) -> tuple[float, float]:
    """Build the ranking key for one eval run.

    Hit@5 is the primary criterion and MRR breaks ties; a larger tuple means a
    better run on both components.
    """
    primary = metrics["hit5"]
    tiebreak = metrics["mrr"]
    return primary, tiebreak
# Numeric metrics that are recorded in the JSON summary (the raw stdout is not).
_SUMMARY_KEYS = ("hit1", "hit3", "hit5", "hit10", "mrr")


def _numeric_metrics(metrics: dict) -> dict:
    """Return only the five numeric metrics of a run, dropping its raw stdout."""
    return {key: metrics[key] for key in _SUMMARY_KEYS}


def main():
    """Run one coordinate-descent pass over ``KNOBS`` and record the results.

    Measures a baseline at the default knob values, then for each knob in turn
    evaluates every candidate value while the other knobs stay at their current
    best. A candidate replaces the current value only if its ``_score`` is
    strictly higher, so ties keep the existing setting. Every row is printed and
    appended to ``LOG``; the final settings, their metrics and all runs are
    written to ``SUMMARY`` as JSON.

    Raises:
        RuntimeError: Propagated from ``_run_eval`` when an eval run fails or
            its output cannot be parsed. The log file is still closed, but no
            summary is written in that case.
    """
    LOG.parent.mkdir(parents=True, exist_ok=True)
    # The context manager guarantees the log is flushed and closed even when an
    # eval run raises part-way through the sweep.
    with LOG.open("w", encoding="utf-8") as log:
        log.write(f"# CanLex tuning sweep -- {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        print(f"# CanLex tuning sweep -- writing log to {LOG}\n")
        current = {name: default for name, _values, default in KNOBS}
        all_runs: list[dict] = []

        # Baseline at current defaults.
        print("Baseline:", current)
        log.write(f"\nBaseline: {current}\n")
        baseline = _run_eval(current)
        print(f" Hit@5={baseline['hit5']:.3f} MRR={baseline['mrr']:.3f}")
        log.write(f" Hit@5={baseline['hit5']:.3f} MRR={baseline['mrr']:.3f}\n")
        all_runs.append({"values": dict(current),
                         "metrics": _numeric_metrics(baseline)})
        best_metrics = baseline

        for name, values, _default in KNOBS:
            print(f"\nSweeping {name} in {values} (others held at {current})...")
            log.write(f"\nSweeping {name} in {values} (others held at {current})\n")
            # Start from the incumbent so a candidate must strictly beat it.
            best_value = current[name]
            best_score = _score(best_metrics)
            best_run = best_metrics
            for v in values:
                run_values = dict(current, **{name: v})
                if v == current[name]:
                    # Re-use the already-measured baseline at this knob value.
                    metrics = best_metrics
                else:
                    metrics = _run_eval(run_values)
                row = (f" {name}={v!r:<8s} -> Hit@1={metrics['hit1']:.3f} "
                       f"Hit@3={metrics['hit3']:.3f} Hit@5={metrics['hit5']:.3f} "
                       f"Hit@10={metrics['hit10']:.3f} MRR={metrics['mrr']:.3f}")
                print(row)
                log.write(row + "\n")
                # Flush per row so progress is visible in the log during long sweeps.
                log.flush()
                all_runs.append({"values": run_values,
                                 "metrics": _numeric_metrics(metrics)})
                score = _score(metrics)
                if score > best_score:
                    best_value, best_score, best_run = v, score, metrics
            if best_value != current[name]:
                print(f" ! {name}: {current[name]} -> {best_value} "
                      f"(Hit@5 {best_metrics['hit5']:.3f} -> {best_run['hit5']:.3f})")
                log.write(f" ! {name}: {current[name]} -> {best_value}\n")
            current[name] = best_value
            best_metrics = best_run

        print(f"\nBest: {current}")
        print(f" Hit@1={best_metrics['hit1']:.3f} Hit@3={best_metrics['hit3']:.3f} "
              f"Hit@5={best_metrics['hit5']:.3f} Hit@10={best_metrics['hit10']:.3f} "
              f"MRR={best_metrics['mrr']:.3f}")
        log.write(f"\nBest: {current}\n Hit@1={best_metrics['hit1']:.3f} "
                  f"Hit@5={best_metrics['hit5']:.3f} MRR={best_metrics['mrr']:.3f}\n")

    SUMMARY.write_text(json.dumps({
        "best": current,
        "best_metrics": _numeric_metrics(best_metrics),
        "runs": all_runs,
    }, indent=2), encoding="utf-8")
    print(f"\nLog: {LOG}\nSummary: {SUMMARY}")


# Run the sweep when the module is executed as a script.
if __name__ == "__main__":
    main()