# Best-of-N Weighted Selection on MATH-500

> Hugging Face internship exercise: replicating a baseline from the [Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) blog post.

## Overview

This repo contains the complete code, results, and analysis for **Best-of-N sampling with weighted selection** on math problems. The approach is described in [DeepMind (2408.03314)](https://huggingface.co/papers/2408.03314) and involves:

1. Sampling $N$ independent solutions per problem from an LLM
2. Scoring each solution with a **Process Reward Model (PRM)**, using the **last step prediction** as the final reward
3. Grouping solutions by their final answer and **summing PRM scores per group** (weighted vote)
4. Selecting the answer with the highest total weighted score

Formally, with $a_i$ the final answer extracted from solution $s_i$:

$$\hat{a} = \arg\max_a \sum_{i=1}^{N} \mathbb{1}(a_i = a) \cdot \text{score}(s_i)$$

## Models Used

| Model | Role | Size |
|---|---|---|
| [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Solution generator (LLM) | 1.5B |
| [Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B](https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B) | Process Reward Model (scorer) | 1.5B |

## Results

### Accuracy Comparison

| Method | Accuracy | Improvement over Greedy |
|---|---|---|
| Greedy (N=1) | 9/20 (45%) | baseline |
| Majority Vote (N=16) | 12/20 (60%) | +15pp |
| Standard Best-of-N (N=16) | 11/20 (55%) | +10pp |
| **Weighted Best-of-N (N=16)** | **13/20 (65%)** | **+20pp** |

![Accuracy Comparison](plots/plot1_accuracy_comparison.png)

### Accuracy Scales with N

Weighted Best-of-N accuracy rises steadily as we sample more solutions up to N=8 and then plateaus: N=16 (65%) is essentially unchanged from N=8 (65.3%).

| N | 1 | 2 | 4 | 8 | 16 |
|---|---|---|---|---|---|
| Weighted Best-of-N accuracy | 51.5% | 58% | 63.6% | 65.3% | 65% |

![Accuracy vs N](plots/plot2_accuracy_vs_n.png)

### Per-Problem Analysis

Weighted Best-of-N solved **4 additional problems** that greedy decoding couldn't, while
losing none:

- `algebra/1265`: Greedy answered 3; BoN correctly found 3/2 (12/16 samples correct)
- `intermediate_algebra/860`: Greedy failed to produce a `\boxed{}` answer; BoN identified "ellipse" (only 5/16 samples correct, but weighted voting aggregated enough signal)
- `number_theory/22`: Greedy answered 6; BoN correctly found 2 (5/16 samples correct; weighted voting beat majority vote here)
- `number_theory/45`: Greedy answered 15; BoN correctly found 23 (8/16 samples correct)

![Per-Problem Analysis](plots/plot3_per_problem.png)

### PRM Score Distribution

The PRM separates correct from incorrect solutions well: correct solutions cluster at higher scores.

![PRM Score Distribution](plots/plot4_prm_scores.png)

## Key Design Decisions

### Why Last-Step Prediction?

The Skywork PRM outputs a score at each reasoning step. Following DeepMind Appendix E, we use only the **last step's score** as the full-solution reward, rather than the minimum or product of all step scores. The rationale given there is that a PRM trained with soft Monte Carlo return labels already integrates information about the full solution trajectory into its last-step prediction.

### Why Weighted Selection over Standard Best-of-N?

Standard Best-of-N picks the single solution with the highest PRM score, so a single high-scoring wrong solution can fool it. Weighted selection is more robust: a correct answer that appears in multiple solutions accumulates evidence through its summed PRM scores. Weighting also helps relative to plain majority voting, because low-scoring solutions contribute little to their group's total. This is visible in `number_theory/22`, where the correct answer "2" appeared in only 5/16 samples, too few for majority vote to pick it, yet had the highest total weighted score.

### Answer Matching Limitations

We use exact string matching for answer comparison (no SymPy normalization).
This means several problems are marked "incorrect" purely because of LaTeX formatting differences:

- `\frac43` vs `\frac{4}{3}`: mathematically equivalent
- `4210_{5}` vs `4210_5`: only the braces differ
- `(5,\infty)` vs `(5, \infty)`: only a space differs

With proper normalization, all methods' accuracies would likely be higher.

## Repository Structure

```
├── README.md                          # This file
├── plots/
│   ├── plot1_accuracy_comparison.png  # Bar chart: all methods
│   ├── plot2_accuracy_vs_n.png        # Line chart: accuracy vs N
│   ├── plot3_per_problem.png          # Per-problem breakdown
│   └── plot4_prm_scores.png           # PRM score distributions
├── results/
│   ├── bon_results.json               # Per-problem detailed results
│   ├── accuracy_by_n.json             # Accuracy at each N value
│   └── filtered_problems.json         # The 20 selected problems
└── code/
    ├── run_all.py                     # Complete pipeline (runs on GPU)
    ├── step1_filter_and_greedy.py     # Step 1: Filter + greedy generation
    ├── step2_sample_and_score.py      # Step 2: N=16 sampling + PRM scoring
    ├── step3_best_of_n.py             # Step 3: BoN computation + N analysis
    ├── step4_analysis.py              # Step 4: Plots and analysis
    └── step5_push_dataset.py          # Step 5: Push dataset to Hub
```

## Linked Resources

- **Results dataset**: [cmpatino/math500-bon-weighted-results](https://huggingface.co/datasets/cmpatino/math500-bon-weighted-results), containing all 20 problems with greedy solutions, 16 sampled solutions each, PRM scores, and Best-of-N answers
- **Execution Space**: [cmpatino/math500-bon-exercise](https://huggingface.co/spaces/cmpatino/math500-bon-exercise), the Docker Space used to run the pipeline on a T4 GPU

## References

- [Scaling LLM Test-Time Compute (DeepMind, 2408.03314)](https://huggingface.co/papers/2408.03314): Section 5.1 and Appendix E
- [Math-Shepherd (2312.08935)](https://huggingface.co/papers/2312.08935): Section 3.4, Eq. 5
- [HF Blog: Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute)
- [Skywork PRM inference repo](https://github.com/SkyworkAI/skywork-o1-prm-inference)
- [Answer extraction helper](https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3)

## Co-authorship Note

This code was co-authored with Claude (Anthropic). I can explain all code logic in detail.

**Claude-assisted areas**: Pipeline structure, Skywork PRM model loading, weighted voting implementation, plotting code.

**My contributions**: Paper methodology analysis, last-step prediction choice (Appendix E), Space deployment debugging (health check server, memory management), results analysis and interpretation.
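## Weighted Selection Sketch

To make the three selection rules compared in Results concrete, here is a minimal, self-contained sketch of majority vote, standard Best-of-N, and weighted Best-of-N. It assumes each solution has already been reduced to an extracted answer string (`None` when no `\boxed{}` answer was found) and a list of per-step PRM scores, from which only the last step is kept as the reward. The function names are hypothetical, this is not necessarily how `code/step3_best_of_n.py` is written, and the toy numbers are illustrative rather than taken from `results/`.

```python
from collections import Counter, defaultdict
from typing import Optional, Sequence


def last_step_reward(step_scores: Sequence[float]) -> float:
    """Use the PRM's final-step score as the reward for the whole solution."""
    if not step_scores:
        raise ValueError("expected at least one step score")
    return step_scores[-1]


def majority_vote(answers: Sequence[Optional[str]]) -> Optional[str]:
    """Most frequent extracted answer; ties go to the answer seen first."""
    counts = Counter(a for a in answers if a is not None)
    return counts.most_common(1)[0][0] if counts else None


def best_of_n(answers: Sequence[Optional[str]], rewards: Sequence[float]) -> Optional[str]:
    """Standard Best-of-N: the answer of the single highest-reward solution."""
    if not answers or len(answers) != len(rewards):
        raise ValueError("answers and rewards must be non-empty and equal length")
    best_index = max(range(len(rewards)), key=lambda i: rewards[i])
    return answers[best_index]


def weighted_best_of_n(
    answers: Sequence[Optional[str]], rewards: Sequence[float]
) -> Optional[str]:
    """Weighted Best-of-N: argmax_a sum_i 1(a_i == a) * reward_i, exact string match."""
    if len(answers) != len(rewards):
        raise ValueError("answers and rewards must have equal length")
    totals = defaultdict(float)
    for answer, reward in zip(answers, rewards):
        if answer is None:  # skip solutions with no extractable \boxed{} answer
            continue
        totals[answer] += reward
    if not totals:
        return None
    # max() returns the first maximal key, so ties go to the answer seen first
    return max(totals, key=totals.get)


if __name__ == "__main__":
    # Toy case 1 (hypothetical numbers): a high-scoring minority vs a low-scoring majority.
    answers = ["2", "2", "6", "6", "6"]
    step_scores = [[0.5, 0.9], [0.4, 0.8], [0.6, 0.2], [0.7, 0.3], [0.5, 0.1]]
    rewards = [last_step_reward(s) for s in step_scores]  # [0.9, 0.8, 0.2, 0.3, 0.1]
    print(majority_vote(answers))                # 6
    print(best_of_n(answers, rewards))           # 2
    print(weighted_best_of_n(answers, rewards))  # 2 (totals: "2" ≈ 1.7, "6" ≈ 0.6)

    # Toy case 2: one high-scoring wrong solution vs a consistent answer.
    answers = ["2", "2", "2", "7"]
    rewards = [0.6, 0.6, 0.6, 0.95]
    print(best_of_n(answers, rewards))           # 7
    print(weighted_best_of_n(answers, rewards))  # 2 (totals: "2" ≈ 1.8, "7" = 0.95)
```

The two toy cases mirror the failure modes discussed under Key Design Decisions: the first has weighted voting overturn a low-scoring majority (the pattern seen in `number_theory/22`), and the second has it resist a single high-scoring wrong solution that fools standard Best-of-N. Because grouping uses exact string matching, `\frac43` and `\frac{4}{3}` would land in different groups; normalizing answers before calling `weighted_best_of_n` would merge them.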