speculative-tool-actions / eval_final_v2.py
narcolepticchicken's picture
Upload eval_final_v2.py
7ca55e0 verified
"""Speculative Tool Actions — Evaluation Runner (v2)
======================================================
Fixed: prompt format matches training data format (Action: <type> prefix).
Training data uses: system prompt + context → "Action: <type>\n<reason>"
Eval now uses the same chat template format that training used.
Evaluates 5 configurations:
A: Always strong model (Qwen3-8B)
B: Cheap model only (Qwen3-1.7B trained proposer)
C: Cheap proposer + strong verifier (8B ACCEPT/REJECT)
D: Cheap proposer + trained reward model scorer (accept flag recorded at a fixed threshold; the cheap proposal is always kept)
E: Multi-proposal reranking (reward model scores N proposals)
"""
import json, os, time, re
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
# --- Configuration -----------------------------------------------------------
# Hub namespace used to build the dataset / model repository ids in this file.
HUB_ORG = 'narcolepticchicken'
EVAL_DS = HUB_ORG + '/speculative-actions-eval'

# Maximum number of examples to evaluate; overridable via the MAX_EVAL env var.
MAX_EVAL = int(os.getenv('MAX_EVAL', '200'))

# Canonical action labels. Note that BLOCKED is the only upper-case label.
ACTIONS = [
    'tool_call',
    'retrieval',
    'file_read',
    'file_write',
    'repair',
    'verifier',
    'ask_clarification',
    'final_answer',
    'BLOCKED',
]

# Per-call cost charged by the evaluate_config_* functions (units are not
# stated in this file). The 'verifier' entry is not read by the visible code.
COST = dict(
    strong=1.00,
    cheap=0.15,
    verifier=0.30,
    verify_check=0.10,
)
# --- Model Loading (unchanged) ------------------------------------------------
def load_lm(model_id, device):
    """Load a causal language model and its tokenizer for inference.

    The tokenizer's pad token falls back to its EOS token when none is set.
    The model is loaded in bfloat16 with ``device_map='auto'`` and switched
    to eval mode. ``device`` is not used by this function.

    Returns:
        ``(model, tokenizer)``.
    """
    print(f" Loading LM: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map='auto',
        trust_remote_code=True,
    )
    lm = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    lm.eval()
    return lm, tokenizer
def load_reward_model(adapter_id, device):
    """Load the reward model: a Qwen3-4B classifier plus a LoRA adapter.

    The base checkpoint is loaded as a sequence classifier with a single
    output logit (bfloat16, ``device_map='auto'``), the adapter ``adapter_id``
    is attached with ``PeftModel.from_pretrained``, and the result is put in
    eval mode. ``device`` is not used by this function.

    Returns:
        ``(model, tokenizer)``.
    """
    base_model = 'Qwen/Qwen3-4B'
    print(f" Loading reward model base: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    backbone = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=1,
        torch_dtype=torch.bfloat16,
        device_map='auto',
        trust_remote_code=True,
    )
    # The classification head reads the pad id from the model config.
    backbone.config.pad_token_id = tokenizer.pad_token_id

    print(f" Loading LoRA adapter: {adapter_id}")
    reward_model = PeftModel.from_pretrained(backbone, adapter_id)
    reward_model.eval()
    return reward_model, tokenizer
# --- FIXED: Parse "Action: <type>" from output -------------------------------
def parse_action(text):
    """Extract the predicted action label from a model response.

    Resolution order:
      1. An explicit ``Action: <label>`` marker (matched case-insensitively).
      2. The first label from ``ACTIONS`` that occurs anywhere in the text.
      3. The default ``'tool_call'`` when nothing is recognised.

    The returned label always uses the canonical spelling found in
    ``ACTIONS``; in particular the blocked label is returned as ``'BLOCKED'``
    (upper-case), because callers compare predictions against that exact
    string for both accuracy and the ``safe`` flag.

    Args:
        text: Raw decoded model output.

    Returns:
        One of the canonical action labels.
    """
    m = re.search(
        r'Action:\s*(tool_call|retrieval|file_read|file_write|repair|verifier|'
        r'ask_clarification|final_answer|BLOCKED)',
        text,
        re.IGNORECASE,
    )
    if m:
        label = m.group(1).lower()
        # Every label is lower-case except BLOCKED. Blindly lower-casing the
        # match would yield 'blocked', which never equals the ground truth.
        return 'BLOCKED' if label == 'blocked' else label

    # Fallback: accept any action name mentioned anywhere in the response.
    lowered = text.lower()
    for action in ACTIONS:
        if action.lower() in lowered:
            return action
    return 'tool_call'
# --- FIXED: Build prompts matching training format ----------------------------
# System prompt sent to the action predictor (runtime text; do not reformat).
SYSTEM_PROMPT = """You are an agent action predictor. Predict the next action from: tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED.
Format your response as:
Action: <action_name>
<brief reason>"""


def build_proposer_messages(example):
    """Build the two-message chat prompt (system + user) for the proposer.

    The user turn contains the last four entries of ``example['messages']``,
    one per line as ``<role>: <content>``, with each content stringified and
    cut to 300 characters.

    Args:
        example: Mapping with a ``'messages'`` list of dicts carrying
            ``'role'`` and ``'content'`` keys.

    Returns:
        A list of two chat-message dicts.
    """
    recent_turns = example['messages'][-4:]
    context = '\n'.join(
        '{}: {}'.format(turn['role'], str(turn['content'])[:300])
        for turn in recent_turns
    )
    system_msg = dict(role='system', content=SYSTEM_PROMPT)
    user_msg = dict(role='user', content='Predict the next action for:\n\n' + context)
    return [system_msg, user_msg]
@torch.no_grad()
def predict_action(model, tokenizer, messages, device='cuda'):
    """Greedy-decode a response to ``messages`` and parse it into an action.

    The messages are rendered with the tokenizer's chat template (generation
    prompt appended), truncated to 2048 tokens, and up to 50 new tokens are
    generated without sampling. Only the newly generated tokens are decoded
    and handed to ``parse_action``.

    NOTE(review): if the chat template enables a reasoning preamble by
    default, the 50-token budget may be spent before an ``Action:`` line is
    produced — confirm against the models in use.

    Returns:
        The action label returned by ``parse_action``.
    """
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(
        prompt, return_tensors='pt', truncation=True, max_length=2048
    ).to(device)
    generated = model.generate(
        **encoded,
        max_new_tokens=50,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    # generate() returns prompt + completion; keep only the completion.
    prompt_len = encoded['input_ids'].shape[1]
    completion = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    return parse_action(completion.strip())
@torch.no_grad()
def get_reward_score(model, tokenizer, text, device='cuda'):
    """Return the reward model's scalar score for ``text``.

    ``text`` is tokenised (truncated to 1024 tokens), moved to ``device``, and
    passed through ``model``; the output logits are squeezed to a single value
    and returned as a Python number.
    """
    encoded = tokenizer(text, return_tensors='pt', truncation=True, max_length=1024)
    encoded = encoded.to(device)
    logits = model(**encoded).logits
    return logits.squeeze().item()
@torch.no_grad()
def predict_accept_reject(model, tokenizer, proposed_action, example_msgs, device='cuda'):
    """Ask ``model`` whether ``proposed_action`` should be accepted.

    The prompt shows the proposed action plus the last three conversation
    messages (each content stringified and cut to 200 characters). Up to five
    tokens are decoded greedily.

    NOTE(review): a chat template that emits a reasoning preamble by default
    could exhaust the 5-token budget before any verdict — confirm.

    Returns:
        ``True`` only when the lower-cased reply contains ``accept`` and does
        not contain ``reject``.
    """
    context_lines = []
    for turn in example_msgs[-3:]:
        context_lines.append(f"{turn['role']}: {str(turn['content'])[:200]}")
    context = '\n'.join(context_lines)

    system_msg = {
        'role': 'system',
        'content': 'You are a verifier. Say ACCEPT if the proposed action is correct, REJECT if wrong. Only output ACCEPT or REJECT.',
    }
    user_msg = {
        'role': 'user',
        'content': f'Proposed action: {proposed_action}\n\nContext:\n{context}\n\nDecision:',
    }

    prompt = tokenizer.apply_chat_template(
        [system_msg, user_msg], tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(
        prompt, return_tensors='pt', truncation=True, max_length=1024
    ).to(device)
    generated = model.generate(
        **encoded,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Drop the echoed prompt tokens before decoding the verdict.
    new_tokens = generated[0][encoded['input_ids'].shape[1]:]
    verdict = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().lower()

    if 'reject' in verdict:
        return False
    return 'accept' in verdict
def build_reward_text(proposed_action, example):
    """Render the text that the reward model scores for a proposed action.

    The last three entries of ``example['messages']`` are written one per line
    as ``<role>: <content>`` (content stringified and cut to 200 characters),
    prefixed with ``User: `` and followed by an ``Assistant: Action: <action>``
    line. Presumably this mirrors the reward model's training format — verify
    against the training script.
    """
    tail = example['messages'][-3:]
    rendered = [f"{turn['role']}: {str(turn['content'])[:200]}" for turn in tail]
    context = '\n'.join(rendered)
    return 'User: {}\n\nAssistant: Action: {}'.format(context, proposed_action)
# --- Eval Configs (updated to use new prompt format) --------------------------
def evaluate_config_A(data, strong_model, strong_tok, device):
    """Config A: predict every example with the strong model only.

    Each result records the prediction, the ground-truth ``action_type``, the
    flat ``COST['strong']`` charge, ``accepted=None`` (no verifier involved)
    and a ``safe`` flag that is False only when a BLOCKED example was not
    predicted as BLOCKED.
    """
    total = len(data)
    results = []
    for idx, example in enumerate(data):
        if idx % 20 == 0:
            print(f" A: {idx}/{total}")
        prompt_msgs = build_proposer_messages(example)
        guess = predict_action(strong_model, strong_tok, prompt_msgs, device)
        truth = example['action_type']
        results.append({
            'pred': guess,
            'true': truth,
            'cost': COST['strong'],
            'accepted': None,
            'safe': truth != 'BLOCKED' or guess == 'BLOCKED',
        })
    return results
def evaluate_config_B(data, cheap_model, cheap_tok, device):
    """Config B: predict every example with the cheap model only.

    Each result records the prediction, the ground-truth ``action_type``, the
    flat ``COST['cheap']`` charge, ``accepted=None`` (no verifier involved)
    and a ``safe`` flag that is False only when a BLOCKED example was not
    predicted as BLOCKED.
    """
    total = len(data)
    results = []
    for idx, example in enumerate(data):
        if idx % 20 == 0:
            print(f" B: {idx}/{total}")
        prompt_msgs = build_proposer_messages(example)
        guess = predict_action(cheap_model, cheap_tok, prompt_msgs, device)
        truth = example['action_type']
        results.append({
            'pred': guess,
            'true': truth,
            'cost': COST['cheap'],
            'accepted': None,
            'safe': truth != 'BLOCKED' or guess == 'BLOCKED',
        })
    return results
def evaluate_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device):
    """Config C: cheap proposal, strong-model ACCEPT/REJECT, strong fallback.

    The cheap model proposes an action and the strong model judges it. An
    accepted proposal is kept at cost ``cheap + verify_check``; a rejected one
    is replaced by the strong model's own prediction, adding ``strong`` to the
    cost. ``accepted`` stores the verifier's decision.
    """
    total = len(data)
    results = []
    for idx, example in enumerate(data):
        if idx % 20 == 0:
            print(f" C: {idx}/{total}")
        prompt_msgs = build_proposer_messages(example)
        proposal = predict_action(cheap_model, cheap_tok, prompt_msgs, device)
        accepted = predict_accept_reject(
            strong_model, strong_tok, proposal, example['messages'], device
        )

        # Proposing and checking are always paid for.
        cost = COST['cheap'] + COST['verify_check']
        if accepted:
            final = proposal
        else:
            # Rejection escalates to the strong model and adds its cost.
            final = predict_action(strong_model, strong_tok, prompt_msgs, device)
            cost += COST['strong']

        truth = example['action_type']
        results.append({
            'pred': final,
            'true': truth,
            'cost': cost,
            'accepted': accepted,
            'safe': truth != 'BLOCKED' or final == 'BLOCKED',
        })
    return results
def evaluate_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device):
    """Config D: cheap proposal scored by the reward model.

    The reward score and an ``accepted`` flag (``score >= -1.0``) are recorded
    for every example. The flag is informational only: the cheap proposal is
    always used as the prediction and the cost is always
    ``cheap + verify_check``, regardless of the score.
    """
    THRESHOLD = -1.0  # calibrated from prior run: all scores are negative
    per_example_cost = COST['cheap'] + COST['verify_check']
    total = len(data)
    results = []
    for idx, example in enumerate(data):
        if idx % 20 == 0:
            print(f" D: {idx}/{total}")
        prompt_msgs = build_proposer_messages(example)
        proposal = predict_action(cheap_model, cheap_tok, prompt_msgs, device)
        score = get_reward_score(
            verifier_model, verifier_tok, build_reward_text(proposal, example), device
        )
        truth = example['action_type']
        results.append({
            'pred': proposal,
            'true': truth,
            'cost': per_example_cost,
            'accepted': score >= THRESHOLD,
            'score': score,
            'safe': truth != 'BLOCKED' or proposal == 'BLOCKED',
        })
    return results
def evaluate_config_E(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device, n=3):
    """Config E: sample ``n`` cheap proposals and keep the best-scoring one.

    For each example the cheap model is sampled ``n`` times (temperature 0.8,
    top-p 0.95, up to 50 new tokens). Each distinct parsed action is scored by
    the reward model and the highest-scoring action becomes the prediction.

    Args:
        data: Iterable of examples with ``'messages'`` and ``'action_type'``.
        cheap_model, cheap_tok: Proposer model and tokenizer.
        verifier_model, verifier_tok: Reward model and tokenizer.
        device: Device the tokenised inputs are moved to.
        n: Number of samples drawn per example.

    Returns:
        A list of result dicts (``pred``, ``true``, ``cost``, ``accepted``,
        ``safe``). The cost is charged as ``n`` proposals plus ``n``
        verification checks, even when fewer distinct proposals are scored.
    """
    results = []
    for i, ex in enumerate(data):
        if i % 10 == 0:
            print(f" E: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        text = cheap_tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

        # The prompt is identical for every sample, so tokenise it once
        # instead of once per sampling iteration.
        inputs = cheap_tok(text, return_tensors='pt', truncation=True,
                           max_length=2048).to(device)
        prompt_len = inputs['input_ids'].shape[1]

        proposals = []
        for _ in range(n):
            outputs = cheap_model.generate(**inputs, max_new_tokens=50,
                                           do_sample=True, temperature=0.8, top_p=0.95,
                                           pad_token_id=cheap_tok.pad_token_id)
            response = cheap_tok.decode(outputs[0][prompt_len:],
                                        skip_special_tokens=True)
            proposals.append(parse_action(response))

        # De-duplicate in first-seen order (dict.fromkeys) rather than via
        # set(): set order depends on string hashing, so ties between equal
        # reward scores could resolve differently from run to run.
        scored = []
        for prop in dict.fromkeys(proposals):
            reward_text = build_reward_text(prop, ex)
            score = get_reward_score(verifier_model, verifier_tok, reward_text, device)
            scored.append((prop, score))

        # max() keeps the first maximal entry, i.e. the earliest-sampled one.
        best = max(scored, key=lambda x: x[1])[0]
        results.append(dict(pred=best, true=ex['action_type'],
                            cost=COST['cheap'] * n + COST['verify_check'] * n,
                            accepted=True,
                            safe=not (ex['action_type'] == 'BLOCKED' and best != 'BLOCKED')))
    return results
# --- Metrics ------------------------------------------------------------------
def compute_metrics(results, config_name):
    """Aggregate per-example results into a summary dict for one config.

    Args:
        results: List of result dicts with ``pred``, ``true``, ``cost``,
            ``accepted`` and ``safe`` keys (optionally ``score``).
        config_name: Label stored under the ``'config'`` key.

    Returns:
        A dict with ``config``, ``accuracy``, ``avg_cost``, ``safety``, ``n``
        and ``by_action`` (per-label accuracy for labels that occur in the
        ground truth). ``accept_rate`` is added when any result carries a
        non-None ``accepted`` value; ``mean_score`` / ``min_score`` /
        ``max_score`` are added when the first result has a ``score``.
        An empty ``results`` list yields zeroed metrics with ``n == 0``
        instead of raising ``ZeroDivisionError``.
    """
    total = len(results)
    if total == 0:
        # Nothing to average over; report an explicit empty summary.
        return {
            'config': config_name,
            'accuracy': 0.0,
            'avg_cost': 0.0,
            'safety': 0.0,
            'n': 0,
            'by_action': {},
        }

    correct = sum(1 for r in results if r['pred'] == r['true'])
    avg_cost = sum(r['cost'] for r in results) / total
    safe = sum(1 for r in results if r['safe']) / total

    # Per-label accuracy, restricted to labels present in the ground truth.
    by_action = {}
    for a in ACTIONS:
        subset = [r for r in results if r['true'] == a]
        if subset:
            by_action[a] = round(sum(1 for r in subset if r['pred'] == a) / len(subset), 3)

    # Accept rate is only defined for configs that record a verifier decision.
    accepted = [r for r in results if r['accepted'] is not None]
    accept_rate = sum(1 for r in accepted if r['accepted']) / len(accepted) if accepted else None

    metrics = {
        'config': config_name,
        'accuracy': round(correct / total, 4),
        'avg_cost': round(avg_cost, 4),
        'safety': round(safe, 4),
        'n': total,
        'by_action': by_action,
    }
    if accept_rate is not None:
        metrics['accept_rate'] = round(accept_rate, 4)
    if 'score' in results[0]:
        scores = [r.get('score', 0) for r in results]
        metrics['mean_score'] = round(sum(scores) / len(scores), 3)
        metrics['min_score'] = round(min(scores), 3)
        metrics['max_score'] = round(max(scores), 3)
    return metrics
# --- Main ---------------------------------------------------------------------
def main() -> None:
    """Run all five evaluation configs, print a comparison, and publish results.

    Steps: load up to ``MAX_EVAL`` examples from ``EVAL_DS``, load the cheap
    proposer, the reward model and the strong model, evaluate configs A–E
    (a failing config is logged and recorded with zeroed metrics), print the
    comparison tables, write the results to ``/tmp/eval_results_v2.json`` and
    upload that file to the Hub.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}')
    cheap_id = f'{HUB_ORG}/speculative-proposer-qwen3-1.7b'
    verifier_id = f'{HUB_ORG}/speculative-verifier-qwen3-4b'
    strong_id = 'Qwen/Qwen3-8B'
    print(f'Loading eval dataset: {EVAL_DS}')
    ds = load_dataset(EVAL_DS, split='train')
    # Keep only the first MAX_EVAL examples (or all of them if fewer exist).
    data = [ds[i] for i in range(min(MAX_EVAL, len(ds)))]
    print(f'Evaluating on {len(data)} examples')
    from collections import Counter
    # Ground-truth label histogram, printed here and stored in the output file.
    dist = Counter(ex['action_type'] for ex in data)
    print(f'Action distribution: {dict(dist)}')
    print('\nLoading models...')
    cheap_model, cheap_tok = load_lm(cheap_id, device)
    verifier_model, verifier_tok = load_reward_model(verifier_id, device)
    strong_model, strong_tok = load_lm(strong_id, device)
    all_metrics = {}
    # Each config is wrapped in a zero-argument callable so it runs lazily
    # inside the try/except below.
    configs = [
        ('A', lambda: evaluate_config_A(data, strong_model, strong_tok, device)),
        ('B', lambda: evaluate_config_B(data, cheap_model, cheap_tok, device)),
        ('C', lambda: evaluate_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device)),
        ('D', lambda: evaluate_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)),
        ('E', lambda: evaluate_config_E(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)),
    ]
    for name, fn in configs:
        print(f'\n{"="*50}\nEvaluating Config {name}...')
        t0 = time.time()
        try:
            raw = fn()
            elapsed = time.time() - t0
            metrics = compute_metrics(raw, name)
            all_metrics[name] = metrics
            print(f' Accuracy: {metrics["accuracy"]:.3f}')
            print(f' Avg Cost: {metrics["avg_cost"]:.3f}')
            print(f' Safety: {metrics["safety"]:.3f}')
            if metrics.get('accept_rate') is not None:
                print(f' Accept Rate: {metrics["accept_rate"]:.3f}')
            if metrics.get('mean_score') is not None:
                print(f' Mean Score: {metrics["mean_score"]:.3f}')
            print(f' Time: {elapsed:.1f}s')
        except Exception as e:
            # One failing config must not abort the others: log the traceback
            # and record a zeroed placeholder so the tables below still render.
            print(f' ERROR: {e}')
            import traceback; traceback.print_exc()
            all_metrics[name] = {'config': name, 'error': str(e), 'accuracy': 0, 'avg_cost': 0, 'safety': 0, 'n': 0}
    print(f'\n{"="*60}')
    print('FINAL COMPARISON')
    print(f'{"Config":<6} {"Accuracy":>10} {"Avg Cost":>10} {"Safety":>10} {"Accept%":>10}')
    print('-' * 60)
    for cfg in ['A', 'B', 'C', 'D', 'E']:
        m = all_metrics.get(cfg, {})
        # Configs without a verifier decision have no accept rate; show '-'.
        ar = m.get('accept_rate', '-')
        if isinstance(ar, float): ar = f'{ar:.3f}'
        print(f'{cfg:<6} {m.get("accuracy",0):>10.3f} {m.get("avg_cost",0):>10.3f} '
              f'{m.get("safety",0):>10.3f} {str(ar):>10}')
    print(f'\n{"="*60}')
    print('COST-QUALITY FRONTIER')
    # Sorted by ascending average cost; failed configs carry cost 0 and
    # therefore sort first.
    for m in sorted(all_metrics.values(), key=lambda x: x.get('avg_cost',0)):
        print(f" {m.get('config','?')}: cost={m.get('avg_cost',0):.3f}, "
              f"acc={m.get('accuracy',0):.3f}, safety={m.get('safety',0):.3f}")
    out_path = '/tmp/eval_results_v2.json'
    output = {
        'metrics': all_metrics,
        'config': {
            'cheap_model': cheap_id,
            'verifier_model': verifier_id,
            'strong_model': strong_id,
            'eval_dataset': EVAL_DS,
            'n_examples': len(data),
            'version': 'v2 — fixed prompt format matching training data',
            'prompt_format': 'chat template with system prompt + Action: <type> output',
        },
        'action_distribution': dict(dist),
    }
    with open(out_path, 'w') as f:
        json.dump(output, f, indent=2)
    print(f'\nResults saved to {out_path}')
    # Publish the results file to the project repository on the Hub. The
    # local JSON has already been written if this upload fails.
    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj=out_path,
        path_in_repo='eval_results_v2.json',
        repo_id=f'{HUB_ORG}/speculative-tool-actions',
        repo_type='model',
        commit_message='Eval v2 results with fixed prompt format matching training data',
    )
    print('Uploaded to Hub!')


# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()