narcolepticchicken committed on
Commit
ec39fa1
·
verified ·
1 Parent(s): 8911258

Add standalone data generation script

Browse files
Files changed (1) hide show
  1. generate_data_only.py +202 -0
generate_data_only.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate synthetic agent trace datasets and push to Hub."""
3
+ import os, json, re, random
4
+ from collections import Counter
5
+ from datasets import Dataset, DatasetDict
6
+
7
+ HUB_ORG = "narcolepticchicken"
8
+ ACTION_TYPES = [
9
+ "tool_call", "retrieval", "file_read", "file_write",
10
+ "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED",
11
+ ]
12
+
13
+ TASK_TEMPLATES = [
14
+ "Fix a bug in the authentication module.",
15
+ "Implement a new search feature.",
16
+ "Write unit tests for the API layer.",
17
+ "Refactor the database connection pool.",
18
+ "Add logging to the payment gateway.",
19
+ "Update documentation for the CLI tool.",
20
+ "Debug a memory leak in the worker process.",
21
+ "Optimize the image processing pipeline.",
22
+ "Integrate a third-party OAuth provider.",
23
+ "Set up CI/CD for the microservice.",
24
+ "Migrate from REST to GraphQL.",
25
+ "Add rate limiting to the public API.",
26
+ "Create a backup strategy for the database.",
27
+ "Audit the codebase for security vulnerabilities.",
28
+ "Implement caching for frequently accessed data.",
29
+ ]
30
+
31
+ STATE_TEMPLATES = {
32
+ "tool_call": [
33
+ "I need to call the API to fetch user data.",
34
+ "Let me invoke the linter to check syntax.",
35
+ "I'll execute the test runner now.",
36
+ "Time to trigger the deployment script.",
37
+ ],
38
+ "retrieval": [
39
+ "I should search for similar issues in the tracker.",
40
+ "Let me look up the documentation for this function.",
41
+ "I'll query the knowledge base for best practices.",
42
+ "Need to find examples of this pattern online.",
43
+ ],
44
+ "file_read": [
45
+ "I need to read the configuration file first.",
46
+ "Let me check the existing implementation.",
47
+ "I'll examine the log file for clues.",
48
+ "Need to view the schema definition.",
49
+ ],
50
+ "file_write": [
51
+ "I'll write the fix to the source file.",
52
+ "Let me save the test cases to disk.",
53
+ "Need to update the requirements file.",
54
+ "I'll create a new migration script.",
55
+ ],
56
+ "repair": [
57
+ "The build is failing; let me fix the import error.",
58
+ "There's a null pointer exception to patch.",
59
+ "I need to correct the regex pattern.",
60
+ "Let me resolve the merge conflict.",
61
+ ],
62
+ "verifier": [
63
+ "Let me verify the fix by running tests.",
64
+ "I should check if the output is valid JSON.",
65
+ "Need to validate the schema changes.",
66
+ "I'll confirm the permissions are correct.",
67
+ ],
68
+ "ask_clarification": [
69
+ "I'm unsure about the expected behavior—could you clarify?",
70
+ "What is the target environment for this change?",
71
+ "Do you want me to preserve backward compatibility?",
72
+ "Which branch should I base this on?",
73
+ ],
74
+ "final_answer": [
75
+ "The task is complete. Summary of changes: ...",
76
+ "All tests pass. Here's the final solution.",
77
+ "Deployment successful. Verification complete.",
78
+ "Issue resolved. Closing the ticket.",
79
+ ],
80
+ "BLOCKED": [
81
+ "This request appears unsafe and I cannot proceed.",
82
+ "I'm sorry, but I cannot execute this command.",
83
+ "Blocked: the action violates safety policies.",
84
+ "Unsafe operation detected. Refusing to continue.",
85
+ ],
86
+ }
87
+
88
+ OBSERVATION_TEMPLATES = {
89
+ "tool_call": "Tool returned: status=200, data={...}",
90
+ "retrieval": "Found 3 relevant documents. Top result: ...",
91
+ "file_read": "File contents: 142 lines, class Foo { ... }",
92
+ "file_write": "File saved successfully. 3 lines changed.",
93
+ "repair": "Build passing. 0 errors, 2 warnings.",
94
+ "verifier": "Validation passed. Schema matches.",
95
+ "ask_clarification": "User replied: please use the main branch.",
96
+ "final_answer": "(no further action)",
97
+ "BLOCKED": "(no further action)",
98
+ }
99
+
100
+
101
def generate_trace(length=5, resolved_prob=0.8):
    """Fabricate one synthetic agent conversation.

    Samples a user task, then ``length`` assistant turns whose action types
    follow a fixed schedule: the first turn is an information-gathering
    action, the last turn is terminal ("final_answer" or "BLOCKED"), and
    middle turns are any non-terminal action. Each non-terminal,
    non-clarification turn is followed by a canned "tool" observation.

    Args:
        length: number of assistant turns to generate.
        resolved_prob: probability that the trace is labeled resolved.

    Returns:
        Tuple of (messages, gold_actions, resolved) — the chat messages,
        the action label for each assistant turn, and the resolved flag.
    """
    convo = [{"role": "user", "content": random.choice(TASK_TEMPLATES)}]
    labels = []
    final_step = length - 1
    for idx in range(length):
        if idx == final_step:
            # Terminal turn: mostly a final answer, occasionally a refusal.
            act = random.choices(["final_answer", "BLOCKED"], weights=[0.85, 0.15])[0]
        elif idx == 0:
            # Opening turn: information-gathering actions only.
            act = random.choices(
                ["tool_call", "retrieval", "file_read", "ask_clarification"],
                weights=[0.3, 0.25, 0.25, 0.2]
            )[0]
        else:
            # Middle turns: anything except the two terminal actions.
            act = random.choice(ACTION_TYPES[:-2])
        convo.append({"role": "assistant", "content": random.choice(STATE_TEMPLATES[act])})
        labels.append(act)
        if act not in ("final_answer", "BLOCKED", "ask_clarification"):
            convo.append({"role": "tool", "content": OBSERVATION_TEMPLATES[act]})
    return convo, labels, random.random() < resolved_prob
122
+
123
+
124
def build_datasets(n_train=5000, n_test=500):
    """Generate synthetic traces and push three datasets to the Hub.

    Args:
        n_train: number of *traces* generated with resolved_prob=0.75.
        n_test: number of additional traces generated with resolved_prob=0.5.

    Side effects: pushes the proposer SFT, verifier preference, and eval
    datasets to the Hub under HUB_ORG. Reseeds the module-level RNG.
    """
    print("=== Generating Synthetic Datasets ===")
    random.seed(42)  # deterministic generation across runs
    # One row per assistant turn, collected into three parallel views:
    #   p_rows: proposer SFT rows (prompt -> completion)
    #   v_rows: verifier rows (later turned into preference pairs)
    #   e_rows: full-conversation eval rows
    p_rows, v_rows, e_rows = [], [], []
    for _ in range(n_train + n_test):
        # NOTE(review): `_` is used as a real variable below (train/test
        # resolved_prob switch); a named index would be clearer.
        msgs, actions, resolved = generate_trace(length=random.randint(2, 6), resolved_prob=0.75 if _ < n_train else 0.5)
        state = []  # running prefix of the conversation seen so far
        assistant_count = 0
        for msg in msgs:
            if msg["role"] == "assistant":
                action = actions[assistant_count]
                assistant_count += 1
                # NOTE(review): `comp` (and the dict inside it) is shared by
                # reference between p_rows, v_rows, and e_rows. fmt_proposer
                # below mutates comp[0]["content"] in place, so its
                # "Action: X" prefix also ends up in the verifier and eval
                # datasets — confirm this cross-dataset mutation is intended.
                comp = [{"role": "assistant", "content": msg["content"]}]
                p_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "action_type": action})
                v_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "label": resolved, "action_type": action})
                e_rows.append({"messages": [m.copy() for m in state] + comp, "resolved": resolved, "action_type": action})
            state.append(msg)

    print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
    print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())

    def fmt_proposer(r):
        # Wrap a raw row in the proposer SFT format: system instruction,
        # an action menu appended to the last prompt turn, and an
        # "Action: <label>" prefix on the completion.
        sys_msg = {"role": "system", "content": (
            "You are an agent action predictor. Predict the next action from: "
            + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name.")}
        prompt = [sys_msg] + r["prompt"]
        if prompt:  # NOTE(review): always true — sys_msg is always present
            prompt[-1]["content"] += "\n\n[Next Action] Choose one: " + ", ".join(ACTION_TYPES)
        comp = r["completion"]
        # In-place mutation of the completion dict shared with v_rows/e_rows
        # (see the aliasing note above).
        comp[0]["content"] = f"Action: {r['action_type']}\n" + comp[0]["content"]
        return {"prompt": prompt, "completion": comp}

    proposer_all = [fmt_proposer(r) for r in p_rows]
    random.shuffle(proposer_all)
    # NOTE(review): split is by *rows* while n_train counts traces, so the
    # train split holds n_train rows (not traces), and turns from the same
    # trace can land in both splits — confirm this leakage is acceptable.
    proposer_ds = DatasetDict({
        "train": Dataset.from_list(proposer_all[:n_train]),
        "test": Dataset.from_list(proposer_all[n_train:]),
    })
    proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    # Verifier preference pairs: chosen = completion from a resolved trace;
    # rejected = completion from an unresolved trace, topped up with
    # synthetic wrong-action completions when negatives are scarce.
    rng = random.Random(42)  # separate stream: pair sampling stays reproducible
    good = [r for r in v_rows if r["label"]]
    bad = [r for r in v_rows if not r["label"]]
    if len(bad) < len(good) * 0.2:
        for r in good:
            wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad.append({
                "prompt": [m.copy() for m in r["prompt"]],
                "completion": [{"role": "assistant", "content": f"Action: {wa}\n(incorrect action)"}],
                "label": False, "action_type": wa,
            })
    pairs = []
    for g in good:
        # NOTE(review): the rejected completion is drawn from *any* bad row,
        # so it generally belongs to a different prompt than the chosen one —
        # confirm that prompt-mismatched pairs are intended.
        b = rng.choice(bad)
        pairs.append({
            "prompt": [m.copy() for m in g["prompt"]],
            "chosen": g["completion"],
            "rejected": b["completion"],
            "action_type": g["action_type"],
        })
    random.shuffle(pairs)
    verifier_ds = DatasetDict({
        "train": Dataset.from_list(pairs[:n_train]),
        "test": Dataset.from_list(pairs[n_train:]),
    })
    verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    eval_all = e_rows
    random.shuffle(eval_all)
    eval_ds = Dataset.from_list(eval_all[:n_test])  # keep only n_test eval rows
    eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")
199
+
200
+
201
if __name__ == "__main__":
    # Script entry point: generate and push all three datasets with defaults.
    build_datasets()