File size: 5,837 Bytes
8be73c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import re
import sys
from collections import Counter
from itertools import islice
from random import Random

from datasets import Dataset, load_dataset

# Hugging Face Hub namespace that every dataset built by this script is pushed under.
HUB_ORG = "narcolepticchicken"

# Closed set of next-action labels. classify_action() returns one of these,
# and the proposer prompt lists them (in this order) as the allowed answers.
ACTION_TYPES = [
    "tool_call",
    "retrieval",
    "file_read",
    "file_write",
    "repair",
    "verifier",
    "ask_clarification",
    "final_answer",
    "BLOCKED",
]

# Ordered (label, regex) rules for classify_action(). The first rule that
# matches wins, so the tuple order encodes label priority. Patterns are
# searched against lower-cased text.
_ACTION_RULES = (
    ("final_answer", re.compile(
        r"\b(?:final answer|conclusion|in conclusion|the answer is)\b"
        # "summary:" ends in a non-word char, so it only gets a leading \b;
        # a trailing \b would demand a letter immediately after the colon.
        r"|\bsummary:"
    )),
    ("ask_clarification", re.compile(
        r"\b(?:ask for clarification|need more info|could you clarify|what do you mean)\b"
    )),
    ("BLOCKED", re.compile(
        r"\b(?:blocked|unsafe|i cannot|i'm sorry, but|refuse|not allowed|harmful)\b"
    )),
    ("file_write", re.compile(
        r"\b(?:write.*file|save.*file|edit.*file|patch|diff)\b"
    )),
    ("file_read", re.compile(
        r"\b(?:read.*file|view.*file|open.*file|get_content)\b"
        # Shell commands keep their trailing space but no trailing \b, so
        # arguments that start with a non-word char (e.g. "cat /path") match.
        r"|\b(?:cat|head|tail) "
    )),
    ("repair", re.compile(
        r"\b(?:repair|fix.*bug|correct.*error|debug|resolve|try.*again with)\b"
    )),
    ("verifier", re.compile(
        r"\b(?:verify|check|validate|test|assert|review)\b"
    )),
    ("retrieval", re.compile(
        r"\b(?:search|retrieve|find|lookup|query|google|bing)\b"
    )),
)


def classify_action(content, tool_calls=None):
    """Heuristically label one assistant turn with an entry of ACTION_TYPES.

    The message text and its JSON-serialized tool calls (if any) are
    lower-cased, joined with a space, and checked against keyword rules in
    priority order. The first rule that matches decides the label.

    Args:
        content: Assistant message text. ``None`` is treated as empty. Any
            other non-string value (e.g. a list of content parts) is
            JSON-serialized so it can still be searched.
        tool_calls: Optional tool-call payload attached to the message. It
            must be JSON-serializable.

    Returns:
        One of "final_answer", "ask_clarification", "BLOCKED", "file_write",
        "file_read", "repair", "verifier" or "retrieval" when a rule matches.
        Otherwise returns "tool_call", whether or not tool calls are present.
    """
    if content is None:
        text = ""
    elif isinstance(content, str):
        text = content
    else:
        # Structured content would crash on .lower(); search its JSON form instead.
        text = json.dumps(content)
    calls = json.dumps(tool_calls) if tool_calls else ""
    combined = text.lower() + " " + calls.lower()

    for label, pattern in _ACTION_RULES:
        if pattern.search(combined):
            return label
    # Default label when no keyword rule fires.
    return "tool_call"

def _swe_smith_steps(limit):
    """Yield ``(context, completion_msg, action_type, resolved)`` per assistant turn.

    Streams the ``tool`` config of SWE-smith trajectories and stops after
    ``limit`` trajectories (``None`` streams all of them). ``context`` is a
    snapshot list of the messages that precede the turn. Its message dicts
    are shared with the source example, so callers must copy before mutating.
    """
    ds = load_dataset("SWE-bench/SWE-smith-trajectories", "tool", split="train", streaming=True)
    for ex in islice(ds, limit):
        resolved = ex.get("resolved", False)
        state = []
        for msg in ex.get("messages", []):
            if msg.get("role", "") in ("assistant", "agent"):
                content = msg.get("content", "")
                completion = {"role": "assistant", "content": content}
                if msg.get("tool_calls"):
                    completion["tool_calls"] = msg["tool_calls"]
                atype = classify_action(content, msg.get("tool_calls"))
                yield list(state), completion, atype, resolved
            state.append(msg)


def _toolbench_steps(limit):
    """Yield ``(context, completion_msg, action_type, resolved)`` per assistant turn.

    Streams ToolBench conversations, stored column-wise as parallel
    ``from``/``value`` lists, and stops after ``limit`` examples (``None`` =
    all). Every ToolBench turn is treated as resolved (``True``).
    """
    ds = load_dataset("tuandunghcmut/toolbench-v1", split="train", streaming=True)
    for ex in islice(ds, limit):
        conv = ex.get("conversations", {})
        state = []
        for role, content in zip(conv.get("from", []), conv.get("value", [])):
            msg = {"role": role, "content": content}
            if role == "assistant":
                yield list(state), msg, classify_action(content), True
            state.append(msg)


def _format_proposer(row):
    """Return a new proposer SFT example for ``row`` without mutating ``row``.

    A system message lists the allowed actions. The last prompt message gets
    a "[Next Action Prediction]" suffix, and the completion is prefixed with
    ``"Action: <label>\\n"``. ``None`` message content is treated as "".
    """
    sys_msg = {"role": "system", "content": (
        "You are an agent action predictor. Predict the next action from: "
        + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name and brief justification.")}
    # NOTE(review): when the trajectory already starts with its own system
    # message, the prompt ends up with two system messages; confirm the target
    # chat template accepts that.
    prompt = [sys_msg] + [dict(m) for m in row["prompt"]]
    last = prompt[-1]
    last["content"] = ((last.get("content") or "")
                       + "\n\n[Next Action Prediction] Choose one: " + ", ".join(ACTION_TYPES))
    completion = [dict(m) for m in row["completion"]]
    completion[0]["content"] = f"Action: {row['action_type']}\n" + (completion[0].get("content") or "")
    return {"prompt": prompt, "completion": completion}


def _build_verifier_pairs(v_rows, rng):
    """Turn labelled verifier rows into chosen/rejected preference pairs.

    Every positive row becomes "chosen" in exactly one pair. When negatives
    are fewer than 20% of positives, one synthetic wrong-action negative is
    added per positive. ``rng`` draws are made in a fixed order, so a seeded
    ``Random`` gives reproducible pairs.
    """
    good = [r for r in v_rows if r["label"]]
    bad = [r for r in v_rows if not r["label"]]
    if len(bad) < len(good) * 0.2:
        for r in good:
            wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
            bad.append({
                "prompt": [m.copy() for m in r["prompt"]],
                "completion": [{"role": "assistant", "content": f"Action: {wa}\n(synthetic incorrect action)"}],
                "label": False, "action_type": wa,
            })
    pairs = []
    for g in good:
        # NOTE(review): the rejected completion is drawn from any negative row,
        # so it usually belongs to a different prompt than the chosen one.
        b = rng.choice(bad)
        pairs.append({"prompt": [m.copy() for m in g["prompt"]], "chosen": g["completion"],
                      "rejected": b["completion"], "action_type": g["action_type"]})
    return pairs


def build(*, swe_limit=3000, toolbench_limit=2000, eval_size=1000, seed=42, hub_org=HUB_ORG):
    """Collect agent trajectories, derive three datasets, and push them to the Hub.

    Every assistant turn from SWE-smith and ToolBench becomes one row in each
    of these datasets:

    * ``{hub_org}/speculative-actions-proposer-sft``: prompt/completion SFT
      examples whose completion is prefixed with the heuristic action label.
    * ``{hub_org}/speculative-actions-verifier-pref``: chosen/rejected
      preference pairs.
    * ``{hub_org}/speculative-actions-eval``: up to ``eval_size`` raw
      message prefixes with their ``resolved`` flag.

    Args:
        swe_limit: Max SWE-smith trajectories to read (``None`` = all).
        toolbench_limit: Max ToolBench conversations to read (``None`` = all).
        eval_size: Max rows kept in the eval dataset.
        seed: Seed for shuffles, train/test splits and negative sampling.
        hub_org: Hub namespace the datasets are pushed under.

    Raises:
        RuntimeError: if no assistant turns, or no positive verifier rows,
            were collected.
    """
    p_rows, v_rows, e_rows = [], [], []

    def add_step(context, completion, atype, resolved):
        # Each dataset row gets its own message dicts, so formatting applied
        # to one dataset (the proposer "Action:" prefix) cannot leak into the
        # verifier or eval rows.
        p_rows.append({"prompt": [dict(m) for m in context], "completion": [dict(completion)],
                       "action_type": atype})
        v_rows.append({"prompt": [dict(m) for m in context], "completion": [dict(completion)],
                       "label": bool(resolved), "action_type": atype})
        e_rows.append({"messages": [dict(m) for m in context] + [dict(completion)],
                       "resolved": resolved, "action_type": atype})

    print("Loading SWE-smith ...")
    for context, completion, atype, resolved in _swe_smith_steps(swe_limit):
        add_step(context, completion, atype, resolved)

    print("Loading ToolBench ...")
    for context, completion, atype, resolved in _toolbench_steps(toolbench_limit):
        add_step(context, completion, atype, resolved)

    print(f"Rows: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
    print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())
    if not p_rows:
        raise RuntimeError("No assistant turns were collected; nothing to push.")

    # Seed the split as well as the shuffle so the train/test partition is reproducible.
    proposer_ds = (Dataset.from_list([_format_proposer(r) for r in p_rows])
                   .shuffle(seed=seed).train_test_split(test_size=0.1, seed=seed))
    proposer_ds.push_to_hub(f"{hub_org}/speculative-actions-proposer-sft")
    print("Pushed proposer dataset")

    pairs = _build_verifier_pairs(v_rows, Random(seed))
    if not pairs:
        raise RuntimeError("No positive verifier rows; cannot build preference pairs.")
    verifier_ds = Dataset.from_list(pairs).shuffle(seed=seed).train_test_split(test_size=0.1, seed=seed)
    verifier_ds.push_to_hub(f"{hub_org}/speculative-actions-verifier-pref")
    print("Pushed verifier dataset")

    eval_ds = Dataset.from_list(e_rows).shuffle(seed=seed).select(range(min(eval_size, len(e_rows))))
    eval_ds.push_to_hub(f"{hub_org}/speculative-actions-eval")
    print("Pushed eval dataset")
    print("Done.")

if __name__ == "__main__":
    # Run the dataset build only when executed as a script, never on import.
    build()