narcolepticchicken committed on
Commit
2aced14
·
verified ·
1 Parent(s): dd862b7

Add dataset builder script

Browse files
Files changed (1) hide show
  1. dataset_builder.py +259 -0
dataset_builder.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speculative Tool Actions — Dataset Builder
3
+ ==========================================
4
+ Converts agent trace datasets into a unified schema with 9 action types:
5
+ tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED
6
+
7
+ Sources:
8
+ - SWE-bench/SWE-smith-trajectories (tool split, resolved=True)
9
+ - tuandunghcmut/toolbench-v1
10
+
11
+ Output datasets (pushed to Hub):
12
+ - {hub_org}/speculative-actions-proposer-sft -> prompt-completion for next-action SFT
13
+ - {hub_org}/speculative-actions-verifier-pref -> chosen/rejected pairs for verifier DPO/Reward
14
+ - {hub_org}/speculative-actions-eval -> held-out eval set with gold labels
15
+ """
16
+
17
+ import json
18
+ import re
19
+ import argparse
20
+ from collections import Counter
21
+ from datasets import load_dataset, Dataset
22
+ from random import Random
23
+
24
# Canonical action vocabulary for the unified schema. Order matters:
# ACTION_MAP assigns each action the index of its position in this list,
# so new actions must be appended, never inserted.
ACTION_TYPES = [
    "tool_call",
    "retrieval",
    "file_read",
    "file_write",
    "repair",
    "verifier",
    "ask_clarification",
    "final_answer",
    "BLOCKED",
]

# Stable integer id for each action name (its position in ACTION_TYPES).
ACTION_MAP = {name: idx for idx, name in enumerate(ACTION_TYPES)}
37
+
38
+
39
# Ordered heuristic rules: the first matching pattern wins, so more specific
# intents (final answer, clarification, refusal) are checked before broad
# ones (verify/search). Compiled once at import so the per-message loops in
# the processors do not recompile eight regexes per call.
_ACTION_RULES = [
    ("final_answer", re.compile(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b')),
    ("ask_clarification", re.compile(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b')),
    ("BLOCKED", re.compile(r'\b(blocked|unsafe|i cannot|i\'m sorry, but|refuse|not allowed|harmful)\b')),
    ("file_write", re.compile(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b')),
    ("file_read", re.compile(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b')),
    ("repair", re.compile(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b')),
    ("verifier", re.compile(r'\b(verify|check|validate|test|assert|review)\b')),
    ("retrieval", re.compile(r'\b(search|retrieve|find|lookup|query|google|bing)\b')),
]


def classify_action(content: str, tool_calls=None) -> str:
    """Heuristically map raw agent output to one of ACTION_TYPES.

    Args:
        content: Assistant message text. May be None or empty — messages
            that carry only structured tool calls often have no text.
        tool_calls: Optional structured tool-call payload; it is JSON-
            serialized and searched alongside the text.

    Returns:
        One of the ACTION_TYPES strings. Falls back to "tool_call" when no
        rule matches (explicit tool calls are generic tool calls anyway).
    """
    # Guard against None content (tool-call-only messages): the previous
    # version raised AttributeError on content.lower() here.
    c = (content or "").lower()
    tc = json.dumps(tool_calls).lower() if tool_calls else ""
    combined = c + " " + tc

    for action, pattern in _ACTION_RULES:
        if pattern.search(combined):
            return action
    # Both remaining branches of the original returned "tool_call", so the
    # tool_calls / "invoke|execute" check was redundant — default directly.
    return "tool_call"
64
+
65
+
66
def process_swe_smith(split="train", max_rows=10_000):
    """Stream SWE-smith trajectories into proposer/verifier/eval rows.

    Every assistant/agent turn becomes one example whose prompt is the
    conversation history preceding that turn. The trajectory-level
    `resolved` flag labels the verifier rows.

    Returns:
        Tuple of (proposer_rows, verifier_rows, eval_rows).
    """
    print(f"Loading SWE-smith tool/{split} ...")
    ds = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split=split, streaming=True)

    rows_proposer = []
    rows_verifier = []
    rows_eval = []

    # Streaming dataset: cap work at max_rows examples.
    for seen, example in enumerate(ds, start=1):
        if seen > max_rows:
            break

        messages = example.get("messages", [])
        resolved = example.get("resolved", False)

        history = []
        for msg in messages:
            if msg.get("role", "") in ("assistant", "agent"):
                content = msg.get("content", "")
                calls = msg.get("tool_calls", None)
                action = classify_action(content, calls)

                # Snapshot the history so later turns don't mutate it.
                prompt = list(history)
                completion = [{"role": "assistant", "content": content}]
                if calls:
                    completion[0]["tool_calls"] = calls

                rows_proposer.append({
                    "prompt": prompt,
                    "completion": completion,
                    "action_type": action,
                })
                rows_verifier.append({
                    "prompt": prompt,
                    "completion": completion,
                    "label": bool(resolved),
                    "action_type": action,
                })
                rows_eval.append({
                    "messages": prompt + completion,
                    "resolved": resolved,
                    "action_type": action,
                })
            history.append(msg)

    print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows")
    return rows_proposer, rows_verifier, rows_eval
116
+
117
+
118
def process_toolbench(split="train", max_rows=5_000):
    """Stream ToolBench conversations into proposer/verifier/eval rows.

    ToolBench stores each conversation as parallel `from`/`value` arrays.
    Every assistant turn yields one example; all rows are labeled True —
    this source provides only successful trajectories.

    Returns:
        Tuple of (proposer_rows, verifier_rows, eval_rows).
    """
    print(f"Loading toolbench/{split} ...")
    ds = load_dataset("tuandunghcmut/toolbench-v1", split=split, streaming=True)

    rows_proposer = []
    rows_verifier = []
    rows_eval = []

    # Streaming dataset: cap work at max_rows examples.
    for seen, example in enumerate(ds, start=1):
        if seen > max_rows:
            break

        conv = example.get("conversations", {})
        turns = zip(conv.get("from", []), conv.get("value", []))

        history = []
        for role, content in turns:
            msg = {"role": role, "content": content}
            if role == "assistant":
                action = classify_action(content)
                snapshot = list(history)
                rows_proposer.append({
                    "prompt": snapshot,
                    "completion": [msg],
                    "action_type": action,
                })
                rows_verifier.append({
                    "prompt": snapshot,
                    "completion": [msg],
                    "label": True,
                    "action_type": action,
                })
                rows_eval.append({
                    "messages": snapshot + [msg],
                    "resolved": True,
                    "action_type": action,
                })
            history.append(msg)

    print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows")
    return rows_proposer, rows_verifier, rows_eval
161
+
162
+
163
def build_proposer_dataset(rows, hub_org):
    """Format proposer rows as prompt/completion SFT data and push to Hub.

    Each row gets a system instruction, a "[Next Action Prediction]" suffix
    on the last prompt message, and an "Action: <type>" prefix on the
    completion, then the shuffled data is split 90/10 and pushed.

    Args:
        rows: Dicts with "prompt" (message list), "completion" (single-
            message list), and "action_type".
        hub_org: Hub namespace to push under.

    Returns:
        The pushed DatasetDict (train/test split).
    """
    instruction = (
        "You are an agent action predictor. Given the conversation state, "
        "predict the next action from: " + ", ".join(ACTION_TYPES) + ". "
        "Respond with exactly the action name and a brief justification."
    )
    suffix = "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES)

    def fmt(row):
        # BUG FIX: the original appended to row["prompt"][-1]["content"] and
        # row["completion"][0]["content"] in place. Those dicts are shared
        # (shallow copies upstream) with the verifier/eval row lists and
        # with other rows' prompts, so the mutation corrupted the other
        # datasets and stacked the suffix repeatedly on reused messages.
        # Copy every message dict before modifying it.
        prompt = [{"role": "system", "content": instruction}]
        prompt += [dict(m) for m in row["prompt"]]
        # `or ""` guards tool-call-only messages whose content is None.
        prompt[-1]["content"] = (prompt[-1]["content"] or "") + suffix

        completion = [dict(m) for m in row["completion"]]
        completion[0]["content"] = (
            f"Action: {row['action_type']}\n" + (completion[0]["content"] or "")
        )
        return {"prompt": prompt, "completion": completion}

    data = [fmt(r) for r in rows]
    ds = Dataset.from_list(data)
    ds = ds.shuffle(seed=42).train_test_split(test_size=0.1)
    ds.push_to_hub(f"{hub_org}/speculative-actions-proposer-sft")
    print(f"Pushed proposer SFT dataset to {hub_org}/speculative-actions-proposer-sft")
    return ds
188
+
189
+
190
def build_verifier_dataset(rows, hub_org):
    """Build chosen/rejected preference pairs for verifier training and push.

    Positive rows keep their real completion as `chosen`; the `rejected`
    side is drawn at random from negatively-labeled rows. When organic
    failures are scarce (fewer than 20% of successes), one synthetic
    wrong-action negative is fabricated per positive row.

    Returns:
        The pushed DatasetDict (train/test split).
    """
    rng = Random(42)
    good_rows = [r for r in rows if r["label"]]
    bad_rows = [r for r in rows if not r["label"]]

    # Too few real failures -> fabricate negatives so every positive can be
    # paired with a rejected completion.
    if len(bad_rows) < len(good_rows) * 0.2:
        for src in good_rows:
            alt = rng.choice([a for a in ACTION_TYPES if a != src["action_type"]])
            bad_rows.append({
                "prompt": src["prompt"],
                "completion": [{"role": "assistant", "content": f"Action: {alt}\n(synthetic incorrect action)"}],
                "label": False,
                "action_type": alt,
            })

    # NOTE(review): the rejected completion is sampled independently of the
    # prompt, so pairs are not prompt-matched — confirm this is intended.
    pairs = [
        {
            "prompt": g["prompt"],
            "chosen": g["completion"],
            "rejected": rng.choice(bad_rows)["completion"],
            "action_type": g["action_type"],
        }
        for g in good_rows
    ]

    ds = Dataset.from_list(pairs)
    ds = ds.shuffle(seed=42).train_test_split(test_size=0.1)
    ds.push_to_hub(f"{hub_org}/speculative-actions-verifier-pref")
    print(f"Pushed verifier preference dataset to {hub_org}/speculative-actions-verifier-pref")
    return ds
221
+
222
+
223
def build_eval_dataset(rows, hub_org):
    """Shuffle the eval rows, cap at 2,000 examples, and push to the Hub."""
    cap = min(2_000, len(rows))
    ds = Dataset.from_list(rows).shuffle(seed=42).select(range(cap))
    ds.push_to_hub(f"{hub_org}/speculative-actions-eval")
    print(f"Pushed eval dataset to {hub_org}/speculative-actions-eval")
    return ds
229
+
230
+
231
def main():
    """CLI entry point: build all three datasets and push them to the Hub."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hub_org", default="narcolepticchicken", type=str)
    parser.add_argument("--max_swe", type=int, default=5_000)
    parser.add_argument("--max_toolbench", type=int, default=3_000)
    args = parser.parse_args()

    swe_parts = process_swe_smith("train", args.max_swe)
    tb_parts = process_toolbench("train", args.max_toolbench)

    # Concatenate the per-source triples: (proposer, verifier, eval).
    proposer_rows, verifier_rows, eval_rows = (
        a + b for a, b in zip(swe_parts, tb_parts)
    )

    print(f"\nTotal rows: proposer={len(proposer_rows)}, verifier={len(verifier_rows)}, eval={len(eval_rows)}")

    print("\nAction distribution (proposer):")
    for act, n in Counter(r["action_type"] for r in proposer_rows).most_common():
        print(f" {act}: {n}")

    build_proposer_dataset(proposer_rows, args.hub_org)
    build_verifier_dataset(verifier_rows, args.hub_org)
    build_eval_dataset(eval_rows, args.hub_org)

    print("\nDataset construction complete.")


if __name__ == "__main__":
    main()