#!/usr/bin/env python3
"""Generate synthetic agent trace datasets and push to Hub."""
import os, json, re, random
from collections import Counter
from datasets import Dataset, DatasetDict
# Hugging Face Hub organization/user the generated datasets are pushed to.
HUB_ORG = "narcolepticchicken"

# Closed label set of agent actions. The last two entries ("final_answer",
# "BLOCKED") are terminal: generate_trace only emits them on the final step,
# and excludes them (via ACTION_TYPES[:-2]) from middle steps.
ACTION_TYPES = [
    "tool_call", "retrieval", "file_read", "file_write",
    "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED",
]

# User-task prompts; one is drawn at random to open each synthetic trace.
TASK_TEMPLATES = [
    "Fix a bug in the authentication module.",
    "Implement a new search feature.",
    "Write unit tests for the API layer.",
    "Refactor the database connection pool.",
    "Add logging to the payment gateway.",
    "Update documentation for the CLI tool.",
    "Debug a memory leak in the worker process.",
    "Optimize the image processing pipeline.",
    "Integrate a third-party OAuth provider.",
    "Set up CI/CD for the microservice.",
    "Migrate from REST to GraphQL.",
    "Add rate limiting to the public API.",
    "Create a backup strategy for the database.",
    "Audit the codebase for security vulnerabilities.",
    "Implement caching for frequently accessed data.",
]

# Assistant-utterance templates keyed by action type; generate_trace draws one
# at random for each step of a trace.
STATE_TEMPLATES = {
    "tool_call": [
        "I need to call the API to fetch user data.",
        "Let me invoke the linter to check syntax.",
        "I'll execute the test runner now.",
        "Time to trigger the deployment script.",
    ],
    "retrieval": [
        "I should search for similar issues in the tracker.",
        "Let me look up the documentation for this function.",
        "I'll query the knowledge base for best practices.",
        "Need to find examples of this pattern online.",
    ],
    "file_read": [
        "I need to read the configuration file first.",
        "Let me check the existing implementation.",
        "I'll examine the log file for clues.",
        "Need to view the schema definition.",
    ],
    "file_write": [
        "I'll write the fix to the source file.",
        "Let me save the test cases to disk.",
        "Need to update the requirements file.",
        "I'll create a new migration script.",
    ],
    "repair": [
        "The build is failing; let me fix the import error.",
        "There's a null pointer exception to patch.",
        "I need to correct the regex pattern.",
        "Let me resolve the merge conflict.",
    ],
    "verifier": [
        "Let me verify the fix by running tests.",
        "I should check if the output is valid JSON.",
        "Need to validate the schema changes.",
        "I'll confirm the permissions are correct.",
    ],
    "ask_clarification": [
        "I'm unsure about the expected behavior—could you clarify?",
        "What is the target environment for this change?",
        "Do you want me to preserve backward compatibility?",
        "Which branch should I base this on?",
    ],
    "final_answer": [
        "The task is complete. Summary of changes: ...",
        "All tests pass. Here's the final solution.",
        "Deployment successful. Verification complete.",
        "Issue resolved. Closing the ticket.",
    ],
    "BLOCKED": [
        "This request appears unsafe and I cannot proceed.",
        "I'm sorry, but I cannot execute this command.",
        "Blocked: the action violates safety policies.",
        "Unsafe operation detected. Refusing to continue.",
    ],
}

# Canned tool observation per action type; appended after every action except
# the terminal ones and "ask_clarification" (which produce no tool turn).
OBSERVATION_TEMPLATES = {
    "tool_call": "Tool returned: status=200, data={...}",
    "retrieval": "Found 3 relevant documents. Top result: ...",
    "file_read": "File contents: 142 lines, class Foo { ... }",
    "file_write": "File saved successfully. 3 lines changed.",
    "repair": "Build passing. 0 errors, 2 warnings.",
    "verifier": "Validation passed. Schema matches.",
    "ask_clarification": "User replied: please use the main branch.",
    "final_answer": "(no further action)",
    "BLOCKED": "(no further action)",
}
def generate_trace(length=5, resolved_prob=0.8):
    """Build one synthetic agent trace.

    Returns a ``(messages, gold_actions, resolved)`` triple: the chat turns
    (user task, assistant steps, interleaved tool observations), the action
    label for each assistant step, and a resolution flag drawn with
    probability ``resolved_prob``. Uses the module-level ``random`` state.
    """
    messages = [{"role": "user", "content": random.choice(TASK_TEMPLATES)}]
    gold_actions = []
    last_step = length - 1
    for step in range(length):
        # Terminal step closes the trace; the opening step is restricted to
        # information-gathering actions; middle steps draw uniformly from
        # every non-terminal action type.
        if step == last_step:
            action = random.choices(["final_answer", "BLOCKED"], weights=[0.85, 0.15])[0]
        elif step == 0:
            openers = ["tool_call", "retrieval", "file_read", "ask_clarification"]
            action = random.choices(openers, weights=[0.3, 0.25, 0.25, 0.2])[0]
        else:
            action = random.choice(ACTION_TYPES[:-2])
        messages.append({"role": "assistant",
                         "content": random.choice(STATE_TEMPLATES[action])})
        gold_actions.append(action)
        # Terminal and clarification actions yield no tool observation.
        if action not in ("final_answer", "BLOCKED", "ask_clarification"):
            messages.append({"role": "tool", "content": OBSERVATION_TEMPLATES[action]})
    return messages, gold_actions, random.random() < resolved_prob
def build_datasets(n_train=5000, n_test=500):
    """Generate synthetic traces and push proposer/verifier/eval datasets to the Hub.

    ``n_train`` / ``n_test`` control how many *traces* are generated
    (``n_train + n_test`` total) and where the train/test boundary falls when
    slicing rows. Note that each trace yields one row per assistant step, so
    the per-step row lists are larger than the trace counts; the slices below
    keep the original sizing behavior (train gets exactly ``n_train`` rows,
    test gets the remainder).

    Pushes three datasets under ``HUB_ORG`` (network side effect):
      - ``speculative-actions-proposer-sft``  (SFT prompt/completion rows)
      - ``speculative-actions-verifier-pref`` (chosen/rejected preference pairs)
      - ``speculative-actions-eval``          (full-message eval rows)
    """
    print("=== Generating Synthetic Datasets ===")
    random.seed(42)  # reproducible trace generation

    p_rows, v_rows, e_rows = [], [], []
    for i in range(n_train + n_test):
        # Traces destined for the test pool are made harder to resolve.
        msgs, actions, resolved = generate_trace(
            length=random.randint(2, 6),
            resolved_prob=0.75 if i < n_train else 0.5,
        )
        state = []  # conversation prefix visible before each assistant step
        assistant_count = 0
        for msg in msgs:
            if msg["role"] == "assistant":
                action = actions[assistant_count]
                assistant_count += 1
                prefix = f"Action: {action}\n"
                # Build an independent completion per dataset. The previous
                # implementation shared one completion dict across all three
                # row lists and relied on fmt_proposer mutating it in place,
                # which is how the "Action:" prefix reached the verifier and
                # eval rows; the prefix is now added explicitly instead.
                p_rows.append({
                    "prompt": [m.copy() for m in state],
                    "completion": [{"role": "assistant", "content": msg["content"]}],
                    "action_type": action,
                })
                v_rows.append({
                    "prompt": [m.copy() for m in state],
                    "completion": [{"role": "assistant", "content": prefix + msg["content"]}],
                    "label": resolved,
                    "action_type": action,
                })
                e_rows.append({
                    "messages": [m.copy() for m in state]
                                + [{"role": "assistant", "content": prefix + msg["content"]}],
                    "resolved": resolved,
                    "action_type": action,
                })
            state.append(msg)

    print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
    print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())

    def fmt_proposer(r):
        """Render one proposer SFT row; never mutates r's nested dicts."""
        sys_msg = {"role": "system", "content": (
            "You are an agent action predictor. Predict the next action from: "
            + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name.")}
        # prompt always ends with at least sys_msg, so the instruction suffix
        # lands on the last context turn (or the system turn for empty context).
        prompt = [sys_msg] + [m.copy() for m in r["prompt"]]
        prompt[-1]["content"] += "\n\n[Next Action] Choose one: " + ", ".join(ACTION_TYPES)
        comp = [{"role": "assistant",
                 "content": f"Action: {r['action_type']}\n" + r["completion"][0]["content"]}]
        return {"prompt": prompt, "completion": comp}

    proposer_all = [fmt_proposer(r) for r in p_rows]
    random.shuffle(proposer_all)
    proposer_ds = DatasetDict({
        "train": Dataset.from_list(proposer_all[:n_train]),
        "test": Dataset.from_list(proposer_all[n_train:]),
    })
    proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    # Verifier preference pairs: pair each resolved ("good") row with a random
    # unresolved ("bad") completion. A dedicated RNG keeps the pairing stream
    # independent of the module-level generator.
    rng = random.Random(42)
    good = [r for r in v_rows if r["label"]]
    bad = [r for r in v_rows if not r["label"]]
    if len(bad) < len(good) * 0.2:
        # Too few organic negatives: synthesize one wrong-action rejection
        # per good row so every chosen has a plausible rejected counterpart.
        for r in good:
            wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad.append({
                "prompt": [m.copy() for m in r["prompt"]],
                "completion": [{"role": "assistant", "content": f"Action: {wa}\n(incorrect action)"}],
                "label": False, "action_type": wa,
            })
    pairs = []
    for g in good:
        b = rng.choice(bad)
        pairs.append({
            "prompt": [m.copy() for m in g["prompt"]],
            "chosen": [m.copy() for m in g["completion"]],
            "rejected": [m.copy() for m in b["completion"]],
            "action_type": g["action_type"],
        })
    random.shuffle(pairs)
    # NOTE: if len(pairs) <= n_train the test split comes out empty.
    verifier_ds = DatasetDict({
        "train": Dataset.from_list(pairs[:n_train]),
        "test": Dataset.from_list(pairs[n_train:]),
    })
    verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    # Eval set: a random sample of n_test full-conversation rows.
    eval_all = list(e_rows)  # copy so e_rows itself is not reordered
    random.shuffle(eval_all)
    eval_ds = Dataset.from_list(eval_all[:n_test])
    eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")
# Script entry point: generate all three datasets and push them to the Hub.
if __name__ == "__main__":
    build_datasets()