| """Before / after demo β baseline vs GRPO-trained policy. | |
| Designed for hackathon judges: loads both models, runs greedy evaluation on | |
| a fixed problem set, and prints a clean side-by-side comparison with full | |
| solution text for the most interesting examples. | |
| Features | |
| -------- | |
| * Handles all checkpoint types: HF model IDs, GRPO full-weight saves, | |
| PEFT/LoRA adapter directories. | |
| * Automatically loads the chat template from the base model when the | |
| checkpoint tokenizer doesn't have one (fixes the 0% accuracy bug that | |
| silently swallows TemplateErrors). | |
| * Reads ``metrics.jsonl`` (if present) and prints the full accuracy curve, | |
| showing judges the training progression at a glance. | |
| * Saves machine-readable JSON (for grading scripts) and prints a human- | |
| readable Markdown table. | |
| * Shows full solution text for the best wins and worst regressions. | |
| Quick-start | |
| ----------- | |
| After a GRPO run, point at ``best_policy/``:: | |
| python scripts/demo_before_after.py \\ | |
| --baseline-model checkpoints/dual_task_v1 \\ | |
| --trained-model checkpoints/grpo/<run>/best_policy \\ | |
| --problems data/sft/gsm8k_sft.jsonl \\ | |
| --max-samples 100 | |
| Include the training curve:: | |
| python scripts/demo_before_after.py \\ | |
| --baseline-model checkpoints/dual_task_v1 \\ | |
| --trained-model checkpoints/grpo/<run>/best_policy \\ | |
| --metrics-jsonl checkpoints/grpo/<run>/metrics.jsonl \\ | |
| --problems data/sft/gsm8k_sft.jsonl \\ | |
| --max-samples 100 \\ | |
| --records-out results/demo.json | |
| """ | |
from __future__ import annotations

import argparse
import json
import logging
import re
import sys
import time
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from peft import PeftModel
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.sft.solution_format import extract_final_answer_numeric_str
from src.utils.attn_backend import select_attn_implementation
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

_SEP = "=" * 78
_SEP2 = "-" * 78

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
@dataclass
class Problem:
    question: str
    gold_final: str
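
# GSM8K marks the gold answer with a trailing "#### <number>" marker;
# _parse_gold extracts it (stripping thousands-separators, so "1,234"
# becomes "1234") and falls back to the last line when the marker is absent.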
def _parse_gold(answer: str) -> str:
    m = re.search(r"####\s*([-0-9.,/ ]+)", answer)
    if m:
        return m.group(1).strip().replace(",", "")
    return answer.strip().splitlines()[-1].strip()

def _load_problems(path: Path, max_samples: int) -> List[Problem]:
    """Accept GSM8K ``{question, answer}`` or SFT ``{messages}`` JSONL."""
    out: List[Problem] = []
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            if max_samples > 0 and len(out) >= max_samples:
                break
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if "question" in obj and "answer" in obj:
                out.append(Problem(
                    question=obj["question"].strip(),
                    gold_final=_parse_gold(obj["answer"]),
                ))
            elif "messages" in obj:
                user = next(
                    (m["content"] for m in obj["messages"] if m.get("role") == "user"), ""
                ).strip()
                asst = next(
                    (m["content"] for m in obj["messages"] if m.get("role") == "assistant"), ""
                )
                gold = extract_final_answer_numeric_str(asst) or ""
                out.append(Problem(question=user, gold_final=gold.strip()))
    return out

# ---------------------------------------------------------------------------
# Model loading - handles HF IDs, full-weight saves, and PEFT adapters
# ---------------------------------------------------------------------------
def _ensure_chat_template(
    tokenizer: AutoTokenizer,
    fallback_model: str = "Qwen/Qwen2.5-Math-1.5B-Instruct",
) -> None:
    """Load chat template from *fallback_model* when the checkpoint lacks one.

    SFT adapter checkpoints often omit the chat_template from their tokenizer
    config. Without it, ``apply_chat_template`` raises a TemplateError that
    is silently swallowed inside ``evaluate_gsm8k``, returning 0% accuracy.
    """
    if tokenizer.chat_template is not None:
        return
    logger.info("Tokenizer missing chat_template - loading from %s", fallback_model)
    try:
        _base_tok = AutoTokenizer.from_pretrained(fallback_model, trust_remote_code=True)
        if _base_tok.chat_template is not None:
            tokenizer.chat_template = _base_tok.chat_template
            logger.info("Chat template loaded.")
    except Exception as exc:
        logger.warning("Could not load chat template: %s", exc)

def _load_model(
    checkpoint: str,
    base_model_id: str,
    device: torch.device,
    dtype: torch.dtype,
    attn_impl: str,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load model + tokenizer from any checkpoint style.

    Handles:
    * HuggingFace model ID (e.g. ``Qwen/Qwen2.5-Math-1.5B-Instruct``)
    * GRPO full-weight save (directory with ``model.safetensors`` / pytorch_model*)
    * PEFT/LoRA adapter dir (directory with ``adapter_config.json``)
    """
    # PEFT shim - prevents crash in merge_and_unload on some versions.
    if "transformers.integrations.tensor_parallel" not in sys.modules:
        sys.modules["transformers.integrations.tensor_parallel"] = types.ModuleType(
            "tensor_parallel"
        )
    ckpt_path = Path(checkpoint)
    is_adapter = ckpt_path.is_dir() and (ckpt_path / "adapter_config.json").exists()
    is_local_full = ckpt_path.is_dir() and not is_adapter

    # Tokenizer
    tok_src = (
        checkpoint
        if (ckpt_path.is_dir() and (ckpt_path / "tokenizer_config.json").exists())
        else base_model_id
    )
    tokenizer = AutoTokenizer.from_pretrained(tok_src, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # standard for generation
    _ensure_chat_template(tokenizer, fallback_model=base_model_id)
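    # Shared from_pretrained kwargs; device_map={"": device} pins the whole
    # model to a single device instead of sharding it.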
    load_kw = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map={"": device},
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )
    if is_adapter:
        # Read base model from pipeline_meta.json if present
        meta_file = ckpt_path / "pipeline_meta.json"
        _base = base_model_id
        if meta_file.exists():
            _base = json.loads(meta_file.read_text()).get("base_model", _base)
        logger.info("PEFT adapter - loading base %s, then merging %s", _base, checkpoint)
        _base_mdl = AutoModelForCausalLM.from_pretrained(_base, **load_kw)
        model = PeftModel.from_pretrained(_base_mdl, checkpoint).merge_and_unload()
        model = model.to(device)
    else:
        # Full weights: local GRPO save or HF model ID - same loading path.
        logger.info("Loading full-weight model from %s (%s)",
                    checkpoint, "local dir" if is_local_full else "HF hub ID")
        model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kw)
    # No gradients are needed for evaluation; just make sure we're in eval mode.
    model.eval()
    n = sum(p.numel() for p in model.parameters())
    logger.info("Loaded: %s (%.2fB params, %.1f GB VRAM est.)",
                checkpoint, n / 1e9, n * 2 / 1e9)  # estimate assumes 2 bytes/param (bf16/fp16)
    return model, tokenizer
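
# Example call (hypothetical paths; "sdpa" is one valid attn implementation):
#   model, tok = _load_model("checkpoints/grpo/run1/best_policy",
#                            "Qwen/Qwen2.5-Math-1.5B-Instruct",
#                            torch.device("cuda"), torch.bfloat16, "sdpa")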

# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def _build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
    """Format question using the model's chat template (matches training format)."""
    if tokenizer.chat_template is None:
        return question
    msgs = [
        {"role": "system", "content": "You are a helpful math assistant. Solve the problem step-by-step and end with 'Final Answer: <number>'."},
        {"role": "user", "content": question},
    ]
    try:
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return question
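
# Qwen-style ChatML models terminate each turn with <|im_end|>; add it as an
# extra stop token so generation halts at the end of the assistant turn.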
def _stop_ids(tokenizer: AutoTokenizer) -> List[int]:
    ids = []
    if tokenizer.eos_token_id is not None:
        ids.append(tokenizer.eos_token_id)
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end, int) and im_end not in ids:
        ids.append(im_end)
    return ids or None  # type: ignore[return-value]

def _generate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    question: str,
    max_new_tokens: int,
    device: torch.device,
) -> str:
    prompt = _build_prompt(tokenizer, question)
    enc = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(device)
    prompt_len = enc["input_ids"].shape[1]
    out = model.generate(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy - deterministic for reproducibility
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=_stop_ids(tokenizer),
        use_cache=True,
    )
    # Decode only the newly generated tokens, not the prompt.
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

# ---------------------------------------------------------------------------
# Scoring
# ---------------------------------------------------------------------------
def _normalize(x: str) -> str:
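    """Canonicalise an answer string so equivalent numbers compare equal.

    Illustrative behaviour: "$1,234.0" -> "1234", "3.5" -> "3.5"; anything
    non-numeric is returned stripped.
    """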
    if not x:
        return ""
    s = x.strip().replace(",", "").replace("$", "").strip()
    try:
        f = float(s)
        return f"{int(f)}" if f == int(f) else f"{f}"
    except ValueError:
        return s
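
# One scored problem: the prediction, whether it matched the gold answer,
# and the full generated solution text (kept for the demo printouts).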
@dataclass
class Record:
    question: str
    gold: str
    pred: str
    correct: bool
    solution_text: str

def _score_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    problems: List[Problem],
    max_new_tokens: int,
    device: torch.device,
    label: str,
) -> Tuple[int, List[Record]]:
    records: List[Record] = []
    correct = 0
    for prob in tqdm(problems, desc=f"Scoring {label}", unit="q", dynamic_ncols=True):
        try:
            text = _generate(model, tokenizer, prob.question, max_new_tokens, device)
        except Exception as exc:
            text = f"[generation error: {exc}]"
        pred = extract_final_answer_numeric_str(text) or ""
        ok = bool(pred) and _normalize(pred) == _normalize(prob.gold_final)
        if ok:
            correct += 1
        records.append(Record(
            question=prob.question,
            gold=prob.gold_final,
            pred=pred,
            correct=ok,
            solution_text=text,
        ))
    return correct, records

# ---------------------------------------------------------------------------
# Metrics curve
# ---------------------------------------------------------------------------
def _load_metrics_curve(path: Path) -> List[Dict]:
    """Read metrics.jsonl and return rows that contain GSM8K accuracy."""
    rows = []
    if not path.exists():
        return rows
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
                if "accuracy" in obj or "iteration" in obj:
                    rows.append(obj)
            except json.JSONDecodeError:
                pass
    return rows
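
# Render the curve as a fixed-width table; fields missing from a row print
# as "-" so partially-logged iterations still line up.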
def _print_curve(rows: List[Dict]) -> None:
    if not rows:
        return
    print(f"\n{_SEP}")
    print("TRAINING ACCURACY CURVE (from metrics.jsonl)")
    print(_SEP)
    print(f"{'Iter':>5} {'GSM8K%':>7} {'Reward':>7} {'Batch%':>7} {'LR':>10} {'Time(s)':>8}")
    print(_SEP2)
    for r in rows:
        it = r.get("iteration", "")
        acc = r.get("accuracy", None)
        rwd = r.get("mean_reward", None)
        bat = r.get("batch_accuracy", None)
        lr = r.get("learning_rate", None)
        ts = r.get("iter_time_s", None)
        acc_s = f"{100*acc:.1f}%" if acc is not None else "-"
        rwd_s = f"{rwd:.3f}" if rwd is not None else "-"
        bat_s = f"{100*bat:.1f}%" if bat is not None else "-"
        lr_s = f"{lr:.2e}" if lr is not None else "-"
        ts_s = f"{ts:.1f}" if ts is not None else "-"
        print(f"{it:>5} {acc_s:>7} {rwd_s:>7} {bat_s:>7} {lr_s:>10} {ts_s:>8}")
    print()

# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------
def _print_summary(
    base_correct: int,
    tr_correct: int,
    base_records: List[Record],
    tr_records: List[Record],
    baseline_name: str,
    trained_name: str,
    n_solutions: int = 3,
) -> None:
    n = len(base_records)
    # Pair up the two runs problem-by-problem to classify wins and losses.
    wins = [(b, t) for b, t in zip(base_records, tr_records) if not b.correct and t.correct]
    losses = [(b, t) for b, t in zip(base_records, tr_records) if b.correct and not t.correct]
    both_wrong = sum(1 for b, t in zip(base_records, tr_records) if not b.correct and not t.correct)
    both_right = sum(1 for b, t in zip(base_records, tr_records) if b.correct and t.correct)
    delta = tr_correct - base_correct
    sign = "+" if delta >= 0 else ""
    print(f"\n{_SEP}")
    print("BEFORE vs AFTER - GSM8K accuracy (greedy decoding, fixed seed)")
    print(_SEP)
    print(f"  Baseline : {baseline_name}")
    print(f"  Trained  : {trained_name}")
    print(_SEP2)
    print(f"  Baseline accuracy : {base_correct}/{n} ({100*base_correct/n:.1f}%)")
    print(f"  Trained accuracy  : {tr_correct}/{n} ({100*tr_correct/n:.1f}%)")
    print(f"  Delta             : {sign}{delta} problems ({sign}{100*delta/n:.1f} pp)")
    print(_SEP2)
    print(f"  Newly correct (wins)  : {len(wins)}")
    print(f"  Newly wrong (losses)  : {len(losses)}")
    print(f"  Both correct          : {both_right}")
    print(f"  Both wrong            : {both_wrong}")
    print(_SEP)
    if wins:
        print(f"\n{_SEP}")
        print("WINS - problems the RL model now solves that the baseline could not")
        print(_SEP)
        for i, (base_r, tr_r) in enumerate(wins[:n_solutions]):
            print(f"\n[Win {i+1}/{min(n_solutions, len(wins))}]")
            _print_problem(base_r, tr_r)
    if losses:
        print(f"\n{_SEP}")
        print("REGRESSIONS - problems the baseline solved but the RL model now misses")
        print(_SEP)
        for i, (base_r, tr_r) in enumerate(losses[:min(2, len(losses))]):
            print(f"\n[Regression {i+1}/{min(2, len(losses))}]")
            _print_problem(base_r, tr_r, is_regression=True)
    print(f"\n{_SEP}")
    pct_gain = 100 * delta / max(n - base_correct, 1)
    print(f"SUMMARY: RL training fixed {len(wins)} problems, regressed {len(losses)}.")
    print(f"         Net: {sign}{delta} problems. Relative gain on previously-wrong: {pct_gain:+.1f}%")
    print(_SEP)
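
# Pretty-print one before/after pair; a check mark (correct) or cross (wrong)
# tags each side, and the trained model's solution is shown truncated.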
def _print_problem(base_r: Record, tr_r: Record, is_regression: bool = False) -> None:
    q = base_r.question
    # Truncate long questions
    if len(q) > 250:
        q = q[:247] + "..."
    print(f"  Q    : {q}")
    print(f"  Gold : {base_r.gold}")
    if not is_regression:
        print(f"  Before : {base_r.pred!r:30s} ✗")
        print(f"  After  : {tr_r.pred!r:30s} ✓")
        # Show trained solution (truncated)
        sol = tr_r.solution_text.strip()
        if sol:
            lines = sol.splitlines()
            show = "\n    ".join(lines[:12])
            if len(lines) > 12:
                show += f"\n    ... ({len(lines)-12} more lines)"
            print(f"\n  Solution (trained model):\n    {show}")
    else:
        print(f"  Before : {base_r.pred!r:30s} ✓")
        print(f"  After  : {tr_r.pred!r:30s} ✗")

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> int:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--baseline-model", default="checkpoints/dual_task_v1",
        help="Pre-RL checkpoint. HF model ID, full-weight dir, or PEFT adapter dir.",
    )
    parser.add_argument(
        "--trained-model", required=True,
        help="Post-RL checkpoint (GRPO best_policy/ dir, or iteration checkpoint).",
    )
    parser.add_argument(
        "--base-model-for-adapter", default="Qwen/Qwen2.5-Math-1.5B-Instruct",
        help="Base model used when loading a PEFT adapter checkpoint.",
    )
    parser.add_argument(
        "--problems", type=Path, default=Path("data/sft/gsm8k_sft.jsonl"),
        help="JSONL eval set. Defaults to GSM8K training split (first --max-samples rows).",
    )
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--metrics-jsonl", type=Path, default=None,
        help="Path to metrics.jsonl from a GRPO run - prints the accuracy curve.",
    )
    parser.add_argument(
        "--n-solutions", type=int, default=3,
        help="Number of win/loss examples to print in full.",
    )
    parser.add_argument(
        "--records-out", type=Path, default=None,
        help="Save full per-problem JSON records here (for judge grading scripts).",
    )
    parser.add_argument(
        "--device", default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument(
        "--dtype", default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    args = parser.parse_args()
    if not args.problems.is_file():
        logger.error("Problems file not found: %s", args.problems)
        return 2
    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]
    device = torch.device(args.device)
    attn = select_attn_implementation()
    logger.info("Device: %s | dtype: %s | attn: %s", device, args.dtype, attn)

    # Print training curve if available
    if args.metrics_jsonl:
        curve = _load_metrics_curve(args.metrics_jsonl)
        _print_curve(curve)

    problems = _load_problems(args.problems, args.max_samples)
    if not problems:
        logger.error("No problems loaded from %s", args.problems)
        return 2
    logger.info("Evaluating on %d problems from %s", len(problems), args.problems)
    # -- Baseline ----------------------------------------------------------
    logger.info("%s\nScoring BASELINE: %s\n%s", _SEP, args.baseline_model, _SEP)
    t0 = time.perf_counter()
    base_model, base_tok = _load_model(
        args.baseline_model, args.base_model_for_adapter, device, dtype, attn
    )
    base_correct, base_records = _score_model(
        base_model, base_tok, problems, args.max_new_tokens, device, "baseline"
    )
    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Baseline done in %.1fs - accuracy: %d/%d (%.1f%%)",
                time.perf_counter() - t0,
                base_correct, len(problems),
                100 * base_correct / len(problems))

    # -- Trained -----------------------------------------------------------
    logger.info("%s\nScoring TRAINED: %s\n%s", _SEP, args.trained_model, _SEP)
    t0 = time.perf_counter()
    tr_model, tr_tok = _load_model(
        args.trained_model, args.base_model_for_adapter, device, dtype, attn
    )
    tr_correct, tr_records = _score_model(
        tr_model, tr_tok, problems, args.max_new_tokens, device, "trained"
    )
    del tr_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Trained done in %.1fs - accuracy: %d/%d (%.1f%%)",
                time.perf_counter() - t0,
                tr_correct, len(problems),
                100 * tr_correct / len(problems))
    # -- Summary -----------------------------------------------------------
    _print_summary(
        base_correct, tr_correct,
        base_records, tr_records,
        baseline_name=args.baseline_model,
        trained_name=args.trained_model,
        n_solutions=args.n_solutions,
    )

    # -- Save records ------------------------------------------------------
    if args.records_out:
        args.records_out.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "baseline_model": args.baseline_model,
            "trained_model": args.trained_model,
            "n_problems": len(problems),
            "baseline": {
                "correct": base_correct,
                "accuracy": base_correct / len(problems),
                "records": [vars(r) for r in base_records],
            },
            "trained": {
                "correct": tr_correct,
                "accuracy": tr_correct / len(problems),
                "records": [vars(r) for r in tr_records],
            },
        }
        args.records_out.write_text(
            json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        logger.info("Per-problem records saved to %s", args.records_out)
    return 0


if __name__ == "__main__":
    sys.exit(main())