# speculative-tool-actions / eval_runner_v3.py
# Uploaded by narcolepticchicken ("Upload eval_runner_v3.py", commit 4a1299b, verified).
"""
Speculative Tool Actions — Eval Runner v3
==========================================
Evaluates all 5 configurations on the same eval set.
Config A: 8B strong model (fine-tuned on SFT)
Config B: 1.7B cheap proposer (fine-tuned on SFT)
Config C: 1.7B proposes → 8B verifier ACCEPT/REJECT; fallback to 8B on REJECT
Config D: 1.7B proposes → 4B verifier ACCEPT/REJECT; fallback to 8B on REJECT
Config E: 1.7B generates N=3 diverse proposals → 4B verifier picks best
All models fine-tuned on same SFT data in chat-template "messages" format.
"""
import json
import re
import torch
from collections import Counter
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
# Hub namespace that prefixes every dataset, adapter and result repo id below.
HUB = "narcolepticchicken"

# Closed set of action-type labels the proposer may emit.
ACTIONS = [
    "tool_call",
    "retrieval",
    "file_read",
    "file_write",
    "repair",
    "verifier",
    "ask_clarification",
    "final_answer",
    "BLOCKED",
]

# Cost weights used by the per-configuration cost estimates
# (one strong-model call is the unit).
COST = {
    "strong": 1.0,
    "cheap": 0.15,
    "verify": 0.05,
}

# System prompt shared by the cheap and strong proposers.
SYSTEM_PROMPT = (
    "You are an agent action predictor. Given the conversation so far, "
    "predict the type of the next action the assistant should take. "
    f"Choose exactly one from: {', '.join(ACTIONS)}. "
    "Output only the action type name, nothing else."
)

# System prompt for the ACCEPT/REJECT verifier.
VERIFIER_SYSTEM = (
    "You are an action verifier. "
    "Given conversation context and a proposed next action, "
    "determine if the proposal is correct. "
    "Respond with exactly ACCEPT or REJECT."
)
def load_proposer(model_name, adapter_id=None):
    """Load a proposer: base causal LM, its tokenizer and an optional PEFT adapter.

    Args:
        model_name: Identifier of the base model handed to ``from_pretrained``.
        adapter_id: Identifier of the PEFT adapter to attach; a falsy value
            leaves the base model unwrapped.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.
    """
    adapter_label = adapter_id or "none"
    print(f" Loading {model_name} + {adapter_label}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # generate() is given pad_token_id later, so make sure one exists.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    net = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    if adapter_id:
        net = PeftModel.from_pretrained(net, adapter_id)

    net.eval()
    return net, tokenizer
def load_verifier(adapter_id):
    """Load the ACCEPT/REJECT verifier: Qwen/Qwen3-4B plus a PEFT adapter.

    Args:
        adapter_id: Identifier of the verifier's PEFT adapter.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.
    """
    base_id = "Qwen/Qwen3-4B"
    print(f" Loading verifier: {adapter_id}")

    tokenizer = AutoTokenizer.from_pretrained(base_id, trust_remote_code=True)
    # generate() is given pad_token_id later, so make sure one exists.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    verifier = PeftModel.from_pretrained(base_model, adapter_id)
    verifier.eval()
    return verifier, tokenizer
def build_proposer_messages(context):
    """Assemble the proposer prompt as a list of chat messages.

    The result is the system prompt, then the last six messages of ``context``
    (each content stringified and cut to 500 characters), then a fixed user
    question asking for the next action type.

    Args:
        context: Sequence of message dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        A new list of ``{"role", "content"}`` dicts.
    """
    recent_turns = [
        {"role": turn["role"], "content": str(turn["content"])[:500]}
        for turn in context[-6:]
    ]
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        *recent_turns,
        {"role": "user", "content": "What should be the next action type?"},
    ]
def build_verifier_messages(context, proposal):
    """Assemble the verifier prompt as a list of chat messages.

    The result is the verifier system prompt, then the last six messages of
    ``context`` (each content stringified and cut to 400 characters), then a
    user message that states ``proposal`` and asks for ACCEPT or REJECT.

    Args:
        context: Sequence of message dicts with ``"role"`` and ``"content"`` keys.
        proposal: The proposed action, interpolated into the final user message.

    Returns:
        A new list of ``{"role", "content"}`` dicts.
    """
    question = (
        f"Proposed next action: {proposal}\n\n"
        "Is this the correct next action? ACCEPT or REJECT?"
    )
    recent_turns = [
        {"role": turn["role"], "content": str(turn["content"])[:400]}
        for turn in context[-6:]
    ]
    return [
        {"role": "system", "content": VERIFIER_SYSTEM},
        *recent_turns,
        {"role": "user", "content": question},
    ]
def parse_action(text, actions=None, default="tool_call"):
    """Extract the action type named in a model output.

    Matching is case-insensitive substring search. When several action names
    occur in ``text``, the one that appears *earliest in the text* wins (ties
    at the same position go to the longer name), so an output such as
    ``"final_answer (not tool_call)"`` resolves to ``final_answer`` regardless
    of the order in which the candidate names are listed.

    Args:
        text: Decoded model output.
        actions: Candidate action names; defaults to the module-level
            ``ACTIONS`` list.
        default: Value returned when no candidate occurs in ``text``.

    Returns:
        The matching entry of ``actions`` in its original casing, or
        ``default`` if nothing matches.
    """
    candidates = ACTIONS if actions is None else actions
    haystack = text.strip().lower()

    best_key = None
    best_action = default
    for name in candidates:
        pos = haystack.find(name.lower())
        if pos == -1:
            continue
        # Earliest position first; on a tie prefer the longer (more specific) name.
        key = (pos, -len(name))
        if best_key is None or key < best_key:
            best_key = key
            best_action = name
    return best_action
@torch.no_grad()
def predict_action(model, tok, messages, device, do_sample=False, temperature=0.8):
    """Generate up to 20 new tokens for ``messages`` and parse them into an action.

    Args:
        model: Causal LM exposing ``generate``.
        tok: Its tokenizer; must provide ``apply_chat_template``, ``decode``
            and ``pad_token_id``.
        messages: Chat messages passed to the tokenizer's chat template.
        device: Device the encoded prompt is moved to before generation.
        do_sample: Sample instead of decoding greedily.
        temperature: Sampling temperature; only forwarded when ``do_sample``
            is true.

    Returns:
        The action name that ``parse_action`` extracts from the newly
        generated text.
    """
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=2048).to(device)

    gen_kwargs = {
        "max_new_tokens": 20,
        "do_sample": do_sample,
        "pad_token_id": tok.pad_token_id,
    }
    if do_sample:
        # Sampling-only knobs: greedy decoding ignores them and transformers
        # flags them as unused generation flags, so pass them only when used.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = 0.95

    out = model.generate(**inputs, **gen_kwargs)

    # Decode only the continuation, not the echoed prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    return parse_action(decoded)
@torch.no_grad()
def verify_action(model, tok, messages, device):
    """Run the verifier on ``messages`` and report whether it accepted.

    Greedily generates at most five new tokens and inspects the lower-cased
    continuation.

    Args:
        model: Verifier LM exposing ``generate``.
        tok: Its tokenizer; must provide ``apply_chat_template``, ``decode``
            and ``pad_token_id``.
        messages: Chat messages passed to the tokenizer's chat template.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        ``True`` only if the continuation contains "accept" and does not
        contain "reject"; ``False`` otherwise.
    """
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tok(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)
    generated = model.generate(
        **encoded,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=tok.pad_token_id,
    )
    # Drop the echoed prompt tokens before decoding.
    prompt_len = encoded["input_ids"].shape[1]
    verdict = tok.decode(generated[0][prompt_len:], skip_special_tokens=True)
    verdict = verdict.strip().lower()
    if "reject" in verdict:
        return False
    return "accept" in verdict
def _accuracy(records):
    """Return the fraction of records whose "pred" equals their "true" label."""
    return sum(1 for r in records if r["pred"] == r["true"]) / len(records)


def _run_single_model(model, tok, data, device):
    """Greedily predict an action for every example with one proposer model."""
    records = []
    for i, ex in enumerate(data):
        if i % 20 == 0:
            print(f" {i}/{len(data)}")
        msgs = build_proposer_messages(ex["messages"])
        pred = predict_action(model, tok, msgs, device)
        records.append({"pred": pred, "true": ex["action_type"]})
    return records


def evaluate():
    """Evaluate configurations A, B, D and E on the eval set and publish the results.

    Loads up to 200 examples from the eval dataset and the three fine-tuned
    models, then runs:

    * A - strong (8B) proposer only
    * B - cheap (1.7B) proposer only
    * C - recorded as skipped
    * D - cheap proposal checked by the 4B verifier, strong fallback on REJECT
    * E - three sampled cheap proposals, verifier picks the first accepted one

    Prints a summary table, writes ``/tmp/eval_v3.json`` and uploads it to the
    ``speculative-tool-actions`` repo on the Hub.

    Raises:
        ValueError: If the eval dataset contains no examples.
    """
    device = "cuda"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    if torch.cuda.device_count() > 1:
        print(f" 2nd GPU: {torch.cuda.get_device_name(1)}")

    # --- Load eval data -----------------------------------------------------
    eval_ds = load_dataset(f"{HUB}/speculative-eval-v3-main", split="train")
    data = list(eval_ds.select(range(min(200, len(eval_ds)))))
    if not data:
        # Every accuracy and baseline below divides by len(data).
        raise ValueError("Evaluation dataset is empty; nothing to evaluate.")
    n_examples = len(data)
    print(f"\nEvaluating {n_examples} examples")
    dist = Counter(ex["action_type"] for ex in data)
    print("Distribution:", dict(dist))

    # --- Load models --------------------------------------------------------
    print("\nLoading models...")
    # Proposer (1.7B)
    cm, ctok = load_proposer("Qwen/Qwen3-1.7B", f"{HUB}/speculative-proposer-v3-1.7b")
    # Strong model (8B)
    sm, stok = load_proposer("Qwen/Qwen3-8B", f"{HUB}/speculative-proposer-v3-8b")
    # Verifier (4B)
    vm, vtok = load_verifier(f"{HUB}/speculative-verifier-v3-4b")

    results = {}

    # --- Config A: strong only ----------------------------------------------
    print("\nConfig A: 8B strong only")
    acc_a = _accuracy(_run_single_model(sm, stok, data, device))
    results["A"] = {"accuracy": round(acc_a, 4), "cost": COST["strong"]}
    print(f" Acc: {acc_a:.3f} Cost: {COST['strong']:.3f}")

    # --- Config B: cheap only -----------------------------------------------
    print("\nConfig B: 1.7B cheap only")
    acc_b = _accuracy(_run_single_model(cm, ctok, data, device))
    results["B"] = {"accuracy": round(acc_b, 4), "cost": COST["cheap"]}
    print(f" Acc: {acc_b:.3f} Cost: {COST['cheap']:.3f}")

    # --- Config C: cheap + 8B verifier --------------------------------------
    # The 8B verifier was never properly trained, so this configuration is
    # recorded as skipped and the focus is on D.
    print("\nConfig C: cheap + 8B verifier (not implemented — skipping, same as old)")
    results["C"] = {"accuracy": None, "cost": None, "note": "skipped — 8B verifier not trained"}

    # --- Config D: cheap + 4B verifier --------------------------------------
    print("\nConfig D: cheap + 4B verifier")
    rd = []
    n_accept = 0
    n_fallback = 0
    for i, ex in enumerate(data):
        if i % 20 == 0:
            print(f" {i}/{n_examples}")
        msgs = build_proposer_messages(ex["messages"])
        cheap_pred = predict_action(cm, ctok, msgs, device)
        vmsgs = build_verifier_messages(ex["messages"], cheap_pred)
        accepted = verify_action(vm, vtok, vmsgs, device)
        if accepted:
            n_accept += 1
            pred = cheap_pred
        else:
            # Verifier rejected the cheap proposal: fall back to the strong model.
            n_fallback += 1
            pred = predict_action(sm, stok, msgs, device)
        rd.append({"pred": pred, "true": ex["action_type"], "accepted": accepted})
    acc_d = _accuracy(rd)
    # Every example pays cheap + verify; only fallbacks pay the strong call.
    cost_d = COST["cheap"] + COST["verify"] + COST["strong"] * (n_fallback / n_examples)
    results["D"] = {
        "accuracy": round(acc_d, 4),
        "cost": round(cost_d, 4),
        "accept_rate": round(n_accept / n_examples, 4),
    }
    print(
        f" Acc: {acc_d:.3f} Cost: {cost_d:.3f} "
        f"Accept: {n_accept}/{n_examples} ({n_accept/n_examples:.1%})"
    )

    # --- Config E: multi-proposal reranking ---------------------------------
    print("\nConfig E: multi-proposal (n=3) + 4B verifier")
    n_samples = 3
    re_results = []
    n_verifications = 0
    for i, ex in enumerate(data):
        if i % 20 == 0:
            print(f" {i}/{n_examples}")
        msgs = build_proposer_messages(ex["messages"])
        sampled = [
            predict_action(cm, ctok, msgs, device, do_sample=True, temperature=0.8)
            for _ in range(n_samples)
        ]
        # De-duplicate while keeping generation order. A set would make both
        # "first ACCEPT" and the fallback depend on per-process hash order.
        proposals = list(dict.fromkeys(sampled))
        scored = []
        for p in proposals:
            vmsgs = build_verifier_messages(ex["messages"], p)
            scored.append((p, verify_action(vm, vtok, vmsgs, device)))
        n_verifications += len(proposals)
        # Pick the first ACCEPT, or fall back to the first proposal generated.
        best = next((p for p, ok in scored if ok), proposals[0])
        re_results.append({"pred": best, "true": ex["action_type"]})
    acc_e = _accuracy(re_results)
    # Only unique proposals are verified, so charge the verifier calls that
    # actually ran instead of a flat three per example.
    avg_verifications = n_verifications / n_examples
    cost_e = COST["cheap"] * n_samples + COST["verify"] * avg_verifications
    results["E"] = {
        "accuracy": round(acc_e, 4),
        "cost": round(cost_e, 4),
        "avg_verifications": round(avg_verifications, 4),
    }
    print(f" Acc: {acc_e:.3f} Cost: {cost_e:.3f}")

    # --- Baselines & summary ------------------------------------------------
    rand_acc = 1.0 / len(ACTIONS)
    maj_class = dist.most_common(1)[0][0]
    maj_acc = dist[maj_class] / n_examples
    print(f"\n{'='*65}")
    print(f"Baselines: random={rand_acc:.3f}, majority({maj_class})={maj_acc:.3f}")
    print(f"\n{'Config':<8} {'Acc':>8} {'Cost':>8} {'xRand':>8} {'xMaj':>8}")
    print("-" * 55)
    # Config C has no measurements, so it is left out of the table and frontier.
    measured = [(c, results[c]) for c in ["A", "B", "D", "E"] if results[c]["accuracy"] is not None]
    for c, m in measured:
        print(
            f"{c:<8} {m['accuracy']:>8.3f} {m['cost']:>8.3f} "
            f"{m['accuracy']/rand_acc:>8.1f} {m['accuracy']/maj_acc:>8.1f}"
        )

    # --- Cost-quality frontier ----------------------------------------------
    print("\nCOST-QUALITY FRONTIER")
    for c, m in sorted(measured, key=lambda x: x[1]["cost"]):
        print(f" {c}: cost={m['cost']:.3f} acc={m['accuracy']:.3f}")

    # --- Save ---------------------------------------------------------------
    output = {
        "results": results,
        "baselines": {"random": rand_acc, "majority": maj_acc, "majority_class": maj_class},
        "n": n_examples,
        "distribution": dict(dist),
    }
    with open("/tmp/eval_v3.json", "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    # Imported here so the dependency is only needed once results exist.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="/tmp/eval_v3.json",
        path_in_repo="eval_results_v3.json",
        repo_id=f"{HUB}/speculative-tool-actions",
        repo_type="model",
        commit_message="Eval v3 results",
    )
    print("\n✓ Results uploaded.")
# Script entry point: run the full evaluation only when executed directly.
if "__main__" == __name__:
    evaluate()