| |
| """Generate synthetic agent trace datasets and push to Hub.""" |
| import os, json, re, random |
| from collections import Counter |
| from datasets import Dataset, DatasetDict |
|
|
| HUB_ORG = "narcolepticchicken" |
| ACTION_TYPES = [ |
| "tool_call", "retrieval", "file_read", "file_write", |
| "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED", |
| ] |
|
|
# User-turn task descriptions; one is sampled uniformly as the opening
# message of every synthetic trace.
TASK_TEMPLATES = [
    "Fix a bug in the authentication module.",
    "Implement a new search feature.",
    "Write unit tests for the API layer.",
    "Refactor the database connection pool.",
    "Add logging to the payment gateway.",
    "Update documentation for the CLI tool.",
    "Debug a memory leak in the worker process.",
    "Optimize the image processing pipeline.",
    "Integrate a third-party OAuth provider.",
    "Set up CI/CD for the microservice.",
    "Migrate from REST to GraphQL.",
    "Add rate limiting to the public API.",
    "Create a backup strategy for the database.",
    "Audit the codebase for security vulnerabilities.",
    "Implement caching for frequently accessed data.",
]
|
|
# Assistant-turn utterances, keyed by action type. generate_trace() samples
# one string per step to serve as the assistant message content for the
# chosen action. Keys must cover every entry in ACTION_TYPES.
STATE_TEMPLATES = {
    "tool_call": [
        "I need to call the API to fetch user data.",
        "Let me invoke the linter to check syntax.",
        "I'll execute the test runner now.",
        "Time to trigger the deployment script.",
    ],
    "retrieval": [
        "I should search for similar issues in the tracker.",
        "Let me look up the documentation for this function.",
        "I'll query the knowledge base for best practices.",
        "Need to find examples of this pattern online.",
    ],
    "file_read": [
        "I need to read the configuration file first.",
        "Let me check the existing implementation.",
        "I'll examine the log file for clues.",
        "Need to view the schema definition.",
    ],
    "file_write": [
        "I'll write the fix to the source file.",
        "Let me save the test cases to disk.",
        "Need to update the requirements file.",
        "I'll create a new migration script.",
    ],
    "repair": [
        "The build is failing; let me fix the import error.",
        "There's a null pointer exception to patch.",
        "I need to correct the regex pattern.",
        "Let me resolve the merge conflict.",
    ],
    "verifier": [
        "Let me verify the fix by running tests.",
        "I should check if the output is valid JSON.",
        "Need to validate the schema changes.",
        "I'll confirm the permissions are correct.",
    ],
    "ask_clarification": [
        "I'm unsure about the expected behavior—could you clarify?",
        "What is the target environment for this change?",
        "Do you want me to preserve backward compatibility?",
        "Which branch should I base this on?",
    ],
    "final_answer": [
        "The task is complete. Summary of changes: ...",
        "All tests pass. Here's the final solution.",
        "Deployment successful. Verification complete.",
        "Issue resolved. Closing the ticket.",
    ],
    "BLOCKED": [
        "This request appears unsafe and I cannot proceed.",
        "I'm sorry, but I cannot execute this command.",
        "Blocked: the action violates safety policies.",
        "Unsafe operation detected. Refusing to continue.",
    ],
}
|
|
# Canned tool/environment observation per action type, emitted as a
# role="tool" message after each non-terminal step. Terminal actions
# ("final_answer", "BLOCKED") map to a no-op placeholder and are never
# actually appended by generate_trace().
OBSERVATION_TEMPLATES = {
    "tool_call": "Tool returned: status=200, data={...}",
    "retrieval": "Found 3 relevant documents. Top result: ...",
    "file_read": "File contents: 142 lines, class Foo { ... }",
    "file_write": "File saved successfully. 3 lines changed.",
    "repair": "Build passing. 0 errors, 2 warnings.",
    "verifier": "Validation passed. Schema matches.",
    "ask_clarification": "User replied: please use the main branch.",
    "final_answer": "(no further action)",
    "BLOCKED": "(no further action)",
}
|
|
|
|
def generate_trace(length=5, resolved_prob=0.8):
    """Sample one synthetic agent conversation.

    Produces a chat-style message list (user task, then alternating
    assistant/tool turns), the gold action label for each assistant turn,
    and a Bernoulli(resolved_prob) "resolved" flag.

    The first step favors information-gathering actions, the last step is
    always terminal ("final_answer" or, rarely, "BLOCKED"), and middle
    steps draw uniformly from the non-terminal actions. Terminal and
    clarification steps get no tool observation.

    Returns (messages, gold_actions, resolved).
    """
    convo = [{"role": "user", "content": random.choice(TASK_TEMPLATES)}]
    labels = []
    last = length - 1
    for step in range(length):
        if step == last:
            # Terminal turn: mostly a clean completion, occasionally a refusal.
            act = random.choices(["final_answer", "BLOCKED"], weights=[0.85, 0.15])[0]
        elif step == 0:
            # Opening turn: weighted toward gathering context or clarifying.
            act = random.choices(
                ["tool_call", "retrieval", "file_read", "ask_clarification"],
                weights=[0.3, 0.25, 0.25, 0.2]
            )[0]
        else:
            # Middle turns: uniform over non-terminal actions.
            act = random.choice(ACTION_TYPES[:-2])
        convo.append({"role": "assistant", "content": random.choice(STATE_TEMPLATES[act])})
        labels.append(act)
        if act not in ("final_answer", "BLOCKED", "ask_clarification"):
            convo.append({"role": "tool", "content": OBSERVATION_TEMPLATES[act]})
    return convo, labels, random.random() < resolved_prob
|
|
|
|
def build_datasets(n_train=5000, n_test=500):
    """Generate synthetic traces and push three datasets to the Hub.

    Datasets produced:
      - proposer SFT: prompt/completion rows where the completion is
        prefixed with an "Action: <label>" header,
      - verifier preference pairs: chosen (resolved) vs. rejected completions,
      - eval: full message transcripts with a `resolved` flag.

    Parameters:
      n_train: number of traces generated for the higher-resolution pool
        (resolved_prob=0.75).
      n_test: number of additional traces (resolved_prob=0.5).

    NOTE(review): each trace expands into several per-assistant-turn rows,
    so the row-level slices below ([:n_train], [:n_test]) are NOT
    trace-aligned — the proposer/verifier test splits hold far more than
    n_test rows. Confirm this is intended.

    Side effects: seeds the global `random` module and pushes three
    datasets to the Hub (requires authentication and network access).
    """
    print("=== Generating Synthetic Datasets ===")
    random.seed(42)  # reproducible generation and shuffles
    p_rows, v_rows, e_rows = [], [], []
    for i in range(n_train + n_test):
        # Later traces (the "test" pool) resolve less often: 0.5 vs 0.75.
        msgs, actions, resolved = generate_trace(
            length=random.randint(2, 6),
            resolved_prob=0.75 if i < n_train else 0.5,
        )
        state = []  # conversation prefix visible before each assistant turn
        assistant_count = 0
        for msg in msgs:
            if msg["role"] == "assistant":
                action = actions[assistant_count]
                assistant_count += 1
                # BUGFIX: each dataset gets its own copies of the message
                # dicts. Previously the same completion dict was shared by
                # p_rows/v_rows/e_rows, so fmt_proposer's in-place edit
                # leaked the "Action: ..." prefix into the verifier and
                # eval datasets.
                comp = [{"role": "assistant", "content": msg["content"]}]
                p_rows.append({"prompt": [m.copy() for m in state],
                               "completion": [c.copy() for c in comp],
                               "action_type": action})
                v_rows.append({"prompt": [m.copy() for m in state],
                               "completion": [c.copy() for c in comp],
                               "label": resolved, "action_type": action})
                e_rows.append({"messages": [m.copy() for m in state] + [c.copy() for c in comp],
                               "resolved": resolved, "action_type": action})
            state.append(msg)

    print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
    print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())

    def fmt_proposer(r):
        """Format one proposer SFT row without mutating `r`.

        Adds a system prompt, appends the action-choice instruction to the
        last prompt message (the system message itself when the prompt is
        otherwise empty — preserving the original behavior), and prefixes
        the completion with the gold action label.
        """
        sys_msg = {"role": "system", "content": (
            "You are an agent action predictor. Predict the next action from: "
            + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name.")}
        prompt = [sys_msg] + [m.copy() for m in r["prompt"]]
        # prompt always has at least the system message, so no emptiness check.
        prompt[-1]["content"] += "\n\n[Next Action] Choose one: " + ", ".join(ACTION_TYPES)
        comp = [c.copy() for c in r["completion"]]
        comp[0]["content"] = f"Action: {r['action_type']}\n" + comp[0]["content"]
        return {"prompt": prompt, "completion": comp}

    proposer_all = [fmt_proposer(r) for r in p_rows]
    random.shuffle(proposer_all)
    proposer_ds = DatasetDict({
        "train": Dataset.from_list(proposer_all[:n_train]),
        "test": Dataset.from_list(proposer_all[n_train:]),
    })
    proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    # --- Verifier preference pairs -------------------------------------
    rng = random.Random(42)  # independent stream: pair sampling stays stable
    good = [r for r in v_rows if r["label"]]
    bad = [r for r in v_rows if not r["label"]]
    if len(bad) < len(good) * 0.2:
        # Too few organic negatives: synthesize one per positive by pairing
        # its prompt with a deliberately wrong action label.
        for r in good:
            wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad.append({
                "prompt": [m.copy() for m in r["prompt"]],
                "completion": [{"role": "assistant", "content": f"Action: {wa}\n(incorrect action)"}],
                "label": False, "action_type": wa,
            })
    pairs = []
    for g in good:
        b = rng.choice(bad)  # rejected drawn with replacement
        pairs.append({
            "prompt": [m.copy() for m in g["prompt"]],
            "chosen": g["completion"],
            "rejected": b["completion"],
            "action_type": g["action_type"],
        })
    random.shuffle(pairs)
    verifier_ds = DatasetDict({
        "train": Dataset.from_list(pairs[:n_train]),
        "test": Dataset.from_list(pairs[n_train:]),
    })
    verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    # --- Eval transcripts ----------------------------------------------
    eval_all = e_rows
    random.shuffle(eval_all)
    eval_ds = Dataset.from_list(eval_all[:n_test])
    eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")
|
|
|
|
| if __name__ == "__main__": |
| build_datasets() |
|
|