cmpatino HF Staff committed on
Commit
1ce1b4f
·
verified ·
1 Parent(s): 707fcea

Upload code/step1_filter_and_greedy.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/step1_filter_and_greedy.py +213 -0
code/step1_filter_and_greedy.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step 1: Filter MATH-500 to 20 level 1-3 problems and generate greedy (N=1) solutions.
3
+
4
+ This script:
5
+ 1. Loads the MATH-500 dataset and filters to level 1-3 problems
6
+ 2. Randomly samples 20 problems (with a fixed seed for reproducibility)
7
+ 3. Generates a single greedy solution per problem using Qwen2.5-1.5B-Instruct
8
+ 4. Extracts answers from \boxed{} format and computes accuracy
9
+ 5. Saves results as JSON for the next steps
10
+
11
+ Co-authored with Claude (Anthropic) — used for structuring the pipeline and
12
+ prompt engineering. I can explain all code logic.
13
+ """
14
+
15
+ import json
16
+ import os
17
+ import random
18
+ import torch
19
+ from datasets import load_dataset
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM
21
+ from typing import Optional
22
+
23
+
24
+ # ──────────────────────────────────────────────────────────────────────────────
25
+ # Helper: Extract answer from \boxed{...}
26
+ # Source: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
27
+ # ──────────────────────────────────────────────────────────────────────────────
28
def extract_boxed_solution(text: str) -> Optional[str]:
    """
    Extract the content of the last \\boxed{...} in a LaTeX-style text.

    The closing brace is located by counting brace depth, so nested groups
    such as \\boxed{\\frac{1}{2}} are returned whole.

    Args:
        text: The text to search (e.g. a model-generated solution).

    Returns:
        The content between the braces of the last \\boxed{...}, with
        surrounding whitespace stripped. Returns None when `text` is not a
        string, contains no \\boxed{ marker, or the braces after the last
        marker are never balanced.
    """
    marker = "\\boxed{"

    # Non-string input (e.g. None for a missing generation) has no answer.
    if not isinstance(text, str):
        return None

    # rfind picks the LAST occurrence, i.e. the final answer of a solution.
    marker_index = text.rfind(marker)
    if marker_index == -1:
        return None

    content_start = marker_index + len(marker)
    depth = 1  # the brace opened by the marker itself
    for position in range(content_start, len(text)):
        character = text[position]
        if character == "{":
            depth += 1
        elif character == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:position].strip()

    # Ran off the end of the text with the \boxed{ group still open.
    return None
55
+
56
+
57
+ # ──────────────────────────────────────────────────────────────────────────────
58
+ # Step 1a: Filter dataset to level 1-3 and sample 20 problems
59
+ # ──────────────────────────────────────────────────────────────────────────────
60
print("=" * 70, "STEP 1: Loading and filtering MATH-500 dataset", "=" * 70, sep="\n")

dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
print(f"Total problems in MATH-500: {len(dataset)}")

# Keep only the easier problems (levels 1-3), which suit a small model.
filtered = dataset.filter(lambda row: row["level"] in (1, 2, 3))
print(f"Problems at levels 1-3: {len(filtered)}")

# Fixed seed so the same 20 problems are drawn on every run.
random.seed(42)
indices = random.sample(range(len(filtered)), k=20)
problems = filtered.select(indices)

# Show a short summary of each selected problem.
print(f"\nSelected {len(problems)} problems:")
for position, item in enumerate(problems, start=1):
    print(f" [{position}] Level {item['level']} | {item['subject']} | {item['unique_id']}")
    print(f" Answer: {item['answer']}")
    # Only the first 80 characters, flattened onto one line.
    preview = item["problem"][:80].replace("\n", " ")
    print(f" Problem: {preview}...")

# Plain dicts (JSON-serialisable) carrying the fields kept for later steps.
kept_fields = ("problem", "solution", "answer", "subject", "level", "unique_id")
problems_data = []
for idx, item in enumerate(problems):
    record = {"idx": idx}
    for field in kept_fields:
        record[field] = item[field]
    problems_data.append(record)

outputs_dir = "/Users/cmpatino/Projects/ml-intern/exercise/outputs"
os.makedirs(outputs_dir, exist_ok=True)
with open(outputs_dir + "/filtered_problems.json", "w") as out_file:
    json.dump(problems_data, out_file, indent=2)
print(f"\nSaved {len(problems_data)} problems to outputs/filtered_problems.json")
103
+
104
+ # ──────────────────────────────────────────────────────────────────────────────
105
+ # Step 1b: Generate greedy (N=1) solutions
106
+ # ──────────────────────────────────────────────────────────────────────────────
107
print(
    "\n" + "=" * 70,
    "STEP 2: Generating greedy solutions with Qwen2.5-1.5B-Instruct",
    "=" * 70,
    sep="\n",
)

MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# bfloat16 weights; device placement is delegated to device_map="auto".
model_load_options = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_load_options)

# System prompt asking for step-by-step reasoning and a final \boxed{} answer,
# which is the format the answer extractor looks for.
SYSTEM_PROMPT = (
    "You are a helpful math assistant. "
    "Solve the problem step by step, showing your reasoning clearly. "
    "Put your final answer inside \\boxed{answer} at the end of your solution."
)
125
+
126
+
127
def generate_solutions(problems_data, model, tokenizer, n=1, temperature=None, do_sample=False):
    """
    Produce `n` model completions for every problem.

    Each problem is wrapped in a chat prompt (SYSTEM_PROMPT as the system
    message, the problem text as the user message) and passed to
    `model.generate` once per requested solution.

    Args:
        problems_data: list of problem dicts; each needs "problem" and "unique_id"
        model: the language model used for generation
        tokenizer: the tokenizer matching `model`
        n: number of solutions to generate per problem
        temperature: sampling temperature; only forwarded when `do_sample` is True
        do_sample: whether to sample (False = greedy decoding)

    Returns:
        list of dicts, one per problem: the original problem fields plus a
        "generated_solutions" list holding the `n` decoded completions
    """
    # Decoding settings do not depend on the problem, so assemble them once.
    generation_options = {"max_new_tokens": 2048, "do_sample": do_sample}
    if do_sample and temperature is not None:
        generation_options["temperature"] = temperature

    results = []
    for problem_number, problem in enumerate(problems_data, start=1):
        print(f"\n Generating for problem {problem_number}/{len(problems_data)}: {problem['unique_id']}")

        # Build the chat-formatted prompt and move its tensors to the model's device.
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": problem["problem"]},
        ]
        prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]

        solutions = []
        for completed in range(1, n + 1):
            with torch.no_grad():
                output = model.generate(**inputs, **generation_options)

            # The output begins with the prompt tokens; keep only what follows.
            new_tokens = output[0][prompt_length:]
            solutions.append(tokenizer.decode(new_tokens, skip_special_tokens=True))

            # Progress note every 4 samples when generating several per problem.
            if n > 1 and completed % 4 == 0:
                print(f" Generated {completed}/{n} solutions")

        results.append({**problem, "generated_solutions": solutions})

    return results
180
+
181
+
182
+ # Generate greedy solutions (N=1, no sampling)
183
# One deterministic completion per problem: n=1 with sampling disabled.
greedy_results = generate_solutions(problems_data, model, tokenizer, n=1, do_sample=False)

# ──────────────────────────────────────────────────────────────────────────────
# Step 1c: Evaluate greedy accuracy
# ──────────────────────────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("STEP 3: Evaluating greedy accuracy")
print("=" * 70)

correct = 0
for r in greedy_results:
    extracted = extract_boxed_solution(r["generated_solutions"][0])
    r["greedy_extracted_answer"] = extracted
    # Exact string match against the reference answer; a missing \boxed{}
    # (extracted is None) always counts as incorrect.
    r["greedy_correct"] = (extracted is not None) and (extracted == r["answer"])
    if r["greedy_correct"]:
        correct += 1
    status = "✓" if r["greedy_correct"] else "✗"
    print(f" {status} [{r['unique_id']}] Expected: {r['answer']} | Got: {extracted}")

# Guard the division so an empty result list reports 0% instead of crashing.
greedy_accuracy = correct / len(greedy_results) if greedy_results else 0.0
print(f"\nGreedy accuracy: {correct}/{len(greedy_results)} = {greedy_accuracy:.1%}")

# Save greedy results (each record now also carries the extracted answer
# and the correctness flag added above).
with open("/Users/cmpatino/Projects/ml-intern/exercise/outputs/greedy_results.json", "w") as f:
    json.dump(greedy_results, f, indent=2)
print("Saved greedy results to outputs/greedy_results.json")

# Clean up model to free memory for PRM scoring.
del model
# Release cached allocator memory on whichever accelerator is in use:
# torch.cuda.empty_cache() does nothing for a model placed on Apple MPS.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif mps_backend is not None and mps_backend.is_available() and hasattr(torch, "mps"):
    torch.mps.empty_cache()
print("\nFreed LLM memory. Ready for Step 2 (sampling + PRM scoring).")