# Best-of-N Weighted Selection on MATH-500
> HuggingFace internship exercise — replicating a baseline from the [Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) blog post.
## Overview
This repo contains the complete code, results, and analysis for **Best-of-N sampling with weighted selection** on a 20-problem subset of MATH-500. The approach is described in [DeepMind (2408.03314)](https://huggingface.co/papers/2408.03314) and involves:
1. Sampling $N$ independent solutions per problem from an LLM
2. Scoring each solution with a **Process Reward Model (PRM)** — using the **last step prediction** as the final reward
3. Grouping solutions by their final answer and **summing PRM scores per group** (weighted vote)
4. Selecting the answer with the highest total weighted score
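
Formally: $\hat{a} = \arg\max_a \sum_{i=1}^{N} \mathbb{1}(a_i = a) \cdot \text{score}(s_i)$, where $s_i$ is the $i$-th sampled solution, $a_i$ is its extracted final answer, and $\text{score}(s_i)$ is the PRM's last-step score for $s_i$.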
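
As a minimal sketch of the formula above (not the code in `step3_best_of_n.py`), the weighted vote reduces to summing scores per distinct answer and taking the argmax. The function name `weighted_best_of_n` is hypothetical, and the sketch assumes final answers have already been extracted and are compared as exact strings.

```python
from collections import defaultdict


def weighted_best_of_n(answers: list[str], scores: list[float]) -> str:
    """Return the answer whose samples have the largest summed PRM score.

    answers[i] is the final answer a_i extracted from sample s_i, and
    scores[i] is score(s_i), e.g. the PRM's last-step prediction.
    """
    if not answers or len(answers) != len(scores):
        raise ValueError("need at least one sample and one score per answer")
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score  # accumulates sum_i 1(a_i = a) * score(s_i)
    # argmax over distinct answers; on a tie, max() keeps the first answer seen
    return max(totals, key=totals.__getitem__)


answers = ["3/2", "3", "3/2", "5"]
scores = [0.6, 0.9, 0.5, 0.2]
print(weighted_best_of_n(answers, scores))  # 3/2 (summed score about 1.1 vs 0.9 for "3")
```

In this toy input, "3" holds the single highest score, so standard Best-of-N would have picked it instead.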
## Models Used
| Model | Role | Size |
|---|---|---|
| [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Solution generator (LLM) | 1.5B |
| [Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B](https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B) | Process Reward Model (scorer) | 1.5B |
## Results
### Accuracy Comparison
| Method | Accuracy | Improvement over Greedy |
|---|---|---|
| Greedy (N=1) | 9/20 (45%) | — |
| Majority Vote (N=16) | 12/20 (60%) | +15pp |
| Standard Best-of-N (N=16) | 11/20 (55%) | +10pp |
| **Weighted Best-of-N (N=16)** | **13/20 (65%)** | **+20pp** |

### Accuracy Scales with N
Weighted Best-of-N accuracy rises steadily from N=1 to N=8 and then plateaus (65.3% at N=8 vs 65% at N=16). N=1 here is a single sampled solution, not the greedy baseline above:
| N | 1 | 2 | 4 | 8 | 16 |
|---|---|---|---|---|---|
| Weighted BoN accuracy | 51.5% | 58% | 63.6% | 65.3% | 65% |
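
The README does not state how accuracy at N < 16 was computed. Values such as 63.6% are not multiples of 5% on 20 problems, which is consistent with averaging over random subsets of the 16 samples. Assuming that estimator, a sketch could look like the following; `accuracy_at_n`, `weighted_vote`, the `trials` count, and the `(answers, scores, gold)` tuple layout are all hypothetical.

```python
import random
from collections import defaultdict


def weighted_vote(answers, scores):
    """Answer with the highest summed score among the given samples."""
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.__getitem__)


def accuracy_at_n(problems, n, trials=100, seed=0):
    """Estimate weighted Best-of-N accuracy at sample budget n.

    problems: list of (answers, scores, gold) tuples holding all sampled
    solutions for each problem. Each trial draws n samples per problem
    without replacement, applies the weighted vote, and counts hits.
    """
    rng = random.Random(seed)
    correct = 0
    for _ in range(trials):
        for answers, scores, gold in problems:
            idx = rng.sample(range(len(answers)), n)
            picked = weighted_vote([answers[i] for i in idx], [scores[i] for i in idx])
            correct += picked == gold
    return correct / (trials * len(problems))


problems = [
    (["2", "6", "2", "6"], [0.9, 0.3, 0.8, 0.2], "2"),
    (["15", "23", "23", "15"], [0.7, 0.6, 0.5, 0.1], "23"),
]
print(accuracy_at_n(problems, n=4))  # 1.0: using all samples, both votes are right
```

Averaging over many random subsets smooths out the luck of any single draw, which is why such estimates rarely land on exact multiples of 1/20.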

### Per-Problem Analysis
Weighted Best-of-N solved **4 additional problems** that greedy decoding couldn't, while losing none:
- `algebra/1265` — Greedy answered 3, BoN correctly found 3/2 (12/16 samples correct)
- `intermediate_algebra/860` — Greedy failed to produce \boxed{}, BoN identified "ellipse" (5/16 correct, but weighted voting aggregated enough signal)
- `number_theory/22` — Greedy answered 6, BoN correctly found 2 (5/16 correct — weighted voting beat majority vote here)
- `number_theory/45` — Greedy answered 15, BoN correctly found 23 (8/16 correct)

### PRM Score Distribution
The PRM effectively separates correct from incorrect solutions: correct solutions cluster at higher scores (see `plots/plot4_prm_scores.png`).

## Key Design Decisions
### Why Last-Step Prediction?
The Skywork PRM outputs a score at each reasoning step. Following DeepMind Appendix E, we use only the **last step's score** as the full-solution reward, rather than the minimum or product of all step scores. DeepMind's rationale is that their PRM was trained with soft Monte Carlo return labels, so its last-step prediction already integrates information about the full solution trajectory; we assume the same reasoning carries over to the Skywork PRM.
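
For comparison, the three aggregation choices named above (last step, minimum, product) fit in one small helper. This is a hypothetical sketch, not the repo's code: `aggregate_step_scores` and its `how` argument are illustrative names, and it assumes the PRM's per-step scores for one solution are already available as a list of floats in [0, 1].

```python
import math


def aggregate_step_scores(step_scores: list[float], how: str = "last") -> float:
    """Collapse per-step PRM scores into a single solution-level reward."""
    if not step_scores:
        raise ValueError("expected at least one step score")
    if how == "last":
        return step_scores[-1]  # last-step prediction, the choice made in this repo
    if how == "min":
        return min(step_scores)  # the weakest step decides
    if how == "prod":
        return math.prod(step_scores)  # never grows as steps are added (scores <= 1)
    raise ValueError(f"unknown aggregation: {how!r}")


steps = [0.9, 0.4, 0.8]
print(aggregate_step_scores(steps, "last"))  # 0.8
print(aggregate_step_scores(steps, "min"))   # 0.4
```

Since each factor is at most 1, the product can only decrease as steps are added, so it systematically favors shorter solutions.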
### Why Weighted Selection over Standard Best-of-N?
Standard Best-of-N picks the single solution with the highest PRM score. This can be fooled by a single high-scoring wrong solution. Weighted selection is more robust: a correct answer appearing in multiple solutions accumulates evidence through summed PRM scores. This is especially visible in problem `number_theory/22`, where the correct answer "2" appeared in only 5/16 samples but had the highest total weighted score.
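
The failure mode described above can be reproduced with made-up numbers (not data from this repo). The sketch below applies majority vote, standard Best-of-N, and weighted Best-of-N to the same seven hypothetical samples, where "2" is taken to be the correct answer; all function names are illustrative, and answers are compared as exact strings.

```python
from collections import Counter, defaultdict


def majority_vote(answers):
    # Most frequent answer, ignoring PRM scores
    return Counter(answers).most_common(1)[0][0]


def standard_best_of_n(answers, scores):
    # Answer of the single highest-scoring sample
    best = max(range(len(answers)), key=lambda i: scores[i])
    return answers[best]


def weighted_best_of_n(answers, scores):
    # Answer whose samples have the largest summed score
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.__getitem__)


# Made-up numbers: one wrong sample ("6") has the top score, the wrong
# answer "4" is the most frequent, and the correct answer "2" is a
# minority whose samples all score well.
answers = ["6", "2", "4", "4", "4", "2", "4"]
scores = [0.95, 0.8, 0.1, 0.2, 0.15, 0.85, 0.1]
print(majority_vote(answers))               # 4
print(standard_best_of_n(answers, scores))  # 6
print(weighted_best_of_n(answers, scores))  # 2
```

Majority vote ignores scores entirely, and standard Best-of-N trusts a single sample; only the weighted rule combines both frequency and PRM confidence.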
### Answer Matching Limitations
We use exact string matching for answer comparison (no SymPy normalization). This means several problems are marked "incorrect" due to LaTeX formatting differences:
- `\frac43` vs `\frac{4}{3}` (mathematically equivalent)
- `4210_{5}` vs `4210_5` (only the braces around the subscript differ)
- `(5,\infty)` vs `(5, \infty)` (only a space differs)

With proper normalization, these mismatches would be scored as correct, so the accuracies above likely understate the performance of every method they affect.
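
A lightweight step toward normalization is to canonicalize the formatting variants listed above before comparing strings. The sketch below is hypothetical (not part of this repo) and handles only those three patterns: it strips whitespace, braces single-character `\frac` arguments, and braces single-character subscripts. It does not test mathematical equivalence in general (e.g. `0.5` vs `\frac{1}{2}`), which is what SymPy-based normalization would add.

```python
import re


def normalize_latex_answer(ans: str) -> str:
    """Canonicalize a few LaTeX formatting variants before exact matching."""
    s = re.sub(r"\s+", "", ans)  # drop whitespace, e.g. "(5, \infty)" -> "(5,\infty)"
    s = re.sub(r"\\frac(\w)(\w)", r"\\frac{\1}{\2}", s)  # \frac43 -> \frac{4}{3}
    s = re.sub(r"_(\w)", r"_{\1}", s)  # 4210_5 -> 4210_{5}
    return s


def answers_match(pred: str, gold: str) -> bool:
    # Both sides go through the same normalization before exact comparison
    return normalize_latex_answer(pred) == normalize_latex_answer(gold)


print(answers_match(r"\frac43", r"\frac{4}{3}"))     # True
print(answers_match("4210_{5}", "4210_5"))           # True
print(answers_match(r"(5,\infty)", r"(5, \infty)"))  # True
print(answers_match("3/2", "3"))                     # False
```

If the weighted vote also groups samples by normalized answer, equivalent answers pool their scores instead of splitting them across string variants.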
## Repository Structure
```
├── README.md # This file
├── plots/
│ ├── plot1_accuracy_comparison.png # Bar chart: all methods
│ ├── plot2_accuracy_vs_n.png # Line chart: accuracy vs N
│ ├── plot3_per_problem.png # Per-problem breakdown
│ └── plot4_prm_scores.png # PRM score distributions
├── results/
│ ├── bon_results.json # Per-problem detailed results
│ ├── accuracy_by_n.json # Accuracy at each N value
│ └── filtered_problems.json # The 20 selected problems
└── code/
├── run_all.py # Complete pipeline (runs on GPU)
├── step1_filter_and_greedy.py # Step 1: Filter + greedy generation
├── step2_sample_and_score.py # Step 2: N=16 sampling + PRM scoring
├── step3_best_of_n.py # Step 3: BoN computation + N analysis
├── step4_analysis.py # Step 4: Plots and analysis
└── step5_push_dataset.py # Step 5: Push dataset to Hub
```
## Linked Resources
- **Results dataset**: [cmpatino/math500-bon-weighted-results](https://huggingface.co/datasets/cmpatino/math500-bon-weighted-results) — contains all 20 problems with greedy solutions, 16 sampled solutions each, PRM scores, and Best-of-N answers
- **Execution Space**: [cmpatino/math500-bon-exercise](https://huggingface.co/spaces/cmpatino/math500-bon-exercise) — Docker Space used to run the pipeline on a T4 GPU
## References
- [Scaling LLM Test-Time Compute (DeepMind, 2408.03314)](https://huggingface.co/papers/2408.03314) — Section 5.1 + Appendix E
- [Math-Shepherd (2312.08935)](https://huggingface.co/papers/2312.08935) — Section 3.4, Eq. 5
- [HF Blog: Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute)
- [Skywork PRM inference repo](https://github.com/SkyworkAI/skywork-o1-prm-inference)
- [Answer extraction helper](https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3)
## Co-authorship Note
This code was co-authored with Claude (Anthropic). I can explain all code logic in detail.

**Claude-assisted areas**: Pipeline structure, Skywork PRM model loading, weighted voting implementation, plotting code.

**My contributions**: Paper methodology analysis, last-step prediction choice (Appendix E), Space deployment debugging (health check server, memory management), results analysis and interpretation.