speculative-tool-actions / eval_base_models.py
narcolepticchicken's picture
Upload eval_base_models.py
69159fc verified
"""
Evaluation using base models with prompt engineering.
Since fine-tuning takes hours, we use carefully crafted prompts
on base models to simulate the speculative decoding pipeline.
"""
import json, random, sys
from collections import Counter
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
# Hugging Face Hub account that hosts the evaluation dataset.
HUB_ORG = "narcolepticchicken"
# Fully qualified Hub dataset id read by load_eval_data().
EVAL_DS = "/".join((HUB_ORG, "speculative-actions-eval"))
def load_model(name, device):
    """Load a causal LM and its tokenizer from the Hugging Face Hub.

    Args:
        name: Hub model id (e.g. ``'Qwen/Qwen3-1.7B'``). Loaded with
            ``trust_remote_code=True``, so the repo may run custom code.
        device: ``'cuda'`` loads float16 weights and moves the model to the
            GPU; any other value keeps float32 weights where they were loaded.

    Returns:
        ``(model, tokenizer)``. When the tokenizer has no pad token, its EOS
        token is reused as the pad token.
    """
    print(f'Loading {name}...', flush=True)
    on_gpu = device == 'cuda'

    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    weights_dtype = torch.float16 if on_gpu else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=weights_dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    return (model.to(device) if on_gpu else model), tokenizer
def predict_action(model, tokenizer, prompt, device, max_new_tokens=15, max_length=1024):
    """Greedily generate a short continuation of ``prompt`` and return it as text.

    Args:
        model: causal LM exposing ``generate``.
        tokenizer: matching tokenizer; its ``pad_token_id`` is passed to
            ``generate``.
        prompt: the full prompt string.
        device: ``'cuda'`` moves the encoded inputs to the GPU; any other
            value leaves them as returned by the tokenizer.
        max_new_tokens: generation budget.
        max_length: maximum number of prompt tokens kept after truncation.

    Returns:
        Only the newly generated text (the prompt tokens are sliced off),
        decoded without special tokens and stripped of surrounding whitespace.
    """
    # Every prompt in this script ends with the cue the model must answer
    # ("Action:", "Answer:", "Score (1-10):"). Right-side truncation would cut
    # exactly that cue off a long prompt, so truncate from the left instead.
    # The tokenizer object is shared across calls, so restore its setting.
    previous_side = tokenizer.truncation_side
    tokenizer.truncation_side = 'left'
    try:
        inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=max_length)
    finally:
        tokenizer.truncation_side = previous_side

    if device == 'cuda':
        inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    prompt_len = inputs['input_ids'].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
def parse_action(text):
    """Map free-form model output onto one of the nine known action labels.

    The text is lower-cased, and runs of whitespace and hyphens are collapsed
    to ``_``, so outputs such as ``"Final answer"`` or ``"ask-clarification"``
    still match their labels. When several labels occur, the one appearing
    earliest in the text wins. For example, a model that keeps reciting the
    option list ("file_read -tool_call ...") is credited with its first answer.

    Returns:
        The matched label, or ``'tool_call'`` when no label is found.
    """
    actions = (
        'tool_call', 'retrieval', 'file_read', 'file_write', 'repair',
        'verifier', 'ask_clarification', 'final_answer', 'blocked',
    )
    normalized = '_'.join(text.lower().replace('-', ' ').split())
    # No label is a prefix of another, so start positions never tie.
    hits = [(normalized.find(a), a) for a in actions if a in normalized]
    return min(hits)[1] if hits else 'tool_call'
def load_eval_data(n=50):
    """Return up to ``n`` rows from the ``test`` split of ``EVAL_DS``.

    Rows are taken in dataset order starting at index 0. They are returned as
    a plain list, one entry per row, as produced by indexing the split.
    """
    test_split = load_dataset(EVAL_DS)['test']
    count = n if n < len(test_split) else len(test_split)
    return [test_split[idx] for idx in range(count)]
def build_proposer_prompt(context, task_type):
    """Build the prompt asking a model to pick a single next action.

    The prompt lists the nine allowed actions (one per line, prefixed with
    ``-``), then the task type and context. It ends with ``Action:`` so the
    model's continuation is its chosen label.
    """
    options = (
        'tool_call', 'retrieval', 'file_read', 'file_write', 'repair',
        'verifier', 'ask_clarification', 'final_answer', 'blocked',
    )
    lines = ['You are an AI agent. Choose ONE action from:']
    lines.extend(f'-{option}' for option in options)
    lines.extend((f'Task: {task_type}', f'Context: {context}', 'Action:'))
    return '\n'.join(lines)
def build_verifier_prompt(context, task_type, proposed):
    """Build the prompt asking a model to accept or reject a proposed action.

    The prompt ends with ``Answer:`` and instructs the model to reply with
    just YES or NO.
    """
    header = 'You verify if an agent action is correct.'
    fields = (
        f'Task: {task_type}',
        f'Context: {context}',
        f'Proposed action: {proposed}',
    )
    question = 'Is this the best action? Answer with just YES or NO.'
    return '\n'.join((header, *fields, question, 'Answer:'))
def _sample_proposals(model, tokenizer, prompt, device, k, max_new_tokens=15, temperature=0.7):
    """Return ``k`` action labels parsed from sampled (non-greedy) completions of ``prompt``."""
    if k <= 0:
        return []
    # Truncate from the left so an over-long prompt still ends with its
    # "Action:" cue; restore the shared tokenizer's setting afterwards.
    previous_side = tokenizer.truncation_side
    tokenizer.truncation_side = 'left'
    try:
        inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024)
    finally:
        tokenizer.truncation_side = previous_side
    if device == 'cuda':
        inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=k,
            pad_token_id=tokenizer.pad_token_id,
        )
    prompt_len = inputs['input_ids'].shape[1]
    return [
        parse_action(tokenizer.decode(seq[prompt_len:], skip_special_tokens=True).strip())
        for seq in outputs
    ]


def _score_action(model, tokenizer, ex, action, device):
    """Ask ``model`` to rate ``action`` 1-10 for example ``ex``.

    Returns the first whitespace-separated token that parses as an int (after
    stripping ``.,!?``), or 5 when the reply contains none.
    """
    score_prompt = (
        "Rate this action 1-10 for the task.\n"
        f"Task: {ex['task_type']}\n"
        f"Context: {ex['context']}\n"
        f"Action: {action}\n"
        "Score (1-10):"
    )
    score_text = predict_action(model, tokenizer, score_prompt, device, max_new_tokens=5)
    for word in score_text.split():
        try:
            return int(word.strip('.,!?'))
        except ValueError:
            continue
    return 5


def run_eval(data, proposer, proposer_tok, verifier, verifier_tok, strong, strong_tok, device, seed=0):
    """Run the five pipeline configurations over ``data``.

    Configurations (costs are the script's relative units):
      A  strong model alone                                   cost 1.0
      B  cheap proposer alone                                 cost 0.2
      C  proposer + strong-model YES/NO check; on rejection
         the strong model's own answer is used                cost 0.5 / 1.5
      D  proposer + ``verifier`` YES/NO check; on rejection
         the verifier's own answer is used                    cost 0.35 / 0.95
      E  3 proposals (the greedy one + 2 sampled) reranked by
         strong-model 1-10 scores; ties go to the earliest    cost 1.5

    Args:
        data: examples with ``'context'``, ``'task_type'`` and ``'action'`` keys.
        proposer, verifier, strong: models with their tokenizers
            (``*_tok``), as returned by ``load_model``.
        device: passed through to generation.
        seed: seeds torch's global RNG once, before any sampling, so that
            Config E's sampled proposals are repeatable.

    Returns:
        dict mapping ``'A'``..``'E'`` to lists of ``{'pred', 'true', 'cost'}``
        records, one per example, in ``data`` order.
    """
    torch.manual_seed(seed)
    results = {cfg: [] for cfg in ('A', 'B', 'C', 'D', 'E')}
    for i, ex in enumerate(data):
        print(f'Processing {i+1}/{len(data)}...', flush=True)
        true_action = ex['action']
        prompt = build_proposer_prompt(ex['context'], ex['task_type'])

        # Decoding in predict_action is greedy, so a model's answer to the same
        # prompt never changes: compute each once and reuse it below instead
        # of regenerating it for every configuration.
        cheap_pred = parse_action(predict_action(proposer, proposer_tok, prompt, device))
        strong_pred = parse_action(predict_action(strong, strong_tok, prompt, device))

        # Config B: cheap proposer only.
        results['B'].append({'pred': cheap_pred, 'true': true_action, 'cost': 0.2})
        # Config A: strong model only.
        results['A'].append({'pred': strong_pred, 'true': true_action, 'cost': 1.0})

        verify_prompt = build_verifier_prompt(ex['context'], ex['task_type'], cheap_pred)

        # Config C: cheap proposal checked by the strong model.
        verify_text = predict_action(strong, strong_tok, verify_prompt, device, max_new_tokens=5)
        if 'yes' in verify_text.lower():
            pred_c, cost_c = cheap_pred, 0.2 + 0.3
        else:
            pred_c, cost_c = strong_pred, 0.2 + 0.3 + 1.0
        results['C'].append({'pred': pred_c, 'true': true_action, 'cost': cost_c})

        # Config D: cheap proposal checked by the mid-size verifier model.
        verify_text_d = predict_action(verifier, verifier_tok, verify_prompt, device, max_new_tokens=5)
        if 'yes' in verify_text_d.lower():
            pred_d, cost_d = cheap_pred, 0.2 + 0.15
        else:
            # Fall back to the verifier's own judgment.
            pred_d = parse_action(predict_action(verifier, verifier_tok, prompt, device))
            cost_d = 0.2 + 0.15 + 0.6
        results['D'].append({'pred': pred_d, 'true': true_action, 'cost': cost_d})

        # Config E: multi-proposal reranking. Greedy decoding would make all
        # proposals identical (a no-op rerank), so besides the greedy proposal
        # two more are sampled.
        proposals = [cheap_pred] + _sample_proposals(proposer, proposer_tok, prompt, device, k=2)
        # Scoring is greedy too, so identical proposals get identical scores.
        score_cache = {}
        scores = []
        for prop in proposals:
            if prop not in score_cache:
                score_cache[prop] = _score_action(strong, strong_tok, ex, prop, device)
            scores.append(score_cache[prop])
        pred_e = proposals[scores.index(max(scores))]
        cost_e = 0.2 * 3 + 0.3 * 3
        results['E'].append({'pred': pred_e, 'true': true_action, 'cost': cost_e})
    return results
def compute_metrics(results_list):
    """Summarize one configuration's per-example records.

    Args:
        results_list: list of records with ``'pred'``, ``'true'`` and
            ``'cost'`` keys.

    Returns:
        dict with
          ``accuracy``   fraction of records where pred == true,
          ``avg_cost``   mean of ``cost``,
          ``n``          number of records,
          ``by_action``  per-true-label accuracy (recall), keyed by label.
        An empty list yields zeros and an empty ``by_action`` rather than
        raising ZeroDivisionError.
    """
    total = len(results_list)
    if total == 0:
        return {'accuracy': 0.0, 'avg_cost': 0.0, 'n': 0, 'by_action': {}}

    # One pass: how often each true label occurs, and how often it was hit.
    support = Counter()
    hits = Counter()
    for r in results_list:
        support[r['true']] += 1
        if r['pred'] == r['true']:
            hits[r['true']] += 1

    return {
        'accuracy': sum(hits.values()) / total,
        'avg_cost': sum(r['cost'] for r in results_list) / total,
        'n': total,
        'by_action': {label: hits[label] / count for label, count in support.items()},
    }
def main():
    """Evaluate configurations A-E and save a JSON summary.

    Loads up to 50 eval examples and three models (proposer, verifier,
    strong), then runs run_eval. It prints per-configuration accuracy and
    average cost, and writes the metrics to /tmp/eval_results_empirical.json.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}', flush=True)

    # A 50-example subset keeps the three-model evaluation fast.
    data = load_eval_data(n=50)
    print(f'Loaded {len(data)} eval examples', flush=True)

    proposer, proposer_tok = load_model('Qwen/Qwen3-1.7B', device)
    verifier, verifier_tok = load_model('Qwen/Qwen3-4B', device)
    strong, strong_tok = load_model('Qwen/Qwen2.5-7B', device)

    print('Running evaluation...', flush=True)
    all_results = run_eval(
        data, proposer, proposer_tok, verifier, verifier_tok, strong, strong_tok, device
    )

    print('\n=== RESULTS ===', flush=True)
    summary = {}
    for cfg in ('A', 'B', 'C', 'D', 'E'):
        summary[cfg] = metrics = compute_metrics(all_results[cfg])
        print(
            f"Config {cfg}: Accuracy={metrics['accuracy']:.3f}, Cost={metrics['avg_cost']:.2f}",
            flush=True,
        )

    out_path = '/tmp/eval_results_empirical.json'
    with open(out_path, 'w') as fh:
        json.dump(summary, fh, indent=2)
    print('\nSaved to /tmp/eval_results_empirical.json', flush=True)


if __name__ == '__main__':
    main()