# Best-of-N Weighted Selection on MATH-500

> Hugging Face internship exercise: replicating a baseline from the [Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) blog post.

## Overview

This repo contains the complete code, results, and analysis for **Best-of-N sampling with weighted selection** on 20 problems from MATH-500. The approach follows [DeepMind (2408.03314)](https://huggingface.co/papers/2408.03314) and involves:

1. Sampling $N$ independent solutions per problem from an LLM
2. Scoring each solution with a **Process Reward Model (PRM)**, using the **last step's prediction** as the solution's reward
3. Grouping solutions by their final answer and **summing PRM scores per group** (weighted vote)
4. Selecting the answer with the highest total weighted score

Formally: $\hat{a} = \arg\max_a \sum_{i=1}^{N} \mathbb{1}(a_i = a) \cdot \text{score}(s_i)$, where $s_i$ is the $i$-th sampled solution, $a_i$ is its extracted final answer, and $\text{score}(s_i)$ is its last-step PRM score.
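
The selection rule can be written in a few lines of Python. This is a minimal sketch rather than the code in `code/step3_best_of_n.py`: it assumes each sample has already been reduced to an extracted answer string and one last-step PRM score, and the function name `weighted_best_of_n` is hypothetical.

```python
from collections import defaultdict


def weighted_best_of_n(answers: list[str], scores: list[float]) -> str:
    """Pick the answer with the largest summed PRM score.

    answers[i] is a_i, the final answer extracted from sample s_i;
    scores[i] is score(s_i), that sample's last-step PRM reward.
    """
    if not answers or len(answers) != len(scores):
        raise ValueError("need at least one sample and one score per answer")
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score  # accumulates 1(a_i = a) * score(s_i) for each answer a
    # argmax over distinct answers; max() keeps the first answer seen on ties
    return max(totals, key=lambda a: totals[a])


answers = ["3/2", "3", "3/2", "1"]
scores = [0.7, 0.9, 0.6, 0.2]
print(weighted_best_of_n(answers, scores))  # 3/2  (0.7 + 0.6 beats 0.9 for "3")
```

Ties go to the answer that appeared first among the samples, because Python's `max` returns the first maximal key and dicts preserve insertion order; the repository's own tie-breaking may differ.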

## Models Used

| Model | Role | Size |
|---|---|---|
| [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Solution generator (LLM) | 1.5B |
| [Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B](https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B) | Process Reward Model (scorer) | 1.5B |

## Results

### Accuracy Comparison

| Method | Accuracy | Improvement over Greedy |
|---|---|---|
| Greedy (N=1) | 9/20 (45%) | baseline |
| Majority Vote (N=16) | 12/20 (60%) | +15pp |
| Standard Best-of-N (N=16) | 11/20 (55%) | +10pp |
| **Weighted Best-of-N (N=16)** | **13/20 (65%)** | **+20pp** |

![Accuracy comparison across all methods](plots/plot1_accuracy_comparison.png)

### Accuracy Scales with N

Weighted Best-of-N accuracy rises quickly as we sample more solutions and plateaus around N=8-16 (65.3% at N=8, 65% at N=16):

| N=1 | N=2 | N=4 | N=8 | N=16 |
|---|---|---|---|---|
| 51.5% | 58% | 63.6% | 65.3% | 65% |

![Weighted Best-of-N accuracy vs. N](plots/plot2_accuracy_vs_n.png)
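
With only 20 problems, a single evaluation run can only score in steps of 5%, so the fractional values for N < 16 appear to be averages over repeated draws. One plausible way to produce such numbers from the 16 stored samples per problem is to draw random size-N subsets many times and average the weighted-BoN accuracy. The sketch below assumes that procedure; the `Problem` container, `accuracy_at_n`, the trial count, and the seed are all illustrative and not taken from `code/step3_best_of_n.py`.

```python
import random
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Problem:
    answers: list[str]   # extracted final answer of each stored sample
    scores: list[float]  # last-step PRM score of each stored sample
    gold: str            # reference answer


def weighted_vote(answers: list[str], scores: list[float]) -> str:
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=lambda a: totals[a])


def accuracy_at_n(problems: list[Problem], n: int, trials: int = 200, seed: int = 0) -> float:
    """Mean weighted-BoN accuracy over `trials` random size-n subsets per problem."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        for p in problems:
            idx = rng.sample(range(len(p.answers)), n)  # n distinct samples, no replacement
            pick = weighted_vote([p.answers[i] for i in idx], [p.scores[i] for i in idx])
            hits += pick == p.gold
    return hits / (trials * len(problems))


# Two toy problems with 16 samples each (invented values)
problems = [
    Problem(answers=["2"] * 4 + ["4"] * 12, scores=[0.9] * 4 + [0.1] * 12, gold="2"),
    Problem(answers=["7"] * 16, scores=[0.8] * 16, gold="7"),
]
for n in (1, 2, 4, 8, 16):
    print(n, accuracy_at_n(problems, n))
```

At N=16 every subset is the full set of samples, so the estimate collapses to the single-run accuracy; that is consistent with the table's 65% at N=16 matching the 13/20 reported above.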

### Per-Problem Analysis

Weighted Best-of-N solved **4 additional problems** that greedy decoding missed, while losing none:

- `algebra/1265`: Greedy answered 3; BoN found the correct 3/2 (12/16 samples correct)
- `intermediate_algebra/860`: Greedy produced no `\boxed{}` answer; BoN identified "ellipse" (only 5/16 samples correct, but their combined PRM score was still the highest)
- `number_theory/22`: Greedy answered 6; BoN found the correct 2 (5/16 samples correct; weighted voting succeeded where majority vote failed)
- `number_theory/45`: Greedy answered 15; BoN found the correct 23 (8/16 samples correct)

![Per-problem breakdown](plots/plot3_per_problem.png)
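
The `intermediate_algebra/860` case counts against greedy decoding because its output contained no `\boxed{}` answer to extract. A brace-matching extractor that returns the last `\boxed{...}` in a completion can be sketched as follows. It is an illustrative stand-in, not the answer extraction helper linked under References, and the name `last_boxed` is hypothetical.

```python
from typing import Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None.

    Braces are matched by counting depth, so nested groups such as
    \\boxed{\\frac{3}{2}} come back whole.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None  # no boxed answer at all
    begin = start + len(marker)
    depth = 1
    for j in range(begin, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:j]
    return None  # unbalanced braces: treat as no answer


print(last_boxed(r"So the answer is $\boxed{\frac{3}{2}}$."))  # \frac{3}{2}
print(last_boxed("The curve is an ellipse."))                 # None
```

Taking the last box rather than the first is a common convention for chain-of-thought outputs, where intermediate results may also be boxed; whether this repository does the same is not stated.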

### PRM Score Distribution

The PRM effectively separates correct from incorrect solutions: correct solutions cluster at higher scores.

![PRM score distributions](plots/plot4_prm_scores.png)

## Key Design Decisions

### Why Last-Step Prediction?

The Skywork PRM outputs a score at each reasoning step. Following DeepMind's Appendix E, we use only the **last step's score** as the full-solution reward, rather than the minimum or the product of all step scores. The rationale given there is that a PRM trained with soft Monte Carlo return labels estimates, at every step, how likely the solution is to end up correct, so the last step's score already integrates information about the full solution trajectory; we rely on the same property for the Skywork PRM.
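
The three aggregation options differ only in how a list of per-step scores collapses to one number. The sketch assumes the step rewards have already been read off the PRM as probabilities in [0, 1] (obtaining them from the Skywork model is not shown); `aggregate_steps` and the example scores are hypothetical.

```python
import math


def aggregate_steps(step_scores: list[float], how: str = "last") -> float:
    """Collapse per-step PRM scores into one solution-level reward."""
    if not step_scores:
        raise ValueError("a solution needs at least one scored step")
    if how == "last":
        return step_scores[-1]         # the choice made in this repo
    if how == "min":
        return min(step_scores)        # the weakest step bounds the solution
    if how == "prod":
        return math.prod(step_scores)  # treats step correctness as independent
    raise ValueError(f"unknown aggregation: {how}")


steps = [0.95, 0.60, 0.90, 0.85]
for how in ("last", "min", "prod"):
    print(how, round(aggregate_steps(steps, how), 3))
```

`prod` shrinks as solutions get longer even when every step looks sound, and `min` is driven by the single weakest step; `last` uses one prediction per solution and so is insensitive to solution length.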

### Why Weighted Selection over Standard Best-of-N?

Standard Best-of-N picks the single solution with the highest PRM score, so one wrong solution with an inflated score is enough to fool it. Weighted selection is more robust: every solution that reaches a given answer adds its PRM score to that answer's total, so an answer backed by several well-scored solutions can outweigh a single high-scoring outlier. The same mechanism lets it outvote answers that are frequent but poorly scored, which plain majority vote cannot do. Problem `number_theory/22` shows this: the correct answer "2" appeared in only 5/16 samples but had the highest total weighted score.
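
To make the contrast concrete, the toy example below applies majority vote, standard Best-of-N, and weighted selection to the same 16 samples. The answers and scores are invented for illustration and are not data from this repository: a correct answer "2" appears in 5 well-scored samples, a wrong answer "8" in 7 poorly scored ones, and a stray wrong answer "5" holds the single best score.

```python
from collections import Counter, defaultdict

Sample = tuple[str, float]  # (extracted answer, last-step PRM score)

# 16 invented samples: not real data from this repository
samples: list[Sample] = (
    [("2", 0.80)] * 5      # correct answer, well scored
    + [("8", 0.30)] * 7    # most frequent answer, poorly scored
    + [("5", 0.95)]        # one wrong answer with the single best score
    + [("9", 0.20)] * 3    # other wrong answers
)


def majority_vote(samples: list[Sample]) -> str:
    return Counter(answer for answer, _ in samples).most_common(1)[0][0]


def standard_best_of_n(samples: list[Sample]) -> str:
    return max(samples, key=lambda s: s[1])[0]


def weighted_best_of_n(samples: list[Sample]) -> str:
    totals: dict[str, float] = defaultdict(float)
    for answer, score in samples:
        totals[answer] += score
    return max(totals, key=lambda a: totals[a])


print(majority_vote(samples))       # 8  (7 of 16 votes)
print(standard_best_of_n(samples))  # 5  (top single score, 0.95)
print(weighted_best_of_n(samples))  # 2  (total ≈ 4.0 vs ≈ 2.1 for "8")
```

Each rule fails on a different kind of noise: majority vote ignores scores, standard Best-of-N ignores agreement, and the weighted vote uses both.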
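
### Answer Matching Limitations

We use exact string matching for answer comparison (no SymPy normalization). As a result, several problems are marked "incorrect" because of LaTeX formatting differences:

- `\frac43` vs `\frac{4}{3}`: mathematically equivalent
- `4210_{5}` vs `4210_5`: only braces around a one-character subscript
- `(5,\infty)` vs `(5, \infty)`: only a space

Because normalizing the grading step can turn a mismatch into a match but never the reverse, the reported accuracies are lower bounds for the answers each method selected. If answers were also normalized before grouping, equivalent spellings would pool their weighted votes, which could change the selected answers themselves.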
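
A light string normalizer is enough to reconcile the three mismatches listed above. The sketch below is a stand-in for a proper SymPy-based equivalence check, not part of this repository; the name `normalize_answer` is hypothetical, and its rules cover only the patterns shown: whitespace, `\dfrac`/`\tfrac`, unbraced single-digit fractions, and braced one-character subscripts.

```python
import re


def normalize_answer(answer: str) -> str:
    """Canonicalize a few LaTeX spellings so equivalent answers compare equal."""
    s = re.sub(r"\s+", "", answer)                       # "(5, \infty)" -> "(5,\infty)"
    s = re.sub(r"\\[dt]frac", r"\\frac", s)              # \dfrac, \tfrac -> \frac
    s = re.sub(r"\\frac(\d)(\d)", r"\\frac{\1}{\2}", s)  # \frac43 -> \frac{4}{3}
    s = re.sub(r"_\{(\w)\}", r"_\1", s)                  # 4210_{5} -> 4210_5
    return s


pairs = [
    (r"\frac43", r"\frac{4}{3}"),
    (r"4210_{5}", r"4210_5"),
    (r"(5,\infty)", r"(5, \infty)"),
]
for pred, gold in pairs:
    print(normalize_answer(pred) == normalize_answer(gold))  # True
```

Deleting all whitespace is deliberately blunt (it would also turn `\text{no solution}` into `\text{nosolution}`), but because the prediction and the reference pass through the same function, the comparison stays consistent.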

## Repository Structure

```
├── README.md                              # This file
├── plots/
│   ├── plot1_accuracy_comparison.png      # Bar chart: all methods
│   ├── plot2_accuracy_vs_n.png            # Line chart: accuracy vs N
│   ├── plot3_per_problem.png              # Per-problem breakdown
│   └── plot4_prm_scores.png               # PRM score distributions
├── results/
│   ├── bon_results.json                   # Per-problem detailed results
│   ├── accuracy_by_n.json                 # Accuracy at each N value
│   └── filtered_problems.json             # The 20 selected problems
└── code/
    ├── run_all.py                         # Complete pipeline (runs on GPU)
    ├── step1_filter_and_greedy.py         # Step 1: Filter + greedy generation
    ├── step2_sample_and_score.py          # Step 2: N=16 sampling + PRM scoring
    ├── step3_best_of_n.py                 # Step 3: BoN computation + N analysis
    ├── step4_analysis.py                  # Step 4: Plots and analysis
    └── step5_push_dataset.py              # Step 5: Push dataset to Hub
```

## Linked Resources

- **Results dataset**: [cmpatino/math500-bon-weighted-results](https://huggingface.co/datasets/cmpatino/math500-bon-weighted-results): all 20 problems with greedy solutions, 16 sampled solutions each, PRM scores, and Best-of-N answers
- **Execution Space**: [cmpatino/math500-bon-exercise](https://huggingface.co/spaces/cmpatino/math500-bon-exercise): the Docker Space used to run the pipeline on a T4 GPU

## References

- [Scaling LLM Test-Time Compute (DeepMind, 2408.03314)](https://huggingface.co/papers/2408.03314): Section 5.1 and Appendix E
- [Math-Shepherd (2312.08935)](https://huggingface.co/papers/2312.08935): Section 3.4, Eq. 5
- [HF Blog: Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute)
- [Skywork PRM inference repo](https://github.com/SkyworkAI/skywork-o1-prm-inference)
- [Answer extraction helper](https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3)

## Co-authorship Note

This code was co-authored with Claude (Anthropic). I can explain all code logic in detail.

**Claude-assisted areas**: Pipeline structure, Skywork PRM model loading, weighted voting implementation, plotting code.

**My contributions**: Paper methodology analysis, last-step prediction choice (Appendix E), Space deployment debugging (health check server, memory management), results analysis and interpretation.