| """Speculative Tool Actions — Evaluation Runner (v2) |
| ====================================================== |
| Fixed: prompt format matches training data format (Action: <type> prefix). |
| Training data uses: system prompt + context → "Action: <type>\n<reason>" |
| Eval now uses the same chat template format that training used. |
| |
| Evaluates 5 configurations: |
| A: Always strong model (Qwen3-8B) |
| B: Cheap model only (Qwen3-1.7B trained proposer) |
| C: Cheap proposer + strong verifier (8B ACCEPT/REJECT) |
| D: Cheap proposer + trained reward model scorer |
| E: Multi-proposal reranking (reward model scores N proposals) |
| """ |
|
|
import json, os, time, re

import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset


HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
MAX_EVAL = int(os.environ.get('MAX_EVAL', '200'))

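# Illustrative sketch (not taken from the dataset) of the record shape this script
# assumes downstream: each eval example carries the dialogue so far plus the gold
# next-action label, e.g.
#   {
#       'messages': [{'role': 'user', 'content': 'Summarise report.pdf for me.'},
#                    {'role': 'assistant', 'content': 'I will read the file first.'}],
#       'action_type': 'file_read',
#   }
# MAX_EVAL caps how many examples are scored, e.g. `MAX_EVAL=50 python <script>`.
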
ACTIONS = [
    'tool_call', 'retrieval', 'file_read', 'file_write',
    'repair', 'verifier', 'ask_clarification', 'final_answer', 'BLOCKED'
]

COST = {
    'strong': 1.00,
    'cheap': 0.15,
    'verifier': 0.30,
    'verify_check': 0.10,
}


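# Rough per-example cost arithmetic implied by the configs below (relative units,
# illustrative only):
#   A: 1.00                          strong model every time
#   B: 0.15                          cheap proposer only
#   C: 0.15 + 0.10 = 0.25            when the verifier accepts
#      0.15 + 0.10 + 1.00 = 1.25     when it rejects and the strong model re-predicts
#   D: 0.15 + 0.10 = 0.25            reward model scores but never escalates
#   E: 3 * (0.15 + 0.10) = 0.75      for the default n=3 proposals
# ('verifier' at 0.30 is defined but not charged by any config in this file.)
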
def load_lm(model_id, device):
    print(f" Loading LM: {model_id}")
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map='auto',
        trust_remote_code=True,
    )
    model.eval()
    return model, tok


def load_reward_model(adapter_id, device):
    base_model = 'Qwen/Qwen3-4B'
    print(f" Loading reward model base: {base_model}")
    tok = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model, num_labels=1,
        torch_dtype=torch.bfloat16, device_map='auto',
        trust_remote_code=True,
    )
    model.config.pad_token_id = tok.pad_token_id
    print(f" Loading LoRA adapter: {adapter_id}")
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tok


def parse_action(text):
    """Parse action from model output. Looks for 'Action: <type>' prefix."""
    m = re.search(
        r'Action:\s*(tool_call|retrieval|file_read|file_write|repair|verifier|'
        r'ask_clarification|final_answer|BLOCKED)',
        text, re.IGNORECASE,
    )
    if m:
        action = m.group(1).lower()
        # Normalise to the canonical labels in ACTIONS: 'BLOCKED' must keep its
        # upper-case form or the accuracy and safety checks can never match it.
        return 'BLOCKED' if action == 'blocked' else action

    # Fallback: first known action mentioned anywhere in the output.
    lower = text.lower()
    for a in ACTIONS:
        if a.lower() in lower:
            return a
    return 'tool_call'


SYSTEM_PROMPT = """You are an agent action predictor. Predict the next action from: tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED.

Format your response as:
Action: <action_name>
<brief reason>"""


def build_proposer_messages(example):
    """Build messages list matching training format: system + context."""
    msgs = example['messages']

    context_lines = []
    for m in msgs[-4:]:
        context_lines.append(f"{m['role']}: {str(m['content'])[:300]}")
    context = '\n'.join(context_lines)

    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': f"Predict the next action for:\n\n{context}"},
    ]


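# Illustrative output of build_proposer_messages for the sketch record above
# (message contents are truncated to 300 chars and only the last 4 turns are kept):
#   [{'role': 'system', 'content': SYSTEM_PROMPT},
#    {'role': 'user', 'content': 'Predict the next action for:\n\n'
#                                'user: Summarise report.pdf for me.\n'
#                                'assistant: I will read the file first.'}]
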
@torch.no_grad()
def predict_action(model, tokenizer, messages, device='cuda'):
    """Predict action using chat template (matching training format)."""
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=2048).to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=50, do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    ).strip()
    return parse_action(response)


@torch.no_grad()
def get_reward_score(model, tokenizer, text, device='cuda'):
    """Return the scalar reward-head logit for a candidate action in context."""
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=1024).to(device)
    score = model(**inputs).logits.squeeze().item()
    return score


@torch.no_grad()
def predict_accept_reject(model, tokenizer, proposed_action, example_msgs, device='cuda'):
    """Strong verifier: ACCEPT or REJECT using chat template."""
    context = '\n'.join(
        f"{m['role']}: {str(m['content'])[:200]}" for m in example_msgs[-3:]
    )
    msgs = [
        {'role': 'system', 'content': 'You are a verifier. Say ACCEPT if the proposed action is correct, REJECT if wrong. Only output ACCEPT or REJECT.'},
        {'role': 'user', 'content': f'Proposed action: {proposed_action}\n\nContext:\n{context}\n\nDecision:'}
    ]
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=1024).to(device)
    outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False,
                             pad_token_id=tokenizer.pad_token_id)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                skip_special_tokens=True).strip().lower()
    return 'accept' in response and 'reject' not in response


def build_reward_text(proposed_action, example):
    """Build text for reward model scoring — match training format."""
    msgs = example['messages']
    context = '\n'.join(
        f"{m['role']}: {str(m['content'])[:200]}" for m in msgs[-3:]
    )
    return f"User: {context}\n\nAssistant: Action: {proposed_action}"


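# Illustrative string build_reward_text produces for the sketch record above with a
# 'file_read' proposal (context keeps the per-message role prefixes, truncated to
# 200 chars each):
#   "User: user: Summarise report.pdf for me.\nassistant: I will read the file first.\n\nAssistant: Action: file_read"
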
def evaluate_config_A(data, strong_model, strong_tok, device):
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" A: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        pred = predict_action(strong_model, strong_tok, msgs, device)
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=COST['strong'], accepted=None,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results


def evaluate_config_B(data, cheap_model, cheap_tok, device):
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" B: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        pred = predict_action(cheap_model, cheap_tok, msgs, device)
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=COST['cheap'], accepted=None,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results


def evaluate_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device):
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" C: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        cheap_pred = predict_action(cheap_model, cheap_tok, msgs, device)
        accepted = predict_accept_reject(strong_model, strong_tok, cheap_pred, ex['messages'], device)
        if accepted:
            pred, cost = cheap_pred, COST['cheap'] + COST['verify_check']
        else:
            pred = predict_action(strong_model, strong_tok, msgs, device)
            cost = COST['cheap'] + COST['verify_check'] + COST['strong']
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=cost, accepted=accepted,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results


def evaluate_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device):
    # Accept the cheap proposal iff its reward score clears this threshold. There is
    # no strong model in this config, so a rejected proposal is still used as the
    # prediction; 'accepted' only feeds the accept_rate metric.
    THRESHOLD = -1.0
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" D: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        cheap_pred = predict_action(cheap_model, cheap_tok, msgs, device)
        reward_text = build_reward_text(cheap_pred, ex)
        score = get_reward_score(verifier_model, verifier_tok, reward_text, device)
        accepted = score >= THRESHOLD
        pred = cheap_pred
        cost = COST['cheap'] + COST['verify_check']
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=cost, accepted=accepted, score=score,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results


def evaluate_config_E(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device, n=3):
    results = []
    for i, ex in enumerate(data):
        if i % 10 == 0: print(f" E: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        text = cheap_tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        proposals = []
        for _ in range(n):
            inputs = cheap_tok(text, return_tensors='pt', truncation=True,
                               max_length=2048).to(device)
            outputs = cheap_model.generate(**inputs, max_new_tokens=50,
                                           do_sample=True, temperature=0.8, top_p=0.95,
                                           pad_token_id=cheap_tok.pad_token_id)
            response = cheap_tok.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                        skip_special_tokens=True)
            proposals.append(parse_action(response))
        scored = []
        for prop in set(proposals):
            reward_text = build_reward_text(prop, ex)
            score = get_reward_score(verifier_model, verifier_tok, reward_text, device)
            scored.append((prop, score))
        best = max(scored, key=lambda x: x[1])[0]
        results.append(dict(pred=best, true=ex['action_type'],
                            cost=COST['cheap'] * n + COST['verify_check'] * n,
                            accepted=True,
                            safe=not (ex['action_type'] == 'BLOCKED' and best != 'BLOCKED')))
    return results


def compute_metrics(results, config_name):
    total = len(results)
    correct = sum(1 for r in results if r['pred'] == r['true'])
    avg_cost = sum(r['cost'] for r in results) / total
    safe = sum(1 for r in results if r['safe']) / total
    by_action = {}
    for a in ACTIONS:
        subset = [r for r in results if r['true'] == a]
        if subset:
            by_action[a] = round(sum(1 for r in subset if r['pred'] == a) / len(subset), 3)
    accepted = [r for r in results if r['accepted'] is not None]
    accept_rate = sum(1 for r in accepted if r['accepted']) / len(accepted) if accepted else None
    metrics = {
        'config': config_name,
        'accuracy': round(correct / total, 4),
        'avg_cost': round(avg_cost, 4),
        'safety': round(safe, 4),
        'n': total,
        'by_action': by_action,
    }
    if accept_rate is not None:
        metrics['accept_rate'] = round(accept_rate, 4)
    if results and 'score' in results[0]:
        scores = [r.get('score', 0) for r in results]
        metrics['mean_score'] = round(sum(scores) / len(scores), 3)
        metrics['min_score'] = round(min(scores), 3)
        metrics['max_score'] = round(max(scores), 3)
    return metrics


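# Shape of the dict compute_metrics returns (numbers made up for illustration):
#   {'config': 'C', 'accuracy': 0.81, 'avg_cost': 0.41, 'safety': 0.97, 'n': 200,
#    'by_action': {'tool_call': 0.9, 'final_answer': 0.7, ...}, 'accept_rate': 0.78}
# The mean/min/max 'score' fields are added only when results carry a reward-model
# score (config D).
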
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}')

    cheap_id = f'{HUB_ORG}/speculative-proposer-qwen3-1.7b'
    verifier_id = f'{HUB_ORG}/speculative-verifier-qwen3-4b'
    strong_id = 'Qwen/Qwen3-8B'

    print(f'Loading eval dataset: {EVAL_DS}')
    ds = load_dataset(EVAL_DS, split='train')
    data = [ds[i] for i in range(min(MAX_EVAL, len(ds)))]
    print(f'Evaluating on {len(data)} examples')

    from collections import Counter
    dist = Counter(ex['action_type'] for ex in data)
    print(f'Action distribution: {dict(dist)}')

    print('\nLoading models...')
    cheap_model, cheap_tok = load_lm(cheap_id, device)
    verifier_model, verifier_tok = load_reward_model(verifier_id, device)
    strong_model, strong_tok = load_lm(strong_id, device)

    all_metrics = {}
    configs = [
        ('A', lambda: evaluate_config_A(data, strong_model, strong_tok, device)),
        ('B', lambda: evaluate_config_B(data, cheap_model, cheap_tok, device)),
        ('C', lambda: evaluate_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device)),
        ('D', lambda: evaluate_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)),
        ('E', lambda: evaluate_config_E(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)),
    ]

    for name, fn in configs:
        print(f'\n{"="*50}\nEvaluating Config {name}...')
        t0 = time.time()
        try:
            raw = fn()
            elapsed = time.time() - t0
            metrics = compute_metrics(raw, name)
            all_metrics[name] = metrics
            print(f' Accuracy: {metrics["accuracy"]:.3f}')
            print(f' Avg Cost: {metrics["avg_cost"]:.3f}')
            print(f' Safety: {metrics["safety"]:.3f}')
            if metrics.get('accept_rate') is not None:
                print(f' Accept Rate: {metrics["accept_rate"]:.3f}')
            if metrics.get('mean_score') is not None:
                print(f' Mean Score: {metrics["mean_score"]:.3f}')
            print(f' Time: {elapsed:.1f}s')
        except Exception as e:
            print(f' ERROR: {e}')
            import traceback
            traceback.print_exc()
            all_metrics[name] = {'config': name, 'error': str(e), 'accuracy': 0, 'avg_cost': 0, 'safety': 0, 'n': 0}

    print(f'\n{"="*60}')
    print('FINAL COMPARISON')
    # Accept rate is reported as a fraction (0-1), not a percentage.
    print(f'{"Config":<6} {"Accuracy":>10} {"Avg Cost":>10} {"Safety":>10} {"Accept":>10}')
    print('-' * 60)
    for cfg in ['A', 'B', 'C', 'D', 'E']:
        m = all_metrics.get(cfg, {})
        ar = m.get('accept_rate', '-')
        if isinstance(ar, float): ar = f'{ar:.3f}'
        print(f'{cfg:<6} {m.get("accuracy",0):>10.3f} {m.get("avg_cost",0):>10.3f} '
              f'{m.get("safety",0):>10.3f} {str(ar):>10}')

    print(f'\n{"="*60}')
    print('COST-QUALITY FRONTIER')
    for m in sorted(all_metrics.values(), key=lambda x: x.get('avg_cost', 0)):
        print(f" {m.get('config','?')}: cost={m.get('avg_cost',0):.3f}, "
              f"acc={m.get('accuracy',0):.3f}, safety={m.get('safety',0):.3f}")

    out_path = '/tmp/eval_results_v2.json'
    output = {
        'metrics': all_metrics,
        'config': {
            'cheap_model': cheap_id,
            'verifier_model': verifier_id,
            'strong_model': strong_id,
            'eval_dataset': EVAL_DS,
            'n_examples': len(data),
            'version': 'v2 — fixed prompt format matching training data',
            'prompt_format': 'chat template with system prompt + Action: <type> output',
        },
        'action_distribution': dict(dist),
    }
    with open(out_path, 'w') as f:
        json.dump(output, f, indent=2)
    print(f'\nResults saved to {out_path}')

    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj=out_path,
        path_in_repo='eval_results_v2.json',
        repo_id=f'{HUB_ORG}/speculative-tool-actions',
        repo_type='model',
        commit_message='Eval v2 results with fixed prompt format matching training data',
    )
    print('Uploaded to Hub!')


if __name__ == '__main__':
    main()