Upload README.md
Browse files
README.md
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Best-of-N Weighted Selection on MATH-500
|
| 2 |
+
|
| 3 |
+
> HuggingFace internship exercise — replicating a baseline from the [Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) blog post.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This repo contains the complete code, results, and analysis for **Best-of-N sampling with weighted selection** on math problems. The approach was introduced in [DeepMind (2408.03314)](https://huggingface.co/papers/2408.03314) and involves:
|
| 8 |
+
|
| 9 |
+
1. Sampling $N$ independent solutions per problem from an LLM
|
| 10 |
+
2. Scoring each solution with a **Process Reward Model (PRM)** — using the **last step prediction** as the final reward
|
| 11 |
+
3. Grouping solutions by their final answer and **summing PRM scores per group** (weighted vote)
|
| 12 |
+
4. Selecting the answer with the highest total weighted score
|
| 13 |
+
|
| 14 |
+
Formally: $\hat{a} = \arg\max_a \sum_{i=1}^{N} \mathbb{1}(a_i = a) \cdot \text{score}(s_i)$
|
| 15 |
+
|
| 16 |
+
## Models Used
|
| 17 |
+
|
| 18 |
+
| Model | Role | Size |
|
| 19 |
+
|---|---|---|
|
| 20 |
+
| [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Solution generator (LLM) | 1.5B |
|
| 21 |
+
| [Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B](https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B) | Process Reward Model (scorer) | 1.5B |
|
| 22 |
+
|
| 23 |
+
## Results
|
| 24 |
+
|
| 25 |
+
### Accuracy Comparison
|
| 26 |
+
|
| 27 |
+
| Method | Accuracy | Improvement over Greedy |
|
| 28 |
+
|---|---|---|
|
| 29 |
+
| Greedy (N=1) | 9/20 (45%) | — |
|
| 30 |
+
| Majority Vote (N=16) | 12/20 (60%) | +15pp |
|
| 31 |
+
| Standard Best-of-N (N=16) | 11/20 (55%) | +10pp |
|
| 32 |
+
| **Weighted Best-of-N (N=16)** | **13/20 (65%)** | **+20pp** |
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+
|
| 36 |
+
### Accuracy Scales with N
|
| 37 |
+
|
| 38 |
+
Weighted Best-of-N accuracy improves monotonically as we sample more solutions, plateauing around N=8-16:
|
| 39 |
+
|
| 40 |
+
| N=1 | N=2 | N=4 | N=8 | N=16 |
|
| 41 |
+
|---|---|---|---|---|
|
| 42 |
+
| 51.5% | 58% | 63.6% | 65.3% | 65% |
|
| 43 |
+
|
| 44 |
+

|
| 45 |
+
|
| 46 |
+
### Per-Problem Analysis
|
| 47 |
+
|
| 48 |
+
Weighted Best-of-N solved **4 additional problems** that greedy decoding couldn't, while losing none:
|
| 49 |
+
|
| 50 |
+
- `algebra/1265` — Greedy answered 3, BoN correctly found 3/2 (12/16 samples correct)
|
| 51 |
+
- `intermediate_algebra/860` — Greedy failed to produce \boxed{}, BoN identified "ellipse" (5/16 correct, but weighted voting aggregated enough signal)
|
| 52 |
+
- `number_theory/22` — Greedy answered 6, BoN correctly found 2 (5/16 correct — weighted voting beat majority vote here)
|
| 53 |
+
- `number_theory/45` — Greedy answered 15, BoN correctly found 23 (8/16 correct)
|
| 54 |
+
|
| 55 |
+

|
| 56 |
+
|
| 57 |
+
### PRM Score Distribution
|
| 58 |
+
|
| 59 |
+
The PRM effectively separates correct from incorrect solutions — correct solutions cluster at higher scores:
|
| 60 |
+
|
| 61 |
+

|
| 62 |
+
|
| 63 |
+
## Key Design Decisions
|
| 64 |
+
|
| 65 |
+
### Why Last-Step Prediction?
|
| 66 |
+
|
| 67 |
+
The Skywork PRM outputs a score at each reasoning step. Following DeepMind Appendix E, we use only the **last step's score** as the full-solution reward, rather than min or product of all step scores. This works because the PRM was trained with soft Monte Carlo return labels — the last step already integrates information about the full solution trajectory.
|
| 68 |
+
|
| 69 |
+
### Why Weighted Selection over Standard Best-of-N?
|
| 70 |
+
|
| 71 |
+
Standard Best-of-N picks the single solution with the highest PRM score. This can be fooled by a single high-scoring wrong solution. Weighted selection is more robust: a correct answer appearing in multiple solutions accumulates evidence through summed PRM scores. This is especially visible in problem `number_theory/22`, where the correct answer "2" appeared in only 5/16 samples but had the highest total weighted score.
|
| 72 |
+
|
| 73 |
+
### Answer Matching Limitations
|
| 74 |
+
|
| 75 |
+
We use exact string matching for answer comparison (no SymPy normalization). This means several problems are marked "incorrect" due to LaTeX formatting differences:
|
| 76 |
+
- `\frac43` vs `\frac{4}{3}` — mathematically equivalent
|
| 77 |
+
- `4210_{5}` vs `4210_5` — just spacing difference
|
| 78 |
+
- `(5,\infty)` vs `(5, \infty)` — just a space
|
| 79 |
+
|
| 80 |
+
With proper normalization, all methods' accuracies would be higher.
|
| 81 |
+
|
| 82 |
+
## Repository Structure
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
├── README.md # This file
|
| 86 |
+
├── plots/
|
| 87 |
+
│ ├── plot1_accuracy_comparison.png # Bar chart: all methods
|
| 88 |
+
│ ├── plot2_accuracy_vs_n.png # Line chart: accuracy vs N
|
| 89 |
+
│ ├── plot3_per_problem.png # Per-problem breakdown
|
| 90 |
+
│ └── plot4_prm_scores.png # PRM score distributions
|
| 91 |
+
├── results/
|
| 92 |
+
│ ├── bon_results.json # Per-problem detailed results
|
| 93 |
+
│ ├── accuracy_by_n.json # Accuracy at each N value
|
| 94 |
+
│ └── filtered_problems.json # The 20 selected problems
|
| 95 |
+
└── code/
|
| 96 |
+
├── run_all.py # Complete pipeline (runs on GPU)
|
| 97 |
+
├── step1_filter_and_greedy.py # Step 1: Filter + greedy generation
|
| 98 |
+
├── step2_sample_and_score.py # Step 2: N=16 sampling + PRM scoring
|
| 99 |
+
├── step3_best_of_n.py # Step 3: BoN computation + N analysis
|
| 100 |
+
├── step4_analysis.py # Step 4: Plots and analysis
|
| 101 |
+
└── step5_push_dataset.py # Step 5: Push dataset to Hub
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## Linked Resources
|
| 105 |
+
|
| 106 |
+
- **Results dataset**: [cmpatino/math500-bon-weighted-results](https://huggingface.co/datasets/cmpatino/math500-bon-weighted-results) — contains all 20 problems with greedy solutions, 16 sampled solutions each, PRM scores, and Best-of-N answers
|
| 107 |
+
- **Execution Space**: [cmpatino/math500-bon-exercise](https://huggingface.co/spaces/cmpatino/math500-bon-exercise) — Docker Space used to run the pipeline on a T4 GPU
|
| 108 |
+
|
| 109 |
+
## References
|
| 110 |
+
|
| 111 |
+
- [Scaling LLM Test-Time Compute (DeepMind, 2408.03314)](https://huggingface.co/papers/2408.03314) — Section 5.1 + Appendix E
|
| 112 |
+
- [Math-Shepherd (2312.08935)](https://huggingface.co/papers/2312.08935) — Section 3.4, Eq. 5
|
| 113 |
+
- [HF Blog: Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute)
|
| 114 |
+
- [Skywork PRM inference repo](https://github.com/SkyworkAI/skywork-o1-prm-inference)
|
| 115 |
+
- [Answer extraction helper](https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3)
|
| 116 |
+
|
| 117 |
+
## Co-authorship Note
|
| 118 |
+
|
| 119 |
+
This code was co-authored with Claude (Anthropic). I can explain all code logic in detail.
|
| 120 |
+
|
| 121 |
+
**Claude-assisted areas**: Pipeline structure, Skywork PRM model loading, weighted voting implementation, plotting code.
|
| 122 |
+
|
| 123 |
+
**My contributions**: Paper methodology analysis, last-step prediction choice (Appendix E), Space deployment debugging (health check server, memory management), results analysis and interpretation.
|