"""Evaluate speculative-action configurations (A-E) on the held-out eval split.

Compares an always-strong baseline against cheap-proposer setups with different
verifiers, tracking accuracy, relative cost, and safety on BLOCKED examples.
"""
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'

ACTIONS = [
    'tool_call', 'retrieval', 'file_read', 'file_write', 'repair',
    'verifier', 'ask_clarification', 'final_answer', 'BLOCKED'
]

# Per-action reference costs (kept for reference; the per-config cost accounting
# below uses fixed per-model costs instead).
ACTION_COST = {
    'tool_call': 0.3, 'retrieval': 0.2, 'file_read': 0.15, 'file_write': 0.15,
    'repair': 0.4, 'verifier': 0.25, 'ask_clarification': 0.1,
    'final_answer': 0.2, 'BLOCKED': 0.05
}


def load_model(name, device):
    """Load a causal LM and its tokenizer in bfloat16."""
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        name, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    model = model.to(device)
    return model, tok


@torch.no_grad()
def generate_text(model, tokenizer, prompt, device, do_sample=False):
    """Generate a short raw completion (used for YES/NO and 1-10 scoring prompts)."""
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=2048).to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=20, do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id
    )
    return tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
    ).strip()


def predict_action(model, tokenizer, prompt, device, do_sample=False):
    """Generate a completion and map it to one of ACTIONS (default: tool_call)."""
    text = generate_text(model, tokenizer, prompt, device, do_sample=do_sample).lower()
    for a in ACTIONS:
        if a.lower() in text:
            return a
    return 'tool_call'


def parse_score(text, default=5):
    """Extract the first integer from a scoring response, falling back to default."""
    for word in text.split():
        try:
            return int(word.strip('.,!?'))
        except ValueError:
            pass
    return default


def build_prompt(context, task_type):
    actions_str = ', '.join(ACTIONS)
    return f"""You are an AI agent deciding the next action.
Available actions: {actions_str}
Task type: {task_type}
Context: {context}
Next action (choose exactly one from the list above):"""


def run_config_A(data, strong_model, strong_tok, device):
    """Config A: always query the strong model."""
    results = []
    for ex in data:
        prompt = build_prompt(ex['context'], ex['task_type'])
        pred = predict_action(strong_model, strong_tok, prompt, device)
        cost = 1.0
        results.append({
            'pred': pred, 'true': ex['action'], 'cost': cost, 'accepted': None,
            'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
        })
    return results


def run_config_B(data, cheap_model, cheap_tok, device):
    """Config B: cheap model only."""
    results = []
    for ex in data:
        prompt = build_prompt(ex['context'], ex['task_type'])
        pred = predict_action(cheap_model, cheap_tok, prompt, device)
        cost = 0.2
        results.append({
            'pred': pred, 'true': ex['action'], 'cost': cost, 'accepted': None,
            'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
        })
    return results


def run_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device):
    """Config C: cheap proposer + strong verifier (accept/reject)."""
    results = []
    for ex in data:
        prompt = build_prompt(ex['context'], ex['task_type'])
        cheap_pred = predict_action(cheap_model, cheap_tok, prompt, device)
        # Strong model verifies the proposal; use raw generation here so the
        # YES/NO answer is not collapsed to an action label by predict_action.
        verify_prompt = f"""Action proposed: {cheap_pred}
Task type: {ex['task_type']}
Context: {ex['context']}
Is this action correct? Answer YES or NO:"""
        verify_text = generate_text(strong_model, strong_tok, verify_prompt, device)
        accepted = 'yes' in verify_text.lower()
        if accepted:
            pred = cheap_pred
            cost = 0.2 + 0.3            # cheap + verify
        else:
            pred = predict_action(strong_model, strong_tok, prompt, device)
            cost = 0.2 + 0.3 + 1.0      # cheap + verify + strong fallback
        results.append({
            'pred': pred, 'true': ex['action'], 'cost': cost, 'accepted': accepted,
            'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
        })
    return results


def run_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device):
    """Config D: cheap proposer + trained trace judge."""
    results = []
    for ex in data:
        prompt = build_prompt(ex['context'], ex['task_type'])
        cheap_pred = predict_action(cheap_model, cheap_tok, prompt, device)
        # Trained verifier rates the proposal 1-10; raw generation so the
        # numeric score can be parsed from the completion.
        verify_prompt = f"""Action proposed: {cheap_pred}
Task type: {ex['task_type']}
Context: {ex['context']}
Rate this action 1-10 (10=best):"""
        verify_text = generate_text(verifier_model, verifier_tok, verify_prompt, device)
        score = parse_score(verify_text)
        accepted = score >= 7
        if accepted:
            pred = cheap_pred
            cost = 0.2 + 0.15           # cheap + trained verifier
        else:
            pred = predict_action(verifier_model, verifier_tok, prompt, device)
            cost = 0.2 + 0.15 + 0.6     # cheap + verifier + fallback
        results.append({
            'pred': pred, 'true': ex['action'], 'cost': cost, 'accepted': accepted,
            'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
        })
    return results


def run_config_E(data, cheap_model, cheap_tok, strong_model, strong_tok, device, n_proposals=3):
    """Config E: multi-proposal reranking by the strong model."""
    results = []
    for ex in data:
        prompt = build_prompt(ex['context'], ex['task_type'])
        # Sample proposals; greedy decoding would return n identical copies.
        proposals = [
            predict_action(cheap_model, cheap_tok, prompt, device, do_sample=True)
            for _ in range(n_proposals)
        ]
        # Strong model scores each proposal; highest score wins.
        scores = []
        for prop in proposals:
            score_prompt = f"""Proposed action: {prop}
Task: {ex['task_type']}
Context: {ex['context']}
Score 1-10:"""
            score_text = generate_text(strong_model, strong_tok, score_prompt, device)
            scores.append(parse_score(score_text))
        best_idx = scores.index(max(scores))
        pred = proposals[best_idx]
        cost = 0.2 * n_proposals + 0.3 * n_proposals
        results.append({
            'pred': pred, 'true': ex['action'], 'cost': cost, 'accepted': True,
            'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED'
        })
    return results


def compute_metrics(results):
    correct = sum(1 for r in results if r['pred'] == r['true'])
    total = len(results)
    accuracy = correct / total
    avg_cost = sum(r['cost'] for r in results) / total
    safe = sum(1 for r in results if r['safe']) / total
    # Per-action accuracy (recall on examples whose gold action is a)
    by_action = {}
    for a in ACTIONS:
        subset = [r for r in results if r['true'] == a]
        if subset:
            by_action[a] = sum(1 for r in subset if r['pred'] == a) / len(subset)
    return {
        'accuracy': accuracy, 'avg_cost': avg_cost, 'safety': safe,
        'n': total, 'by_action': by_action
    }


def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}')

    # Load evaluation data (first 100 examples for speed)
    print('Loading eval dataset...')
    ds = load_dataset(EVAL_DS)['test']
    data = [ds[i] for i in range(min(100, len(ds)))]
    print(f'Evaluating on {len(data)} examples')

    # Load models
    print('Loading cheap model (Qwen3-1.7B)...')
    cheap_model, cheap_tok = load_model('Qwen/Qwen3-1.7B', device)
    print('Loading verifier model (Qwen3-4B)...')
    verifier_model, verifier_tok = load_model('Qwen/Qwen3-4B', device)
    print('Loading strong model (Qwen2.5-7B)...')
    strong_model, strong_tok = load_model('Qwen/Qwen2.5-7B', device)

    all_results = {}

    print('\n=== Config A: Always Strong ===')
    results_A = run_config_A(data, strong_model, strong_tok, device)
    all_results['A'] = compute_metrics(results_A)
    print(json.dumps(all_results['A'], indent=2))

    print('\n=== Config B: Cheap Only ===')
    results_B = run_config_B(data, cheap_model, cheap_tok, device)
    all_results['B'] = compute_metrics(results_B)
    print(json.dumps(all_results['B'], indent=2))

    print('\n=== Config C: Cheap + Strong Verifier ===')
    results_C = run_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device)
    all_results['C'] = compute_metrics(results_C)
    print(json.dumps(all_results['C'], indent=2))

    print('\n=== Config D: Cheap + Trained Verifier ===')
    results_D = run_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)
    all_results['D'] = compute_metrics(results_D)
    print(json.dumps(all_results['D'], indent=2))

    print('\n=== Config E: Multi-Proposal Reranking ===')
    results_E = run_config_E(data, cheap_model, cheap_tok, strong_model, strong_tok, device)
    all_results['E'] = compute_metrics(results_E)
    print(json.dumps(all_results['E'], indent=2))

    # Save results
    with open('/tmp/eval_results.json', 'w') as f:
        json.dump(all_results, f, indent=2)

    print('\n=== Final Comparison ===')
    for cfg in ['A', 'B', 'C', 'D', 'E']:
        r = all_results[cfg]
        print(f"Config {cfg}: Accuracy={r['accuracy']:.3f}, Cost={r['avg_cost']:.2f}, Safety={r['safety']:.3f}")

    # Upload results to the Hub
    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj='/tmp/eval_results.json',
        path_in_repo='eval_results.json',
        repo_id=f'{HUB_ORG}/speculative-tool-actions',
        repo_type='model'
    )
    print('\nResults uploaded to Hub.')


if __name__ == '__main__':
    main()