# decodeshare/artifacts/rebuttal/reasoning/quick_reasoning_sweep.py
# Zishan Shao
# Add lighthouse rebuttal artifacts
# 1c8e365
from __future__ import annotations

import argparse
import json
import os
import shlex
import subprocess
import sys
from pathlib import Path
from statistics import mean
from typing import Any, Dict, List, Optional
# Directory that contains this script; the default locations are derived from it.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Two levels above the script directory; also used as the working directory
# for the per-task subprocess runs.
ROOT = _SCRIPT_DIR.parents[1]

# Script that is launched once per held-out task.
DEFAULT_BASE_SCRIPT = ROOT.joinpath("reasoning", "disturb_CoT_shared_loto_reasoning.py")

# Default directory for the per-task raw JSON/Markdown outputs.
DEFAULT_RAW_DIR = _SCRIPT_DIR / "raw"
# Human-readable task type shown in the "Type" column of the Markdown table.
# Tasks without an entry fall back to their own name in extract_metrics.
TASK_LABELS: Dict[str, str] = {
    "gsm8k": "Open-ended numeric reasoning",
    "commonsenseqa": "Commonsense multiple choice",
    "strategyqa": "Yes/No reasoning",
    "aqua": "Math multiple choice",
    "arc_challenge": "Science reasoning multiple choice",
    "openbookqa": "Open-book science reasoning",
    "qasc": "Multi-hop science reasoning",
    "logiqa": "Logical reasoning multiple choice",
    "boolq": "Reading comprehension yes/no",
    "piqa": "Physical reasoning multiple choice",
}

# Number of answer candidates per task; chance accuracy is taken as 1 / count.
# A task with no entry here (e.g. "gsm8k") gets no chance level.
CANDIDATE_COUNTS: Dict[str, int] = {
    "commonsenseqa": 5,
    "strategyqa": 2,
    "aqua": 5,
    "arc_challenge": 4,
    "openbookqa": 4,
    "qasc": 8,
    "logiqa": 4,
    "boolq": 2,
    "piqa": 2,
}
def parse_csv(value: str) -> List[str]:
    """Split a comma-separated string into its stripped, non-empty fields.

    ``value`` is passed through ``str()`` first, so non-string inputs are
    accepted. Fields that are empty or whitespace-only are dropped.
    """
    stripped = (piece.strip() for piece in str(value).split(","))
    return [piece for piece in stripped if piece]
def build_command(args: argparse.Namespace, holdout: str, out_json: Path, out_md: Path) -> List[str]:
    """Assemble the argv list that runs the base script for one held-out task.

    Args:
        args: Parsed sweep options; supplies the interpreter, the base script
            path and every forwarded option value.
        holdout: Task name passed as ``--loto_only``.
        out_json: Destination passed as ``--out_json``.
        out_md: Destination passed as ``--out_md``.

    Returns:
        ``[python, base_script, flag, value, ...]``. The mode is always
        ``loto`` with ``--loto_eval_mode heldout``; boolean options are
        encoded as ``"0"`` / ``"1"``.
    """
    # (flag, value) pairs in the exact order they appear on the command line.
    options = [
        ("--model", args.model),
        ("--device", args.device),
        ("--model_dtype", args.model_dtype),
        ("--layer", str(args.layer)),
        ("--tasks", args.tasks),
        ("--mode", "loto"),
        ("--loto_eval_mode", "heldout"),
        ("--loto_only", holdout),
        ("--n_subspace", str(args.n_subspace)),
        ("--n_eval", str(args.n_eval)),
        ("--pca_var", str(args.pca_var)),
        ("--tau", str(args.tau)),
        ("--m_shared", args.m_shared),
        ("--calib_decode_max_new_tokens", str(args.calib_decode_max_new_tokens)),
        ("--per_task_max_states", str(args.per_task_max_states)),
        ("--alpha_remove", str(args.alpha_remove)),
        ("--reasoning_tokens", str(args.reasoning_tokens)),
        ("--max_new_tokens", str(args.max_new_tokens)),
        ("--temperature", str(args.temperature)),
        ("--top_p", str(args.top_p)),
        ("--top_k", str(args.top_k)),
        ("--do_sample", str(int(args.do_sample))),
        ("--template_randomization", str(int(args.template_randomization))),
        ("--template_seed", str(args.template_seed)),
        ("--shuffle_choices", str(int(args.shuffle_choices))),
        ("--add_answer_prefix", str(int(args.add_answer_prefix))),
        ("--answer_prefix", args.answer_prefix),
        ("--use_forced_choice", str(int(args.use_forced_choice))),
        ("--fc_warmup_tokens", str(args.fc_warmup_tokens)),
        ("--fc_prefix_mode", args.fc_prefix_mode),
        ("--fc_answer_prefix", args.fc_answer_prefix),
        ("--batch_size", str(args.batch_size)),
        ("--max_prompt_len", str(args.max_prompt_len)),
        ("--bootstrap_iters", str(args.bootstrap_iters)),
        ("--perm_iters", str(args.perm_iters)),
        ("--ci_alpha", str(args.ci_alpha)),
        ("--seed", str(args.seed)),
        ("--sample_seed", str(args.sample_seed)),
        ("--out_json", str(out_json)),
        ("--out_md", str(out_md)),
    ]

    argv = [args.python, str(args.base_script)]
    for flag, value in options:
        argv.extend((flag, value))
    return argv
def load_json(path: Path) -> Dict[str, Any]:
    """Read ``path`` as UTF-8 text and return the parsed JSON document."""
    return json.loads(path.read_text(encoding="utf-8"))
def extract_metrics(payload: Dict[str, Any], holdout: str) -> Dict[str, Any]:
    """Flatten one held-out fold of a raw LOTO result into a summary row.

    The row is read from ``payload["folds"][holdout]["eval"][holdout]``.

    Args:
        payload: Parsed content of the fold's raw JSON file.
        holdout: Name of the held-out task; used both as the fold key and as
            the evaluated-task key.

    Returns:
        A flat dict holding the four condition accuracies, the paired
        decode-minus-prefill statistics, the three deltas against baseline,
        the chance level (``None`` when ``holdout`` has no entry in
        ``CANDIDATE_COUNTS``) and the ``baseline_near_floor`` flag.

    Raises:
        KeyError: If the fold, its eval entry, or the ``acc`` of any of the
            ``baseline``, ``decode_shared``, ``prefill_shared`` or ``random``
            conditions is missing.
    """
    fold = payload["folds"][holdout]
    result = fold["eval"][holdout]

    # ``or {}`` treats an explicit JSON null the same as a missing section.
    paired = result.get("paired") or {}

    def paired_section(name: str) -> Dict[str, Any]:
        return paired.get(name) or {}

    def mean_diff(name: str, fallback: float) -> float:
        return float(paired_section(name).get("mean_diff", fallback))

    baseline_acc = float(result["baseline"]["acc"])
    decode_acc = float(result["decode_shared"]["acc"])
    prefill_acc = float(result["prefill_shared"]["acc"])
    random_acc = float(result["random"]["acc"])

    candidates = CANDIDATE_COUNTS.get(holdout)
    chance_acc: Optional[float] = None if candidates is None else 1.0 / float(candidates)

    # Near floor: baseline accuracy is at most 1/n, or at/below the chance
    # level when one is known for this task.
    n_items = max(int(result.get("n", 1)), 1)
    floor_flag = baseline_acc <= 1.0 / n_items
    if chance_acc is not None:
        floor_flag = floor_flag or baseline_acc <= chance_acc

    decode_vs_prefill = paired_section("decode_minus_prefill")
    return {
        "task": holdout,
        "task_label": TASK_LABELS.get(holdout, holdout),
        "n_eval": int(result.get("n", 0)),
        "protocol": result.get("protocol", ""),
        "k_eval": int((fold.get("bases") or {}).get("k_eval", 0)),
        "chance_acc": chance_acc,
        "baseline_near_floor": bool(floor_flag),
        "baseline_acc": baseline_acc,
        "decode_shared_acc": decode_acc,
        "prefill_shared_acc": prefill_acc,
        "random_acc": random_acc,
        # Like the deltas against baseline below, the point estimate falls
        # back to the plain accuracy difference when no paired value exists,
        # instead of reporting a spurious 0.0.
        "decode_minus_prefill": mean_diff("decode_minus_prefill", decode_acc - prefill_acc),
        # Interval bounds and p-value cannot be derived from the accuracies
        # alone, so they keep their neutral placeholders when absent.
        "decode_minus_prefill_ci_low": float(decode_vs_prefill.get("ci_low", 0.0)),
        "decode_minus_prefill_ci_high": float(decode_vs_prefill.get("ci_high", 0.0)),
        "decode_minus_prefill_p": float(decode_vs_prefill.get("p_value", 1.0)),
        "decode_minus_baseline": mean_diff("decode_minus_baseline", decode_acc - baseline_acc),
        "prefill_minus_baseline": mean_diff("prefill_minus_baseline", prefill_acc - baseline_acc),
        "random_minus_baseline": mean_diff("random_minus_baseline", random_acc - baseline_acc),
    }
def fmt_pct(value: float) -> str:
    """Render a fraction as a percentage with one decimal place and no ``%`` sign."""
    return format(100.0 * value, ".1f")
def fmt_delta(value: float) -> str:
    """Render a fraction as signed percentage points with one decimal place."""
    return "{:+.1f}".format(100.0 * value)
def build_markdown(summary: Dict[str, Any]) -> str:
    """Render the sweep summary as a Markdown report.

    Args:
        summary: Mapping with ``config``, ``per_task`` and ``aggregate``
            entries. Each ``per_task`` item supplies the keys read below
            (accuracies, deltas, ``chance_acc``, ``baseline_near_floor``).

    Returns:
        The document as one string ending with a newline: a configuration
        header, an aligned per-task table, aggregate bullets and one
        interpretation bullet per task.
    """
    config = summary["config"]
    totals = summary["aggregate"]
    entries = summary["per_task"]

    doc: List[str] = [
        "# Quick Reasoning Rebuttal Check",
        "",
        "This is a quick-turn held-out-task check for reasoning-heavy tasks.",
        "",
        f"- Model: `{config['model']}` dtype={config['model_dtype']} device={config['device']}",
        f"- Tasks used for basis/eval: `{config['tasks']}`",
        f"- Held-out tasks run: `{config['heldout_tasks']}`",
        f"- Per-task eval size: n_eval={config['n_eval']}, n_subspace={config['n_subspace']}, layer={config['layer']}",
        f"- Protocol: LOTO heldout, forced_choice={config['use_forced_choice']}, do_sample={config['do_sample']}",
        "",
        "## Per-task results",
        "",
    ]

    header = [
        "Held-out",
        "Type",
        "n",
        "Baseline",
        "Decode-shared",
        "Prefill-shared",
        "Random",
        "D-P delta",
        "p",
    ]

    def baseline_cell(entry: Dict[str, Any]) -> str:
        # The chance level is appended only for tasks that have one.
        shown = fmt_pct(entry["baseline_acc"])
        if entry["chance_acc"] is None:
            return shown
        return f"{shown} (chance {fmt_pct(entry['chance_acc'])})"

    def delta_cell(entry: Dict[str, Any]) -> str:
        point = fmt_delta(entry["decode_minus_prefill"])
        low = fmt_delta(entry["decode_minus_prefill_ci_low"])
        high = fmt_delta(entry["decode_minus_prefill_ci_high"])
        return f"{point} [{low}, {high}]"

    body: List[List[str]] = [
        [
            entry["task"],
            entry["task_label"],
            str(entry["n_eval"]),
            baseline_cell(entry),
            fmt_pct(entry["decode_shared_acc"]),
            fmt_pct(entry["prefill_shared_acc"]),
            fmt_pct(entry["random_acc"]),
            delta_cell(entry),
            f"{entry['decode_minus_prefill_p']:.3g}",
        ]
        for entry in entries
    ]

    # Each column is padded to its widest cell, header included.
    widths = [max(len(str(cell)) for cell in column) for column in zip(header, *body)]

    def render(cells: List[str]) -> str:
        padded = [str(cell).ljust(width) for cell, width in zip(cells, widths)]
        return "| " + " | ".join(padded) + " |"

    doc.append(render(header))
    # The rule spans each column's width plus its two padding characters.
    doc.append("|" + "|".join("-" * (width + 2) for width in widths) + "|")
    doc.extend(render(cells) for cells in body)

    doc.extend(["", "## Aggregate", ""])
    doc.append(
        f"- Mean accuracy: baseline={fmt_pct(totals['baseline_acc_mean'])}, "
        f"decode_shared={fmt_pct(totals['decode_shared_acc_mean'])}, "
        f"prefill_shared={fmt_pct(totals['prefill_shared_acc_mean'])}, "
        f"random={fmt_pct(totals['random_acc_mean'])}"
    )
    doc.append(
        f"- Mean deltas vs baseline: decode={fmt_delta(totals['decode_minus_baseline_mean'])}, "
        f"prefill={fmt_delta(totals['prefill_minus_baseline_mean'])}, "
        f"random={fmt_delta(totals['random_minus_baseline_mean'])}"
    )
    doc.append(f"- Mean decode-minus-prefill delta: {fmt_delta(totals['decode_minus_prefill_mean'])}")

    informative = totals.get("informative_tasks", [])
    inconclusive = totals.get("inconclusive_tasks", [])
    if informative:
        doc.append(f"- Informative held-out tasks: `{','.join(informative)}`")
    if inconclusive:
        doc.append(f"- Inconclusive due to baseline floor/chance: `{','.join(inconclusive)}`")

    doc.extend(["", "## Interpretation", ""])
    for entry in entries:
        if entry["baseline_near_floor"]:
            bullet = f"- `{entry['task']}` is currently inconclusive: baseline is at or near floor/chance, so this fold does not say much about decode-vs-prefill selectivity."
        else:
            bullet = f"- `{entry['task']}` is informative: decode-shared changes accuracy by {fmt_delta(entry['decode_minus_baseline'])} vs baseline and {fmt_delta(entry['decode_minus_prefill'])} vs prefill-shared."
        doc.append(bullet)

    # NOTE(review): emitted once, after the per-task bullets. The original
    # indentation was lost, so confirm this placement.
    doc.append(
        "- Use informative folds as rebuttal evidence that the decode-shared phenomenon is not confined to short classification tasks."
    )
    return "\n".join(doc) + "\n"
def aggregate_metrics(per_task: List[Dict[str, Any]]) -> Dict[str, float]:
    """Average the per-task metrics and split tasks by the baseline-floor flag.

    Args:
        per_task: Summary rows; each needs ``task``, ``baseline_near_floor``
            and the eight metric fields averaged below.

    Returns:
        ``<field>_mean`` for every averaged field (over all tasks, including
        inconclusive ones), plus ``informative_tasks`` and
        ``inconclusive_tasks``. Note that, despite the ``Dict[str, float]``
        annotation, those last two values are lists of task names.

    Raises:
        statistics.StatisticsError: If ``per_task`` is empty.
    """
    informative_names: List[str] = []
    inconclusive_names: List[str] = []
    for row in per_task:
        bucket = inconclusive_names if row["baseline_near_floor"] else informative_names
        bucket.append(row["task"])

    averaged_fields = (
        "baseline_acc",
        "decode_shared_acc",
        "prefill_shared_acc",
        "random_acc",
        "decode_minus_baseline",
        "prefill_minus_baseline",
        "random_minus_baseline",
        "decode_minus_prefill",
    )
    summary: Dict[str, Any] = {
        f"{field}_mean": mean([row[field] for row in per_task]) for field in averaged_fields
    }
    summary["informative_tasks"] = informative_names
    summary["inconclusive_tasks"] = inconclusive_names
    return summary
def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser: sweep-level options plus those forwarded to the base script."""
    here = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Run a small reasoning held-out sweep and summarize it.")

    # Options consumed by this wrapper.
    parser.add_argument("--python", default=sys.executable)
    parser.add_argument("--base_script", type=Path, default=DEFAULT_BASE_SCRIPT)
    parser.add_argument("--raw_dir", type=Path, default=DEFAULT_RAW_DIR)
    parser.add_argument("--summary_json", type=Path, default=here / "quick_reasoning_summary.json")
    parser.add_argument("--summary_md", type=Path, default=here / "quick_reasoning_summary.md")
    parser.add_argument("--reuse_existing", action=argparse.BooleanOptionalAction, default=False)

    # Options forwarded to the base script by build_command.
    parser.add_argument("--model", default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--model_dtype", default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--tasks", default="gsm8k,commonsenseqa,strategyqa,arc_challenge,openbookqa,qasc,logiqa")
    parser.add_argument("--heldout_tasks", default="gsm8k,logiqa")
    parser.add_argument("--layer", type=int, default=10)
    parser.add_argument("--n_subspace", type=int, default=64)
    parser.add_argument("--n_eval", type=int, default=32)
    parser.add_argument("--pca_var", type=float, default=0.95)
    parser.add_argument("--tau", type=float, default=0.001)
    parser.add_argument("--m_shared", default="all")
    parser.add_argument("--calib_decode_max_new_tokens", type=int, default=64)
    parser.add_argument("--per_task_max_states", type=int, default=4096)
    parser.add_argument("--alpha_remove", type=float, default=1.0)
    parser.add_argument("--reasoning_tokens", type=int, default=64)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--do_sample", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--template_randomization", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--template_seed", type=int, default=1234)
    parser.add_argument("--shuffle_choices", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--add_answer_prefix", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--answer_prefix", default="\nFinal answer:")
    parser.add_argument("--use_forced_choice", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--fc_warmup_tokens", type=int, default=0)
    parser.add_argument("--fc_prefix_mode", default="auto", choices=["auto", "always", "never"])
    parser.add_argument("--fc_answer_prefix", default="\nFinal answer:")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_prompt_len", type=int, default=512)
    parser.add_argument("--bootstrap_iters", type=int, default=1000)
    parser.add_argument("--perm_iters", type=int, default=2000)
    parser.add_argument("--ci_alpha", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--sample_seed", type=int, default=12345)

    # Only affects the environment handed to the subprocesses.
    parser.add_argument("--offline", action=argparse.BooleanOptionalAction, default=True)
    return parser


def main() -> None:
    """Run the base script once per held-out task and write the JSON/Markdown summary.

    For every held-out task the base script is launched as a subprocess
    (skipped when ``--reuse_existing`` is set and the raw JSON already
    exists), its raw JSON is reduced to one summary row, and the rows plus
    their aggregate are written to ``--summary_json`` and ``--summary_md``.

    Raises:
        ValueError: If ``--heldout_tasks`` is empty or names a task that is
            not listed in ``--tasks``.
        subprocess.CalledProcessError: If a base-script run exits non-zero.
    """
    args = _build_parser().parse_args()

    for directory in (args.raw_dir, args.summary_json.parent, args.summary_md.parent):
        directory.mkdir(parents=True, exist_ok=True)

    # Drop repeated names (keeping first-seen order) so a task is neither run
    # twice nor counted twice in the aggregate means.
    heldout_tasks = list(dict.fromkeys(parse_csv(args.heldout_tasks)))
    if not heldout_tasks:
        raise ValueError("heldout_tasks must not be empty")
    task_set = set(parse_csv(args.tasks))
    bad = [task for task in heldout_tasks if task not in task_set]
    if bad:
        raise ValueError(f"heldout_tasks must be a subset of tasks; got invalid {bad}")

    env = os.environ.copy()
    if args.offline:
        # setdefault keeps any value the caller has already exported.
        env.setdefault("TRANSFORMERS_OFFLINE", "1")
        env.setdefault("HF_DATASETS_OFFLINE", "1")

    per_task: List[Dict[str, Any]] = []
    commands: Dict[str, List[str]] = {}
    raw_files: Dict[str, Dict[str, str]] = {}
    for holdout in heldout_tasks:
        out_json = args.raw_dir / f"{holdout}_loto.json"
        out_md = args.raw_dir / f"{holdout}_loto.md"
        cmd = build_command(args, holdout, out_json, out_md)
        commands[holdout] = cmd
        raw_files[holdout] = {"json": str(out_json), "md": str(out_md)}
        if args.reuse_existing and out_json.exists():
            print(f"[reuse] holdout={holdout} json={out_json}")
        else:
            print(f"[run] holdout={holdout}")
            # shlex.join quotes arguments containing spaces or newlines (the
            # default answer prefix has both), so the logged line can be
            # pasted into a shell unchanged.
            print("[cmd]", shlex.join(cmd))
            subprocess.run(cmd, check=True, env=env, cwd=str(ROOT))
        payload = load_json(out_json)
        # Rows are appended in heldout_tasks order, so no re-sort is needed.
        per_task.append(extract_metrics(payload, holdout))

    summary = {
        "config": {
            "model": args.model,
            "device": args.device,
            "model_dtype": args.model_dtype,
            "tasks": args.tasks,
            "heldout_tasks": ",".join(heldout_tasks),
            "layer": args.layer,
            "n_subspace": args.n_subspace,
            "n_eval": args.n_eval,
            "use_forced_choice": bool(args.use_forced_choice),
            "do_sample": bool(args.do_sample),
            "offline": bool(args.offline),
        },
        "commands": commands,
        "raw_files": raw_files,
        "per_task": per_task,
        "aggregate": aggregate_metrics(per_task),
    }
    with args.summary_json.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)
    with args.summary_md.open("w", encoding="utf-8") as handle:
        handle.write(build_markdown(summary))
    print(f"[done] summary_json={args.summary_json}")
    print(f"[done] summary_md={args.summary_md}")


if __name__ == "__main__":
    main()