"""scripts/eval.py - held-out evaluation harness (Sections 6.2 + 7.3). Runs a model (or one of the deterministic baselines) over a held-out set of syndromes and reports: * format compliance rate * logical correction rate * mean Hamming-overlap with PyMatching * PyMatching beat-rate * mean total reward Usage:: # Baseline run (no model; uses PyMatching-imitator): python -m scripts.eval --policy pymatching --episodes 200 # Trained model (loads adapters via Unsloth): python -m scripts.eval --adapter checkpoints/grpo --episodes 500 # With W&B logging (summary + per-episode table): python -m scripts.eval --adapter checkpoints/grpo --episodes 500 \ --report-to wandb --wandb-group my-experiment """ from __future__ import annotations import argparse import json import sys from typing import Iterable from qubit_medic.client.client import LocalDecoderClient from qubit_medic.config import primary_level def _summary(name: str, results: list[dict]) -> dict: """Aggregate per-episode reward dicts into the metrics the master spec benchmarks against (sections 6 + 7 of the locked spec). Each entry in ``results`` is the env's per-step ``info["rewards"]`` dict, optionally with extra fields the eval loop decorated: * ``exact_match_pymatching`` (model-eval only) * ``output_length`` (model-eval only) * ``n_true_errors`` (any caller; enables hard-syndrome subset) """ n = max(1, len(results)) # Hard-syndrome subset = episodes where the simulated truth contains # at least 2 X|Z errors. This is the cohort where MWPM ambiguity # matters and trained-model contributions are most visible. hard = [r for r in results if int(r.get("n_true_errors", 0)) >= 2] n_hard = len(hard) out = { "name": name, "episodes": len(results), # Headline metrics (master spec, section 6). "logical_correction_rate": sum(r["logical_correction"] >= 0.5 for r in results) / n, "pymatching_beat_rate": sum(r["pymatching_beat"] >= 0.5 for r in results) / n, "format_compliance_rate": sum(r["format_compliance"] >= 0.999 for r in results) / n, "format_partial_rate": sum((r["format_compliance"] >= 0.5 and r["format_compliance"] < 0.999) for r in results) / n, # Continuous progress metrics. "syndrome_consistency_rate": sum(r["syndrome_consistency"] >= 0.999 for r in results) / n, "mean_syndrome_consistency": sum(r["syndrome_consistency"] for r in results) / n, "mean_hamming_overlap": sum(r["hamming_overlap"] for r in results) / n, "mean_total_reward": sum(r["total"] for r in results) / n, # Model-eval extras (present iff the model loop populated them). "exact_match_pymatching": sum(int(r.get("exact_match_pymatching", 0)) for r in results) / n, "mean_output_length": sum(int(r.get("output_length", 0)) for r in results) / n, # Hard-syndrome subset (FIX 5, 2026-04 eval spec). Easy syndromes # are where every baseline already hits ~95%+; the hard subset is # where differentiation actually shows up. 
"hard_syndrome_count": n_hard, "hard_syndrome_lcr": (sum(r["logical_correction"] >= 0.5 for r in hard) / n_hard if n_hard else 0.0), "hard_syndrome_beat_rate": (sum(r["pymatching_beat"] >= 0.5 for r in hard) / n_hard if n_hard else 0.0), } return out def _eval_baseline(name: str, episodes: int, level: str, collect_rows: bool = False): from scripts.baseline_policies import ( policy_pymatching, policy_zeros, policy_random, ) import random as _r rng = _r.Random(0) pol_map = { "pymatching": lambda obs: policy_pymatching(obs, env_client=None), "zeros": policy_zeros, "random": lambda obs: policy_random(obs, rng=rng), } if name not in pol_map: raise ValueError(f"unknown baseline {name}; choose from {sorted(pol_map)}") pol = pol_map[name] client = LocalDecoderClient() rewards = [] rows = [] for ep in range(episodes): obs = client.reset(forced_level=level, seed=10_000 + ep) completion = pol(obs) result = client.step(raw_response=completion, episode_id=obs.episode_id) rwd = dict(result.info["rewards"]) # copy so we can decorate # Tag with true-error count so _summary can filter the hard subset. rwd["n_true_errors"] = ( len(result.info.get("pymatching_x_errors", []) or []) + len(result.info.get("pymatching_z_errors", []) or []) ) rewards.append(rwd) if collect_rows and ep < 50: # cap table size rows.append({ "episode": ep, "completion": completion, "logical_correction": rwd["logical_correction"], "syndrome_consistency": rwd["syndrome_consistency"], "hamming_overlap": rwd["hamming_overlap"], "format_compliance": rwd["format_compliance"], "pymatching_beat": rwd["pymatching_beat"], "total": rwd["total"], "actual_obs_flip": result.info["actual_observable_flip"], "pm_obs_flip": result.info["pymatching_observable_pred"], }) return _summary(name, rewards), rows def _eval_model(adapter: str, episodes: int, level: str, base_model: str, max_new_tokens: int, collect_rows: bool = False): """Use Unsloth to load the adapter and generate completions. Populates ``exact_match_pymatching`` and ``output_length`` on each per-episode reward dict so :func:`_summary` can report the master spec's full benchmark suite (section 6 + section 7). """ from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=adapter if adapter else base_model, max_seq_length=2048, load_in_4bit=True, dtype=None, ) FastLanguageModel.for_inference(model) client = LocalDecoderClient() rewards = [] rows = [] for ep in range(episodes): obs = client.reset(forced_level=level, seed=10_000 + ep) chat = [{"role": "user", "content": obs.prompt}] text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) inputs = tokenizer(text, return_tensors="pt").to(model.device) out = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, # deterministic / greedy eval eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, ) gen_ids = out[0][inputs["input_ids"].shape[1]:] completion = tokenizer.decode(gen_ids, skip_special_tokens=True) n_tokens = int(gen_ids.shape[0]) result = client.step(raw_response=completion, episode_id=obs.episode_id) rwd = dict(result.info["rewards"]) # copy so we can decorate # Decorate with the master-spec extras. 
        action = result.info.get("parsed_action", {}) or {}
        pm_x = sorted(set(map(int, result.info.get("pymatching_x_errors", []) or [])))
        pm_z = sorted(set(map(int, result.info.get("pymatching_z_errors", []) or [])))
        our_x = sorted(set(map(int, action.get("x_error_qubits", []) or [])))
        our_z = sorted(set(map(int, action.get("z_error_qubits", []) or [])))
        rwd["exact_match_pymatching"] = int(
            bool(action.get("parse_success", False)) and our_x == pm_x and our_z == pm_z
        )
        rwd["output_length"] = n_tokens
        rwd["n_true_errors"] = len(pm_x) + len(pm_z)
        rewards.append(rwd)
        if collect_rows and ep < 50:
            rows.append({
                "episode": ep,
                "completion": completion[:300],
                "logical_correction": rwd["logical_correction"],
                "syndrome_consistency": rwd["syndrome_consistency"],
                "hamming_overlap": rwd["hamming_overlap"],
                "format_compliance": rwd["format_compliance"],
                "pymatching_beat": rwd["pymatching_beat"],
                "exact_match_pymatching": rwd["exact_match_pymatching"],
                "output_length": rwd["output_length"],
                "total": rwd["total"],
                "actual_obs_flip": result.info["actual_observable_flip"],
                "pm_obs_flip": result.info["pymatching_observable_pred"],
            })
    return _summary(f"model[{adapter}]", rewards), rows


def main(argv: Iterable[str] = ()) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--policy", choices=["random", "zeros", "pymatching"], default=None,
                        help="evaluate a deterministic baseline instead of a model")
    parser.add_argument("--adapter", type=str, default=None,
                        help="path to LoRA adapter dir; mutually exclusive with --policy")
    parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--episodes", type=int, default=200)
    parser.add_argument("--level", type=str, default=primary_level().name)
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--out", type=str, default=None)
    parser.add_argument("--report-to", type=str, default="none", choices=["wandb", "none"],
                        help="If 'wandb', log summary + per-episode table.")
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-tags", type=str, nargs="*", default=("eval",))
    parser.add_argument("--wandb-notes", type=str, default=None)
    args = parser.parse_args(list(argv))

    if (args.policy is None) == (args.adapter is None):
        print("ERROR: exactly one of --policy and --adapter is required", file=sys.stderr)
        return 1

    from qubit_medic import wandb_utils

    report_to = wandb_utils.derive_report_to(args.report_to)
    use_wandb = report_to == "wandb"
    if use_wandb:
        slug = args.policy or (args.adapter or "model").replace("/", "_")
        run_name = args.wandb_run_name or wandb_utils.make_run_name("eval", suffix=slug)
        wandb_utils.init_run(
            run_name=run_name,
            job_type="eval",
            tags=tuple(list(args.wandb_tags) + [args.level]),
            notes=args.wandb_notes,
            group=args.wandb_group,
            extra_config={
                "cli": {
                    "policy": args.policy,
                    "adapter": args.adapter,
                    "episodes": args.episodes,
                    "level": args.level,
                    "max_new_tokens": args.max_new_tokens,
                    "base_model": args.base_model,
                },
            },
        )

    if args.policy is not None:
        result, rows = _eval_baseline(args.policy, args.episodes, args.level,
                                      collect_rows=use_wandb)
    else:
        result, rows = _eval_model(args.adapter, args.episodes, args.level,
                                   args.base_model, args.max_new_tokens,
                                   collect_rows=use_wandb)

    result["level"] = args.level
    print(json.dumps(result, indent=2))

    if args.out:
        from pathlib import Path

        Path(args.out).parent.mkdir(parents=True, exist_ok=True)
        with open(args.out, "w") as f:
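            # Persist the same summary dict that was just printed to stdout.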
            json.dump(result, f, indent=2)

    if use_wandb:
        wandb_utils.log_eval_summary(result, prefix="eval")
        if rows:
            wandb_utils.log_generation_table(
                rows,
                step=None,
                table_name="eval/episode_breakdown",
            )
        wandb_utils.update_summary({
            "eval/policy_or_adapter": args.policy or args.adapter,
            "eval/episodes": args.episodes,
            "eval/level": args.level,
        })
        wandb_utils.finish_run()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
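
# Minimal sketch of programmatic use (not part of the CLI). Assumes this file
# is importable as ``scripts.eval``, which the docstring's ``python -m
# scripts.eval`` invocation implies; the helpers and their signatures are the
# ones defined above.
#
#   from scripts.eval import _eval_baseline
#   from qubit_medic.config import primary_level
#
#   summary, _rows = _eval_baseline("pymatching", episodes=20,
#                                   level=primary_level().name)
#   print(summary["logical_correction_rate"], summary["mean_total_reward"])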