""" Speculative Tool Actions — Dataset Builder ========================================== Converts agent trace datasets into a unified schema with 8 action types: tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED Sources: - SWE-bench/SWE-smith-trajectories (tool split, resolved=True) - tuandunghcmut/toolbench-v1 Output datasets (pushed to Hub): - {hub_org}/speculative-actions-proposer-sft -> prompt-completion for next-action SFT - {hub_org}/speculative-actions-verifier-pref -> chosen/rejected pairs for verifier DPO/Reward - {hub_org}/speculative-actions-eval -> held-out eval set with gold labels """ import json import re import argparse from collections import Counter from datasets import load_dataset, Dataset from random import Random ACTION_TYPES = [ "tool_call", "retrieval", "file_read", "file_write", "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED", ] ACTION_MAP = {a: i for i, a in enumerate(ACTION_TYPES)} def classify_action(content: str, tool_calls=None) -> str: """Heuristic classifier mapping raw agent output to one of ACTION_TYPES.""" c = content.lower() tc = json.dumps(tool_calls).lower() if tool_calls else "" combined = c + " " + tc if re.search(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b', combined): return "final_answer" if re.search(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b', combined): return "ask_clarification" if re.search(r'\b(blocked|unsafe|i cannot|i\'m sorry, but|refuse|not allowed|harmful)\b', combined): return "BLOCKED" if re.search(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b', combined): return "file_write" if re.search(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b', combined): return "file_read" if re.search(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b', combined): return "repair" if re.search(r'\b(verify|check|validate|test|assert|review)\b', combined): return "verifier" if re.search(r'\b(search|retrieve|find|lookup|query|google|bing)\b', combined): return "retrieval" if tool_calls or re.search(r'\b(function call|tool call|invoke|execute)\b', combined): return "tool_call" return "tool_call" def process_swe_smith(split="train", max_rows=10_000): print(f"Loading SWE-smith tool/{split} ...") ds = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split=split, streaming=True) rows_proposer = [] rows_verifier = [] rows_eval = [] count = 0 for example in ds: count += 1 if count > max_rows: break messages = example.get("messages", []) resolved = example.get("resolved", False) state_so_far = [] for msg in messages: role = msg.get("role", "") content = msg.get("content", "") tool_calls = msg.get("tool_calls", None) if role in ("assistant", "agent"): action_type = classify_action(content, tool_calls) prompt_messages = state_so_far.copy() completion_messages = [{"role": "assistant", "content": content}] if tool_calls: completion_messages[0]["tool_calls"] = tool_calls rows_proposer.append({ "prompt": prompt_messages, "completion": completion_messages, "action_type": action_type, }) rows_verifier.append({ "prompt": prompt_messages, "completion": completion_messages, "label": bool(resolved), "action_type": action_type, }) rows_eval.append({ "messages": prompt_messages + completion_messages, "resolved": resolved, "action_type": action_type, }) state_so_far.append(msg) print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows") return rows_proposer, rows_verifier, rows_eval def process_toolbench(split="train", max_rows=5_000): print(f"Loading toolbench/{split} ...") ds = load_dataset("tuandunghcmut/toolbench-v1", split=split, streaming=True) rows_proposer = [] rows_verifier = [] rows_eval = [] count = 0 for example in ds: count += 1 if count > max_rows: break conv = example.get("conversations", {}) froms = conv.get("from", []) values = conv.get("value", []) state_so_far = [] for role, content in zip(froms, values): msg = {"role": role, "content": content} if role == "assistant": action_type = classify_action(content) rows_proposer.append({ "prompt": state_so_far.copy(), "completion": [msg], "action_type": action_type, }) rows_verifier.append({ "prompt": state_so_far.copy(), "completion": [msg], "label": True, "action_type": action_type, }) rows_eval.append({ "messages": state_so_far + [msg], "resolved": True, "action_type": action_type, }) state_so_far.append(msg) print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows") return rows_proposer, rows_verifier, rows_eval def build_proposer_dataset(rows, hub_org): def fmt(row): system_msg = { "role": "system", "content": ( "You are an agent action predictor. Given the conversation state, " "predict the next action from: " + ", ".join(ACTION_TYPES) + ". " "Respond with exactly the action name and a brief justification." ), } prompt = [system_msg] + row["prompt"] prompt[-1]["content"] += ( "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES) ) completion = row["completion"] action_type = row["action_type"] completion[0]["content"] = f"Action: {action_type}\n" + completion[0]["content"] return {"prompt": prompt, "completion": completion} data = [fmt(r) for r in rows] ds = Dataset.from_list(data) ds = ds.shuffle(seed=42).train_test_split(test_size=0.1) ds.push_to_hub(f"{hub_org}/speculative-actions-proposer-sft") print(f"Pushed proposer SFT dataset to {hub_org}/speculative-actions-proposer-sft") return ds def build_verifier_dataset(rows, hub_org): rng = Random(42) good_rows = [r for r in rows if r["label"]] bad_rows = [r for r in rows if not r["label"]] if len(bad_rows) < len(good_rows) * 0.2: for r in good_rows: wrong_action = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]]) bad = { "prompt": r["prompt"], "completion": [{"role": "assistant", "content": f"Action: {wrong_action}\n(synthetic incorrect action)"}], "label": False, "action_type": wrong_action, } bad_rows.append(bad) pairs = [] for g in good_rows: b = rng.choice(bad_rows) pairs.append({ "prompt": g["prompt"], "chosen": g["completion"], "rejected": b["completion"], "action_type": g["action_type"], }) ds = Dataset.from_list(pairs) ds = ds.shuffle(seed=42).train_test_split(test_size=0.1) ds.push_to_hub(f"{hub_org}/speculative-actions-verifier-pref") print(f"Pushed verifier preference dataset to {hub_org}/speculative-actions-verifier-pref") return ds def build_eval_dataset(rows, hub_org): ds = Dataset.from_list(rows) ds = ds.shuffle(seed=42).select(range(min(2_000, len(rows)))) ds.push_to_hub(f"{hub_org}/speculative-actions-eval") print(f"Pushed eval dataset to {hub_org}/speculative-actions-eval") return ds def main(): parser = argparse.ArgumentParser() parser.add_argument("--hub_org", default="narcolepticchicken", type=str) parser.add_argument("--max_swe", type=int, default=5_000) parser.add_argument("--max_toolbench", type=int, default=3_000) args = parser.parse_args() p1, v1, e1 = process_swe_smith("train", args.max_swe) p2, v2, e2 = process_toolbench("train", args.max_toolbench) proposer_rows = p1 + p2 verifier_rows = v1 + v2 eval_rows = e1 + e2 print(f"\nTotal rows: proposer={len(proposer_rows)}, verifier={len(verifier_rows)}, eval={len(eval_rows)}") print("\nAction distribution (proposer):") for act, n in Counter(r["action_type"] for r in proposer_rows).most_common(): print(f" {act}: {n}") build_proposer_dataset(proposer_rows, args.hub_org) build_verifier_dataset(verifier_rows, args.hub_org) build_eval_dataset(eval_rows, args.hub_org) print("\nDataset construction complete.") if __name__ == "__main__": main()