"""
Speculative Tool Actions — Dataset Builder
==========================================
Converts agent trace datasets into a unified schema with nine action types:
tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED
Sources:
- SWE-bench/SWE-smith-trajectories (tool split, resolved=True)
- tuandunghcmut/toolbench-v1
Output datasets (pushed to Hub):
- {hub_org}/speculative-actions-proposer-sft -> prompt-completion for next-action SFT
- {hub_org}/speculative-actions-verifier-pref -> chosen/rejected pairs for verifier DPO/Reward
- {hub_org}/speculative-actions-eval -> held-out eval set with gold labels
"""
import json
import re
import argparse
from collections import Counter
from datasets import load_dataset, Dataset
from random import Random
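# For orientation, a proposer row produced below has this shape (illustrative
# values, derived from the builder functions in this file):
#
#   {
#       "prompt": [{"role": "user", "content": "Fix the failing test ..."}],
#       "completion": [{"role": "assistant", "content": "I'll read the file ..."}],
#       "action_type": "file_read",
#   }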
ACTION_TYPES = [
"tool_call",
"retrieval",
"file_read",
"file_write",
"repair",
"verifier",
"ask_clarification",
"final_answer",
"BLOCKED",
]
ACTION_MAP = {a: i for i, a in enumerate(ACTION_TYPES)}
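# ACTION_MAP gives each action a stable integer id, e.g.
# ACTION_MAP["tool_call"] == 0 and ACTION_MAP["BLOCKED"] == 8.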
def classify_action(content: str, tool_calls=None) -> str:
    """Heuristic classifier mapping raw agent output to one of ACTION_TYPES."""
    # Match against both the message text and any serialized tool calls.
    c = content.lower()
    tc = json.dumps(tool_calls).lower() if tool_calls else ""
    combined = c + " " + tc
if re.search(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b', combined):
return "final_answer"
if re.search(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b', combined):
return "ask_clarification"
    if re.search(r"\b(blocked|unsafe|i cannot|i'm sorry, but|refuse|not allowed|harmful)\b", combined):
return "BLOCKED"
if re.search(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b', combined):
return "file_write"
if re.search(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b', combined):
return "file_read"
if re.search(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b', combined):
return "repair"
if re.search(r'\b(verify|check|validate|test|assert|review)\b', combined):
return "verifier"
if re.search(r'\b(search|retrieve|find|lookup|query|google|bing)\b', combined):
return "retrieval"
    if tool_calls or re.search(r'\b(function call|tool call|invoke|execute)\b', combined):
        return "tool_call"
    # Fallback: anything unmatched is treated as a generic tool call.
    return "tool_call"
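# Illustrative behaviour on hypothetical inputs (patterns are tried top-down,
# so the first match wins):
#
#   classify_action("The answer is 42")                     -> "final_answer"
#   classify_action("Let me search the docs for that API")  -> "retrieval"
#   classify_action("I'll apply a patch to utils.py")       -> "file_write"
#   classify_action("", tool_calls=[{"name": "run_tests"}]) -> "tool_call"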
def process_swe_smith(split="train", max_rows=10_000):
print(f"Loading SWE-smith tool/{split} ...")
ds = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split=split, streaming=True)
rows_proposer = []
rows_verifier = []
rows_eval = []
count = 0
for example in ds:
count += 1
if count > max_rows:
break
messages = example.get("messages", [])
resolved = example.get("resolved", False)
state_so_far = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", None)
if role in ("assistant", "agent"):
action_type = classify_action(content, tool_calls)
prompt_messages = state_so_far.copy()
completion_messages = [{"role": "assistant", "content": content}]
if tool_calls:
completion_messages[0]["tool_calls"] = tool_calls
rows_proposer.append({
"prompt": prompt_messages,
"completion": completion_messages,
"action_type": action_type,
})
rows_verifier.append({
"prompt": prompt_messages,
"completion": completion_messages,
"label": bool(resolved),
"action_type": action_type,
})
rows_eval.append({
"messages": prompt_messages + completion_messages,
"resolved": resolved,
"action_type": action_type,
})
state_so_far.append(msg)
print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows")
return rows_proposer, rows_verifier, rows_eval
def process_toolbench(split="train", max_rows=5_000):
print(f"Loading toolbench/{split} ...")
ds = load_dataset("tuandunghcmut/toolbench-v1", split=split, streaming=True)
rows_proposer = []
rows_verifier = []
rows_eval = []
count = 0
for example in ds:
count += 1
if count > max_rows:
break
        # Assumes the ShareGPT-style columnar layout: parallel "from"/"value"
        # lists under a "conversations" key.
        conv = example.get("conversations", {})
        froms = conv.get("from", [])
        values = conv.get("value", [])
state_so_far = []
for role, content in zip(froms, values):
msg = {"role": role, "content": content}
if role == "assistant":
action_type = classify_action(content)
rows_proposer.append({
"prompt": state_so_far.copy(),
"completion": [msg],
"action_type": action_type,
})
rows_verifier.append({
"prompt": state_so_far.copy(),
"completion": [msg],
"label": True,
"action_type": action_type,
})
rows_eval.append({
"messages": state_so_far + [msg],
"resolved": True,
"action_type": action_type,
})
state_so_far.append(msg)
print(f" -> {len(rows_proposer)} proposer rows, {len(rows_verifier)} verifier rows")
return rows_proposer, rows_verifier, rows_eval
def build_proposer_dataset(rows, hub_org):
    def fmt(row):
        system_msg = {
            "role": "system",
            "content": (
                "You are an agent action predictor. Given the conversation state, "
                "predict the next action from: " + ", ".join(ACTION_TYPES) + ". "
                "Respond with exactly the action name and a brief justification."
            ),
        }
        # Copy the message dicts before mutating them: the same dicts are shared
        # with the verifier and eval row lists (and across rows), so editing
        # them in place would corrupt those datasets and stack the prediction
        # header onto repeated prompts.
        prompt = [system_msg] + [dict(m) for m in row["prompt"]]
        prompt[-1]["content"] += (
            "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES)
        )
        completion = [dict(m) for m in row["completion"]]
        action_type = row["action_type"]
        completion[0]["content"] = f"Action: {action_type}\n" + completion[0]["content"]
        return {"prompt": prompt, "completion": completion}
data = [fmt(r) for r in rows]
ds = Dataset.from_list(data)
ds = ds.shuffle(seed=42).train_test_split(test_size=0.1)
ds.push_to_hub(f"{hub_org}/speculative-actions-proposer-sft")
print(f"Pushed proposer SFT dataset to {hub_org}/speculative-actions-proposer-sft")
return ds
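# A formatted SFT example then looks like (illustrative values):
#
#   prompt:     [system message] + history, with the final message suffixed by
#               "\n\n[Next Action Prediction] Choose one: tool_call, retrieval, ..."
#   completion: [{"role": "assistant", "content": "Action: file_read\n..."}]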
def build_verifier_dataset(rows, hub_org):
    rng = Random(42)
    good_rows = [r for r in rows if r["label"]]
    bad_rows = [r for r in rows if not r["label"]]
    # If natural negatives (unresolved trajectories) are scarce, synthesize one
    # per positive by pairing the same prompt with a wrong action label.
    if len(bad_rows) < len(good_rows) * 0.2:
        for r in good_rows:
            wrong_action = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad = {
                "prompt": r["prompt"],
                "completion": [{"role": "assistant", "content": f"Action: {wrong_action}\n(synthetic incorrect action)"}],
                "label": False,
                "action_type": wrong_action,
            }
            bad_rows.append(bad)
    pairs = []
    for g in good_rows:
        # The rejected completion is sampled at random, so it may come from a
        # different prompt than the chosen one.
        b = rng.choice(bad_rows)
        pairs.append({
            "prompt": g["prompt"],
            "chosen": g["completion"],
            "rejected": b["completion"],
            "action_type": g["action_type"],
        })
ds = Dataset.from_list(pairs)
ds = ds.shuffle(seed=42).train_test_split(test_size=0.1)
ds.push_to_hub(f"{hub_org}/speculative-actions-verifier-pref")
print(f"Pushed verifier preference dataset to {hub_org}/speculative-actions-verifier-pref")
return ds
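# Each row follows the chosen/rejected preference layout used by DPO and
# reward-model trainers (e.g. TRL's DPOTrainer); illustrative shape:
#
#   {"prompt": [...], "chosen": [assistant msg], "rejected": [assistant msg],
#    "action_type": "tool_call"}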
def build_eval_dataset(rows, hub_org):
ds = Dataset.from_list(rows)
ds = ds.shuffle(seed=42).select(range(min(2_000, len(rows))))
ds.push_to_hub(f"{hub_org}/speculative-actions-eval")
print(f"Pushed eval dataset to {hub_org}/speculative-actions-eval")
return ds
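# Eval rows keep the full trajectory plus gold labels, e.g.:
#
#   {"messages": [...full conversation...], "resolved": True, "action_type": "repair"}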
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--hub_org", default="narcolepticchicken", type=str)
parser.add_argument("--max_swe", type=int, default=5_000)
parser.add_argument("--max_toolbench", type=int, default=3_000)
args = parser.parse_args()
p1, v1, e1 = process_swe_smith("train", args.max_swe)
p2, v2, e2 = process_toolbench("train", args.max_toolbench)
proposer_rows = p1 + p2
verifier_rows = v1 + v2
eval_rows = e1 + e2
print(f"\nTotal rows: proposer={len(proposer_rows)}, verifier={len(verifier_rows)}, eval={len(eval_rows)}")
print("\nAction distribution (proposer):")
for act, n in Counter(r["action_type"] for r in proposer_rows).most_common():
print(f" {act}: {n}")
build_proposer_dataset(proposer_rows, args.hub_org)
build_verifier_dataset(verifier_rows, args.hub_org)
build_eval_dataset(eval_rows, args.hub_org)
print("\nDataset construction complete.")
if __name__ == "__main__":
main()
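# Typical invocation (assumes you are authenticated with the Hub, e.g. via
# `huggingface-cli login`, since all three builders call push_to_hub):
#
#   python dataset_builder.py --hub_org my-org --max_swe 5000 --max_toolbench 3000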