import json
import re
import sys

from collections import Counter
from random import Random

from datasets import load_dataset, Dataset
|
|
# Hugging Face Hub organization/user that the built datasets are pushed to.
HUB_ORG = "narcolepticchicken"

# Closed set of agent-action labels produced by classify_action(). Listing
# order here is canonical for prompts; classification priority is defined
# separately inside classify_action.
ACTION_TYPES = [
    "tool_call", "retrieval", "file_read", "file_write",
    "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED",
]
|
|
# Ordered, precompiled classification rules: the FIRST matching rule wins, so
# the more specific intents (answers, clarification requests, refusals) are
# tested before the generic ones (verify / search / tool use). Compiling once
# at import time avoids repeated re-cache lookups on every call.
_ACTION_RULES = [
    ("final_answer", re.compile(r'\b(final answer|conclusion|summary:|in conclusion|the answer is)\b')),
    ("ask_clarification", re.compile(r'\b(ask for clarification|need more info|could you clarify|what do you mean)\b')),
    ("BLOCKED", re.compile(r'\b(blocked|unsafe|i cannot|i\'m sorry, but|refuse|not allowed|harmful)\b')),
    ("file_write", re.compile(r'\b(write.*file|save.*file|edit.*file|patch|diff)\b')),
    ("file_read", re.compile(r'\b(read.*file|view.*file|cat |head |tail |open.*file|get_content)\b')),
    ("repair", re.compile(r'\b(repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b')),
    ("verifier", re.compile(r'\b(verify|check|validate|test|assert|review)\b')),
    ("retrieval", re.compile(r'\b(search|retrieve|find|lookup|query|google|bing)\b')),
]


def classify_action(content, tool_calls=None):
    """Heuristically assign one of ACTION_TYPES to an assistant turn.

    Args:
        content: Assistant message text (may be None or empty).
        tool_calls: Optional tool-call payload; it is JSON-serialized and
            searched alongside the text so tool names also influence the label.

    Returns:
        The label of the first matching rule in _ACTION_RULES, or "tool_call"
        when nothing matches (every turn receives some label).
    """
    text = (content or "").lower()
    calls = json.dumps(tool_calls).lower() if tool_calls else ""
    # The separator space is appended even when there are no tool calls so
    # patterns ending in a space (e.g. "cat ") can still match at end-of-text,
    # matching the original concatenation behavior exactly.
    combined = text + " " + calls
    for label, pattern in _ACTION_RULES:
        if pattern.search(combined):
            return label
    # Fallback: anything unmatched is treated as a generic tool call. (A prior
    # version special-cased tool_calls / "invoke" phrases here, but both
    # branches returned "tool_call", so a single fallback is equivalent.)
    return "tool_call"
|
|
def build():
    """Build and publish three speculative-action datasets to the Hub.

    Streams trajectories from SWE-smith (first 3000 examples) and ToolBench
    (first 2000 examples), labels every assistant turn via classify_action,
    and pushes:
      * {HUB_ORG}/speculative-actions-proposer-sft  -- next-action SFT pairs
      * {HUB_ORG}/speculative-actions-verifier-pref -- chosen/rejected prefs
      * {HUB_ORG}/speculative-actions-eval          -- full labeled traces
    """
    print("Loading SWE-smith ...")
    ds_swe = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split="train", streaming=True)
    p_rows, v_rows, e_rows = [], [], []
    count = 0
    for ex in ds_swe:
        count += 1
        if count > 3000:
            break
        msgs = ex.get("messages", [])
        resolved = ex.get("resolved", False)
        state = []  # conversation prefix seen so far (original message dicts)
        for msg in msgs:
            role = msg.get("role", "")
            if role in ("assistant", "agent"):
                atype = classify_action(msg.get("content", ""), msg.get("tool_calls"))
                comp = [{"role": "assistant", "content": msg.get("content", "")}]
                if msg.get("tool_calls"):
                    comp[0]["tool_calls"] = msg["tool_calls"]
                # NOTE: comp is intentionally shared across the three row lists;
                # fmt_proposer below copies before mutating, so the sharing is safe.
                p_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "action_type": atype})
                v_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "label": bool(resolved), "action_type": atype})
                e_rows.append({"messages": [m.copy() for m in state] + comp, "resolved": resolved, "action_type": atype})
            state.append(msg)

    print("Loading ToolBench ...")
    ds_tb = load_dataset("tuandunghcmut/toolbench-v1", split="train", streaming=True)
    count = 0
    for ex in ds_tb:
        count += 1
        if count > 2000:
            break
        # Assumes a columnar conversation layout: parallel "from"/"value"
        # lists under "conversations" -- TODO confirm against the dataset card.
        conv = ex.get("conversations", {})
        state = []
        for role, content in zip(conv.get("from", []), conv.get("value", [])):
            msg = {"role": role, "content": content}
            if role == "assistant":
                atype = classify_action(content)
                # ToolBench has no per-trajectory outcome signal; every real
                # assistant turn is treated as a positive (label=True) example.
                p_rows.append({"prompt": [m.copy() for m in state], "completion": [msg.copy()], "action_type": atype})
                v_rows.append({"prompt": [m.copy() for m in state], "completion": [msg.copy()], "label": True, "action_type": atype})
                e_rows.append({"messages": [m.copy() for m in state] + [msg.copy()], "resolved": True, "action_type": atype})
            state.append(msg)

    print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
    print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())

    def fmt_proposer(r):
        # Format one proposer SFT row WITHOUT mutating r. The completion dicts
        # are shared with v_rows/e_rows (see the SWE-smith loop); the original
        # code mutated them in place, leaking the "Action:" prefix into the
        # verifier "chosen" completions and the eval messages pushed below.
        sys_msg = {"role": "system", "content": (
            "You are an agent action predictor. Predict the next action from: "
            + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name and brief justification.")}
        prompt = [sys_msg] + [dict(m) for m in r["prompt"]]
        # prompt is never empty (sys_msg is always present); with no prior
        # turns the instruction lands on the system message, as before.
        prompt[-1]["content"] += "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES)
        comp = [dict(m) for m in r["completion"]]
        comp[0]["content"] = f"Action: {r['action_type']}\n" + comp[0]["content"]
        return {"prompt": prompt, "completion": comp}

    proposer_ds = Dataset.from_list([fmt_proposer(r) for r in p_rows]).shuffle(seed=42).train_test_split(test_size=0.1)
    proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    rng = Random(42)
    good = [r for r in v_rows if r["label"]]
    bad = [r for r in v_rows if not r["label"]]
    # Too few organic negatives: synthesize one wrong-action rejection per
    # positive so preference pairs can still be formed.
    if len(bad) < len(good) * 0.2:
        for r in good:
            wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad.append({
                "prompt": [m.copy() for m in r["prompt"]],
                "completion": [{"role": "assistant", "content": f"Action: {wa}\n(synthetic incorrect action)"}],
                "label": False, "action_type": wa,
            })
    pairs = []
    for g in good:
        b = rng.choice(bad)  # rejected sampled independently of the prompt
        pairs.append({"prompt": [m.copy() for m in g["prompt"]], "chosen": g["completion"], "rejected": b["completion"], "action_type": g["action_type"]})
    verifier_ds = Dataset.from_list(pairs).shuffle(seed=42).train_test_split(test_size=0.1)
    verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    eval_ds = Dataset.from_list(e_rows).shuffle(seed=42).select(range(min(1000, len(e_rows))))
    eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")
|
|
# Script entry point: build and push all three datasets when run directly.
if __name__ == "__main__":
    build()
|
|