| import json, random, time, os |
| from collections import Counter |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from datasets import load_dataset |
|
|
| HUB_ORG = 'narcolepticchicken' |
| EVAL_DS = f'{HUB_ORG}/speculative-actions-eval' |
|
|
| ACTIONS = ['tool_call','retrieval','file_read','file_write','repair','verifier','ask_clarification','final_answer','BLOCKED'] |
| ACTION_COST = { |
| 'tool_call': 0.3, 'retrieval': 0.2, 'file_read': 0.15, 'file_write': 0.15, |
| 'repair': 0.4, 'verifier': 0.25, 'ask_clarification': 0.1, |
| 'final_answer': 0.2, 'BLOCKED': 0.05 |
| } |
|
|
| |
| def load_model(name, device): |
| tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True) |
| if tok.pad_token is None: |
| tok.pad_token = tok.eos_token |
| model = AutoModelForCausalLM.from_pretrained( |
| name, torch_dtype=torch.bfloat16, trust_remote_code=True |
| ) |
| model = model.to(device) |
| return model, tok |
|
|
| @torch.no_grad() |
| def predict_action(model, tokenizer, prompt, device): |
| inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=2048).to(device) |
| outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.pad_token_id) |
| text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower() |
| for a in ACTIONS: |
| if a.lower() in text: |
| return a |
| return 'tool_call' |
|
|
| def build_prompt(context, task_type): |
| actions_str = ', '.join(ACTIONS) |
| return f"""You are an AI agent deciding the next action. |
| Available actions: {actions_str} |
| |
| Task type: {task_type} |
| Context: {context} |
| |
| Next action (choose exactly one from the list above):""" |
|
|
| def run_config_A(data, strong_model, strong_tok, device): |
| """Always strong model""" |
| results = [] |
| for ex in data: |
| prompt = build_prompt(ex['context'], ex['task_type']) |
| pred = predict_action(strong_model, strong_tok, prompt, device) |
| cost = 1.0 |
| results.append({ |
| 'pred': pred, 'true': ex['action'], |
| 'cost': cost, 'accepted': None, |
| 'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED' |
| }) |
| return results |
|
|
| def run_config_B(data, cheap_model, cheap_tok, device): |
| """Cheap model only""" |
| results = [] |
| for ex in data: |
| prompt = build_prompt(ex['context'], ex['task_type']) |
| pred = predict_action(cheap_model, cheap_tok, prompt, device) |
| cost = 0.2 |
| results.append({ |
| 'pred': pred, 'true': ex['action'], |
| 'cost': cost, 'accepted': None, |
| 'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED' |
| }) |
| return results |
|
|
| def run_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device): |
| """Cheap proposer + strong verifier (accept/reject)""" |
| results = [] |
| for ex in data: |
| prompt = build_prompt(ex['context'], ex['task_type']) |
| cheap_pred = predict_action(cheap_model, cheap_tok, prompt, device) |
| |
| |
| verify_prompt = f"""Action proposed: {cheap_pred} |
| Task type: {ex['task_type']} |
| Context: {ex['context']} |
| Is this action correct? Answer YES or NO:""" |
| verify_text = predict_action(strong_model, strong_tok, verify_prompt, device) |
| accepted = 'yes' in verify_text.lower() |
| |
| if accepted: |
| pred = cheap_pred |
| cost = 0.2 + 0.3 |
| else: |
| pred = predict_action(strong_model, strong_tok, prompt, device) |
| cost = 0.2 + 0.3 + 1.0 |
| |
| results.append({ |
| 'pred': pred, 'true': ex['action'], |
| 'cost': cost, 'accepted': accepted, |
| 'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED' |
| }) |
| return results |
|
|
| def run_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device): |
| """Cheap proposer + trained trace judge""" |
| results = [] |
| for ex in data: |
| prompt = build_prompt(ex['context'], ex['task_type']) |
| cheap_pred = predict_action(cheap_model, cheap_tok, prompt, device) |
| |
| |
| verify_prompt = f"""Action proposed: {cheap_pred} |
| Task type: {ex['task_type']} |
| Context: {ex['context']} |
| Rate this action 1-10 (10=best):""" |
| verify_text = predict_action(verifier_model, verifier_tok, verify_prompt, device) |
| |
| score = 5 |
| for word in verify_text.split(): |
| try: |
| score = int(word.strip('.,!?')) |
| break |
| except: |
| pass |
| accepted = score >= 7 |
| |
| if accepted: |
| pred = cheap_pred |
| cost = 0.2 + 0.15 |
| else: |
| pred = predict_action(verifier_model, verifier_tok, prompt, device) |
| cost = 0.2 + 0.15 + 0.6 |
| |
| results.append({ |
| 'pred': pred, 'true': ex['action'], |
| 'cost': cost, 'accepted': accepted, |
| 'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED' |
| }) |
| return results |
|
|
| def run_config_E(data, cheap_model, cheap_tok, strong_model, strong_tok, device, n_proposals=3): |
| """Multi-proposal reranking""" |
| results = [] |
| for ex in data: |
| prompt = build_prompt(ex['context'], ex['task_type']) |
| proposals = [] |
| for _ in range(n_proposals): |
| proposals.append(predict_action(cheap_model, cheap_tok, prompt, device)) |
| |
| |
| scores = [] |
| for prop in proposals: |
| score_prompt = f"""Proposed action: {prop} |
| Task: {ex['task_type']} |
| Context: {ex['context']} |
| Score 1-10:""" |
| score_text = predict_action(strong_model, strong_tok, score_prompt, device) |
| score = 5 |
| for word in score_text.split(): |
| try: |
| score = int(word.strip('.,!?')) |
| break |
| except: |
| pass |
| scores.append(score) |
| |
| best_idx = scores.index(max(scores)) |
| pred = proposals[best_idx] |
| cost = 0.2 * n_proposals + 0.3 * n_proposals |
| |
| results.append({ |
| 'pred': pred, 'true': ex['action'], |
| 'cost': cost, 'accepted': True, |
| 'safe': ex['action'] != 'BLOCKED' or pred == 'BLOCKED' |
| }) |
| return results |
|
|
| def compute_metrics(results): |
| correct = sum(1 for r in results if r['pred'] == r['true']) |
| total = len(results) |
| accuracy = correct / total |
| avg_cost = sum(r['cost'] for r in results) / total |
| safe = sum(1 for r in results if r['safe']) / total |
| |
| |
| by_action = {} |
| for a in ACTIONS: |
| subset = [r for r in results if r['true'] == a] |
| if subset: |
| by_action[a] = sum(1 for r in subset if r['pred'] == a) / len(subset) |
| |
| return { |
| 'accuracy': accuracy, |
| 'avg_cost': avg_cost, |
| 'safety': safe, |
| 'n': total, |
| 'by_action': by_action |
| } |
|
|
| def main(): |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| print(f'Device: {device}') |
| |
| |
| print('Loading eval dataset...') |
| ds = load_dataset(EVAL_DS)['test'] |
| data = [ds[i] for i in range(min(100, len(ds)))] |
| print(f'Evaluating on {len(data)} examples') |
| |
| |
| print('Loading cheap model (Qwen3-1.7B)...') |
| cheap_model, cheap_tok = load_model('Qwen/Qwen3-1.7B', device) |
| |
| print('Loading verifier model (Qwen3-4B)...') |
| verifier_model, verifier_tok = load_model('Qwen/Qwen3-4B', device) |
| |
| print('Loading strong model (Qwen2.5-7B)...') |
| strong_model, strong_tok = load_model('Qwen/Qwen2.5-7B', device) |
| |
| all_results = {} |
| |
| print('\n=== Config A: Always Strong ===') |
| results_A = run_config_A(data, strong_model, strong_tok, device) |
| all_results['A'] = compute_metrics(results_A) |
| print(json.dumps(all_results['A'], indent=2)) |
| |
| print('\n=== Config B: Cheap Only ===') |
| results_B = run_config_B(data, cheap_model, cheap_tok, device) |
| all_results['B'] = compute_metrics(results_B) |
| print(json.dumps(all_results['B'], indent=2)) |
| |
| print('\n=== Config C: Cheap + Strong Verifier ===') |
| results_C = run_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device) |
| all_results['C'] = compute_metrics(results_C) |
| print(json.dumps(all_results['C'], indent=2)) |
| |
| print('\n=== Config D: Cheap + Trained Verifier ===') |
| results_D = run_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device) |
| all_results['D'] = compute_metrics(results_D) |
| print(json.dumps(all_results['D'], indent=2)) |
| |
| print('\n=== Config E: Multi-Proposal Reranking ===') |
| results_E = run_config_E(data, cheap_model, cheap_tok, strong_model, strong_tok, device) |
| all_results['E'] = compute_metrics(results_E) |
| print(json.dumps(all_results['E'], indent=2)) |
| |
| |
| with open('/tmp/eval_results.json', 'w') as f: |
| json.dump(all_results, f, indent=2) |
| |
| print('\n=== Final Comparison ===') |
| for cfg in ['A','B','C','D','E']: |
| r = all_results[cfg] |
| print(f"Config {cfg}: Accuracy={r['accuracy']:.3f}, Cost={r['avg_cost']:.2f}, Safety={r['safety']:.3f}") |
| |
| |
| from huggingface_hub import HfApi |
| api = HfApi() |
| api.upload_file( |
| path_or_fileobj='/tmp/eval_results.json', |
| path_in_repo='eval_results.json', |
| repo_id=f'{HUB_ORG}/speculative-tool-actions', |
| repo_type='model' |
| ) |
| print('\nResults uploaded to Hub.') |
|
|
| if __name__ == '__main__': |
| main() |
|
|