# Best-of-N Weighted Selection on MATH-500

> Hugging Face internship exercise: replicating a baseline from the [Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) blog post.

## Overview

This repo contains the complete code, results, and analysis for **Best-of-N sampling with weighted selection** on math problems, following the approach described in [DeepMind (2408.03314)](https://huggingface.co/papers/2408.03314). It involves:

1. Sampling $N$ independent solutions per problem from an LLM
2. Scoring each solution with a **Process Reward Model (PRM)** — using the **last step prediction** as the final reward
3. Grouping solutions by their final answer and **summing PRM scores per group** (weighted vote)
4. Selecting the answer with the highest total weighted score

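Formally: $\hat{a} = \arg\max_a \sum_{i=1}^{N} \mathbb{1}(a_i = a) \cdot \text{score}(s_i)$, where $s_i$ is the $i$-th sampled solution, $a_i$ is its extracted final answer, and $\text{score}(s_i)$ is its last-step PRM score.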
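
The four steps and the formula above fit in a few lines of Python. This is a minimal sketch, not the code in `code/step3_best_of_n.py`: the function name `weighted_best_of_n` is hypothetical, and it assumes final answers have already been extracted as strings and are compared by exact match (as in this repo), with one last-step PRM score per solution.

```python
from collections import defaultdict


def weighted_best_of_n(answers: list[str], scores: list[float]) -> str:
    """Return the answer whose solutions have the largest summed PRM score.

    answers[i] is the final answer extracted from solution s_i and
    scores[i] is the PRM's last-step score for s_i (hypothetical inputs).
    """
    if not answers or len(answers) != len(scores):
        raise ValueError("need at least one sample and one score per answer")
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score  # accumulates 1(a_i = a) * score(s_i) per answer a
    # argmax over distinct answers; ties go to the answer seen first
    return max(totals, key=totals.__getitem__)


answers = ["2", "6", "2", "6", "6"]
scores = [0.9, 0.3, 0.8, 0.4, 0.5]
print(weighted_best_of_n(answers, scores))  # prints 2
```

With these toy inputs `"6"` is the most frequent answer, but `"2"` wins on summed score (about 1.7 vs 1.2), so weighted selection and a plain majority vote disagree.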

## Models Used

| Model | Role | Size |
|---|---|---|
| [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Solution generator (LLM) | 1.5B |
| [Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B](https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B) | Process Reward Model (scorer) | 1.5B |

## Results

### Accuracy Comparison

| Method | Accuracy | Improvement over Greedy |
|---|---|---|
| Greedy (N=1) | 9/20 (45%) | — |
| Majority Vote (N=16) | 12/20 (60%) | +15pp |
| Standard Best-of-N (N=16) | 11/20 (55%) | +10pp |
| **Weighted Best-of-N (N=16)** | **13/20 (65%)** | **+20pp** |

![Accuracy Comparison](plots/plot1_accuracy_comparison.png)

### Accuracy Scales with N

Weighted Best-of-N accuracy rises steadily from N=1 to N=8 and then plateaus (65.3% at N=8 vs. 65% at N=16):

| N=1 | N=2 | N=4 | N=8 | N=16 |
|---|---|---|---|---|
| 51.5% | 58% | 63.6% | 65.3% | 65% |

![Accuracy vs N](plots/plot2_accuracy_vs_n.png)
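
This README does not state how the accuracies for N < 16 were computed, and the fractional N=1 value (51.5% on 20 problems) suggests some form of averaging. One common approach, assumed here rather than taken from the repo, is to repeatedly draw random size-N subsets of the 16 scored samples, apply weighted selection to each subset, and average correctness. The helpers `weighted_pick` and `accuracy_at_n`, and the `(answers, scores, reference)` input format, are hypothetical.

```python
import random
from collections import defaultdict


def weighted_pick(answers: list[str], scores: list[float]) -> str:
    """Weighted Best-of-N: the answer with the largest summed PRM score."""
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.__getitem__)


def accuracy_at_n(problems, n: int, trials: int = 100, seed: int = 0) -> float:
    """Estimate accuracy when only n of the available samples are used.

    problems: list of (answers, scores, reference) tuples, one per problem,
    each with at least n sampled solutions (16 in this repo).
    """
    rng = random.Random(seed)  # fixed seed so the estimate is reproducible
    hits = 0
    for answers, scores, reference in problems:
        for _ in range(trials):
            idx = rng.sample(range(len(answers)), n)  # size-n subset, no replacement
            pick = weighted_pick([answers[i] for i in idx], [scores[i] for i in idx])
            hits += pick == reference  # bool counts as 0 or 1
    return hits / (len(problems) * trials)
```

Calling `accuracy_at_n(problems, n)` for n in 1, 2, 4, 8, 16 gives one value per column of the table above. When n equals the full sample count, every trial uses all samples, so the estimate reduces to plain weighted Best-of-N accuracy.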

### Per-Problem Analysis

Weighted Best-of-N solved **4 additional problems** that greedy decoding couldn't, while losing none:

- `algebra/1265`: Greedy answered 3; BoN correctly found 3/2 (12/16 samples correct)
- `intermediate_algebra/860`: Greedy failed to produce a `\boxed{}` answer; BoN identified "ellipse" (only 5/16 samples correct, but weighted voting aggregated enough signal)
- `number_theory/22`: Greedy answered 6; BoN correctly found 2 (5/16 samples correct; weighted voting beat majority vote here)
- `number_theory/45`: Greedy answered 15; BoN correctly found 23 (8/16 samples correct)
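
Finding gains and losses like these is a set comparison over per-problem correctness. Below is a sketch with a hypothetical `compare_methods` helper and hypothetical input dicts: the four real IDs are the ones listed above (greedy wrong, BoN right), while the `toy/...` entries are invented to show that problems both methods solve, or both miss, are ignored.

```python
def compare_methods(greedy_correct: dict[str, bool], bon_correct: dict[str, bool]):
    """Return (gained, lost) problem IDs for Best-of-N relative to greedy.

    gained: solved by Best-of-N but not by greedy.
    lost:   solved by greedy but not by Best-of-N.
    """
    gained = sorted(p for p, ok in bon_correct.items() if ok and not greedy_correct[p])
    lost = sorted(p for p, ok in bon_correct.items() if not ok and greedy_correct[p])
    return gained, lost


# Hypothetical correctness flags; the toy/... IDs are invented for illustration.
greedy = {"algebra/1265": False, "intermediate_algebra/860": False,
          "number_theory/22": False, "number_theory/45": False,
          "toy/both_right": True, "toy/both_wrong": False}
bon = {"algebra/1265": True, "intermediate_algebra/860": True,
       "number_theory/22": True, "number_theory/45": True,
       "toy/both_right": True, "toy/both_wrong": False}

gained, lost = compare_methods(greedy, bon)
print(len(gained), len(lost))  # prints 4 0
```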

![Per-Problem Analysis](plots/plot3_per_problem.png)

### PRM Score Distribution

Correct solutions tend to receive higher PRM scores than incorrect ones, so the score carries a useful (if imperfect) signal about correctness:

![PRM Score Distribution](plots/plot4_prm_scores.png)
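
One way to put a number on this separation, not computed in this repo, is the fraction of (correct, incorrect) solution pairs in which the correct solution gets the higher PRM score. With ties counted as half, this equals the ROC AUC of the PRM score used as a correctness classifier: 0.5 is chance and 1.0 is perfect separation. The helper name `pairwise_separation` is hypothetical.

```python
def pairwise_separation(correct_scores: list[float], incorrect_scores: list[float]) -> float:
    """Fraction of (correct, incorrect) pairs where the correct solution scores higher.

    Ties count as half a win, which makes this equal to the ROC AUC.
    """
    if not correct_scores or not incorrect_scores:
        raise ValueError("need at least one score in each group")
    wins = 0.0
    for c in correct_scores:
        for w in incorrect_scores:
            if c > w:
                wins += 1.0
            elif c == w:
                wins += 0.5
    return wins / (len(correct_scores) * len(incorrect_scores))
```

The quadratic pair loop is fine at this scale (20 problems × 16 samples); for much larger sets, a rank-based formula computes the same quantity faster.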

## Key Design Decisions

### Why Last-Step Prediction?

The Skywork PRM outputs a score at each reasoning step. Following DeepMind Appendix E, we use only the **last step's score** as the full-solution reward, rather than the minimum or the product of all step scores. The rationale given there is that a PRM trained with soft Monte Carlo return labels already folds information about the whole solution trajectory into its last-step prediction.
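
The difference between these aggregation choices is easy to see on a single list of step scores. This hypothetical helper is illustrative only and is not taken from the Skywork inference repo; it assumes the per-step scores have already been obtained from the PRM.

```python
import math


def aggregate_step_scores(step_scores: list[float], how: str = "last") -> float:
    """Collapse per-step PRM scores into one solution-level reward."""
    if not step_scores:
        raise ValueError("solution has no scored steps")
    if how == "last":
        return step_scores[-1]  # last-step prediction, the choice made in this repo
    if how == "min":
        return min(step_scores)  # weakest step dominates
    if how == "prod":
        return math.prod(step_scores)  # every step multiplies in
    raise ValueError(f"unknown aggregation: {how}")
```

On `[0.9, 0.4, 0.8]`, `last` gives 0.8 while `min` gives 0.4 and `prod` about 0.29, so a solution that recovers from one low-scoring intermediate step is penalized far less under last-step scoring.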

### Why Weighted Selection over Standard Best-of-N?

Standard Best-of-N picks the single solution with the highest PRM score. This can be fooled by a single high-scoring wrong solution. Weighted selection is more robust: a correct answer appearing in multiple solutions accumulates evidence through summed PRM scores. This is especially visible in problem `number_theory/22`, where the correct answer "2" appeared in only 5/16 samples but had the highest total weighted score.
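
A toy example makes this failure mode concrete. The three selection rules below are minimal sketches with hypothetical names, and the inputs are invented: one wrong answer (`"7"`) receives the single highest score, another wrong answer (`"6"`) is the most frequent, and the correct answer (`"2"`) has the largest summed score.

```python
from collections import Counter, defaultdict


def majority_vote(answers: list[str], scores: list[float]) -> str:
    """Most frequent answer; PRM scores are ignored."""
    return Counter(answers).most_common(1)[0][0]


def standard_best_of_n(answers: list[str], scores: list[float]) -> str:
    """Answer of the single highest-scoring solution."""
    best = max(range(len(scores)), key=scores.__getitem__)
    return answers[best]


def weighted_best_of_n(answers: list[str], scores: list[float]) -> str:
    """Answer with the largest summed PRM score across its solutions."""
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.__getitem__)


# Invented inputs; "2" plays the role of the correct answer.
answers = ["7", "2", "2", "2", "6", "6", "6", "6"]
scores = [0.95, 0.80, 0.70, 0.75, 0.30, 0.20, 0.25, 0.35]

for rule in (majority_vote, standard_best_of_n, weighted_best_of_n):
    print(rule.__name__, rule(answers, scores))
# majority_vote 6
# standard_best_of_n 7
# weighted_best_of_n 2
```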

### Answer Matching Limitations

We use exact string matching for answer comparison (no SymPy normalization), so several problems are marked "incorrect" only because of LaTeX formatting differences:
- `\frac43` vs `\frac{4}{3}`: mathematically equivalent
- `4210_{5}` vs `4210_5`: only the subscript braces differ
- `(5,\infty)` vs `(5, \infty)`: only a space differs

With proper normalization these answers would count as correct, so the reported accuracies are likely underestimates.
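
A lightweight fix for the three cases above is to extract the last `\boxed{...}` and apply a few string rewrites before comparing. This is a hypothetical sketch, not the repo's code or the linked answer-extraction gist; it only covers the patterns listed here and is no substitute for symbolic checking (e.g. with SymPy).

```python
import re
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None  # e.g. the model never produced a boxed answer
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces


def normalize(answer: str) -> str:
    """Canonicalize the LaTeX spellings listed above (not a general normalizer)."""
    s = answer.replace(" ", "")  # (5, \infty) -> (5,\infty)
    s = re.sub(r"\\frac(\d)(\d)", r"\\frac{\1}{\2}", s)  # \frac43 -> \frac{4}{3}
    s = re.sub(r"_\{(\w)\}", r"_\1", s)  # 4210_{5} -> 4210_5
    return s


def answers_match(predicted: str, reference: str) -> bool:
    return normalize(predicted) == normalize(reference)
```

For example, `extract_boxed(r"... \boxed{\frac43}")` returns `\frac43`, which `answers_match` then treats as equal to `\frac{4}{3}`. Rewrites like these only help with formatting; they cannot recognize equivalent values written in different forms (e.g. `0.5` vs `\frac{1}{2}`).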

## Repository Structure

```
├── README.md                          # This file
├── plots/
│   ├── plot1_accuracy_comparison.png  # Bar chart: all methods
│   ├── plot2_accuracy_vs_n.png        # Line chart: accuracy vs N
│   ├── plot3_per_problem.png          # Per-problem breakdown
│   └── plot4_prm_scores.png           # PRM score distributions
├── results/
│   ├── bon_results.json               # Per-problem detailed results
│   ├── accuracy_by_n.json             # Accuracy at each N value
│   └── filtered_problems.json         # The 20 selected problems
└── code/
    ├── run_all.py                     # Complete pipeline (runs on GPU)
    ├── step1_filter_and_greedy.py     # Step 1: Filter + greedy generation
    ├── step2_sample_and_score.py      # Step 2: N=16 sampling + PRM scoring
    ├── step3_best_of_n.py             # Step 3: BoN computation + N analysis
    ├── step4_analysis.py              # Step 4: Plots and analysis
    └── step5_push_dataset.py          # Step 5: Push dataset to Hub
```

## Linked Resources

- **Results dataset**: [cmpatino/math500-bon-weighted-results](https://huggingface.co/datasets/cmpatino/math500-bon-weighted-results) — contains all 20 problems with greedy solutions, 16 sampled solutions each, PRM scores, and Best-of-N answers
- **Execution Space**: [cmpatino/math500-bon-exercise](https://huggingface.co/spaces/cmpatino/math500-bon-exercise) — Docker Space used to run the pipeline on a T4 GPU

## References

- [Scaling LLM Test-Time Compute (DeepMind, 2408.03314)](https://huggingface.co/papers/2408.03314) — Section 5.1 + Appendix E
- [Math-Shepherd (2312.08935)](https://huggingface.co/papers/2312.08935) — Section 3.4, Eq. 5
- [HF Blog: Scaling Test-Time Compute with Open Models](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute)
- [Skywork PRM inference repo](https://github.com/SkyworkAI/skywork-o1-prm-inference)
- [Answer extraction helper](https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3)

## Co-authorship Note

This code was co-authored with Claude (Anthropic). I can explain all code logic in detail.

**Claude-assisted areas**: Pipeline structure, Skywork PRM model loading, weighted voting implementation, plotting code.

**My contributions**: Paper methodology analysis, last-step prediction choice (Appendix E), Space deployment debugging (health check server, memory management), results analysis and interpretation.