#!/usr/bin/env python3
"""Generate synthetic agent trace datasets and push them to the Hugging Face Hub.

Three datasets are built from one pool of synthetic traces:

* ``speculative-actions-proposer-sft`` -- prompt/completion examples for
  next-action prediction (train/test splits).
* ``speculative-actions-verifier-pref`` -- prompt/chosen/rejected preference
  pairs (train/test splits).
* ``speculative-actions-eval`` -- held-out conversations labelled with the
  gold action and the trace-level ``resolved`` flag.
"""
import os
import json
import re
import random
from collections import Counter

from datasets import Dataset, DatasetDict

HUB_ORG = "narcolepticchicken"

ACTION_TYPES = [
    "tool_call",
    "retrieval",
    "file_read",
    "file_write",
    "repair",
    "verifier",
    "ask_clarification",
    "final_answer",
    "BLOCKED",
]

TASK_TEMPLATES = [
    "Fix a bug in the authentication module.",
    "Implement a new search feature.",
    "Write unit tests for the API layer.",
    "Refactor the database connection pool.",
    "Add logging to the payment gateway.",
    "Update documentation for the CLI tool.",
    "Debug a memory leak in the worker process.",
    "Optimize the image processing pipeline.",
    "Integrate a third-party OAuth provider.",
    "Set up CI/CD for the microservice.",
    "Migrate from REST to GraphQL.",
    "Add rate limiting to the public API.",
    "Create a backup strategy for the database.",
    "Audit the codebase for security vulnerabilities.",
    "Implement caching for frequently accessed data.",
]

STATE_TEMPLATES = {
    "tool_call": [
        "I need to call the API to fetch user data.",
        "Let me invoke the linter to check syntax.",
        "I'll execute the test runner now.",
        "Time to trigger the deployment script.",
    ],
    "retrieval": [
        "I should search for similar issues in the tracker.",
        "Let me look up the documentation for this function.",
        "I'll query the knowledge base for best practices.",
        "Need to find examples of this pattern online.",
    ],
    "file_read": [
        "I need to read the configuration file first.",
        "Let me check the existing implementation.",
        "I'll examine the log file for clues.",
        "Need to view the schema definition.",
    ],
    "file_write": [
        "I'll write the fix to the source file.",
        "Let me save the test cases to disk.",
        "Need to update the requirements file.",
        "I'll create a new migration script.",
    ],
    "repair": [
        "The build is failing; let me fix the import error.",
        "There's a null pointer exception to patch.",
        "I need to correct the regex pattern.",
        "Let me resolve the merge conflict.",
    ],
    "verifier": [
        "Let me verify the fix by running tests.",
        "I should check if the output is valid JSON.",
        "Need to validate the schema changes.",
        "I'll confirm the permissions are correct.",
    ],
    "ask_clarification": [
        "I'm unsure about the expected behavior—could you clarify?",
        "What is the target environment for this change?",
        "Do you want me to preserve backward compatibility?",
        "Which branch should I base this on?",
    ],
    "final_answer": [
        "The task is complete. Summary of changes: ...",
        "All tests pass. Here's the final solution.",
        "Deployment successful. Verification complete.",
        "Issue resolved. Closing the ticket.",
    ],
    "BLOCKED": [
        "This request appears unsafe and I cannot proceed.",
        "I'm sorry, but I cannot execute this command.",
        "Blocked: the action violates safety policies.",
        "Unsafe operation detected. Refusing to continue.",
    ],
}

OBSERVATION_TEMPLATES = {
    "tool_call": "Tool returned: status=200, data={...}",
    "retrieval": "Found 3 relevant documents. Top result: ...",
    "file_read": "File contents: 142 lines, class Foo { ... }",
    "file_write": "File saved successfully. 3 lines changed.",
    "repair": "Build passing. 0 errors, 2 warnings.",
    "verifier": "Validation passed. Schema matches.",
    "ask_clarification": "User replied: please use the main branch.",
    "final_answer": "(no further action)",
    "BLOCKED": "(no further action)",
}


def generate_trace(length=5, resolved_prob=0.8):
    """Generate one synthetic agent trace using the global ``random`` module.

    Args:
        length: Number of assistant turns in the trace.
        resolved_prob: Probability that the trace is labelled as resolved.

    Returns:
        ``(messages, gold_actions, resolved)``:

        * ``messages`` -- chat transcript starting with the user task.
        * ``gold_actions`` -- the action type behind each assistant turn,
          in order.
        * ``resolved`` -- a single trace-level boolean, drawn independently
          of the actions.
    """
    task = random.choice(TASK_TEMPLATES)
    messages = [{"role": "user", "content": task}]
    gold_actions = []
    for step in range(length):
        if step == length - 1:
            # The last turn always ends the trace with a terminal action.
            action = random.choices(["final_answer", "BLOCKED"], weights=[0.85, 0.15])[0]
        elif step == 0:
            # The first turn is restricted to these four actions.
            action = random.choices(
                ["tool_call", "retrieval", "file_read", "ask_clarification"],
                weights=[0.3, 0.25, 0.25, 0.2],
            )[0]
        else:
            # ACTION_TYPES[:-2] drops the two terminal actions
            # ("final_answer", "BLOCKED"), which are listed last.
            action = random.choice(ACTION_TYPES[:-2])
        content = random.choice(STATE_TEMPLATES[action])
        messages.append({"role": "assistant", "content": content})
        gold_actions.append(action)
        # Terminal actions and clarification questions get no tool observation.
        if action not in ("final_answer", "BLOCKED", "ask_clarification"):
            messages.append({"role": "tool", "content": OBSERVATION_TEMPLATES[action]})
    resolved = random.random() < resolved_prob
    return messages, gold_actions, resolved


def _trace_to_steps(messages, gold_actions, resolved):
    """Explode one trace into one record per assistant turn.

    Each record holds:

    * ``history`` -- a copy of the conversation *before* that turn.
    * ``content`` -- the assistant text.
    * ``action_type`` -- the gold action for the turn.
    * ``resolved`` -- the trace-level resolved flag.
    """
    steps = []
    history = []
    actions = iter(gold_actions)
    for msg in messages:
        if msg["role"] == "assistant":
            steps.append({
                "history": [m.copy() for m in history],
                "content": msg["content"],
                "action_type": next(actions),
                "resolved": resolved,
            })
        history.append(msg)
    return steps


def _action_completion(action, content):
    """Return a NEW one-message completion in the ``Action: <name>`` format.

    A fresh list/dict is built on every call so no two datasets share
    (and can accidentally mutate) the same completion object.
    """
    return [{"role": "assistant", "content": f"Action: {action}\n{content}"}]


def _proposer_row(step):
    """Format a step record as a proposer SFT example (prompt/completion)."""
    # NOTE(review): the system prompt asks for "exactly the action name" but the
    # target completion is "Action: <name>\n<assistant text>" -- confirm which
    # format the downstream trainer/parser expects.
    sys_msg = {"role": "system", "content": (
        "You are an agent action predictor. Predict the next action from: "
        + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name.")}
    prompt = [sys_msg] + [m.copy() for m in step["history"]]
    # Append the choice list to the last prompt message. This is a copy, so
    # the step record itself is left untouched.
    prompt[-1]["content"] += "\n\n[Next Action] Choose one: " + ", ".join(ACTION_TYPES)
    return {
        "prompt": prompt,
        "completion": _action_completion(step["action_type"], step["content"]),
    }


def _verifier_pairs(steps, rng):
    """Build preference pairs from the step records of ONE split.

    Steps from resolved traces become "chosen". Each is paired with a
    "rejected" completion drawn with ``rng`` from the same split's unresolved
    steps. When unresolved steps number fewer than 20% of resolved ones,
    one synthetic wrong-action completion per resolved step is added to the
    rejected pool.
    """
    good = [s for s in steps if s["resolved"]]
    rejected_pool = [
        _action_completion(s["action_type"], s["content"])
        for s in steps
        if not s["resolved"]
    ]
    if len(rejected_pool) < len(good) * 0.2:
        for s in good:
            wrong = rng.choice([a for a in ACTION_TYPES if a != s["action_type"]])
            rejected_pool.append(_action_completion(wrong, "(incorrect action)"))
    pairs = []
    for s in good:
        rejected = rng.choice(rejected_pool)
        pairs.append({
            "prompt": [m.copy() for m in s["history"]],
            "chosen": _action_completion(s["action_type"], s["content"]),
            "rejected": [m.copy() for m in rejected],
            "action_type": s["action_type"],
        })
    return pairs


def _eval_row(step):
    """Format a step record as an eval example (history + gold completion)."""
    return {
        "messages": [m.copy() for m in step["history"]]
        + _action_completion(step["action_type"], step["content"]),
        "resolved": step["resolved"],
        "action_type": step["action_type"],
    }


def build_datasets(n_train=5000, n_test=500, seed=42):
    """Generate synthetic traces and push all three datasets to the Hub.

    ``n_train`` and ``n_test`` count *traces*, not rows. Each trace yields one
    row per assistant turn (2-6 turns), and every row inherits its trace's
    split, so no trace is shared between train and test. Test traces are
    generated with a lower ``resolved_prob`` (0.5) than train traces (0.75).

    Args:
        n_train: Number of training traces.
        n_test: Number of test traces; also caps the eval dataset's row count.
        seed: Seed for the global ``random`` module and the pairing RNG.
    """
    print("=== Generating Synthetic Datasets ===")
    random.seed(seed)

    steps_by_split = {"train": [], "test": []}
    for trace_idx in range(n_train + n_test):
        split = "train" if trace_idx < n_train else "test"
        msgs, actions, resolved = generate_trace(
            length=random.randint(2, 6),
            resolved_prob=0.75 if split == "train" else 0.5,
        )
        steps_by_split[split].extend(_trace_to_steps(msgs, actions, resolved))

    all_steps = steps_by_split["train"] + steps_by_split["test"]
    n_rows = len(all_steps)
    print(f"Rows: proposer={n_rows}, verifier={n_rows}, eval={n_rows}")
    print("Distribution:", Counter(s["action_type"] for s in all_steps).most_common())

    # Proposer SFT: one example per step, shuffled within its own split.
    proposer_splits = {}
    for split, steps in steps_by_split.items():
        rows = [_proposer_row(s) for s in steps]
        random.shuffle(rows)
        proposer_splits[split] = Dataset.from_list(rows)
    DatasetDict(proposer_splits).push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    # Verifier preferences: rejected completions are drawn only from the same
    # split, so pairs never mix train and test traces.
    rng = random.Random(seed)
    verifier_splits = {}
    for split, steps in steps_by_split.items():
        pairs = _verifier_pairs(steps, rng)
        random.shuffle(pairs)
        verifier_splits[split] = Dataset.from_list(pairs)
    DatasetDict(verifier_splits).push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    # Eval rows come only from held-out test traces, so they never overlap
    # the training split.
    eval_rows = [_eval_row(s) for s in steps_by_split["test"]]
    random.shuffle(eval_rows)
    Dataset.from_list(eval_rows[:n_test]).push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")


if __name__ == "__main__":
    build_datasets()