# speculative-tool-actions / eval_v2_clean.py
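"""Evaluation harness for speculative tool-action prediction.

Compares five routing configurations on the next-action prediction task:
  A: strong model (Qwen3-8B) only
  B: cheap fine-tuned proposer (Qwen3-1.7B) only
  C: cheap proposer; the strong model accepts/rejects and re-answers on reject
  D: cheap proposer scored by a reward-model verifier (Qwen3-4B + LoRA adapter)
  E: best-of-N sampling from the cheap proposer, ranked by the reward verifier

Results are written to /tmp/eval_results_v2.json and uploaded to the Hub.
"""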
import json, os, re
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
MAX_EVAL = int(os.environ.get('MAX_EVAL', '200'))
ACTIONS = ['tool_call', 'retrieval', 'file_read', 'file_write', 'repair', 'verifier', 'ask_clarification', 'final_answer', 'BLOCKED']
COST = {'strong': 1.00, 'cheap': 0.15, 'verifier': 0.30, 'verify_check': 0.10}
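# Relative per-call costs used for the cost-quality frontier; presumably
# normalized so that one strong-model call costs 1.0 (absolute units are arbitrary).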
SYSTEM_PROMPT = """You are an agent action predictor. Predict the next action from: tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED.
Format your response as:
Action: <action_name>
<brief reason>"""
def load_lm(model_id, device):
    print(f" Loading LM: {model_id}")
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
    model.eval()
    return model, tok
def load_reward_model(adapter_id, device):
    # Scalar reward head (num_labels=1) on the Qwen3-4B base, with the trained LoRA adapter applied on top.
    tok = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B', trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForSequenceClassification.from_pretrained('Qwen/Qwen3-4B', num_labels=1, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
    model.config.pad_token_id = tok.pad_token_id
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tok
def parse_action(text):
    m = re.search(r'Action:\s*(tool_call|retrieval|file_read|file_write|repair|verifier|ask_clarification|final_answer|BLOCKED)', text, re.IGNORECASE)
    if m:
        name = m.group(1).lower()
        # Restore canonical casing: BLOCKED is uppercase in ACTIONS and in the labels.
        return 'BLOCKED' if name == 'blocked' else name
    # Fallback: case-insensitive substring match against the action list
    # (comparing raw ACTIONS entries to lowercased text would never match 'BLOCKED').
    lowered = text.lower()
    for a in ACTIONS:
        if a.lower() in lowered:
            return a
    return 'tool_call'
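# Usage sketch (illustrative strings, not drawn from the eval set):
#   parse_action('Action: file_read\nthe config must be inspected')  -> 'file_read'
#   parse_action('unparseable output')                               -> 'tool_call' (fallback)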
def build_proposer_messages(example):
    msgs = example['messages']
    context = '\n'.join(f"{m['role']}: {str(m['content'])[:300]}" for m in msgs[-4:])
    return [{'role': 'system', 'content': SYSTEM_PROMPT}, {'role': 'user', 'content': f"Predict the next action for:\n\n{context}"}]
@torch.no_grad()
def predict_action(model, tokenizer, messages, device='cuda'):
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=2048).to(device)
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tokenizer.pad_token_id)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
    return parse_action(response)
@torch.no_grad()
def get_reward_score(model, tokenizer, text, device='cuda'):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=1024).to(device)
    return model(**inputs).logits.squeeze().item()
@torch.no_grad()
def acc_rej(model, tokenizer, proposed_action, example_msgs, device='cuda'):
    context = '\n'.join(f"{m['role']}: {str(m['content'])[:200]}" for m in example_msgs[-3:])
    msgs = [
        {'role': 'system', 'content': 'You are a verifier. Say ACCEPT if the proposed action is correct, REJECT if wrong. Only output ACCEPT or REJECT.'},
        {'role': 'user', 'content': f'Proposed action: {proposed_action}\n\nContext:\n{context}\n\nDecision:'}
    ]
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=1024).to(device)
    outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False, pad_token_id=tokenizer.pad_token_id)
    resp = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()
    return 'accept' in resp and 'reject' not in resp
def build_reward_text(proposed_action, example):
    msgs = example['messages']
    context = '\n'.join(f"{m['role']}: {str(m['content'])[:200]}" for m in msgs[-3:])
    return f"User: {context}\n\nAssistant: Action: {proposed_action}"
def main():
    device = 'cuda'
    print(f'Device: {device}')
    ds = load_dataset(EVAL_DS, split='train')
    data = [ds[i] for i in range(min(MAX_EVAL, len(ds)))]
    print(f'Evaluating {len(data)} examples')
    print('Loading cheap model...')
    cm, ctok = load_lm(f'{HUB_ORG}/speculative-proposer-qwen3-1.7b', device)
    print('Loading reward verifier...')
    vm, vtok = load_reward_model(f'{HUB_ORG}/speculative-verifier-qwen3-4b', device)
    print('Loading strong model...')
    sm, stok = load_lm('Qwen/Qwen3-8B', device)
    all_metrics = {}
    # A
    print('\nConfig A: strong only')
    res = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f' {i}/{len(data)}')
        pred = predict_action(sm, stok, build_proposer_messages(ex), device)
        res.append({'pred': pred, 'true': ex['action_type'], 'cost': COST['strong']})
    acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
    all_metrics['A'] = {'config': 'A', 'accuracy': round(acc, 4), 'avg_cost': COST['strong'], 'n': len(res)}
    print(f' Acc: {acc:.3f}')
    # B
    print('\nConfig B: cheap only')
    res = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f' {i}/{len(data)}')
        pred = predict_action(cm, ctok, build_proposer_messages(ex), device)
        res.append({'pred': pred, 'true': ex['action_type'], 'cost': COST['cheap']})
    acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
    all_metrics['B'] = {'config': 'B', 'accuracy': round(acc, 4), 'avg_cost': COST['cheap'], 'n': len(res)}
    print(f' Acc: {acc:.3f}')
    # C: cheap proposer, strong model verifies; on REJECT the strong model re-predicts.
    print('\nConfig C: cheap + strong verifier')
    res = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f' {i}/{len(data)}')
        cheap_pred = predict_action(cm, ctok, build_proposer_messages(ex), device)
        accepted = acc_rej(sm, stok, cheap_pred, ex['messages'], device)
        if accepted:
            pred, cost = cheap_pred, COST['cheap'] + COST['verify_check']
        else:
            pred = predict_action(sm, stok, build_proposer_messages(ex), device)
            cost = COST['cheap'] + COST['verify_check'] + COST['strong']
        res.append({'pred': pred, 'true': ex['action_type'], 'cost': cost, 'accepted': accepted})
    acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
    ar = sum(1 for r in res if r['accepted']) / len(res)
    all_metrics['C'] = {'config': 'C', 'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res)/len(res), 4), 'accept_rate': round(ar, 4), 'n': len(res)}
    print(f' Acc: {acc:.3f} | Accept: {ar:.3f}')
# D
THR = -1.0
print(f'\nConfig D: cheap + reward verifier (thr={THR})')
res = []
for i, ex in enumerate(data):
if i % 20 == 0: print(f' {i}/{len(data)}')
cheap_pred = predict_action(cm, ctok, build_proposer_messages(ex), device)
score = get_reward_score(vm, vtok, build_reward_text(cheap_pred, ex), device)
res.append({'pred': cheap_pred, 'true': ex['action_type'], 'cost': COST['cheap'] + COST['verify_check'], 'accepted': score >= THR, 'score': score})
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
ar = sum(1 for r in res if r['accepted']) / len(res)
scores = [r['score'] for r in res]
all_metrics['D'] = {'config': 'D', 'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res)/len(res), 4), 'accept_rate': round(ar, 4), 'mean_score': round(sum(scores)/len(scores), 3), 'min_score': round(min(scores), 3), 'max_score': round(max(scores), 3), 'n': len(res)}
print(f' Acc: {acc:.3f} | Accept: {ar:.3f} | Score: {sum(scores)/len(scores):.3f}')
# E
N = 3
print(f'\nConfig E: multi-proposal (n={N})')
res = []
for i, ex in enumerate(data):
if i % 10 == 0: print(f' {i}/{len(data)}')
msgs = build_proposer_messages(ex)
text = ctok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
proposals = []
for _ in range(N):
inputs = ctok(text, return_tensors='pt', truncation=True, max_length=2048).to(device)
outputs = cm.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.8, top_p=0.95, pad_token_id=ctok.pad_token_id)
proposals.append(parse_action(ctok.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)))
scored = [(p, get_reward_score(vm, vtok, build_reward_text(p, ex), device)) for p in set(proposals)]
best = max(scored, key=lambda x: x[1])[0]
res.append({'pred': best, 'true': ex['action_type'], 'cost': COST['cheap'] * N + COST['verify_check'] * N})
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
all_metrics['E'] = {'config': 'E', 'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res)/len(res), 4), 'n': len(res)}
print(f' Acc: {acc:.3f}')
print(f'\n{"="*60}')
print(f'{"Config":<6} {"Acc":>8} {"Cost":>8} {"Accept%":>10}')
print('-'*60)
for c in ['A','B','C','D','E']:
m = all_metrics[c]
ar = f'{m.get("accept_rate","-"):.3f}' if isinstance(m.get('accept_rate'), float) else '-'
print(f'{c:<6} {m["accuracy"]:>8.3f} {m["avg_cost"]:>8.3f} {ar:>10}')
print('\nCOST-QUALITY FRONTIER')
for m in sorted(all_metrics.values(), key=lambda x: x['avg_cost']):
print(f" {m['config']}: cost={m['avg_cost']:.3f} acc={m['accuracy']:.3f}")
with open('/tmp/eval_results_v2.json', 'w') as f:
json.dump({'metrics': all_metrics, 'version': 'v2-fixed-prompts', 'n': len(data)}, f, indent=2)
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(path_or_fileobj='/tmp/eval_results_v2.json', path_in_repo='eval_results_v2.json', repo_id=f'{HUB_ORG}/speculative-tool-actions', repo_type='model', commit_message='Eval v2 results (fixed prompt format)')
print('Uploaded!')
if __name__ == '__main__':
    main()