Buckets:
| #!/usr/bin/env python3 | |
| """Run linear probes on all trained checkpoints and write results to JSON.""" | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| from pawn.config import CLMConfig | |
| from pawn.model import PAWNCLM | |
| from pawn.eval_suite.probes import extract_probe_data, train_all_probes | |
def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
    """Load a PAWNCLM from a checkpoint, reconstructing its configuration.

    Prefers the model config stored with the checkpoint; if absent, infers
    the architecture from state-dict tensor shapes, mapping known
    (d_model, n_layers) pairs onto the named CLMConfig presets and falling
    back to a custom config otherwise.

    Args:
        checkpoint_path: Path to the checkpoint (directory or legacy .pt).
        device: Torch device string (e.g. "cpu" or "cuda") for the model.

    Returns:
        The model in eval mode with weights loaded, moved to ``device``.

    Raises:
        ValueError: If config inference is needed but the state dict has no
            ``layers.*`` keys to infer the layer count from.
    """
    from pawn.checkpoint import load_backbone_weights

    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
    if model_config:
        cfg = CLMConfig(**model_config)
    else:
        # Fallback: infer architecture from state-dict tensor shapes.
        d_model = state_dict["embed.src_embed.weight"].shape[1]
        layer_ids = [
            int(k.split(".")[1]) for k in state_dict if k.startswith("layers.")
        ]
        if not layer_ids:
            # Previously max() on an empty generator raised an opaque
            # ValueError; fail with an actionable message instead.
            raise ValueError(
                f"Cannot infer n_layers: no 'layers.*' keys in state dict "
                f"loaded from {checkpoint_path}"
            )
        n_layers = max(layer_ids) + 1
        # Map known shape combinations to the named presets.
        if d_model == 256 and n_layers == 8:
            cfg = CLMConfig.small()
        elif d_model == 512 and n_layers == 8:
            cfg = CLMConfig.base()
        elif d_model == 640 and n_layers == 10:
            cfg = CLMConfig.large()
        else:
            cfg = CLMConfig(d_model=d_model, n_layers=n_layers)
    model = PAWNCLM(cfg).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def _find_runs(args: argparse.Namespace, log_dir: Path) -> list:
    """Collect ``(run_dir, checkpoint_path, run_config)`` tuples to evaluate.

    With ``--checkpoint``, builds a single entry for that path directly;
    otherwise scans ``log_dir`` for ``run_*/config.json`` directories that
    contain checkpoints and takes the latest checkpoint of each run.
    """
    runs = []
    if args.checkpoint:
        # Direct checkpoint path — build a minimal run entry.
        ckpt_path = Path(args.checkpoint)
        # Try to load config from the checkpoint dir (directory-style
        # checkpoints may carry their own config.json; legacy .pt won't).
        cfg = {}
        cfg_file = ckpt_path / "config.json"
        if cfg_file.exists():
            with open(cfg_file) as f:
                cfg = json.load(f)
        run_dir = log_dir
        run_dir.mkdir(parents=True, exist_ok=True)
        runs.append((run_dir, ckpt_path, cfg))
        return runs
    for config_path in sorted(log_dir.glob("run_*/config.json")):
        run_dir = config_path.parent
        if args.run and run_dir.name != args.run:
            continue
        # Find checkpoints: directory-based (safetensors) or legacy .pt.
        checkpoints = sorted(
            [d for d in run_dir.glob("checkpoints/step_*") if d.is_dir()]
            or list(run_dir.glob("checkpoints/step_*.pt"))
        )
        if not checkpoints:
            continue
        with open(config_path) as f:
            cfg = json.load(f)
        runs.append((run_dir, checkpoints[-1], cfg))
    return runs


def main():
    """CLI entry point: run linear probes on checkpoints, write JSON results.

    Probe train/val data is generated once and shared across all evaluated
    models; each run's results are written to ``<run_dir>/probe_results.json``.
    """
    parser = argparse.ArgumentParser(description="Run linear probes on checkpoints")
    parser.add_argument("--log-dir", type=str, default="logs", help="Log directory containing run dirs")
    parser.add_argument("--n-games", type=int, default=4096, help="Games for probe train set")
    parser.add_argument("--n-val-games", type=int, default=1024, help="Games for probe val set")
    parser.add_argument("--n-epochs", type=int, default=20, help="Probe training epochs")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--run", type=str, default=None, help="Only evaluate this run dir name")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Evaluate a single checkpoint path directly (skips log-dir scan)")
    parser.add_argument("--no-outcome-token", action="store_true",
                        help="Strip outcome token from sequences (auto-detected from checkpoint config)")
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    if device == "cuda":
        from pawn.gpu import configure_gpu
        gpu_cfg = configure_gpu()
        # Propagate the selected SDPA backend to the model module.
        import pawn.model as model_module
        model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")

    log_dir = Path(args.log_dir)
    runs = _find_runs(args, log_dir)
    if not runs:
        print("No runs with checkpoints found.")
        sys.exit(1)
    print(f"Found {len(runs)} runs to evaluate")

    # Generate probe data once (shared across all models with same max_ply).
    max_ply = 256
    print(f"\nGenerating probe data: {args.n_games} train + {args.n_val_games} val games...")
    train_data = extract_probe_data(args.n_games, max_ply, seed=12345)
    val_data = extract_probe_data(args.n_val_games, max_ply, seed=54321)
    print("Done.")

    for run_dir, ckpt_path, run_cfg in runs:
        model_cfg = run_cfg.get("model", {})
        train_cfg = run_cfg.get("training", {})
        variant = f"{model_cfg.get('d_model', '?')}d/{model_cfg.get('n_layers', '?')}L"
        discard = train_cfg.get("discard_ply_limit", False)
        step = ckpt_path.stem.replace("step_", "")
        # BUG FIX: int(step) crashed for --checkpoint paths whose stem is not
        # "step_<N>". Use -1 to mark an unknown step in the JSON output.
        step_num = int(step) if step.isdigit() else -1
        print(f"\n{'='*60}")
        print(f"Run: {run_dir.name} ({variant}, discard_ply={discard}, step={step})")
        print(f"Checkpoint: {ckpt_path}")
        print(f"{'='*60}")

        model = load_model_from_checkpoint(str(ckpt_path), device)

        # Auto-detect no_outcome_token from checkpoint config; the CLI flag
        # forces it on regardless.
        no_outcome = args.no_outcome_token or train_cfg.get("no_outcome_token", False)
        if no_outcome:
            print(" [no_outcome_token=True] Stripping outcome token from probe inputs")
        results = train_all_probes(
            model, train_data, val_data, device,
            per_layer=True, n_epochs=args.n_epochs, verbose=True,
            no_outcome_token=no_outcome,
        )

        # Save results, rounding floats for stable, readable JSON.
        output = {
            "run": run_dir.name,
            "checkpoint": str(ckpt_path),
            "step": step_num,
            "variant": variant,
            "discard_ply_limit": discard,
            "no_outcome_token": no_outcome,
            "model_config": model_cfg,
            "probes": {
                pname: {
                    lname: {k: round(v, 6) if isinstance(v, float) else v for k, v in metrics.items()}
                    for lname, metrics in layer_results.items()
                }
                for pname, layer_results in results.items()
            },
        }
        out_path = run_dir / "probe_results.json"
        with open(out_path, "w") as f:
            json.dump(output, f, indent=2)
        print(f"\nSaved: {out_path}")

        # Free GPU memory before loading the next model.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Script entry point guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 6.25 kB
- Xet hash:
- 38fc11ae9e2f50fef5d9a58bf0cad382c04587096ba557349c4256dc22ada5f1
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.