| """ |
| Speculative Tool Actions — Dataset Builder |
| ========================================== |
| Converts agent trace datasets into a unified schema with 8 action types: |
| tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED |
| |
| Sources: |
| - SWE-bench/SWE-smith-trajectories (tool split, resolved=True) |
| - tuandunghcmut/toolbench-v1 |
| |
| Output datasets (pushed to Hub): |
| - {hub_org}/speculative-actions-proposer-sft -> prompt-completion for next-action SFT |
| - {hub_org}/speculative-actions-verifier-pref -> chosen/rejected pairs for verifier DPO/Reward |
| - {hub_org}/speculative-actions-eval -> held-out eval set with gold labels |
| """ |
|
|
import json
import re
import argparse
from collections import Counter
from random import Random

from datasets import load_dataset, Dataset
|
|
ACTION_TYPES = [
    "tool_call",
    "retrieval",
    "file_read",
    "file_write",
    "repair",
    "verifier",
    "ask_clarification",
    "final_answer",
    "BLOCKED",
]

ACTION_MAP = {a: i for i, a in enumerate(ACTION_TYPES)}
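# ACTION_MAP assigns each action a stable integer id. It is not used below, but is
# kept so downstream consumers (e.g. a classification head) can share the mapping.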
|
|
def classify_action(content: str, tool_calls=None) -> str:
    """Heuristic classifier mapping raw agent output to one of ACTION_TYPES."""
    c = content.lower()
    tc = json.dumps(tool_calls).lower() if tool_calls else ""
    combined = c + " " + tc

    # Order matters: the first matching pattern wins, so the more specific
    # intents (final answer, clarification, refusal) are checked first.
    if re.search(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b', combined):
        return "final_answer"
    if re.search(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b', combined):
        return "ask_clarification"
    if re.search(r'\b(blocked|unsafe|i cannot|i\'m sorry, but|refuse|not allowed|harmful)\b', combined):
        return "BLOCKED"
    if re.search(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b', combined):
        return "file_write"
    if re.search(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b', combined):
        return "file_read"
    if re.search(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b', combined):
        return "repair"
    if re.search(r'\b(verify|check|validate|test|assert|review)\b', combined):
        return "verifier"
    if re.search(r'\b(search|retrieve|find|lookup|query|google|bing)\b', combined):
        return "retrieval"
    # Anything else (including explicit tool_calls payloads) defaults to a generic tool call.
    return "tool_call"
|
|
| def process_swe_smith(split="train", max_rows=10_000): |
| print(f"Loading SWE-smith tool/{split} ...") |
| ds = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split=split, streaming=True) |
|
|
| rows_proposer = [] |
| rows_verifier = [] |
| rows_eval = [] |
|
|
| count = 0 |
| for example in ds: |
| count += 1 |
| if count > max_rows: |
| break |
|
|
| messages = example.get("messages", []) |
| resolved = example.get("resolved", False) |
|
|
| state_so_far = [] |
| for msg in messages: |
| role = msg.get("role", "") |
| content = msg.get("content", "") |
| tool_calls = msg.get("tool_calls", None) |
|
|
| if role in ("assistant", "agent"): |
| action_type = classify_action(content, tool_calls) |
| prompt_messages = state_so_far.copy() |
| completion_messages = [{"role": "assistant", "content": content}] |
| if tool_calls: |
| completion_messages[0]["tool_calls"] = tool_calls |
|
|
| rows_proposer.append({ |
| "prompt": prompt_messages, |
| "completion": completion_messages, |
| "action_type": action_type, |
| }) |
| rows_verifier.append({ |
| "prompt": prompt_messages, |
| "completion": completion_messages, |
| "label": bool(resolved), |
| "action_type": action_type, |
| }) |
| rows_eval.append({ |
| "messages": prompt_messages + completion_messages, |
| "resolved": resolved, |
| "action_type": action_type, |
| }) |
| state_so_far.append(msg) |
|
|
| print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows") |
| return rows_proposer, rows_verifier, rows_eval |
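# Each proposer row has the shape (contents abbreviated):
#   {"prompt": [<prior messages>],
#    "completion": [{"role": "assistant", "content": ...}],
#    "action_type": "file_read"}
# Verifier rows add a boolean "label"; eval rows store the full message list plus "resolved".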
|
|
| def process_toolbench(split="train", max_rows=5_000): |
| print(f"Loading toolbench/{split} ...") |
| ds = load_dataset("tuandunghcmut/toolbench-v1", split=split, streaming=True) |
|
|
| rows_proposer = [] |
| rows_verifier = [] |
| rows_eval = [] |
|
|
| count = 0 |
| for example in ds: |
| count += 1 |
| if count > max_rows: |
| break |
|
|
| conv = example.get("conversations", {}) |
| froms = conv.get("from", []) |
| values = conv.get("value", []) |
|
|
| state_so_far = [] |
| for role, content in zip(froms, values): |
| msg = {"role": role, "content": content} |
| if role == "assistant": |
| action_type = classify_action(content) |
| rows_proposer.append({ |
| "prompt": state_so_far.copy(), |
| "completion": [msg], |
| "action_type": action_type, |
| }) |
| rows_verifier.append({ |
| "prompt": state_so_far.copy(), |
| "completion": [msg], |
| "label": True, |
| "action_type": action_type, |
| }) |
| rows_eval.append({ |
| "messages": state_so_far + [msg], |
| "resolved": True, |
| "action_type": action_type, |
| }) |
| state_so_far.append(msg) |
|
|
| print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows") |
| return rows_proposer, rows_verifier, rows_eval |
|
|
def build_proposer_dataset(rows, hub_org):
    def fmt(row):
        system_msg = {
            "role": "system",
            "content": (
                "You are an agent action predictor. Given the conversation state, "
                "predict the next action from: " + ", ".join(ACTION_TYPES) + ". "
                "Respond with exactly the action name and a brief justification."
            ),
        }
        # Copy every message dict before editing it: the same dicts are shared across
        # proposer/verifier/eval rows, so in-place mutation would corrupt the other splits.
        prompt = [system_msg] + [dict(m) for m in row["prompt"]]
        prompt[-1]["content"] += (
            "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES)
        )
        completion = [dict(m) for m in row["completion"]]
        action_type = row["action_type"]
        completion[0]["content"] = f"Action: {action_type}\n" + completion[0]["content"]
        return {"prompt": prompt, "completion": completion}

    data = [fmt(r) for r in rows]
    ds = Dataset.from_list(data)
    ds = ds.shuffle(seed=42).train_test_split(test_size=0.1)
    ds.push_to_hub(f"{hub_org}/speculative-actions-proposer-sft")
    print(f"Pushed proposer SFT dataset to {hub_org}/speculative-actions-proposer-sft")
    return ds
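# fmt() emits conversational prompt/completion columns, i.e. the kind of layout
# TRL-style SFT trainers generally accept for chat data; double-check the exact
# column names and message format your trainer expects before training.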
|
|
def build_verifier_dataset(rows, hub_org):
    rng = Random(42)
    good_rows = [r for r in rows if r["label"]]
    bad_rows = [r for r in rows if not r["label"]]

    # If real negatives are scarce (< 20% of positives), synthesize rejected actions
    # by relabelling each positive with a randomly chosen wrong action type.
    if len(bad_rows) < len(good_rows) * 0.2:
        for r in good_rows:
            wrong_action = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad = {
                "prompt": r["prompt"],
                "completion": [{"role": "assistant", "content": f"Action: {wrong_action}\n(synthetic incorrect action)"}],
                "label": False,
                "action_type": wrong_action,
            }
            bad_rows.append(bad)

    # Pair every positive with a randomly drawn negative completion.
    pairs = []
    for g in good_rows:
        b = rng.choice(bad_rows)
        pairs.append({
            "prompt": g["prompt"],
            "chosen": g["completion"],
            "rejected": b["completion"],
            "action_type": g["action_type"],
        })

    ds = Dataset.from_list(pairs)
    ds = ds.shuffle(seed=42).train_test_split(test_size=0.1)
    ds.push_to_hub(f"{hub_org}/speculative-actions-verifier-pref")
    print(f"Pushed verifier preference dataset to {hub_org}/speculative-actions-verifier-pref")
    return ds
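# The resulting rows use the prompt/chosen/rejected layout commonly expected by
# DPO and reward-model trainers; "chosen" and "rejected" are message lists, so
# flatten them to plain strings first if your trainer only accepts text columns.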
|
|
def build_eval_dataset(rows, hub_org):
    ds = Dataset.from_list(rows)
    ds = ds.shuffle(seed=42).select(range(min(2_000, len(rows))))
    ds.push_to_hub(f"{hub_org}/speculative-actions-eval")
    print(f"Pushed eval dataset to {hub_org}/speculative-actions-eval")
    return ds
|
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hub_org", default="narcolepticchicken", type=str)
    parser.add_argument("--max_swe", type=int, default=5_000)
    parser.add_argument("--max_toolbench", type=int, default=3_000)
    args = parser.parse_args()

    p1, v1, e1 = process_swe_smith("train", args.max_swe)
    p2, v2, e2 = process_toolbench("train", args.max_toolbench)

    proposer_rows = p1 + p2
    verifier_rows = v1 + v2
    eval_rows = e1 + e2

    print(f"\nTotal rows: proposer={len(proposer_rows)}, verifier={len(verifier_rows)}, eval={len(eval_rows)}")

    print("\nAction distribution (proposer):")
    for act, n in Counter(r["action_type"] for r in proposer_rows).most_common():
        print(f" {act}: {n}")

    build_proposer_dataset(proposer_rows, args.hub_org)
    build_verifier_dataset(verifier_rows, args.hub_org)
    build_eval_dataset(eval_rows, args.hub_org)

    print("\nDataset construction complete.")
|
|
if __name__ == "__main__":
    main()
|
|