File size: 10,740 Bytes
312d5c4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 | """
Step 3: Compute Best-of-N accuracy with weighted selection.
Best-of-N weighted selection (from DeepMind 2408.03314, Section 5.1):
1. For each problem, we have N=16 solutions with PRM scores
2. Extract the final answer from each solution
3. Group solutions by their final answer string
4. Sum the PRM scores within each group (weighted vote)
5. Select the answer with the highest total weighted score
Formally:
â = argmax_a Σᵢ 𝟙(aᵢ = a) · score(sᵢ)
where the sum runs over the solutions whose final answer could be extracted,
and score(sᵢ) is the PRM's last-step prediction for solution i.
Co-authored with Claude (Anthropic). I can explain all code logic.
"""
import json
from collections import defaultdict
def extract_boxed_solution(text):
    """Extract the content of the last ``\\boxed{...}`` in *text*.

    Nested braces inside the box are balanced, so ``\\boxed{\\frac{1}{2}}``
    yields ``\\frac{1}{2}``. Surrounding whitespace is stripped.

    Args:
        text: The solution text to search. Non-string input (e.g. ``None``)
            yields ``None``.

    Returns:
        The stripped box content, or ``None`` when there is no ``\\boxed{``
        in the text or its opening brace is never closed.
    """
    # Explicit type check instead of a catch-all ``except Exception`` that
    # would also hide genuine bugs.
    if not isinstance(text, str):
        return None
    marker = "\\boxed{"
    start_index = text.rfind(marker)
    if start_index == -1:
        return None
    content_start = start_index + len(marker)
    depth = 1  # the marker's own opening brace is already consumed
    pos = content_start
    while pos < len(text):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:pos].strip()
        pos += 1
    # Ran off the end of the text without closing the box.
    return None
def weighted_best_of_n(extracted_answers, prm_scores):
    """
    Compute the Best-of-N answer using weighted selection.

    Groups solutions by their extracted answer, sums PRM scores
    per group, and returns the answer with the highest total score.

    Args:
        extracted_answers: list of N answer strings (may contain None)
        prm_scores: list of N PRM scores (floats in [0,1]), paired with
            the answers by position

    Returns:
        tuple: (best_answer, answer_scores_dict). ``best_answer`` is ``None``
        and the dict is empty when no answer could be extracted. On a tie,
        the answer encountered first wins, because ``max`` returns the first
        maximal key in dict insertion order.

    Note:
        The lists are paired with ``zip``, so trailing elements of the
        longer list are ignored.
    """
    answer_scores = defaultdict(float)
    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            # Skip solutions where we couldn't extract an answer
            # (following DeepMind's filtering of unparseable solutions)
            continue
        answer_scores[answer] += score
    if not answer_scores:
        return None, {}
    # Select the answer with highest total weighted score
    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, dict(answer_scores)
def standard_best_of_n(extracted_answers, prm_scores):
    """
    Standard (non-weighted) Best-of-N: pick the single solution
    with the highest PRM score and use its answer.

    Solutions whose answer is ``None`` are ignored. Ties are broken in
    favour of the earliest solution.

    Args:
        extracted_answers: list of N answer strings (may contain None)
        prm_scores: list of N PRM scores, paired with the answers by position

    Returns:
        The answer of the top-scoring parseable solution, or ``None`` when
        no solution has an extracted answer.
    """
    best_answer = None
    best_score = None
    for answer, score in zip(extracted_answers, prm_scores):
        if answer is None:
            continue
        # Seed with the first real score instead of a fixed sentinel such as
        # -1, so scores on any scale (e.g. negative logits) are still ranked.
        if best_score is None or score > best_score:
            best_score = score
            best_answer = answer
    return best_answer
def majority_vote(extracted_answers):
    """
    Pure majority vote (no reward weighting): pick the most frequent answer.

    ``None`` entries (unparseable solutions) are not counted. When several
    answers share the top count, the one that appeared first is returned.
    Returns ``None`` if there is nothing to vote on.
    """
    tally = {}
    for ans in extracted_answers:
        if ans is None:
            continue
        tally[ans] = tally.get(ans, 0) + 1
    return max(tally, key=tally.get) if tally else None
# ──────────────────────────────────────────────────────────────────────────────
# Load scored results
# ──────────────────────────────────────────────────────────────────────────────
print("=" * 70)
print("STEP 3: Computing Best-of-N accuracy with weighted selection")
print("=" * 70)
# NOTE(review): these are absolute, machine-specific paths; they must be
# adjusted (or made configurable) before running on another machine.
# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the
# platform's locale default cannot mis-decode non-ASCII content.
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json", encoding="utf-8") as f:
    scored_results = json.load(f)
# Also load greedy results for comparison
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json", encoding="utf-8") as f:
    greedy_results = json.load(f)
# ──────────────────────────────────────────────────────────────────────────────
# Compute Best-of-N for each problem
# ──────────────────────────────────────────────────────────────────────────────
weighted_correct = 0
standard_correct = 0
majority_correct = 0
greedy_correct_count = 0
results_summary = []
# The two result lists are paired by position. zip() would silently drop the
# tail of the longer list, while the accuracy denominator (the number of
# scored problems) still counts every problem. Refuse to continue instead.
if len(scored_results) != len(greedy_results):
    raise ValueError(
        f"scored_results has {len(scored_results)} problems but "
        f"greedy_results has {len(greedy_results)}; they must be aligned"
    )
for i, (scored, greedy) in enumerate(zip(scored_results, greedy_results)):
    problem_id = scored["unique_id"]
    # If a greedy record carries its own id, verify that the positional
    # pairing really refers to the same problem.
    if "unique_id" in greedy and greedy["unique_id"] != problem_id:
        raise ValueError(
            f"Problem mismatch at index {i}: scored {problem_id!r} vs "
            f"greedy {greedy['unique_id']!r}"
        )
    ground_truth = scored["answer"]
    # Extract answers from sampled solutions
    extracted = scored["extracted_answers"]
    scores = scored["prm_scores"]
    # NOTE(review): the BoN / majority metrics below use exact string equality
    # with the ground truth, whereas greedy_correct is read precomputed from
    # greedy_results. Confirm both use the same answer comparison.
    # Weighted Best-of-N
    weighted_answer, answer_scores = weighted_best_of_n(extracted, scores)
    weighted_is_correct = (weighted_answer is not None) and (weighted_answer == ground_truth)
    if weighted_is_correct:
        weighted_correct += 1
    # Standard Best-of-N (for comparison)
    standard_answer = standard_best_of_n(extracted, scores)
    standard_is_correct = (standard_answer is not None) and (standard_answer == ground_truth)
    if standard_is_correct:
        standard_correct += 1
    # Majority vote (for comparison)
    majority_answer = majority_vote(extracted)
    majority_is_correct = (majority_answer is not None) and (majority_answer == ground_truth)
    if majority_is_correct:
        majority_correct += 1
    # Greedy baseline
    greedy_answer = greedy["greedy_extracted_answer"]
    greedy_is_correct = greedy["greedy_correct"]
    if greedy_is_correct:
        greedy_correct_count += 1
    # Count how many of the N solutions got the right answer
    n_correct_in_sample = sum(1 for a in extracted if a == ground_truth)
    # Summary for this problem
    summary = {
        "idx": i,
        "unique_id": problem_id,
        "level": scored["level"],
        "subject": scored["subject"],
        "ground_truth": ground_truth,
        "greedy_answer": greedy_answer,
        "greedy_correct": greedy_is_correct,
        "weighted_bon_answer": weighted_answer,
        "weighted_bon_correct": weighted_is_correct,
        "standard_bon_answer": standard_answer,
        "standard_bon_correct": standard_is_correct,
        "majority_vote_answer": majority_answer,
        "majority_vote_correct": majority_is_correct,
        "n_correct_in_16": n_correct_in_sample,
        "answer_score_breakdown": answer_scores,
        "prm_scores": scores,
    }
    results_summary.append(summary)
    # Print per-problem results
    status_g = "✓" if greedy_is_correct else "✗"
    status_w = "✓" if weighted_is_correct else "✗"
    print(f"\n [{problem_id}] Level {scored['level']} | {scored['subject']}")
    print(f" Ground truth: {ground_truth}")
    print(f" Greedy {status_g}: {greedy_answer}")
    print(f" Weighted BoN {status_w}: {weighted_answer}")
    print(f" Correct in sample: {n_correct_in_sample}/{len(extracted)}")
    if answer_scores:
        print(f" Score breakdown: {dict(sorted(answer_scores.items(), key=lambda x: -x[1]))}")
# ──────────────────────────────────────────────────────────────────────────────
# Overall results
# ──────────────────────────────────────────────────────────────────────────────
n_problems = len(scored_results)


def _rate(count):
    """Return count / n_problems, or 0.0 when there are no problems."""
    # Avoids a ZeroDivisionError when scored_results is empty.
    return count / n_problems if n_problems else 0.0


print("\n" + "=" * 70)
print("RESULTS SUMMARY")
print("=" * 70)
print(f" Greedy (N=1): {greedy_correct_count}/{n_problems} = {_rate(greedy_correct_count):.1%}")
print(f" Majority Vote (N=16): {majority_correct}/{n_problems} = {_rate(majority_correct):.1%}")
print(f" Standard Best-of-N (N=16): {standard_correct}/{n_problems} = {_rate(standard_correct):.1%}")
print(f" Weighted Best-of-N (N=16): {weighted_correct}/{n_problems} = {_rate(weighted_correct):.1%}")
# Save results
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json", "w", encoding="utf-8") as f:
    json.dump(results_summary, f, indent=2)
print("\nSaved detailed results to outputs/bon_results.json")
# ──────────────────────────────────────────────────────────────────────────────
# Compute Best-of-N at various N values (using the N=16 sample)
# ──────────────────────────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("ANALYSIS: How accuracy varies with N")
print("=" * 70)
import random

random.seed(42)
n_values = [1, 2, 4, 8, 16]
n_trials = 50  # Average over multiple random subsets for N < 16


def _subset_is_correct(problem, n):
    """Run weighted BoN on a random size-n subset of one problem's solutions.

    When n is at least the number of solutions this problem has, all of them
    are used and no random numbers are drawn.
    """
    answers = problem["extracted_answers"]
    scores = problem["prm_scores"]
    # Use the problem's real pool size rather than assuming exactly 16.
    pool_size = len(answers)
    if n < pool_size:
        indices = random.sample(range(pool_size), n)
        answers = [answers[j] for j in indices]
        scores = [scores[j] for j in indices]
    answer, _ = weighted_best_of_n(answers, scores)
    return answer == problem["answer"]


# Largest per-problem pool. Once n reaches it, every problem uses all of its
# solutions, so the result is deterministic and a single pass suffices.
max_pool = max((len(s["extracted_answers"]) for s in scored_results), default=0)
accuracy_by_n = {}
for n in n_values:
    trials = 1 if n >= max_pool else n_trials
    trial_accs = []
    for _ in range(trials):
        correct = sum(_subset_is_correct(s, n) for s in scored_results)
        # Guard against an empty result set.
        trial_accs.append(correct / n_problems if n_problems else 0.0)
    acc = sum(trial_accs) / len(trial_accs)
    accuracy_by_n[n] = acc
    print(f" N={n:2d}: {acc:.1%}")
# Save accuracy-by-N for plotting (json writes the int N keys as strings)
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json", "w", encoding="utf-8") as f:
    json.dump(accuracy_by_n, f, indent=2)
print("\nDone! Results saved. Run step4_analysis.py for plots.")
|