cmpatino HF Staff committed on
Commit
86a6fb2
·
verified ·
1 Parent(s): e82a87c

Upload README.md

Files changed (1)
  1. README.md +123 -0
README.md ADDED
@@ -0,0 +1,123 @@
1
+ # Best-of-N Weighted Selection on MATH-500
2
+
3
+ > Hugging Face internship exercise: replicating a baseline from the [Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) blog post.
4
+
5
+ ## Overview
6
+
7
+ This repo contains the complete code, results, and analysis for **Best-of-N sampling with weighted selection** on math problems. The approach was introduced in [DeepMind (2408.03314)](https://huggingface.co/papers/2408.03314) and involves:
8
+
9
+ 1. Sampling $N$ independent solutions per problem from an LLM
10
+ 2. Scoring each solution with a **Process Reward Model (PRM)** — using the **last step prediction** as the final reward
11
+ 3. Grouping solutions by their final answer and **summing PRM scores per group** (weighted vote)
12
+ 4. Selecting the answer with the highest total weighted score
13
+
14
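+ Formally: $\hat{a} = \arg\max_a \sum_{i=1}^{N} \mathbb{1}(a_i = a) \cdot \text{score}(s_i)$, where $s_i$ is the $i$-th sampled solution, $a_i$ is its final answer, and $\text{score}(s_i)$ is its last-step PRM score.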
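
The four steps above can be sketched in a few lines of Python. This is a minimal illustration and not the code in `code/step3_best_of_n.py`; the function name `weighted_best_of_n` and the toy scores are assumptions made for the example.

```python
from collections import defaultdict


def weighted_best_of_n(answers, scores):
    """Return the final answer whose samples have the largest summed PRM score.

    answers: final answers a_i extracted from the N sampled solutions
    scores:  last-step PRM scores score(s_i), in the same order
    """
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score  # sum of scores over samples with a_i == a
    # argmax over candidate answers; ties go to the answer seen first
    return max(totals, key=totals.get)


# Toy inputs (made up for illustration, not taken from the results).
answers = ["3", "3/2", "3/2", "3"]
scores = [0.2, 0.9, 0.7, 0.3]
print(weighted_best_of_n(answers, scores))  # prints 3/2
```

Grouping relies on the answers comparing equal as strings, so any answer normalization has to happen before this step.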
15
+
16
+ ## Models Used
17
+
18
+ | Model | Role | Size |
19
+ |---|---|---|
20
+ | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Solution generator (LLM) | 1.5B |
21
+ | [Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B](https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B) | Process Reward Model (scorer) | 1.5B |
22
+
23
+ ## Results
24
+
25
+ ### Accuracy Comparison
26
+
27
+ | Method | Accuracy | Improvement over Greedy |
28
+ |---|---|---|
29
+ | Greedy (N=1) | 9/20 (45%) | — |
30
+ | Majority Vote (N=16) | 12/20 (60%) | +15pp |
31
+ | Standard Best-of-N (N=16) | 11/20 (55%) | +10pp |
32
+ | **Weighted Best-of-N (N=16)** | **13/20 (65%)** | **+20pp** |
33
+
34
+ ![Accuracy Comparison](plots/plot1_accuracy_comparison.png)
35
+
36
+ ### Accuracy Scales with N
37
+
38
+ Weighted Best-of-N accuracy improves as we sample more solutions up to N=8, then plateaus (65.3% at N=8 vs. 65% at N=16):
39
+
40
+ | N=1 | N=2 | N=4 | N=8 | N=16 |
41
+ |---|---|---|---|---|
42
+ | 51.5% | 58% | 63.6% | 65.3% | 65% |
43
+
44
+ ![Accuracy vs N](plots/plot2_accuracy_vs_n.png)
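
The README does not state how accuracy at N < 16 was computed. One common approach, assumed here, is to average weighted Best-of-N accuracy over random size-N subsets of the 16 samples per problem. The record layout (`answers`, `scores`, `gold`) and the function names below are hypothetical and may not match `results/bon_results.json`.

```python
import random
from collections import defaultdict


def weighted_pick(answers, scores):
    """Weighted vote: answer with the largest summed score."""
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.get)


def accuracy_at_n(problems, n, trials=100, seed=0):
    """Average weighted Best-of-N accuracy over random size-n subsets.

    problems: list of dicts with hypothetical keys
              "answers" (list[str]), "scores" (list[float]), "gold" (str)
    """
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        for p in problems:
            # Draw n distinct sample indices without replacement.
            idx = rng.sample(range(len(p["answers"])), n)
            pick = weighted_pick(
                [p["answers"][i] for i in idx],
                [p["scores"][i] for i in idx],
            )
            hits += pick == p["gold"]
    return hits / (trials * len(problems))
```

Averaging over subsets would explain fractional values such as 51.5% on a 20-problem set, but that reading is an inference, not something the results files confirm.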
45
+
46
+ ### Per-Problem Analysis
47
+
48
+ Weighted Best-of-N solved **4 additional problems** that greedy decoding couldn't, while losing none:
49
+
50
+ - `algebra/1265` — Greedy answered 3, BoN correctly found 3/2 (12/16 samples correct)
51
+ - `intermediate_algebra/860` — Greedy failed to produce `\boxed{}`, BoN identified "ellipse" (5/16 correct, but weighted voting aggregated enough signal)
52
+ - `number_theory/22` — Greedy answered 6, BoN correctly found 2 (5/16 correct — weighted voting beat majority vote here)
53
+ - `number_theory/45` — Greedy answered 15, BoN correctly found 23 (8/16 correct)
54
+
55
+ ![Per-Problem Analysis](plots/plot3_per_problem.png)
56
+
57
+ ### PRM Score Distribution
58
+
59
+ The PRM effectively separates correct from incorrect solutions — correct solutions cluster at higher scores:
60
+
61
+ ![PRM Score Distribution](plots/plot4_prm_scores.png)
62
+
63
+ ## Key Design Decisions
64
+
65
+ ### Why Last-Step Prediction?
66
+
67
+ The Skywork PRM outputs a score at each reasoning step. Following DeepMind Appendix E, we use only the **last step's score** as the full-solution reward, rather than the minimum or the product of all step scores. This works because the PRM was trained with soft Monte Carlo return labels, so the last step already integrates information about the full solution trajectory.
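
To make the three aggregation options concrete, here is a small sketch that reduces a list of per-step scores to one solution-level reward. The helper `aggregate_step_scores` and the example scores are illustrative assumptions; the pipeline itself uses only the last-step option.

```python
import math


def aggregate_step_scores(step_scores, how="last"):
    """Collapse per-step PRM scores into a single solution-level reward."""
    if not step_scores:
        raise ValueError("step_scores must contain at least one score")
    if how == "last":
        return step_scores[-1]         # choice used in this repo
    if how == "min":
        return min(step_scores)        # weakest step decides
    if how == "prod":
        return math.prod(step_scores)  # every step discounts the total
    raise ValueError(f"unknown aggregation: {how!r}")


# Made-up scores for a three-step solution.
steps = [0.9, 0.4, 0.7]
for how in ("last", "min", "prod"):
    print(how, aggregate_step_scores(steps, how))
```

With `min` or `prod`, one low-scoring intermediate step drags down the whole solution even if later steps recover; `last` only reflects the PRM's judgment once the full solution has been read.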
68
+
69
+ ### Why Weighted Selection over Standard Best-of-N?
70
+
71
+ Standard Best-of-N picks the single solution with the highest PRM score. This can be fooled by a single high-scoring wrong solution. Weighted selection is more robust: a correct answer appearing in multiple solutions accumulates evidence through summed PRM scores. This is especially visible in problem `number_theory/22`, where the correct answer "2" appeared in only 5/16 samples but had the highest total weighted score.
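
The sketch below puts the three selection rules side by side on a toy case where they disagree. The answers and scores are invented for illustration (they are not the actual samples for `number_theory/22`), and the function names are assumptions.

```python
from collections import Counter, defaultdict


def majority_vote(answers):
    """Most frequent final answer, ignoring PRM scores."""
    return Counter(answers).most_common(1)[0][0]


def standard_best_of_n(answers, scores):
    """Final answer of the single highest-scoring solution."""
    best = max(range(len(answers)), key=lambda i: scores[i])
    return answers[best]


def weighted_vote(answers, scores):
    """Final answer with the largest summed PRM score."""
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.get)


# "6" is frequent but low-scored, "15" appears once with the top score,
# "2" appears twice with solid scores.
answers = ["6", "6", "6", "2", "2", "15"]
scores = [0.2, 0.2, 0.2, 0.6, 0.6, 0.9]

print(majority_vote(answers))               # prints 6
print(standard_best_of_n(answers, scores))  # prints 15
print(weighted_vote(answers, scores))       # prints 2
```

Weighted voting combines both signals: frequency alone favors "6", the single best score favors "15", and only the summed scores favor "2".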
72
+
73
+ ### Answer Matching Limitations
74
+
75
+ We use exact string matching for answer comparison (no SymPy normalization). This means several problems are marked "incorrect" due to LaTeX formatting differences:
76
+ - `\frac43` vs `\frac{4}{3}` — mathematically equivalent
77
+ - `4210_{5}` vs `4210_5` — differs only in the braces around the subscript
78
+ - `(5,\infty)` vs `(5, \infty)` — just a space
79
+
80
+ With proper normalization, all methods' accuracies would likely be higher.
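
As a rough idea of what lightweight normalization could look like, the sketch below canonicalizes the three formatting differences listed above using regular expressions. It is an assumption-laden illustration, not part of the pipeline and not a substitute for symbolic comparison; `normalize_answer` is a hypothetical helper.

```python
import re


def normalize_answer(ans: str) -> str:
    """Canonicalize a few LaTeX formatting variants before string comparison."""
    # Remove all whitespace: "(5, \infty)" -> "(5,\infty)".
    # Note: this also strips spaces inside \text{...}.
    ans = re.sub(r"\s+", "", ans)
    # Brace single-digit fraction arguments: "\frac43" -> "\frac{4}{3}".
    ans = re.sub(r"\\frac(\d)(\d)", r"\\frac{\1}{\2}", ans)
    # Unbrace single-character subscripts: "4210_{5}" -> "4210_5".
    ans = re.sub(r"_\{(\w)\}", r"_\1", ans)
    return ans


pairs = [
    (r"\frac43", r"\frac{4}{3}"),
    (r"4210_{5}", r"4210_5"),
    (r"(5,\infty)", r"(5, \infty)"),
]
for left, right in pairs:
    print(normalize_answer(left) == normalize_answer(right))  # True for each pair
```

Rules like these only cover known surface variants. Answers that are equal mathematically but written differently (for example `0.5` and `\frac{1}{2}`) would still need a symbolic check.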
81
+
82
+ ## Repository Structure
83
+
84
+ ```
85
+ ├── README.md                          # This file
86
+ ├── plots/
87
+ │   ├── plot1_accuracy_comparison.png  # Bar chart: all methods
88
+ │   ├── plot2_accuracy_vs_n.png        # Line chart: accuracy vs N
89
+ │   ├── plot3_per_problem.png          # Per-problem breakdown
90
+ │   └── plot4_prm_scores.png           # PRM score distributions
91
+ ├── results/
92
+ │   ├── bon_results.json               # Per-problem detailed results
93
+ │   ├── accuracy_by_n.json             # Accuracy at each N value
94
+ │   └── filtered_problems.json         # The 20 selected problems
95
+ └── code/
96
+     ├── run_all.py                     # Complete pipeline (runs on GPU)
97
+     ├── step1_filter_and_greedy.py     # Step 1: Filter + greedy generation
98
+     ├── step2_sample_and_score.py      # Step 2: N=16 sampling + PRM scoring
99
+     ├── step3_best_of_n.py             # Step 3: BoN computation + N analysis
100
+     ├── step4_analysis.py              # Step 4: Plots and analysis
101
+     └── step5_push_dataset.py          # Step 5: Push dataset to Hub
102
+ ```
103
+
104
+ ## Linked Resources
105
+
106
+ - **Results dataset**: [cmpatino/math500-bon-weighted-results](https://huggingface.co/datasets/cmpatino/math500-bon-weighted-results) — contains all 20 problems with greedy solutions, 16 sampled solutions each, PRM scores, and Best-of-N answers
107
+ - **Execution Space**: [cmpatino/math500-bon-exercise](https://huggingface.co/spaces/cmpatino/math500-bon-exercise) — Docker Space used to run the pipeline on a T4 GPU
108
+
109
+ ## References
110
+
111
+ - [Scaling LLM Test-Time Compute (DeepMind, 2408.03314)](https://huggingface.co/papers/2408.03314) — Section 5.1 + Appendix E
112
+ - [Math-Shepherd (2312.08935)](https://huggingface.co/papers/2312.08935) — Section 3.4, Eq. 5
113
+ - [HF Blog: Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute)
114
+ - [Skywork PRM inference repo](https://github.com/SkyworkAI/skywork-o1-prm-inference)
115
+ - [Answer extraction helper](https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3)
116
+
117
+ ## Co-authorship Note
118
+
119
+ This code was co-authored with Claude (Anthropic). I can explain all code logic in detail.
120
+
121
+ **Claude-assisted areas**: Pipeline structure, Skywork PRM model loading, weighted voting implementation, plotting code.
122
+
123
+ **My contributions**: Paper methodology analysis, last-step prediction choice (Appendix E), Space deployment debugging (health check server, memory management), results analysis and interpretation.