| """ |
Speculative Tool Actions — Eval Runner v3
=========================================
Evaluates the five configurations below on the same eval set.
| |
| Config A: 8B strong model (fine-tuned on SFT) |
| Config B: 1.7B cheap proposer (fine-tuned on SFT) |
Config C: 1.7B proposes → 8B verifier ACCEPT/REJECT; fallback to 8B on REJECT
          (currently skipped: no 8B verifier has been trained)
Config D: 1.7B proposes → 4B verifier ACCEPT/REJECT; fallback to 8B on REJECT
Config E: 1.7B samples N=3 proposals → 4B verifier screens each distinct
          proposal; an ACCEPTed one is used if any, otherwise a fallback proposal
| |
| All models fine-tuned on same SFT data in chat-template "messages" format. |
| """ |
|
|
| import json |
| import re |
| import torch |
| from collections import Counter |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
| from datasets import load_dataset |
|
|
# Hugging Face Hub namespace hosting the eval set, the adapters and the results repo.
HUB = "narcolepticchicken"

# Closed label set of next-action types the proposer must choose from.
ACTIONS = [
    "tool_call",
    "retrieval",
    "file_read",
    "file_write",
    "repair",
    "verifier",
    "ask_clarification",
    "final_answer",
    "BLOCKED",
]

# Relative per-call cost units, normalized so that one strong-model call == 1.0.
COST = dict(strong=1.0, cheap=0.15, verify=0.05)

# System prompt for the proposer models; enumerates every allowed label.
SYSTEM_PROMPT = (
    "You are an agent action predictor. Given the conversation so far, "
    "predict the type of the next action the assistant should take. "
    f"Choose exactly one from: {', '.join(ACTIONS)}. "
    "Output only the action type name, nothing else."
)

# System prompt for the ACCEPT/REJECT verifier model.
VERIFIER_SYSTEM = " ".join((
    "You are an action verifier.",
    "Given conversation context and a proposed next action, determine if the proposal is correct.",
    "Respond with exactly ACCEPT or REJECT.",
))
|
|
|
|
def load_proposer(model_name, adapter_id=None):
    """Load a causal LM (optionally with a PEFT adapter) in eval mode.

    Args:
        model_name: Base checkpoint passed to ``from_pretrained``.
        adapter_id: Optional PEFT adapter id; when falsy the bare base model is used.

    Returns:
        ``(model, tokenizer)``.
    """
    print(f" Loading {model_name} + {adapter_id or 'none'}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # generate() needs a pad id; fall back to EOS when the checkpoint defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, adapter_id) if adapter_id else base
    model.eval()
    return model, tokenizer
|
|
|
|
def load_verifier(adapter_id, base_model="Qwen/Qwen3-4B"):
    """Load the ACCEPT/REJECT verifier: a base causal LM plus its PEFT adapter, in eval mode.

    Args:
        adapter_id: PEFT adapter id passed to ``PeftModel.from_pretrained`` (required).
        base_model: Base checkpoint the adapter was trained on. Defaults to the
            previously hard-coded ``"Qwen/Qwen3-4B"``; override it to load a
            verifier trained on another base (e.g. an 8B verifier for Config C).

    Returns:
        ``(model, tokenizer)``.
    """
    print(f" Loading verifier: {adapter_id}")
    tok = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # generate() needs a pad id; reuse EOS when the checkpoint defines none.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tok
|
|
|
|
def build_proposer_messages(context):
    """Assemble the proposer chat: system prompt, recent turns, then the action query.

    Only the last 6 turns of ``context`` are kept, and each turn's content is
    stringified and clipped to its first 500 characters.
    """
    recent_turns = [
        {"role": turn["role"], "content": str(turn["content"])[:500]}
        for turn in context[-6:]
    ]
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        *recent_turns,
        {"role": "user", "content": "What should be the next action type?"},
    ]
|
|
|
|
def build_verifier_messages(context, proposal):
    """Assemble the verifier chat: system prompt, recent turns, then the proposal question.

    Only the last 6 turns of ``context`` are kept, each clipped to its first
    400 characters after ``str()`` conversion.
    """
    recent_turns = [
        {"role": turn["role"], "content": str(turn["content"])[:400]}
        for turn in context[-6:]
    ]
    question = (
        f"Proposed next action: {proposal}\n\n"
        "Is this the correct next action? ACCEPT or REJECT?"
    )
    return [
        {"role": "system", "content": VERIFIER_SYSTEM},
        *recent_turns,
        {"role": "user", "content": question},
    ]
|
|
|
|
def parse_action(text, default="tool_call"):
    """Map free-form model output to one of ``ACTIONS``.

    Matching is a case-insensitive substring search. When several action names
    occur, the one appearing EARLIEST in the text wins, i.e. the action the model
    actually emitted first. Ties on position go to the earlier ``ACTIONS`` entry.

    Args:
        text: Raw decoded generation.
        default: Value returned when no action name occurs. Defaults to
            ``"tool_call"`` for backward compatibility. Pass ``None`` to score
            unparseable outputs as misses instead of silently crediting the
            ``"tool_call"`` class.

    Returns:
        The matched ``ACTIONS`` entry (original casing), or ``default``.
    """
    lowered = text.strip().lower()
    # (position, ACTIONS index, action) so min() picks earliest, then list order.
    hits = [
        (pos, idx, action)
        for idx, action in enumerate(ACTIONS)
        if (pos := lowered.find(action.lower())) != -1
    ] if False else [
        (lowered.find(action.lower()), idx, action)
        for idx, action in enumerate(ACTIONS)
        if action.lower() in lowered
    ]
    return min(hits)[2] if hits else default
|
|
|
|
@torch.no_grad()
def predict_action(model, tok, messages, device, do_sample=False, temperature=0.8,
                   max_length=2048, max_new_tokens=20):
    """Generate a short completion for ``messages`` and parse it into an action label.

    Args:
        model, tok: Causal LM and its tokenizer (must provide a chat template).
        messages: Chat messages, e.g. from ``build_proposer_messages``.
        device: Device the input tensors are moved to.
        do_sample: Greedy decoding when False; nucleus sampling when True.
        temperature: Sampling temperature. Only used when ``do_sample`` is True.
        max_length: Maximum prompt length in tokens. Longer prompts keep their
            LAST ``max_length`` tokens.
        max_new_tokens: Generation budget for the label.

    Returns:
        The label produced by ``parse_action`` on the decoded new tokens.
    """
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    enc = tok(prompt, return_tensors="pt")
    # Keep the tail of over-long prompts. Right truncation (the usual tokenizer
    # default) would drop the final question and the assistant-turn header,
    # so the model would no longer be answering the query.
    inputs = {name: ids[:, -max_length:].to(device) for name, ids in enc.items()}

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tok.pad_token_id,
    }
    # Sampling knobs are meaningless (and trigger warnings) under greedy decoding.
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=0.95)

    out = model.generate(**inputs, **gen_kwargs)
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return parse_action(tok.decode(new_tokens, skip_special_tokens=True))
|
|
|
|
@torch.no_grad()
def verify_action(model, tok, messages, device, max_length=1024, max_new_tokens=5):
    """Ask the verifier model whether the proposal in ``messages`` is correct.

    Args:
        model, tok: Verifier LM and its tokenizer (must provide a chat template).
        messages: Chat messages, e.g. from ``build_verifier_messages``.
        device: Device the input tensors are moved to.
        max_length: Maximum prompt length in tokens. Longer prompts keep their
            LAST ``max_length`` tokens.
        max_new_tokens: Generation budget for the verdict.

    Returns:
        True only if the decoded reply contains "accept" and does not contain
        "reject" (case-insensitive). Any other reply counts as a rejection.
    """
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    enc = tok(prompt, return_tensors="pt")
    # Keep the tail: the proposal being judged and the assistant-turn header sit
    # at the END of the prompt. Right truncation would silently remove them.
    inputs = {name: ids[:, -max_length:].to(device) for name, ids in enc.items()}
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tok.pad_token_id,
    )
    reply = tok.decode(
        out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    ).strip().lower()
    return "accept" in reply and "reject" not in reply
|
|
|
|
def _accuracy(records):
    """Return the fraction of ``records`` whose "pred" equals "true" (records must be non-empty)."""
    return sum(r["pred"] == r["true"] for r in records) / len(records)


def _log_progress(i, total, every=20):
    """Print a progress line on every ``every``-th example, starting with the first."""
    if i % every == 0:
        print(f" {i}/{total}")


def _run_proposer_only(model, tok, data, device):
    """Greedy-predict an action for each example with a single proposer model.

    Returns a list of {"pred", "true"} records in eval-set order.
    """
    records = []
    for i, ex in enumerate(data):
        _log_progress(i, len(data))
        msgs = build_proposer_messages(ex["messages"])
        records.append({"pred": predict_action(model, tok, msgs, device), "true": ex["action_type"]})
    return records


def evaluate(max_examples=200, n_proposals=3, seed=0):
    """Run configs A, B, D and E on the eval set, print a summary, and upload the results.

    Config C is recorded as skipped (no 8B verifier has been trained).

    Args:
        max_examples: Evaluate at most this many examples from the head of the eval split.
        n_proposals: Number of sampled proposals per example in Config E.
        seed: torch RNG seed set before Config E's sampling, making proposals
            repeatable across runs (modulo GPU kernel nondeterminism).

    Raises:
        ValueError: If ``n_proposals < 1`` or the eval set is empty.
    """
    if n_proposals < 1:
        raise ValueError(f"n_proposals must be >= 1, got {n_proposals}")

    device = "cuda"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    if torch.cuda.device_count() > 1:
        print(f" 2nd GPU: {torch.cuda.get_device_name(1)}")

    # ---- Eval data --------------------------------------------------------
    eval_ds = load_dataset(f"{HUB}/speculative-eval-v3-main", split="train")
    data = list(eval_ds.select(range(min(max_examples, len(eval_ds)))))
    if not data:
        raise ValueError("Eval set is empty; nothing to evaluate.")
    n = len(data)
    print(f"\nEvaluating {n} examples")
    dist = Counter(ex["action_type"] for ex in data)
    print("Distribution:", dict(dist))

    # ---- Models -----------------------------------------------------------
    print("\nLoading models...")
    cm, ctok = load_proposer("Qwen/Qwen3-1.7B", f"{HUB}/speculative-proposer-v3-1.7b")
    sm, stok = load_proposer("Qwen/Qwen3-8B", f"{HUB}/speculative-proposer-v3-8b")
    vm, vtok = load_verifier(f"{HUB}/speculative-verifier-v3-4b")

    results = {}

    # ---- Config A: strong model only --------------------------------------
    print("\nConfig A: 8B strong only")
    acc_a = _accuracy(_run_proposer_only(sm, stok, data, device))
    results["A"] = {"accuracy": round(acc_a, 4), "cost": COST["strong"]}
    print(f" Acc: {acc_a:.3f} Cost: {COST['strong']:.3f}")

    # ---- Config B: cheap model only ---------------------------------------
    print("\nConfig B: 1.7B cheap only")
    acc_b = _accuracy(_run_proposer_only(cm, ctok, data, device))
    results["B"] = {"accuracy": round(acc_b, 4), "cost": COST["cheap"]}
    print(f" Acc: {acc_b:.3f} Cost: {COST['cheap']:.3f}")

    # ---- Config C: not available ------------------------------------------
    print("\nConfig C: cheap + 8B verifier (not implemented — skipping, same as old)")
    results["C"] = {"accuracy": None, "cost": None, "note": "skipped — 8B verifier not trained"}

    # ---- Config D: cheap proposal, 4B verifier, strong fallback -----------
    print("\nConfig D: cheap + 4B verifier")
    rd = []
    n_accept = 0
    for i, ex in enumerate(data):
        _log_progress(i, n)
        msgs = build_proposer_messages(ex["messages"])
        cheap_pred = predict_action(cm, ctok, msgs, device)
        accepted = verify_action(vm, vtok, build_verifier_messages(ex["messages"], cheap_pred), device)
        if accepted:
            n_accept += 1
            pred = cheap_pred
        else:
            # Rejected: pay for a full strong-model prediction instead.
            pred = predict_action(sm, stok, msgs, device)
        rd.append({"pred": pred, "true": ex["action_type"], "accepted": accepted})
    n_fallback = n - n_accept
    acc_d = _accuracy(rd)
    # Every example pays cheap + verify; only rejected ones also pay for the strong model.
    cost_d = COST["cheap"] + COST["verify"] + COST["strong"] * (n_fallback / n)
    results["D"] = {
        "accuracy": round(acc_d, 4),
        "cost": round(cost_d, 4),
        "accept_rate": round(n_accept / n, 4),
    }
    print(f" Acc: {acc_d:.3f} Cost: {cost_d:.3f} Accept: {n_accept}/{n} ({n_accept/n:.1%})")

    # ---- Config E: N sampled proposals, 4B verifier screens them -----------
    print(f"\nConfig E: multi-proposal (n={n_proposals}) + 4B verifier")
    torch.manual_seed(seed)
    re_results = []
    n_verify_calls = 0
    for i, ex in enumerate(data):
        _log_progress(i, n)
        msgs = build_proposer_messages(ex["messages"])
        samples = [
            predict_action(cm, ctok, msgs, device, do_sample=True, temperature=0.8)
            for _ in range(n_proposals)
        ]
        # Ordered de-duplication. A set's iteration order for str is
        # hash-randomized per process, which made both the pick among several
        # accepted proposals and the fallback choice non-reproducible.
        proposals = list(dict.fromkeys(samples))
        verdicts = [
            verify_action(vm, vtok, build_verifier_messages(ex["messages"], p), device)
            for p in proposals
        ]
        n_verify_calls += len(proposals)
        # First accepted proposal in sampling order; else the first sample.
        best = next((p for p, ok in zip(proposals, verdicts) if ok), proposals[0])
        re_results.append({"pred": best, "true": ex["action_type"]})
    acc_e = _accuracy(re_results)
    # Cheap calls always happen n_proposals times; verifier calls only for distinct proposals.
    avg_verify_calls = n_verify_calls / n
    cost_e = COST["cheap"] * n_proposals + COST["verify"] * avg_verify_calls
    results["E"] = {
        "accuracy": round(acc_e, 4),
        "cost": round(cost_e, 4),
        "avg_verifier_calls": round(avg_verify_calls, 4),
    }
    print(f" Acc: {acc_e:.3f} Cost: {cost_e:.3f}")

    # ---- Baselines & summary ------------------------------------------------
    rand_acc = 1.0 / len(ACTIONS)
    maj_class, maj_count = dist.most_common(1)[0]
    maj_acc = maj_count / n

    print(f"\n{'='*65}")
    print(f"Baselines: random={rand_acc:.3f}, majority({maj_class})={maj_acc:.3f}")
    print(f"\n{'Config':<8} {'Acc':>8} {'Cost':>8} {'xRand':>8} {'xMaj':>8}")
    print("-" * 55)
    evaluated = [c for c in ["A", "B", "D", "E"] if results[c]["accuracy"] is not None]
    for c in evaluated:
        m = results[c]
        print(f"{c:<8} {m['accuracy']:>8.3f} {m['cost']:>8.3f} "
              f"{m['accuracy']/rand_acc:>8.1f} {m['accuracy']/maj_acc:>8.1f}")

    # Lists evaluated configs ordered by cost (not filtered to Pareto-optimal ones).
    print("\nCOST-QUALITY FRONTIER")
    for c in sorted(evaluated, key=lambda name: results[name]["cost"]):
        m = results[c]
        print(f" {c}: cost={m['cost']:.3f} acc={m['accuracy']:.3f}")

    # ---- Persist & upload ----------------------------------------------------
    output = {
        "results": results,
        "baselines": {"random": rand_acc, "majority": maj_acc, "majority_class": maj_class},
        "n": n,
        "distribution": dict(dist),
    }
    with open("/tmp/eval_v3.json", "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj="/tmp/eval_v3.json",
        path_in_repo="eval_results_v3.json",
        repo_id=f"{HUB}/speculative-tool-actions",
        repo_type="model",
        commit_message="Eval v3 results",
    )
    print("\n✓ Results uploaded.")
|
|
|
|
# Script entry point: run the full evaluation only when executed directly, not on import.
if __name__ == "__main__":
    evaluate()
|
|