# speculative-tool-actions / eval_v2_cu121.py
# Uploaded by narcolepticchicken ("Upload eval_v2_cu121.py"), commit 965a8e4 (verified).
import json, os, re
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
# Hub organisation that owns the eval dataset, the fine-tuned models and the results repo.
HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
# Upper bound on the number of evaluation examples (override via the MAX_EVAL env var).
MAX_EVAL = int(os.environ.get('MAX_EVAL', '200'))
# Label vocabulary the predictors choose from; 'BLOCKED' is the only upper-case label.
ACTIONS = ['tool_call', 'retrieval', 'file_read', 'file_write', 'repair', 'verifier', 'ask_clarification', 'final_answer', 'BLOCKED']
# Relative per-call costs for the cost/quality comparison (one strong-model call = 1.0).
# NOTE(review): the 'verifier' entry is not read anywhere in the visible script; the
# configs charge 'verify_check' both for the strong-model ACCEPT/REJECT call and for
# reward-model scoring — confirm which rate was intended for each.
COST = {'strong': 1.00, 'cheap': 0.15, 'verifier': 0.30, 'verify_check': 0.10}
# System prompt for action prediction. The action list is built from ACTIONS so the
# prompt cannot drift from the label set used for parsing and scoring.
SP = (
    f"You are an agent action predictor. Predict the next action from: {', '.join(ACTIONS)}.\n"
    "Format your response as:\n"
    "Action: <action_name>\n"
    "<brief reason>"
)
def load_lm(model_id, device):
    """Load a causal language model and its tokenizer for generation.

    Weights are loaded in bfloat16 and placed with ``device_map='auto'``; the
    ``device`` argument is accepted but not used inside this function. When the
    tokenizer defines no pad token, its EOS token is reused as padding.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.
    """
    print(f" Loading LM: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map='auto',
        trust_remote_code=True,
    )
    lm.eval()
    return lm, tokenizer
def load_rm(adapter_id, device, base_model='Qwen/Qwen3-4B'):
    """Load a PEFT-adapted scalar reward model plus the base model's tokenizer.

    A sequence-classification head with a single output (``num_labels=1``) is
    built on ``base_model``, and the PEFT adapter ``adapter_id`` is applied on top.

    Args:
        adapter_id: Hub id of the PEFT adapter.
        device: unused here; placement is handled by ``device_map='auto'``.
        base_model: Hub id of the checkpoint the adapter sits on. Defaults to
            ``'Qwen/Qwen3-4B'``, the value previously hard-coded here.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=1,
        torch_dtype=torch.bfloat16,
        device_map='auto',
        trust_remote_code=True,
    )
    # transformers' *ForSequenceClassification heads use config.pad_token_id to
    # find the last non-padding token to score, so it must be set explicitly.
    model.config.pad_token_id = tok.pad_token_id
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tok
def parse_action(text, actions=None):
    """Extract an action label from a model reply.

    Resolution order:
    1. an explicit ``Action: <name>`` occurrence (matched case-insensitively);
    2. otherwise the first label, in vocabulary order, that appears anywhere
       in the reply (case-insensitively);
    3. otherwise ``'tool_call'``.

    The label is always returned spelled exactly as in the vocabulary. An
    upper-case label such as ``'BLOCKED'`` therefore still compares equal to a
    gold label spelled that way, even when the model writes it in another case.

    Args:
        text: decoded model output.
        actions: label vocabulary to match; defaults to the module's ``ACTIONS``.

    Returns:
        A label from the vocabulary, or ``'tool_call'`` when nothing matched.
    """
    vocab = list(ACTIONS if actions is None else actions)
    # Map any casing the model produces back to the canonical label.
    canonical = {a.lower(): a for a in vocab}
    # Longest labels first so a label that is a prefix of another cannot win early.
    alternation = '|'.join(re.escape(a) for a in sorted(vocab, key=len, reverse=True))
    m = re.search(rf'Action:\s*({alternation})', text, re.IGNORECASE) if vocab else None
    if m:
        return canonical[m.group(1).lower()]
    lowered = text.lower()
    for a in vocab:
        if a.lower() in lowered:
            return a
    return 'tool_call'
def build_msgs(example):
    """Build the chat prompt that asks a model for the next agent action.

    Only the last four entries of ``example['messages']`` are shown. Each is
    rendered as ``"role: content"``, with the content stringified and cut to
    300 characters.

    Returns:
        A two-message chat: the ``SP`` system prompt, then a user turn that
        carries the rendered context.
    """
    recent = example['messages'][-4:]
    lines = []
    for turn in recent:
        lines.append(f"{turn['role']}: {str(turn['content'])[:300]}")
    ctx = '\n'.join(lines)
    system_turn = {'role': 'system', 'content': SP}
    user_turn = {'role': 'user', 'content': f"Predict the next action for:\n\n{ctx}"}
    return [system_turn, user_turn]
@torch.no_grad()
def predict(model, tok, msgs, device='cuda', enable_thinking=False):
    """Greedily generate a short reply to ``msgs`` and parse it into an action label.

    Args:
        model: causal LM used for generation.
        tok: its tokenizer; must provide a chat template.
        msgs: chat messages (e.g. from ``build_msgs``).
        device: device the tokenized prompt is moved to.
        enable_thinking: forwarded to the chat template. Qwen3 templates
            default to thinking mode, where the reply opens with a ``<think>``
            block that consumes the 50-token budget before any ``Action:``
            line. Extra template kwargs are ignored by templates that do not
            reference them.
            NOTE(review): assumes the fine-tuned proposer was trained on
            prompts without active thinking — confirm against its training setup.

    Returns:
        The action label extracted by ``parse_action``.
    """
    prompt = tok.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking
    )
    inputs = tok(prompt, return_tensors='pt', truncation=True, max_length=2048).to(device)
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tok.pad_token_id)
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    return parse_action(tok.decode(new_tokens, skip_special_tokens=True).strip())
@torch.no_grad()
def reward_score(model, tok, text, device='cuda'):
    """Score ``text`` with the reward model and return the result as a Python float.

    The text is tokenized (truncated to 1024 tokens) and moved to ``device``.
    The model's logits are squeezed to a single value; this expects one
    sequence and a one-output head, as built by ``load_rm``.
    """
    encoded = tok(text, return_tensors='pt', truncation=True, max_length=1024)
    encoded = encoded.to(device)
    logits = model(**encoded).logits
    return logits.squeeze().item()
@torch.no_grad()
def accept_reject(model, tok, prop, example_msgs, device='cuda', enable_thinking=False):
    """Ask a chat model to ACCEPT or REJECT a proposed action.

    The prompt shows ``prop`` and the last three messages of ``example_msgs``
    (content stringified and cut to 200 characters). Up to 5 tokens are
    generated greedily.

    Args:
        model, tok: judging chat model and its tokenizer.
        prop: proposed action label.
        example_msgs: conversation the proposal refers to.
        device: device the tokenized prompt is moved to.
        enable_thinking: forwarded to the chat template. With Qwen3's default
            thinking mode the 5-token reply is just the start of a ``<think>``
            block, so it never contains a decision.

    Returns:
        True only if the lower-cased reply contains 'accept' and not 'reject'.
    """
    ctx = '\n'.join(f"{m['role']}: {str(m['content'])[:200]}" for m in example_msgs[-3:])
    msgs = [
        {'role': 'system', 'content': 'Say ACCEPT or REJECT only.'},
        {'role': 'user', 'content': f'Proposed: {prop}\nContext:\n{ctx}\nDecision:'},
    ]
    text = tok.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking
    )
    inputs = tok(text, return_tensors='pt', truncation=True, max_length=1024).to(device)
    outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False, pad_token_id=tok.pad_token_id)
    resp = tok.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()
    return 'accept' in resp and 'reject' not in resp
def rw_text(prop, ex):
    """Render proposal ``prop`` with the tail of ``ex``'s conversation for the reward model.

    The last three entries of ``ex['messages']`` become ``"role: content"``
    lines (content stringified and cut to 200 characters). They follow a
    ``User:`` prefix and precede ``Assistant: Action: <prop>``.
    """
    tail = ex['messages'][-3:]
    rendered = [f"{msg['role']}: {str(msg['content'])[:200]}" for msg in tail]
    ctx = '\n'.join(rendered)
    return f"User: {ctx}\n\nAssistant: Action: {prop}"
device = 'cuda'
# Evaluation split, truncated to at most MAX_EVAL examples held as plain dicts.
ds = load_dataset(EVAL_DS, split='train')
data = [ds[idx] for idx in range(min(MAX_EVAL, len(ds)))]
print(f'Evaluating {len(data)} examples')
# Models under comparison: cheap proposer (cm), reward verifier (vm), strong model (sm).
cm, ctok = load_lm(f'{HUB_ORG}/speculative-proposer-qwen3-1.7b', device)
vm, vtok = load_rm(f'{HUB_ORG}/speculative-verifier-qwen3-4b', device)
sm, stok = load_lm('Qwen/Qwen3-8B', device)
# Per-configuration metrics keyed 'A'..'E', filled in by the sections below.
all_metrics = {}
# A: every example is answered by the strong model alone.
print('\nConfig A: strong only')
res = []
for ex in data:
    res.append({'pred': predict(sm, stok, build_msgs(ex), device), 'true': ex['action_type'], 'cost': COST['strong']})
acc = sum(r['pred'] == r['true'] for r in res) / len(res)
all_metrics['A'] = {
    'accuracy': round(acc, 4),
    'avg_cost': COST['strong'],
    'n': len(res),
}
print(f' Acc: {acc:.3f}')
# B: every example is answered by the cheap proposer alone.
print('\nConfig B: cheap only')
res = [
    {'pred': predict(cm, ctok, build_msgs(ex), device), 'true': ex['action_type'], 'cost': COST['cheap']}
    for ex in data
]
acc = len([r for r in res if r['pred'] == r['true']]) / len(res)
all_metrics['B'] = {
    'accuracy': round(acc, 4),
    'avg_cost': COST['cheap'],
    'n': len(res),
}
print(f' Acc: {acc:.3f}')
# C: the cheap model proposes, the strong model answers ACCEPT/REJECT, and on a
# REJECT the strong model's own prediction replaces the proposal.
print('\nConfig C: cheap + strong verifier')
res = []
for ex in data:
    cp = predict(cm, ctok, build_msgs(ex), device)
    accepted = accept_reject(sm, stok, cp, ex['messages'], device)
    pred = cp
    cost = COST['cheap'] + COST['verify_check']
    if not accepted:
        # Fallback: pay for a full strong-model prediction on top of the check.
        pred = predict(sm, stok, build_msgs(ex), device)
        cost += COST['strong']
    res.append({'pred': pred, 'true': ex['action_type'], 'accepted': accepted, 'cost': cost})
acc = sum(r['pred'] == r['true'] for r in res) / len(res)
ar = sum(r['accepted'] for r in res) / len(res)
all_metrics['C'] = {
    'accuracy': round(acc, 4),
    'avg_cost': round(sum(r['cost'] for r in res) / len(res), 4),
    'accept_rate': round(ar, 4),
    'n': len(res),
}
print(f' Acc: {acc:.3f} | Accept: {ar:.3f}')
# D: the cheap model proposes and the reward model scores each proposal.
# NOTE(review): `accepted` (score >= THR) is only reported as accept_rate; the
# prediction is always the cheap proposal, so D's accuracy tracks config B's
# (both use greedy cheap-model predictions). Confirm whether rejected proposals
# were meant to fall back to the strong model, as in config C.
THR = -1.0
print(f'\nConfig D: cheap + reward (thr={THR})')
res = []
for ex in data:
    cp = predict(cm, ctok, build_msgs(ex), device)
    score = reward_score(vm, vtok, rw_text(cp, ex), device)
    res.append({
        'pred': cp,
        'true': ex['action_type'],
        'cost': COST['cheap'] + COST['verify_check'],
        'accepted': score >= THR,
        'score': score,
    })
acc = sum(r['pred'] == r['true'] for r in res) / len(res)
ar = sum(r['accepted'] for r in res) / len(res)
scores = [r['score'] for r in res]
mean_score = sum(scores) / len(scores)
all_metrics['D'] = {
    'accuracy': round(acc, 4),
    'avg_cost': round(sum(r['cost'] for r in res) / len(res), 4),
    'accept_rate': round(ar, 4),
    'mean_score': round(mean_score, 3),
    'n': len(res),
}
print(f' Acc: {acc:.3f} | Accept: {ar:.3f} | Score: {mean_score:.3f}')
# E: sample N proposals from the cheap model, score each distinct proposal with
# the reward model, and keep the highest-scoring one.
N = 3
print(f'\nConfig E: multi-proposal (n={N})')
res = []
for i, ex in enumerate(data):
    if i % 20 == 0:
        print(f' {i}/{len(data)}')
    msgs = build_msgs(ex)
    # Thinking mode off, matching the direct "Action: ..." format; with Qwen3's
    # default the 50-token budget is spent inside a <think> block.
    text = ctok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False)
    # The prompt is identical for every sample, so tokenize it once per example.
    inputs = ctok(text, return_tensors='pt', truncation=True, max_length=2048).to(device)
    prompt_len = inputs['input_ids'].shape[1]
    proposals = []
    for _ in range(N):
        outputs = cm.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.8, top_p=0.95, pad_token_id=ctok.pad_token_id)
        proposals.append(parse_action(ctok.decode(outputs[0][prompt_len:], skip_special_tokens=True)))
    # dict.fromkeys de-duplicates while keeping first-seen order, so score ties are
    # broken deterministically (a set's iteration order over str depends on hash seeding).
    unique = list(dict.fromkeys(proposals))
    scored = [(p, reward_score(vm, vtok, rw_text(p, ex), device)) for p in unique]
    best = max(scored, key=lambda x: x[1])[0]
    # Charge every cheap sample, but only the reward-model calls actually made.
    cost = COST['cheap'] * N + COST['verify_check'] * len(unique)
    res.append({'pred': best, 'true': ex['action_type'], 'cost': cost})
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
all_metrics['E'] = {'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res) / len(res), 4), 'n': len(res)}
print(f' Acc: {acc:.3f}')
# Baselines: uniform-random guessing over ACTIONS, and always predicting the most
# frequent gold label within the evaluated slice.
from collections import Counter
dist = Counter(ex['action_type'] for ex in data)
maj, maj_count = dist.most_common(1)[0]
maj_acc = maj_count / len(data)
rand_acc = 1.0 / len(ACTIONS)
print(f'\nBaselines: random={rand_acc:.3f}, majority({maj})={maj_acc:.3f}')
print(f'\n{"="*60}')
# Last column is the verifier accept rate (present for configs C and D only).
print(f'{"Config":<6} {"Acc":>8} {"Cost":>8} {"vsRandom":>10} {"vsMaj":>8} {"Accept":>8}')
print('-'*60)
for c in ['A', 'B', 'C', 'D', 'E']:
    m = all_metrics[c]
    rate = m.get('accept_rate')
    ar = f'{rate:.3f}' if isinstance(rate, float) else '-'
    # Ratio widths include the trailing 'x' so the data lines up with the header.
    print(f'{c:<6} {m["accuracy"]:>8.3f} {m["avg_cost"]:>8.3f} {m["accuracy"]/rand_acc:>9.1f}x {m["accuracy"]/maj_acc:>7.1f}x {ar:>8}')
# Configs listed in ascending order of average cost.
print('\nCOST-QUALITY FRONTIER')
for name, m in sorted(all_metrics.items(), key=lambda kv: kv[1]['avg_cost']):
    print(f" {name}: cost={m['avg_cost']:.3f} acc={m['accuracy']:.3f}")
out = {'metrics': all_metrics, 'baselines': {'random': rand_acc, 'majority': maj_acc, 'majority_class': maj}, 'n': len(data), 'distribution': dict(dist)}
with open('/tmp/results.json', 'w') as f:
    json.dump(out, f, indent=2)
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(path_or_fileobj='/tmp/results.json', path_in_repo='eval_results_v2.json', repo_id=f'{HUB_ORG}/speculative-tool-actions', repo_type='model', commit_message='Eval v2 results (cu121 fix)')
print('Done!')