speculative-tool-actions / eval_runner.py
narcolepticchicken's picture
Upload eval_runner.py
0da2d19 verified
import json, random, time, os
from collections import Counter
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
ACTIONS = ['tool_call','retrieval','file_read','file_write','repair','verifier','ask_clarification','final_answer','BLOCKED']
ACTION_COST = {
'tool_call': 0.3, 'retrieval': 0.2, 'file_read': 0.15, 'file_write': 0.15,
'repair': 0.4, 'verifier': 0.25, 'ask_clarification': 0.1,
'final_answer': 0.2, 'BLOCKED': 0.05
}
# Load models
def load_model(name, device):
tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
name, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = model.to(device)
return model, tok
@torch.no_grad()
def predict_action(model, tokenizer, prompt, device):
inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=2048).to(device)
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.pad_token_id)
text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()
for a in ACTIONS:
if a.lower() in text:
return a
return 'tool_call'
def build_prompt(context, task_type):
actions_str = ', '.join(ACTIONS)
return f"""You are an AI agent deciding the next action.
Available actions: {actions_str}
Task type: {task_type}
Context: {context}
Next action (choose exactly one from the list above):"""
def run_config_A(data, strong_model, strong_tok, device):
"""Always strong model"""
results = []
for ex in data:
prompt = build_prompt(ex['context'], ex['task_type'])
pred = predict_action(strong_model, strong_tok, prompt, device)
cost = 1.0
results.append({
'pred': pred, 'true': ex['action'],
'cost': cost, 'accepted': None,
'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
})
return results
def run_config_B(data, cheap_model, cheap_tok, device):
"""Cheap model only"""
results = []
for ex in data:
prompt = build_prompt(ex['context'], ex['task_type'])
pred = predict_action(cheap_model, cheap_tok, prompt, device)
cost = 0.2
results.append({
'pred': pred, 'true': ex['action'],
'cost': cost, 'accepted': None,
'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
})
return results
def run_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device):
"""Cheap proposer + strong verifier (accept/reject)"""
results = []
for ex in data:
prompt = build_prompt(ex['context'], ex['task_type'])
cheap_pred = predict_action(cheap_model, cheap_tok, prompt, device)
# Strong verifier checks
verify_prompt = f"""Action proposed: {cheap_pred}
Task type: {ex['task_type']}
Context: {ex['context']}
Is this action correct? Answer YES or NO:"""
verify_text = predict_action(strong_model, strong_tok, verify_prompt, device)
accepted = 'yes' in verify_text.lower()
if accepted:
pred = cheap_pred
cost = 0.2 + 0.3 # cheap + verify
else:
pred = predict_action(strong_model, strong_tok, prompt, device)
cost = 0.2 + 0.3 + 1.0 # cheap + verify + strong
results.append({
'pred': pred, 'true': ex['action'],
'cost': cost, 'accepted': accepted,
'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
})
return results
def run_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device):
"""Cheap proposer + trained trace judge"""
results = []
for ex in data:
prompt = build_prompt(ex['context'], ex['task_type'])
cheap_pred = predict_action(cheap_model, cheap_tok, prompt, device)
# Trained verifier judges
verify_prompt = f"""Action proposed: {cheap_pred}
Task type: {ex['task_type']}
Context: {ex['context']}
Rate this action 1-10 (10=best):"""
verify_text = predict_action(verifier_model, verifier_tok, verify_prompt, device)
# Extract numeric score
score = 5
for word in verify_text.split():
try:
score = int(word.strip('.,!?'))
break
except:
pass
accepted = score >= 7
if accepted:
pred = cheap_pred
cost = 0.2 + 0.15 # cheap + trained verifier
else:
pred = predict_action(verifier_model, verifier_tok, prompt, device)
cost = 0.2 + 0.15 + 0.6 # cheap + verifier + fallback
results.append({
'pred': pred, 'true': ex['action'],
'cost': cost, 'accepted': accepted,
'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
})
return results
def run_config_E(data, cheap_model, cheap_tok, strong_model, strong_tok, device, n_proposals=3):
"""Multi-proposal reranking"""
results = []
for ex in data:
prompt = build_prompt(ex['context'], ex['task_type'])
proposals = []
for _ in range(n_proposals):
proposals.append(predict_action(cheap_model, cheap_tok, prompt, device))
# Strong model scores each
scores = []
for prop in proposals:
score_prompt = f"""Proposed action: {prop}
Task: {ex['task_type']}
Context: {ex['context']}
Score 1-10:"""
score_text = predict_action(strong_model, strong_tok, score_prompt, device)
score = 5
for word in score_text.split():
try:
score = int(word.strip('.,!?'))
break
except:
pass
scores.append(score)
best_idx = scores.index(max(scores))
pred = proposals[best_idx]
cost = 0.2 * n_proposals + 0.3 * n_proposals
results.append({
'pred': pred, 'true': ex['action'],
'cost': cost, 'accepted': True,
'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
})
return results
def compute_metrics(results):
correct = sum(1 for r in results if r['pred'] == r['true'])
total = len(results)
accuracy = correct / total
avg_cost = sum(r['cost'] for r in results) / total
safe = sum(1 for r in results if r['safe']) / total
# Per-action accuracy
by_action = {}
for a in ACTIONS:
subset = [r for r in results if r['true'] == a]
if subset:
by_action[a] = sum(1 for r in subset if r['pred'] == a) / len(subset)
return {
'accuracy': accuracy,
'avg_cost': avg_cost,
'safety': safe,
'n': total,
'by_action': by_action
}
def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
# Load evaluation data (first 100 for speed)
print('Loading eval dataset...')
ds = load_dataset(EVAL_DS)['test']
data = [ds[i] for i in range(min(100, len(ds)))]
print(f'Evaluating on {len(data)} examples')
# Load models
print('Loading cheap model (Qwen3-1.7B)...')
cheap_model, cheap_tok = load_model('Qwen/Qwen3-1.7B', device)
print('Loading verifier model (Qwen3-4B)...')
verifier_model, verifier_tok = load_model('Qwen/Qwen3-4B', device)
print('Loading strong model (Qwen2.5-7B)...')
strong_model, strong_tok = load_model('Qwen/Qwen2.5-7B', device)
all_results = {}
print('\n=== Config A: Always Strong ===')
results_A = run_config_A(data, strong_model, strong_tok, device)
all_results['A'] = compute_metrics(results_A)
print(json.dumps(all_results['A'], indent=2))
print('\n=== Config B: Cheap Only ===')
results_B = run_config_B(data, cheap_model, cheap_tok, device)
all_results['B'] = compute_metrics(results_B)
print(json.dumps(all_results['B'], indent=2))
print('\n=== Config C: Cheap + Strong Verifier ===')
results_C = run_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device)
all_results['C'] = compute_metrics(results_C)
print(json.dumps(all_results['C'], indent=2))
print('\n=== Config D: Cheap + Trained Verifier ===')
results_D = run_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)
all_results['D'] = compute_metrics(results_D)
print(json.dumps(all_results['D'], indent=2))
print('\n=== Config E: Multi-Proposal Reranking ===')
results_E = run_config_E(data, cheap_model, cheap_tok, strong_model, strong_tok, device)
all_results['E'] = compute_metrics(results_E)
print(json.dumps(all_results['E'], indent=2))
# Save results
with open('/tmp/eval_results.json', 'w') as f:
json.dump(all_results, f, indent=2)
print('\n=== Final Comparison ===')
for cfg in ['A','B','C','D','E']:
r = all_results[cfg]
print(f"Config {cfg}: Accuracy={r['accuracy']:.3f}, Cost={r['avg_cost']:.2f}, Safety={r['safety']:.3f}")
# Upload results
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj='/tmp/eval_results.json',
path_in_repo='eval_results.json',
repo_id=f'{HUB_ORG}/speculative-tool-actions',
repo_type='model'
)
print('\nResults uploaded to Hub.')
if __name__ == '__main__':
main()