cmpatino HF Staff committed on
Commit
a98a175
·
verified ·
1 Parent(s): 312d5c4

Upload code/step4_analysis.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/step4_analysis.py +260 -0
code/step4_analysis.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step 4: Analysis and visualization of Best-of-N vs greedy performance.
3
+
4
+ This script creates plots comparing:
5
+ 1. Overall accuracy: Greedy vs Majority Vote vs Standard BoN vs Weighted BoN
6
+ 2. Accuracy vs N (how performance scales with number of samples)
7
+ 3. Per-problem analysis: which problems did BoN solve that greedy couldn't?
8
+ 4. PRM score distribution analysis
9
+
10
+ Co-authored with Claude (Anthropic). I can explain all code logic.
11
+ """
12
+
13
+ import json
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib
16
+ import numpy as np
17
+ from collections import defaultdict
18
+
# Figure defaults applied to every plot produced by this script.
matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150})

# ──────────────────────────────────────────────────────────────────────────────
# Load results
# ──────────────────────────────────────────────────────────────────────────────


def _read_json(path):
    """Open the file at *path* and return its parsed JSON content."""
    with open(path) as handle:
        return json.load(handle)


# Per-problem records; the sections below read flags such as
# "greedy_correct" and "weighted_bon_correct" from each entry.
bon_results = _read_json("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json")

# Accuracy keyed by sample count N; the keys are strings in the JSON file.
accuracy_by_n = _read_json("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json")

# Per-problem records carrying "extracted_answers", "prm_scores" and "answer".
scored_results = _read_json("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json")

# Denominator used for every overall accuracy computed below.
n_problems = len(bon_results)
34
+
# ──────────────────────────────────────────────────────────────────────────────
# Plot 1: Overall accuracy comparison (bar chart)
# ──────────────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(8, 5))

# One (tick label, result key, bar colour) triple per selection method,
# listed in the left-to-right order of the bars.
method_specs = [
    ("Greedy\n(N=1)", "greedy_correct", "#4C72B0"),
    ("Majority Vote\n(N=16)", "majority_vote_correct", "#55A868"),
    ("Standard BoN\n(N=16)", "standard_bon_correct", "#C44E52"),
    ("Weighted BoN\n(N=16)", "weighted_bon_correct", "#8172B2"),
]
methods = [spec[0] for spec in method_specs]
colors = [spec[2] for spec in method_specs]

# Sum of each method's correctness flag over all problems, divided by the
# problem count. `accuracies` is read again by the printed summary.
accuracies = [
    sum(record[key] for record in bon_results) / n_problems
    for _, key, _ in method_specs
]

bars = ax.bar(methods, accuracies, color=colors, edgecolor="white", linewidth=1.5)
for rect, value in zip(bars, accuracies):
    # Percentage label centred horizontally, just above the top of the bar.
    ax.text(
        rect.get_x() + rect.get_width() / 2,
        rect.get_height() + 0.01,
        f"{value:.0%}",
        ha="center",
        va="bottom",
        fontweight="bold",
        fontsize=12,
    )

ax.set_ylabel("Accuracy")
ax.set_title("Math Problem Accuracy: Greedy vs Best-of-N Methods\n(20 MATH-500 problems, Levels 1-3)")
# Upper limit slightly above 1.0 leaves room for the label on a 100% bar.
ax.set_ylim(0, 1.05)
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot1_accuracy_comparison.png")
plt.close()
print("Saved plot1_accuracy_comparison.png")
62
+
# ──────────────────────────────────────────────────────────────────────────────
# Plot 2: Accuracy vs N
# ──────────────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(7, 5))

# The mapping is keyed by strings, so convert to int for a numeric sort and
# convert back to str for the lookup.
ns = sorted(map(int, accuracy_by_n))
accs = [accuracy_by_n[str(sample_count)] for sample_count in ns]

# Greedy accuracy, drawn below as a horizontal reference line.
greedy_acc = sum(record["greedy_correct"] for record in bon_results) / n_problems

# Draw the curve first and the baseline second: the legend lists artists in
# the order they were added.
ax.plot(ns, accs, "o-", color="#8172B2", linewidth=2, markersize=8, label="Weighted BoN")
ax.axhline(
    y=greedy_acc,
    color="#4C72B0",
    linestyle="--",
    linewidth=1.5,
    label=f"Greedy baseline ({greedy_acc:.0%})",
)

# Percentage label 10 points above each marker.
for sample_count, value in zip(ns, accs):
    ax.annotate(
        f"{value:.0%}",
        (sample_count, value),
        textcoords="offset points",
        xytext=(0, 10),
        ha="center",
        fontsize=10,
    )

ax.set_xlabel("N (number of samples)")
ax.set_ylabel("Accuracy")
ax.set_title("Weighted Best-of-N Accuracy vs Number of Samples")
# Tick only at the sample counts that were actually evaluated.
ax.set_xticks(ns)
ax.set_ylim(0, 1.05)
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot2_accuracy_vs_n.png")
plt.close()
print("Saved plot2_accuracy_vs_n.png")
92
+
# ──────────────────────────────────────────────────────────────────────────────
# Plot 3: Per-problem comparison (Greedy vs Weighted BoN)
# ──────────────────────────────────────────────────────────────────────────────
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(12, 5))


def _outcome(record):
    """Return the greedy / Weighted BoN outcome category name for one record."""
    greedy_ok = record["greedy_correct"]
    bon_ok = record["weighted_bon_correct"]
    if greedy_ok:
        return "Both correct" if bon_ok else "Only Greedy correct"
    return "Only BoN correct" if bon_ok else "Both wrong"


# Problem labels grouped by outcome category.
categories = {
    "Both correct": [],
    "Only BoN correct": [],
    "Only Greedy correct": [],
    "Both wrong": [],
}

# Bar colour for each outcome category; also drives the legend.
cat_colors = {
    "Both correct": "#55A868",
    "Only BoN correct": "#8172B2",
    "Only Greedy correct": "#C44E52",
    "Both wrong": "#CCCCCC",
}

# Classify every problem once, filling the grouped labels, the per-problem
# level labels and the per-bar colours in the same pass.
labels = []
colors_list = []
for record in bon_results:
    outcome = _outcome(record)
    categories[outcome].append(f"L{record['level']}: {record['unique_id'].split('/')[-1][:15]}")
    labels.append(f"L{record['level']}")
    colors_list.append(cat_colors[outcome])

# Bar height is the stored count of correct solutions among the 16 samples;
# bar colour is the outcome category.
x = range(len(bon_results))
heights = [record["n_correct_in_16"] for record in bon_results]
ax.bar(x, heights, color=colors_list, edgecolor="white", linewidth=0.5)

# Tick labels: last path component of the id, ".json" removed, cut to 12 chars.
ax.set_xticks(x)
short_ids = [
    record["unique_id"].split("/")[-1].replace(".json", "")[:12]
    for record in bon_results
]
ax.set_xticklabels(short_ids, rotation=45, ha="right", fontsize=8)

ax.set_ylabel("# Correct Solutions (out of 16)")
ax.set_title("Per-Problem Analysis: Correct Solutions in N=16 Sample")

# The bars carry no labels of their own, so build the legend from colour swatches.
legend_elements = [
    Patch(facecolor=swatch, label=category)
    for category, swatch in cat_colors.items()
]
ax.legend(handles=legend_elements, loc="upper right", fontsize=9)
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot3_per_problem.png")
plt.close()
print("Saved plot3_per_problem.png")
167
+
# ──────────────────────────────────────────────────────────────────────────────
# Plot 4: PRM Score Distribution (correct vs incorrect solutions)
# ──────────────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(7, 5))

# Split every sampled solution's PRM score by whether its extracted answer
# is exactly equal to the record's reference answer. Both lists are read
# again by the printed summary.
correct_scores = []
incorrect_scores = []
for record in scored_results:
    for extracted, prm_score in zip(record["extracted_answers"], record["prm_scores"]):
        bucket = correct_scores if extracted == record["answer"] else incorrect_scores
        bucket.append(prm_score)

# 25 evenly spaced edges give 24 equal-width bins over [0, 1]; sharing the
# edges keeps the two overlaid histograms directly comparable.
bins = np.linspace(0, 1, 25)
ax.hist(correct_scores, bins=bins, alpha=0.7, label=f"Correct ({len(correct_scores)})", color="#55A868")
ax.hist(incorrect_scores, bins=bins, alpha=0.7, label=f"Incorrect ({len(incorrect_scores)})", color="#C44E52")

ax.set_xlabel("PRM Last-Step Score")
ax.set_ylabel("Count")
ax.set_title("PRM Score Distribution: Correct vs Incorrect Solutions")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot4_prm_scores.png")
plt.close()
print("Saved plot4_prm_scores.png")
197
+
# ──────────────────────────────────────────────────────────────────────────────
# Print detailed analysis
# ──────────────────────────────────────────────────────────────────────────────


def _print_problem_details(results, keep, *, show_answers=True):
    """Print a detail entry for each record in *results* accepted by *keep*.

    Args:
        results: Iterable of per-problem result records.
        keep: Predicate called with one record; truthy means "print it".
        show_answers: When true, also print the greedy and Weighted BoN answers.
    """
    for rec in results:
        if not keep(rec):
            continue
        print(f" - {rec['unique_id']} (Level {rec['level']}, {rec['subject']})")
        print(f" Ground truth: {rec['ground_truth']}")
        if show_answers:
            print(f" Greedy answer: {rec['greedy_answer']}")
            print(f" BoN answer: {rec['weighted_bon_answer']}")
        print(f" Correct in sample: {rec['n_correct_in_16']}/16")


def _score_summary(scores):
    """Return "mean=..., median=..." for *scores*, or a placeholder if empty.

    np.mean / np.median on an empty sequence emit RuntimeWarnings and yield
    nan, so the empty case is reported explicitly instead.
    """
    if not scores:
        return "n/a (no solutions)"
    return f"mean={np.mean(scores):.3f}, median={np.median(scores):.3f}"


def _group_by(results, field):
    """Group *results* by the value of *field* in a single pass."""
    groups = defaultdict(list)
    for rec in results:
        groups[rec[field]].append(rec)
    return groups


def _count_correct(results):
    """Return (greedy, Weighted BoN) sums of the correctness flags in *results*."""
    greedy = sum(rec["greedy_correct"] for rec in results)
    weighted = sum(rec["weighted_bon_correct"] for rec in results)
    return greedy, weighted


print("\n" + "=" * 70)
print("DETAILED ANALYSIS")
print("=" * 70)

# `accuracies` is ordered: greedy, majority vote, standard BoN, weighted BoN.
print("\nOverall Accuracies:")
print(f" Greedy (N=1): {accuracies[0]:.0%}")
print(f" Majority Vote (N=16): {accuracies[1]:.0%}")
print(f" Standard Best-of-N (N=16): {accuracies[2]:.0%}")
print(f" Weighted Best-of-N (N=16): {accuracies[3]:.0%}")

print("\nProblems ONLY solved by Weighted BoN (not greedy):")
_print_problem_details(
    bon_results,
    lambda rec: rec["weighted_bon_correct"] and not rec["greedy_correct"],
)

print("\nProblems ONLY solved by Greedy (not BoN):")
_print_problem_details(
    bon_results,
    lambda rec: rec["greedy_correct"] and not rec["weighted_bon_correct"],
)

# This listing omits the two answer lines.
print("\nProblems neither method solved:")
_print_problem_details(
    bon_results,
    lambda rec: not rec["greedy_correct"] and not rec["weighted_bon_correct"],
    show_answers=False,
)

# PRM Score stats
print("\nPRM Score Statistics:")
print(f" Correct solutions: {_score_summary(correct_scores)}")
print(f" Incorrect solutions: {_score_summary(incorrect_scores)}")

# Accuracy by level. Every group holds at least one record, so n is never 0.
print("\nAccuracy by problem level:")
by_level = _group_by(bon_results, "level")
for level in sorted(by_level):
    level_results = by_level[level]
    n = len(level_results)
    g, w = _count_correct(level_results)
    print(f" Level {level}: Greedy {g}/{n} ({g/n:.0%}) | Weighted BoN {w}/{n} ({w/n:.0%})")

# Accuracy by subject
print("\nAccuracy by subject:")
by_subject = _group_by(bon_results, "subject")
for subj in sorted(by_subject):
    subj_results = by_subject[subj]
    n = len(subj_results)
    g, w = _count_correct(subj_results)
    print(f" {subj}: Greedy {g}/{n} | Weighted BoN {w}/{n}")

print("\nAll plots saved to outputs/ directory.")