| """ |
| Evaluation using base models with prompt engineering. |
| Since fine-tuning takes hours, we use carefully crafted prompts |
| on base models to simulate the speculative decoding pipeline. |
| """ |
| import json, random, sys |
| from collections import Counter |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from datasets import load_dataset |
|
|
HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
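# Assumed dataset schema (inferred from usage below): a 'test' split whose rows
# provide 'context', 'task_type', and a ground-truth 'action' label.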
|
|
def load_model(name, device):
    print(f'Loading {name}...', flush=True)
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if device == 'cuda':
        model = model.to(device)
    return model, tok
|
|
def predict_action(model, tokenizer, prompt, device, max_new_tokens=15, do_sample=False):
    """Generate a short continuation and return only the newly generated text."""
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024)
        if device == 'cuda':
            inputs = {k: v.to(device) for k, v in inputs.items()}
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
        )
        if do_sample:
            # Sampling is only requested where diverse outputs are needed (Config E);
            # the 0.7 temperature is an arbitrary default.
            gen_kwargs['temperature'] = 0.7
        outputs = model.generate(**inputs, **gen_kwargs)
    text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
    return text
|
|
def parse_action(text):
    """Map free-form model output to a known action label; default to 'tool_call'."""
    text_lower = text.lower()
    actions = ['tool_call', 'retrieval', 'file_read', 'file_write', 'repair', 'verifier',
               'ask_clarification', 'final_answer', 'blocked']
    for a in actions:
        if a in text_lower:
            return a
    return 'tool_call'
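
# Example: parse_action('I would choose retrieval here.') -> 'retrieval';
# output that names no known action falls back to 'tool_call'.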
|
|
def load_eval_data(n=50):
    ds = load_dataset(EVAL_DS)['test']
    n = min(n, len(ds))
    return [ds[i] for i in range(n)]
|
|
def build_proposer_prompt(context, task_type):
    return f"""You are an AI agent. Choose ONE action from:
-tool_call
-retrieval
-file_read
-file_write
-repair
-verifier
-ask_clarification
-final_answer
-blocked

Task: {task_type}
Context: {context}

Action:"""
|
|
def build_verifier_prompt(context, task_type, proposed):
    return f"""You verify if an agent action is correct.
Task: {task_type}
Context: {context}
Proposed action: {proposed}

Is this the best action? Answer with just YES or NO.

Answer:"""
|
|
def run_eval(data, proposer, proposer_tok, verifier, verifier_tok, strong, strong_tok, device):
    # Approximate relative cost units assumed throughout: proposer call = 0.2,
    # strong generation = 1.0, strong verify/score = 0.3, mid-size verify = 0.15,
    # mid-size generation = 0.6.
    results = {'A': [], 'B': [], 'C': [], 'D': [], 'E': []}

    for i, ex in enumerate(data):
        print(f'Processing {i+1}/{len(data)}...', flush=True)

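        # Config B: cheap proposer alone.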
        prompt_b = build_proposer_prompt(ex['context'], ex['task_type'])
        pred_b = parse_action(predict_action(proposer, proposer_tok, prompt_b, device))
        results['B'].append({'pred': pred_b, 'true': ex['action'], 'cost': 0.2})

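        # Config A: strong model alone (full-cost baseline).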
        prompt_a = build_proposer_prompt(ex['context'], ex['task_type'])
        pred_a = parse_action(predict_action(strong, strong_tok, prompt_a, device))
        results['A'].append({'pred': pred_a, 'true': ex['action'], 'cost': 1.0})

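        # Config C: proposer draft checked by the strong model; on rejection,
        # the strong model re-predicts from scratch.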
        prompt_c1 = build_proposer_prompt(ex['context'], ex['task_type'])
        cheap_pred = parse_action(predict_action(proposer, proposer_tok, prompt_c1, device))
        prompt_c2 = build_verifier_prompt(ex['context'], ex['task_type'], cheap_pred)
        verify_text = predict_action(strong, strong_tok, prompt_c2, device, max_new_tokens=5)
        accepted = 'yes' in verify_text.lower()
        if accepted:
            pred_c = cheap_pred
            cost_c = 0.2 + 0.3
        else:
            pred_c = parse_action(predict_action(strong, strong_tok, prompt_c1, device))
            cost_c = 0.2 + 0.3 + 1.0
        results['C'].append({'pred': pred_c, 'true': ex['action'], 'cost': cost_c})

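        # Config D: proposer draft checked by the mid-size verifier; on rejection,
        # the verifier itself produces the replacement prediction.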
        prompt_d1 = build_proposer_prompt(ex['context'], ex['task_type'])
        cheap_pred_d = parse_action(predict_action(proposer, proposer_tok, prompt_d1, device))
        prompt_d2 = build_verifier_prompt(ex['context'], ex['task_type'], cheap_pred_d)
        verify_text_d = predict_action(verifier, verifier_tok, prompt_d2, device, max_new_tokens=5)
        accepted_d = 'yes' in verify_text_d.lower()
        if accepted_d:
            pred_d = cheap_pred_d
            cost_d = 0.2 + 0.15
        else:
            pred_d = parse_action(predict_action(verifier, verifier_tok, prompt_d1, device))
            cost_d = 0.2 + 0.15 + 0.6
        results['D'].append({'pred': pred_d, 'true': ex['action'], 'cost': cost_d})

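        # Config E: best-of-3 sampled proposals, each scored 1-10 by the strong model.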
        proposals = []
        for _ in range(3):
            prompt_e = build_proposer_prompt(ex['context'], ex['task_type'])
            # Sampling is required here: greedy decoding would yield three identical proposals.
            proposals.append(parse_action(predict_action(proposer, proposer_tok, prompt_e, device, do_sample=True)))

        scores = []
        for prop in proposals:
            score_prompt = f"""Rate this action 1-10 for the task.
Task: {ex['task_type']}
Context: {ex['context']}
Action: {prop}

Score (1-10):"""
            score_text = predict_action(strong, strong_tok, score_prompt, device, max_new_tokens=5)
            score = 5  # neutral default if no number can be parsed from the reply
            for word in score_text.split():
                try:
                    score = int(word.strip('.,!?'))
                    break
                except ValueError:
                    continue
            scores.append(score)

        best_idx = scores.index(max(scores))
        pred_e = proposals[best_idx]
        cost_e = 0.2 * 3 + 0.3 * 3
        results['E'].append({'pred': pred_e, 'true': ex['action'], 'cost': cost_e})

    return results
|
|
def compute_metrics(results_list):
    correct = sum(1 for r in results_list if r['pred'] == r['true'])
    total = len(results_list)
    accuracy = correct / total
    avg_cost = sum(r['cost'] for r in results_list) / total

    # Per-action recall: fraction of examples with true label `a` predicted as `a`.
    by_action = {}
    for a in set(r['true'] for r in results_list):
        subset = [r for r in results_list if r['true'] == a]
        by_action[a] = sum(1 for r in subset if r['pred'] == a) / len(subset)

    return {'accuracy': accuracy, 'avg_cost': avg_cost, 'n': total, 'by_action': by_action}
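
# Worked example for compute_metrics (hypothetical inputs):
#   compute_metrics([{'pred': 'tool_call', 'true': 'tool_call', 'cost': 0.2},
#                    {'pred': 'repair', 'true': 'verifier', 'cost': 1.0}])
#   -> accuracy 0.5, avg_cost 0.6, by_action {'tool_call': 1.0, 'verifier': 0.0}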
|
|
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}', flush=True)

    data = load_eval_data(n=50)
    print(f'Loaded {len(data)} eval examples', flush=True)

    proposer, proposer_tok = load_model('Qwen/Qwen3-1.7B', device)  # cheap proposer
    verifier, verifier_tok = load_model('Qwen/Qwen3-4B', device)    # mid-size verifier
    strong, strong_tok = load_model('Qwen/Qwen2.5-7B', device)      # strong model

    print('Running evaluation...', flush=True)
    all_results = run_eval(data, proposer, proposer_tok, verifier, verifier_tok, strong, strong_tok, device)

    summary = {}
    print('\n=== RESULTS ===', flush=True)
    for cfg in ['A', 'B', 'C', 'D', 'E']:
        metrics = compute_metrics(all_results[cfg])
        summary[cfg] = metrics
        print(f"Config {cfg}: Accuracy={metrics['accuracy']:.3f}, Cost={metrics['avg_cost']:.2f}", flush=True)

    with open('/tmp/eval_results_empirical.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print('\nSaved to /tmp/eval_results_empirical.json', flush=True)
|
|
if __name__ == '__main__':
    main()
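

# Usage sketch (assumes the three Qwen checkpoints can be downloaded and fit in memory):
#   python <this_script>.py
# Per-config accuracy/cost is printed and a JSON summary is written to
# /tmp/eval_results_empirical.json.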
|
|