# speculative-tool-actions / build_datasets_raw.py
# Author: narcolepticchicken — "Add raw dataset builder script" (commit 8be73c9, verified)
import json, re, sys
from collections import Counter
from datasets import load_dataset, Dataset
from random import Random
# Hugging Face account/organization namespace the built datasets are pushed to.
HUB_ORG = "narcolepticchicken"
# Closed set of next-action labels: classify_action returns one of these, and
# the proposer system prompt enumerates them verbatim for the model to choose from.
ACTION_TYPES = [
    "tool_call","retrieval","file_read","file_write",
    "repair","verifier","ask_clarification","final_answer","BLOCKED",
]
def classify_action(content, tool_calls=None):
c = (content or "").lower()
tc = json.dumps(tool_calls).lower() if tool_calls else ""
combined = c + " " + tc
if re.search(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b', combined):
return "final_answer"
if re.search(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b', combined):
return "ask_clarification"
if re.search(r'\b(blocked|unsafe|i cannot|i\'m sorry, but|refuse|not allowed|harmful)\b', combined):
return "BLOCKED"
if re.search(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b', combined):
return "file_write"
if re.search(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b', combined):
return "file_read"
if re.search(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b', combined):
return "repair"
if re.search(r'\b(verify|check|validate|test|assert|review)\b', combined):
return "verifier"
if re.search(r'\b(search|retrieve|find|lookup|query|google|bing)\b', combined):
return "retrieval"
if tool_calls or re.search(r'\b(function call|tool call|invoke|execute)\b', combined):
return "tool_call"
return "tool_call"
def build():
print("Loading SWE-smith ...")
ds_swe = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split="train", streaming=True)
p_rows, v_rows, e_rows = [], [], []
count = 0
for ex in ds_swe:
count += 1
if count > 3000: break
msgs = ex.get("messages", [])
resolved = ex.get("resolved", False)
state = []
for msg in msgs:
role = msg.get("role", "")
if role in ("assistant", "agent"):
atype = classify_action(msg.get("content", ""), msg.get("tool_calls"))
comp = [{"role": "assistant", "content": msg.get("content", "")}]
if msg.get("tool_calls"):
comp[0]["tool_calls"] = msg["tool_calls"]
p_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "action_type": atype})
v_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "label": bool(resolved), "action_type": atype})
e_rows.append({"messages": [m.copy() for m in state] + comp, "resolved": resolved, "action_type": atype})
state.append(msg)
print("Loading ToolBench ...")
ds_tb = load_dataset("tuandunghcmut/toolbench-v1", split="train", streaming=True)
count = 0
for ex in ds_tb:
count += 1
if count > 2000: break
conv = ex.get("conversations", {})
state = []
for role, content in zip(conv.get("from", []), conv.get("value", [])):
msg = {"role": role, "content": content}
if role == "assistant":
atype = classify_action(content)
p_rows.append({"prompt": [m.copy() for m in state], "completion": [msg.copy()], "action_type": atype})
v_rows.append({"prompt": [m.copy() for m in state], "completion": [msg.copy()], "label": True, "action_type": atype})
e_rows.append({"messages": [m.copy() for m in state] + [msg.copy()], "resolved": True, "action_type": atype})
state.append(msg)
print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())
def fmt_proposer(r):
sys_msg = {"role": "system", "content": (
"You are an agent action predictor. Predict the next action from: "
+ ", ".join(ACTION_TYPES) + ". Respond with exactly the action name and brief justification.")}
prompt = [sys_msg] + r["prompt"]
if prompt:
prompt[-1]["content"] += "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES)
comp = r["completion"]
comp[0]["content"] = f"Action: {r['action_type']}\n" + comp[0]["content"]
return {"prompt": prompt, "completion": comp}
proposer_ds = Dataset.from_list([fmt_proposer(r) for r in p_rows]).shuffle(seed=42).train_test_split(test_size=0.1)
proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
print("Pushed proposer dataset")
rng = Random(42)
good = [r for r in v_rows if r["label"]]
bad = [r for r in v_rows if not r["label"]]
if len(bad) < len(good) * 0.2:
for r in good:
wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
bad.append({
"prompt": [m.copy() for m in r["prompt"]],
"completion": [{"role": "assistant", "content": f"Action: {wa}\n(synthetic incorrect action)"}],
"label": False, "action_type": wa,
})
pairs = []
for g in good:
b = rng.choice(bad)
pairs.append({"prompt": [m.copy() for m in g["prompt"]], "chosen": g["completion"], "rejected": b["completion"], "action_type": g["action_type"]})
verifier_ds = Dataset.from_list(pairs).shuffle(seed=42).train_test_split(test_size=0.1)
verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
print("Pushed verifier dataset")
eval_ds = Dataset.from_list(e_rows).shuffle(seed=42).select(range(min(1000, len(e_rows))))
eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
print("Pushed eval dataset")
print("Done.")
# Script entry point: build all three datasets and push them to the Hub.
if __name__ == "__main__":
    build()