"""Build proposer (SFT), verifier (preference), and eval datasets for speculative
agent actions from SWE-smith and ToolBench trajectories, then push them to the
Hugging Face Hub."""

import json
import re
from collections import Counter
from random import Random

from datasets import Dataset, load_dataset

HUB_ORG = "narcolepticchicken"

ACTION_TYPES = [
    "tool_call", "retrieval", "file_read", "file_write",
    "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED",
]


def classify_action(content, tool_calls=None):
    """Heuristically bucket an assistant turn into one of ACTION_TYPES via keyword regexes."""
    c = (content or "").lower()
    tc = json.dumps(tool_calls).lower() if tool_calls else ""
    combined = c + " " + tc
    if re.search(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b', combined):
        return "final_answer"
    if re.search(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b', combined):
        return "ask_clarification"
    if re.search(r'\b(blocked|unsafe|i cannot|i\'m sorry, but|refuse|not allowed|harmful)\b', combined):
        return "BLOCKED"
    if re.search(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b', combined):
        return "file_write"
    if re.search(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b', combined):
        return "file_read"
    if re.search(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b', combined):
        return "repair"
    if re.search(r'\b(verify|check|validate|test|assert|review)\b', combined):
        return "verifier"
    if re.search(r'\b(search|retrieve|find|lookup|query|google|bing)\b', combined):
        return "retrieval"
    if tool_calls or re.search(r'\b(function call|tool call|invoke|execute)\b', combined):
        return "tool_call"
    # Default bucket when nothing else matches.
    return "tool_call"


def build():
    # --- SWE-smith: replay each trajectory, emitting one row per assistant turn ---
    print("Loading SWE-smith ...")
    ds_swe = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split="train", streaming=True)
    p_rows, v_rows, e_rows = [], [], []
    count = 0
    for ex in ds_swe:
        count += 1
        if count > 3000:
            break
        msgs = ex.get("messages", [])
        resolved = ex.get("resolved", False)
        state = []  # conversation history seen so far
        for msg in msgs:
            role = msg.get("role", "")
            if role in ("assistant", "agent"):
                atype = classify_action(msg.get("content", ""), msg.get("tool_calls"))
                comp = [{"role": "assistant", "content": msg.get("content", "")}]
                if msg.get("tool_calls"):
                    comp[0]["tool_calls"] = msg["tool_calls"]
                p_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "action_type": atype})
                v_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "label": bool(resolved), "action_type": atype})
                e_rows.append({"messages": [m.copy() for m in state] + comp, "resolved": resolved, "action_type": atype})
            state.append(msg)

    # --- ToolBench: same replay, treating every assistant turn as a positive example ---
    print("Loading ToolBench ...")
    ds_tb = load_dataset("tuandunghcmut/toolbench-v1", split="train", streaming=True)
    count = 0
    for ex in ds_tb:
        count += 1
        if count > 2000:
            break
        conv = ex.get("conversations", {})
        state = []
        for role, content in zip(conv.get("from", []), conv.get("value", [])):
            msg = {"role": role, "content": content}
            if role == "assistant":
                atype = classify_action(content)
                p_rows.append({"prompt": [m.copy() for m in state], "completion": [msg.copy()], "action_type": atype})
                v_rows.append({"prompt": [m.copy() for m in state], "completion": [msg.copy()], "label": True, "action_type": atype})
                e_rows.append({"messages": [m.copy() for m in state] + [msg.copy()], "resolved": True, "action_type": atype})
            state.append(msg)

    print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
    print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())

    # --- Proposer SFT dataset: add a system prompt and an "Action: ..." target prefix ---
    def fmt_proposer(r):
        sys_msg = {"role": "system", "content": (
            "You are an agent action predictor. Predict the next action from: "
            + ", ".join(ACTION_TYPES)
            + ". Respond with exactly the action name and a brief justification.")}
        prompt = [sys_msg] + r["prompt"]
        if prompt:
            # Append the prediction cue to the last message in the prompt.
            prompt[-1]["content"] += "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES)
        # NOTE: this mutates the shared completion dict in place, so SWE-smith rows in
        # v_rows and e_rows also pick up the "Action: ..." prefix.
        comp = r["completion"]
        comp[0]["content"] = f"Action: {r['action_type']}\n" + comp[0]["content"]
        return {"prompt": prompt, "completion": comp}

    proposer_ds = Dataset.from_list([fmt_proposer(r) for r in p_rows]).shuffle(seed=42).train_test_split(test_size=0.1)
    proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    # --- Verifier preference dataset: pair each positive with a (possibly synthetic) negative ---
    rng = Random(42)
    good = [r for r in v_rows if r["label"]]
    bad = [r for r in v_rows if not r["label"]]
    if len(bad) < len(good) * 0.2:
        # Too few organic negatives: synthesize rejections carrying a wrong action label.
        for r in good:
            wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad.append({
                "prompt": [m.copy() for m in r["prompt"]],
                "completion": [{"role": "assistant", "content": f"Action: {wa}\n(synthetic incorrect action)"}],
                "label": False,
                "action_type": wa,
            })
    pairs = []
    for g in good:
        b = rng.choice(bad)
        pairs.append({"prompt": [m.copy() for m in g["prompt"]], "chosen": g["completion"],
                      "rejected": b["completion"], "action_type": g["action_type"]})
    verifier_ds = Dataset.from_list(pairs).shuffle(seed=42).train_test_split(test_size=0.1)
    verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    # --- Eval dataset: up to 1000 shuffled full conversations ---
    eval_ds = Dataset.from_list(e_rows).shuffle(seed=42).select(range(min(1000, len(e_rows))))
    eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")


if __name__ == "__main__":
    build()
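

# ---------------------------------------------------------------------------
# Optional smoke test for classify_action (not part of the original pipeline;
# the sample strings below are illustrative assumptions, not drawn from either
# dataset). Run it manually before a push to confirm the keyword buckets behave
# as expected, e.g. `python -c "import build_datasets as b; b._smoke_test_classify_action()"`
# with whatever module name this file actually has.
# ---------------------------------------------------------------------------
def _smoke_test_classify_action():
    samples = {
        "Let me search the web for the relevant API docs.": "retrieval",
        "I will read the file src/main.py to inspect the handler.": "file_read",
        "Final answer: the bug is in the off-by-one loop bound.": "final_answer",
        "I'm sorry, but I cannot help with that request.": "BLOCKED",
    }
    for text, expected in samples.items():
        got = classify_action(text)
        assert got == expected, f"{text!r}: expected {expected}, got {got}"
    print("classify_action smoke test passed")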
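

# ---------------------------------------------------------------------------
# Sketch of how a downstream trainer might load the pushed splits. Assumes
# build() has already pushed all three repos and that the environment is
# authenticated against the Hub (e.g. via `huggingface-cli login`); repo names
# mirror the push_to_hub calls above.
# ---------------------------------------------------------------------------
def _peek_pushed_splits():
    proposer = load_dataset(f"{HUB_ORG}/speculative-actions-proposer-sft", split="train")
    verifier = load_dataset(f"{HUB_ORG}/speculative-actions-verifier-pref", split="train")
    print("proposer rows:", len(proposer), "| verifier rows:", len(verifier))
    # First line of a proposer target is the "Action: <type>" prefix added by fmt_proposer.
    print("sample target:", proposer[0]["completion"][0]["content"].splitlines()[0])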