from __future__ import annotations import argparse import json import os import subprocess import sys from pathlib import Path from statistics import mean from typing import Any, Dict, List, Optional ROOT = Path(__file__).resolve().parents[2] DEFAULT_BASE_SCRIPT = ROOT / "reasoning" / "disturb_CoT_shared_loto_reasoning.py" DEFAULT_RAW_DIR = Path(__file__).resolve().parent / "raw" TASK_LABELS = { "gsm8k": "Open-ended numeric reasoning", "commonsenseqa": "Commonsense multiple choice", "strategyqa": "Yes/No reasoning", "aqua": "Math multiple choice", "arc_challenge": "Science reasoning multiple choice", "openbookqa": "Open-book science reasoning", "qasc": "Multi-hop science reasoning", "logiqa": "Logical reasoning multiple choice", "boolq": "Reading comprehension yes/no", "piqa": "Physical reasoning multiple choice", } CANDIDATE_COUNTS = { "commonsenseqa": 5, "strategyqa": 2, "aqua": 5, "arc_challenge": 4, "openbookqa": 4, "qasc": 8, "logiqa": 4, "boolq": 2, "piqa": 2, } def parse_csv(value: str) -> List[str]: return [item.strip() for item in str(value).split(",") if item.strip()] def build_command(args: argparse.Namespace, holdout: str, out_json: Path, out_md: Path) -> List[str]: cmd = [ args.python, str(args.base_script), "--model", args.model, "--device", args.device, "--model_dtype", args.model_dtype, "--layer", str(args.layer), "--tasks", args.tasks, "--mode", "loto", "--loto_eval_mode", "heldout", "--loto_only", holdout, "--n_subspace", str(args.n_subspace), "--n_eval", str(args.n_eval), "--pca_var", str(args.pca_var), "--tau", str(args.tau), "--m_shared", args.m_shared, "--calib_decode_max_new_tokens", str(args.calib_decode_max_new_tokens), "--per_task_max_states", str(args.per_task_max_states), "--alpha_remove", str(args.alpha_remove), "--reasoning_tokens", str(args.reasoning_tokens), "--max_new_tokens", str(args.max_new_tokens), "--temperature", str(args.temperature), "--top_p", str(args.top_p), "--top_k", str(args.top_k), "--do_sample", str(int(args.do_sample)), "--template_randomization", str(int(args.template_randomization)), "--template_seed", str(args.template_seed), "--shuffle_choices", str(int(args.shuffle_choices)), "--add_answer_prefix", str(int(args.add_answer_prefix)), "--answer_prefix", args.answer_prefix, "--use_forced_choice", str(int(args.use_forced_choice)), "--fc_warmup_tokens", str(args.fc_warmup_tokens), "--fc_prefix_mode", args.fc_prefix_mode, "--fc_answer_prefix", args.fc_answer_prefix, "--batch_size", str(args.batch_size), "--max_prompt_len", str(args.max_prompt_len), "--bootstrap_iters", str(args.bootstrap_iters), "--perm_iters", str(args.perm_iters), "--ci_alpha", str(args.ci_alpha), "--seed", str(args.seed), "--sample_seed", str(args.sample_seed), "--out_json", str(out_json), "--out_md", str(out_md), ] return cmd def load_json(path: Path) -> Dict[str, Any]: with path.open("r", encoding="utf-8") as handle: return json.load(handle) def extract_metrics(payload: Dict[str, Any], holdout: str) -> Dict[str, Any]: fold = payload["folds"][holdout] result = fold["eval"][holdout] paired = result.get("paired", {}) decode_minus_prefill = paired.get("decode_minus_prefill", {}) decode_minus_baseline = paired.get("decode_minus_baseline", {}) prefill_minus_baseline = paired.get("prefill_minus_baseline", {}) random_minus_baseline = paired.get("random_minus_baseline", {}) baseline_acc = float(result["baseline"]["acc"]) decode_acc = float(result["decode_shared"]["acc"]) prefill_acc = float(result["prefill_shared"]["acc"]) random_acc = float(result["random"]["acc"]) chance_acc: Optional[float] = None if holdout in CANDIDATE_COUNTS: chance_acc = 1.0 / float(CANDIDATE_COUNTS[holdout]) floor_flag = baseline_acc <= (1.0 / max(int(result.get("n", 1)), 1)) if chance_acc is not None: floor_flag = floor_flag or baseline_acc <= chance_acc return { "task": holdout, "task_label": TASK_LABELS.get(holdout, holdout), "n_eval": int(result.get("n", 0)), "protocol": result.get("protocol", ""), "k_eval": int(fold.get("bases", {}).get("k_eval", 0)), "chance_acc": chance_acc, "baseline_near_floor": bool(floor_flag), "baseline_acc": baseline_acc, "decode_shared_acc": decode_acc, "prefill_shared_acc": prefill_acc, "random_acc": random_acc, "decode_minus_prefill": float(decode_minus_prefill.get("mean_diff", 0.0)), "decode_minus_prefill_ci_low": float(decode_minus_prefill.get("ci_low", 0.0)), "decode_minus_prefill_ci_high": float(decode_minus_prefill.get("ci_high", 0.0)), "decode_minus_prefill_p": float(decode_minus_prefill.get("p_value", 1.0)), "decode_minus_baseline": float(decode_minus_baseline.get("mean_diff", decode_acc - baseline_acc)), "prefill_minus_baseline": float(prefill_minus_baseline.get("mean_diff", prefill_acc - baseline_acc)), "random_minus_baseline": float(random_minus_baseline.get("mean_diff", random_acc - baseline_acc)), } def fmt_pct(value: float) -> str: return f"{100.0 * value:.1f}" def fmt_delta(value: float) -> str: return f"{100.0 * value:+.1f}" def build_markdown(summary: Dict[str, Any]) -> str: lines: List[str] = [] cfg = summary["config"] aggregate = summary["aggregate"] lines.append("# Quick Reasoning Rebuttal Check") lines.append("") lines.append("This is a quick-turn held-out-task check for reasoning-heavy tasks.") lines.append("") lines.append(f"- Model: `{cfg['model']}` dtype={cfg['model_dtype']} device={cfg['device']}") lines.append(f"- Tasks used for basis/eval: `{cfg['tasks']}`") lines.append(f"- Held-out tasks run: `{cfg['heldout_tasks']}`") lines.append(f"- Per-task eval size: n_eval={cfg['n_eval']}, n_subspace={cfg['n_subspace']}, layer={cfg['layer']}") lines.append(f"- Protocol: LOTO heldout, forced_choice={cfg['use_forced_choice']}, do_sample={cfg['do_sample']}") lines.append("") lines.append("## Per-task results") lines.append("") header = [ "Held-out", "Type", "n", "Baseline", "Decode-shared", "Prefill-shared", "Random", "D-P delta", "p", ] rows: List[List[str]] = [] for item in summary["per_task"]: rows.append( [ item["task"], item["task_label"], str(item["n_eval"]), fmt_pct(item["baseline_acc"]) if item["chance_acc"] is None else f"{fmt_pct(item['baseline_acc'])} (chance {fmt_pct(item['chance_acc'])})", fmt_pct(item["decode_shared_acc"]), fmt_pct(item["prefill_shared_acc"]), fmt_pct(item["random_acc"]), f"{fmt_delta(item['decode_minus_prefill'])} [{fmt_delta(item['decode_minus_prefill_ci_low'])}, {fmt_delta(item['decode_minus_prefill_ci_high'])}]", f"{item['decode_minus_prefill_p']:.3g}", ] ) cols = list(zip(*([header] + rows))) widths = [max(len(str(x)) for x in col) for col in cols] def fmt_row(row: List[str]) -> str: return "| " + " | ".join(str(x).ljust(w) for x, w in zip(row, widths)) + " |" lines.append(fmt_row(header)) lines.append("|-" + "-|-".join("-" * width for width in widths) + "-|") for row in rows: lines.append(fmt_row(row)) lines.append("") lines.append("## Aggregate") lines.append("") lines.append( f"- Mean accuracy: baseline={fmt_pct(aggregate['baseline_acc_mean'])}, " f"decode_shared={fmt_pct(aggregate['decode_shared_acc_mean'])}, " f"prefill_shared={fmt_pct(aggregate['prefill_shared_acc_mean'])}, " f"random={fmt_pct(aggregate['random_acc_mean'])}" ) lines.append( f"- Mean deltas vs baseline: decode={fmt_delta(aggregate['decode_minus_baseline_mean'])}, " f"prefill={fmt_delta(aggregate['prefill_minus_baseline_mean'])}, " f"random={fmt_delta(aggregate['random_minus_baseline_mean'])}" ) lines.append( f"- Mean decode-minus-prefill delta: {fmt_delta(aggregate['decode_minus_prefill_mean'])}" ) informative = summary["aggregate"].get("informative_tasks", []) inconclusive = summary["aggregate"].get("inconclusive_tasks", []) if informative: lines.append(f"- Informative held-out tasks: `{','.join(informative)}`") if inconclusive: lines.append(f"- Inconclusive due to baseline floor/chance: `{','.join(inconclusive)}`") lines.append("") lines.append("## Interpretation") lines.append("") for item in summary["per_task"]: if item["baseline_near_floor"]: lines.append( f"- `{item['task']}` is currently inconclusive: baseline is at or near floor/chance, so this fold does not say much about decode-vs-prefill selectivity." ) else: lines.append( f"- `{item['task']}` is informative: decode-shared changes accuracy by {fmt_delta(item['decode_minus_baseline'])} vs baseline and {fmt_delta(item['decode_minus_prefill'])} vs prefill-shared." ) lines.append( "- Use informative folds as rebuttal evidence that the decode-shared phenomenon is not confined to short classification tasks." ) return "\n".join(lines) + "\n" def aggregate_metrics(per_task: List[Dict[str, Any]]) -> Dict[str, float]: informative = [item for item in per_task if not item["baseline_near_floor"]] return { "baseline_acc_mean": mean(item["baseline_acc"] for item in per_task), "decode_shared_acc_mean": mean(item["decode_shared_acc"] for item in per_task), "prefill_shared_acc_mean": mean(item["prefill_shared_acc"] for item in per_task), "random_acc_mean": mean(item["random_acc"] for item in per_task), "decode_minus_baseline_mean": mean(item["decode_minus_baseline"] for item in per_task), "prefill_minus_baseline_mean": mean(item["prefill_minus_baseline"] for item in per_task), "random_minus_baseline_mean": mean(item["random_minus_baseline"] for item in per_task), "decode_minus_prefill_mean": mean(item["decode_minus_prefill"] for item in per_task), "informative_tasks": [item["task"] for item in informative], "inconclusive_tasks": [item["task"] for item in per_task if item["baseline_near_floor"]], } def main() -> None: parser = argparse.ArgumentParser(description="Run a small reasoning held-out sweep and summarize it.") parser.add_argument("--python", default=sys.executable) parser.add_argument("--base_script", type=Path, default=DEFAULT_BASE_SCRIPT) parser.add_argument("--raw_dir", type=Path, default=DEFAULT_RAW_DIR) parser.add_argument("--summary_json", type=Path, default=Path(__file__).resolve().parent / "quick_reasoning_summary.json") parser.add_argument("--summary_md", type=Path, default=Path(__file__).resolve().parent / "quick_reasoning_summary.md") parser.add_argument("--reuse_existing", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--model", default="meta-llama/Llama-2-7b-chat-hf") parser.add_argument("--device", default="cuda") parser.add_argument("--model_dtype", default="fp16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--tasks", default="gsm8k,commonsenseqa,strategyqa,arc_challenge,openbookqa,qasc,logiqa") parser.add_argument("--heldout_tasks", default="gsm8k,logiqa") parser.add_argument("--layer", type=int, default=10) parser.add_argument("--n_subspace", type=int, default=64) parser.add_argument("--n_eval", type=int, default=32) parser.add_argument("--pca_var", type=float, default=0.95) parser.add_argument("--tau", type=float, default=0.001) parser.add_argument("--m_shared", default="all") parser.add_argument("--calib_decode_max_new_tokens", type=int, default=64) parser.add_argument("--per_task_max_states", type=int, default=4096) parser.add_argument("--alpha_remove", type=float, default=1.0) parser.add_argument("--reasoning_tokens", type=int, default=64) parser.add_argument("--max_new_tokens", type=int, default=128) parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--top_k", type=int, default=0) parser.add_argument("--do_sample", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--template_randomization", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--template_seed", type=int, default=1234) parser.add_argument("--shuffle_choices", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--add_answer_prefix", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--answer_prefix", default="\nFinal answer:") parser.add_argument("--use_forced_choice", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--fc_warmup_tokens", type=int, default=0) parser.add_argument("--fc_prefix_mode", default="auto", choices=["auto", "always", "never"]) parser.add_argument("--fc_answer_prefix", default="\nFinal answer:") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--max_prompt_len", type=int, default=512) parser.add_argument("--bootstrap_iters", type=int, default=1000) parser.add_argument("--perm_iters", type=int, default=2000) parser.add_argument("--ci_alpha", type=float, default=0.05) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--sample_seed", type=int, default=12345) parser.add_argument("--offline", action=argparse.BooleanOptionalAction, default=True) args = parser.parse_args() args.raw_dir.mkdir(parents=True, exist_ok=True) args.summary_json.parent.mkdir(parents=True, exist_ok=True) args.summary_md.parent.mkdir(parents=True, exist_ok=True) heldout_tasks = parse_csv(args.heldout_tasks) if not heldout_tasks: raise ValueError("heldout_tasks must not be empty") task_set = set(parse_csv(args.tasks)) bad = [task for task in heldout_tasks if task not in task_set] if bad: raise ValueError(f"heldout_tasks must be a subset of tasks; got invalid {bad}") env = os.environ.copy() if args.offline: env.setdefault("TRANSFORMERS_OFFLINE", "1") env.setdefault("HF_DATASETS_OFFLINE", "1") per_task: List[Dict[str, Any]] = [] commands: Dict[str, List[str]] = {} raw_files: Dict[str, Dict[str, str]] = {} for holdout in heldout_tasks: out_json = args.raw_dir / f"{holdout}_loto.json" out_md = args.raw_dir / f"{holdout}_loto.md" cmd = build_command(args, holdout, out_json, out_md) commands[holdout] = cmd raw_files[holdout] = {"json": str(out_json), "md": str(out_md)} if args.reuse_existing and out_json.exists(): print(f"[reuse] holdout={holdout} json={out_json}") else: print(f"[run] holdout={holdout}") print("[cmd]", " ".join(cmd)) subprocess.run(cmd, check=True, env=env, cwd=str(ROOT)) payload = load_json(out_json) per_task.append(extract_metrics(payload, holdout)) per_task.sort(key=lambda item: heldout_tasks.index(item["task"])) summary = { "config": { "model": args.model, "device": args.device, "model_dtype": args.model_dtype, "tasks": args.tasks, "heldout_tasks": ",".join(heldout_tasks), "layer": args.layer, "n_subspace": args.n_subspace, "n_eval": args.n_eval, "use_forced_choice": bool(args.use_forced_choice), "do_sample": bool(args.do_sample), "offline": bool(args.offline), }, "commands": commands, "raw_files": raw_files, "per_task": per_task, "aggregate": aggregate_metrics(per_task), } with args.summary_json.open("w", encoding="utf-8") as handle: json.dump(summary, handle, ensure_ascii=False, indent=2) with args.summary_md.open("w", encoding="utf-8") as handle: handle.write(build_markdown(summary)) print(f"[done] summary_json={args.summary_json}") print(f"[done] summary_md={args.summary_md}") if __name__ == "__main__": main()