cmpatino HF Staff commited on
Commit
312d5c4
·
verified ·
1 Parent(s): 02efe1b

Upload code/step3_best_of_n.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/step3_best_of_n.py +259 -0
code/step3_best_of_n.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step 3: Compute Best-of-N accuracy with weighted selection.
3
+
4
+ Best-of-N weighted selection (from DeepMind 2408.03314, Section 5.1):
5
+ 1. For each problem, we have N=16 solutions with PRM scores
6
+ 2. Extract the final answer from each solution
7
+ 3. Group solutions by their final answer string
8
+ 4. Sum the PRM scores within each group (weighted vote)
9
+ 5. Select the answer with the highest total weighted score
10
+
11
+ This is formally:
12
+ â = argmax_a Σᵢ 𝟙(aᵢ = a) · score(sᵢ)
13
+
14
+ Where score(sᵢ) is the PRM's last-step prediction for solution i.
15
+
16
+ Co-authored with Claude (Anthropic). I can explain all code logic.
17
+ """
18
+
19
+ import json
20
+ from collections import defaultdict
21
+
22
+
23
+ def extract_boxed_solution(text):
24
+ """Extract content of the last \\boxed{} in text."""
25
+ try:
26
+ start_index = text.rindex("\\boxed{")
27
+ content_start = start_index + 7
28
+ bracket_count = 1
29
+ current_pos = content_start
30
+ while bracket_count > 0 and current_pos < len(text):
31
+ if text[current_pos] == "{":
32
+ bracket_count += 1
33
+ elif text[current_pos] == "}":
34
+ bracket_count -= 1
35
+ current_pos += 1
36
+ if bracket_count == 0:
37
+ return text[content_start : current_pos - 1].strip()
38
+ return None
39
+ except (ValueError, Exception):
40
+ return None
41
+
42
+
43
+ def weighted_best_of_n(extracted_answers, prm_scores):
44
+ """
45
+ Compute the Best-of-N answer using weighted selection.
46
+
47
+ Groups solutions by their extracted answer, sums PRM scores
48
+ per group, and returns the answer with the highest total score.
49
+
50
+ Args:
51
+ extracted_answers: list of N answer strings (may contain None)
52
+ prm_scores: list of N PRM scores (floats in [0,1])
53
+
54
+ Returns:
55
+ tuple: (best_answer, answer_scores_dict)
56
+ """
57
+ answer_scores = defaultdict(float)
58
+ answer_counts = defaultdict(int)
59
+
60
+ for answer, score in zip(extracted_answers, prm_scores):
61
+ if answer is None:
62
+ # Skip solutions where we couldn't extract an answer
63
+ # (following DeepMind's filtering of unparseable solutions)
64
+ continue
65
+ answer_scores[answer] += score
66
+ answer_counts[answer] += 1
67
+
68
+ if not answer_scores:
69
+ return None, {}
70
+
71
+ # Select the answer with highest total weighted score
72
+ best_answer = max(answer_scores, key=answer_scores.get)
73
+ return best_answer, dict(answer_scores)
74
+
75
+
76
+ def standard_best_of_n(extracted_answers, prm_scores):
77
+ """
78
+ Standard (non-weighted) Best-of-N: pick the single solution
79
+ with the highest PRM score and use its answer.
80
+ """
81
+ best_idx = None
82
+ best_score = -1
83
+ for i, (answer, score) in enumerate(zip(extracted_answers, prm_scores)):
84
+ if answer is not None and score > best_score:
85
+ best_score = score
86
+ best_idx = i
87
+ if best_idx is not None:
88
+ return extracted_answers[best_idx]
89
+ return None
90
+
91
+
92
+ def majority_vote(extracted_answers):
93
+ """
94
+ Pure majority vote (no reward weighting): pick the most frequent answer.
95
+ """
96
+ counts = defaultdict(int)
97
+ for answer in extracted_answers:
98
+ if answer is not None:
99
+ counts[answer] += 1
100
+ if not counts:
101
+ return None
102
+ return max(counts, key=counts.get)
103
+
104
+
105
+ # ──────────────────────────────────────────────────────────────────────────────
106
+ # Load scored results
107
+ # ──────────────────────────────────────────────────────────────────────────────
108
+ print("=" * 70)
109
+ print("STEP 3: Computing Best-of-N accuracy with weighted selection")
110
+ print("=" * 70)
111
+
112
+ with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json") as f:
113
+ scored_results = json.load(f)
114
+
115
+ # Also load greedy results for comparison
116
+ with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json") as f:
117
+ greedy_results = json.load(f)
118
+
119
+ # ──────────────────────────────────────────────────────────────────────────────
120
+ # Compute Best-of-N for each problem
121
+ # ──────────────────────────────────────────────────────────────────────────────
122
+ weighted_correct = 0
123
+ standard_correct = 0
124
+ majority_correct = 0
125
+ greedy_correct_count = 0
126
+
127
+ results_summary = []
128
+
129
+ for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)):
130
+ problem_id = scored["unique_id"]
131
+ ground_truth = scored["answer"]
132
+
133
+ # Extract answers from sampled solutions
134
+ extracted = scored["extracted_answers"]
135
+ scores = scored["prm_scores"]
136
+
137
+ # Weighted Best-of-N
138
+ weighted_answer, answer_scores = weighted_best_of_n(extracted, scores)
139
+ weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth)
140
+ if weighted_is_correct:
141
+ weighted_correct += 1
142
+
143
+ # Standard Best-of-N (for comparison)
144
+ standard_answer = standard_best_of_n(extracted, scores)
145
+ standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth)
146
+ if standard_is_correct:
147
+ standard_correct += 1
148
+
149
+ # Majority vote (for comparison)
150
+ majority_answer = majority_vote(extracted)
151
+ majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth)
152
+ if majority_is_correct:
153
+ majority_correct += 1
154
+
155
+ # Greedy baseline
156
+ greedy_answer = greedy["greedy_extracted_answer"]
157
+ greedy_is_correct = greedy["greedy_correct"]
158
+ if greedy_is_correct:
159
+ greedy_correct_count += 1
160
+
161
+ # Count how many of the N solutions got the right answer
162
+ n_correct_in_sample = sum(1 for a in extracted if a == ground_truth)
163
+
164
+ # Summary for this problem
165
+ summary = {
166
+ "idx": i,
167
+ "unique_id": problem_id,
168
+ "level": scored["level"],
169
+ "subject": scored["subject"],
170
+ "ground_truth": ground_truth,
171
+ "greedy_answer": greedy_answer,
172
+ "greedy_correct": greedy_is_correct,
173
+ "weighted_bon_answer": weighted_answer,
174
+ "weighted_bon_correct": weighted_is_correct,
175
+ "standard_bon_answer": standard_answer,
176
+ "standard_bon_correct": standard_is_correct,
177
+ "majority_vote_answer": majority_answer,
178
+ "majority_vote_correct": majority_is_correct,
179
+ "n_correct_in_16": n_correct_in_sample,
180
+ "answer_score_breakdown": answer_scores,
181
+ "prm_scores": scores,
182
+ }
183
+ results_summary.append(summary)
184
+
185
+ # Print per-problem results
186
+ status_g = "✓" if greedy_is_correct else "✗"
187
+ status_w = "✓" if weighted_is_correct else "✗"
188
+ print(f"\n [{problem_id}] Level {scored['level']} | {scored['subject']}")
189
+ print(f" Ground truth: {ground_truth}")
190
+ print(f" Greedy {status_g}: {greedy_answer}")
191
+ print(f" Weighted BoN {status_w}: {weighted_answer}")
192
+ print(f" Correct in sample: {n_correct_in_sample}/{len(extracted)}")
193
+ if answer_scores:
194
+ print(f" Score breakdown: {dict(sorted(answer_scores.items(), key=lambda x: -x[1]))}")
195
+
196
+ # ──────────────────────────────────────────────────────────────────────────────
197
+ # Overall results
198
+ # ──────────────────────────────────────────────────────────────────────────────
199
+ n_problems = len(scored_results)
200
+ print("\n" + "=" * 70)
201
+ print("RESULTS SUMMARY")
202
+ print("=" * 70)
203
+ print(f" Greedy (N=1): {greedy_correct_count}/{n_problems} = {greedy_correct_count/n_problems:.1%}")
204
+ print(f" Majority Vote (N=16): {majority_correct}/{n_problems} = {majority_correct/n_problems:.1%}")
205
+ print(f" Standard Best-of-N (N=16): {standard_correct}/{n_problems} = {standard_correct/n_problems:.1%}")
206
+ print(f" Weighted Best-of-N (N=16): {weighted_correct}/{n_problems} = {weighted_correct/n_problems:.1%}")
207
+
208
+ # Save results
209
+ with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json", "w") as f:
210
+ json.dump(results_summary, f, indent=2)
211
+ print("\nSaved detailed results to outputs/bon_results.json")
212
+
213
+ # ──────────────────────────────────────────────────────────────────────────────
214
+ # Compute Best-of-N at various N values (using the N=16 sample)
215
+ # ──────────────────────────────────────────────────────────────────────────────
216
+ print("\n" + "=" * 70)
217
+ print("ANALYSIS: How accuracy varies with N")
218
+ print("=" * 70)
219
+
220
+ import random
221
+ random.seed(42)
222
+
223
+ n_values = [1, 2, 4, 8, 16]
224
+ n_trials = 50 # Average over multiple random subsets for N < 16
225
+
226
+ accuracy_by_n = {}
227
+ for n in n_values:
228
+ if n == 16:
229
+ # Use all solutions
230
+ correct = 0
231
+ for s in scored_results:
232
+ answer, _ = weighted_best_of_n(s["extracted_answers"], s["prm_scores"])
233
+ if answer == s["answer"]:
234
+ correct += 1
235
+ acc = correct / n_problems
236
+ else:
237
+ # Subsample and average over trials
238
+ trial_accs = []
239
+ for trial in range(n_trials):
240
+ correct = 0
241
+ for s in scored_results:
242
+ # Random subset of N solutions
243
+ indices = random.sample(range(16), n)
244
+ sub_answers = [s["extracted_answers"][j] for j in indices]
245
+ sub_scores = [s["prm_scores"][j] for j in indices]
246
+ answer, _ = weighted_best_of_n(sub_answers, sub_scores)
247
+ if answer == s["answer"]:
248
+ correct += 1
249
+ trial_accs.append(correct / n_problems)
250
+ acc = sum(trial_accs) / len(trial_accs)
251
+
252
+ accuracy_by_n[n] = acc
253
+ print(f" N={n:2d}: {acc:.1%}")
254
+
255
+ # Save accuracy-by-N for plotting
256
+ with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json", "w") as f:
257
+ json.dump(accuracy_by_n, f, indent=2)
258
+
259
+ print("\nDone! Results saved. Run step4_analysis.py for plots.")