import json
import os
import re
from collections import Counter

import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
MAX_EVAL = int(os.environ.get('MAX_EVAL', '200'))

# Canonical labels; BLOCKED is kept uppercase to match the dataset's action_type values.
ACTIONS = ['tool_call', 'retrieval', 'file_read', 'file_write', 'repair', 'verifier', 'ask_clarification', 'final_answer', 'BLOCKED']

# Relative per-call costs, normalized so one strong-model call costs 1.0.
# (The 'verifier' entry is defined but unused by the configs below.)
COST = {'strong': 1.00, 'cheap': 0.15, 'verifier': 0.30, 'verify_check': 0.10}

SP = """You are an agent action predictor. Predict the next action from: tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED.
Format your response as:
Action: <action_name>
<brief reason>"""

def load_lm(model_id, device):
    # device is unused here; device_map='auto' handles placement.
    print(f"  Loading LM: {model_id}")
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
    model.eval()
    return model, tok

def load_rm(adapter_id, device):
    # Reward model: a LoRA adapter on a Qwen3-4B base with a single-logit head.
    tok = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B', trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForSequenceClassification.from_pretrained('Qwen/Qwen3-4B', num_labels=1, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
    model.config.pad_token_id = tok.pad_token_id  # classification head needs an explicit pad id
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tok

def parse_action(text):
    # Case-insensitive matching so the uppercase BLOCKED label is still caught
    # and returned in its canonical casing.
    m = re.search(r'Action:\s*(tool_call|retrieval|file_read|file_write|repair|verifier|ask_clarification|final_answer|BLOCKED)', text, re.IGNORECASE)
    if m:
        name = m.group(1).lower()
        return 'BLOCKED' if name == 'blocked' else name
    low = text.lower()
    for a in ACTIONS:
        if a.lower() in low:
            return a
    return 'tool_call'  # fall back to the most common action
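
# Smoke test for the parser; the reason text is made up and the asserts are
# safe to delete.
assert parse_action('Action: final_answer\nTask looks complete.') == 'final_answer'
assert parse_action('ACTION: BLOCKED') == 'BLOCKED'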


def build_msgs(example):
    # Keep only the last 4 turns, truncated to 300 chars each, to bound prompt length.
    msgs = example['messages']
    ctx = '\n'.join(f"{m['role']}: {str(m['content'])[:300]}" for m in msgs[-4:])
    return [{'role': 'system', 'content': SP}, {'role': 'user', 'content': f"Predict the next action for:\n\n{ctx}"}]
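
# For a hypothetical one-turn trace, build_msgs produces:
#   build_msgs({'messages': [{'role': 'user', 'content': 'read config.yaml'}]})
#   -> [{'role': 'system', 'content': SP},
#       {'role': 'user', 'content': 'Predict the next action for:\n\nuser: read config.yaml'}]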


@torch.no_grad()
def predict(model, tok, msgs, device='cuda'):
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tok(text, return_tensors='pt', truncation=True, max_length=2048).to(device)
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tok.pad_token_id)
    completion = tok.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return parse_action(completion.strip())


@torch.no_grad()
def reward_score(model, tok, text, device='cuda'):
    inputs = tok(text, return_tensors='pt', truncation=True, max_length=1024).to(device)
    return model(**inputs).logits.squeeze().item()


@torch.no_grad()
def accept_reject(model, tok, prop, example_msgs, device='cuda'):
    # Ask the strong model to vet the cheap proposal with a single ACCEPT/REJECT
    # judgment; this is the cheap 'verify_check' call.
    ctx = '\n'.join(f"{m['role']}: {str(m['content'])[:200]}" for m in example_msgs[-3:])
    msgs = [{'role': 'system', 'content': 'Say ACCEPT or REJECT only.'}, {'role': 'user', 'content': f'Proposed: {prop}\nContext:\n{ctx}\nDecision:'}]
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tok(text, return_tensors='pt', truncation=True, max_length=1024).to(device)
    outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False, pad_token_id=tok.pad_token_id)
    resp = tok.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()
    return 'accept' in resp and 'reject' not in resp


def rw_text(prop, ex):
    # Format a (context, proposed action) pair as a single string for the reward model.
    msgs = ex['messages']
    ctx = '\n'.join(f"{m['role']}: {str(m['content'])[:200]}" for m in msgs[-3:])
    return f"User: {ctx}\n\nAssistant: Action: {prop}"


device = 'cuda'
ds = load_dataset(EVAL_DS, split='train')
data = [ds[i] for i in range(min(MAX_EVAL, len(ds)))]
print(f'Evaluating {len(data)} examples')

cm, ctok = load_lm(f'{HUB_ORG}/speculative-proposer-qwen3-1.7b', device)  # cheap proposer
vm, vtok = load_rm(f'{HUB_ORG}/speculative-verifier-qwen3-4b', device)    # reward-model verifier
sm, stok = load_lm('Qwen/Qwen3-8B', device)                               # strong model

all_metrics = {}

# Config A: strong model only (expensive accuracy baseline).
print('\nConfig A: strong only')
res = [{'pred': predict(sm, stok, build_msgs(ex), device), 'true': ex['action_type'], 'cost': COST['strong']} for ex in data]
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
all_metrics['A'] = {'accuracy': round(acc, 4), 'avg_cost': COST['strong'], 'n': len(res)}
print(f'  Acc: {acc:.3f}')

# Config B: cheap proposer only (cheap baseline).
print('\nConfig B: cheap only')
res = [{'pred': predict(cm, ctok, build_msgs(ex), device), 'true': ex['action_type'], 'cost': COST['cheap']} for ex in data]
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
all_metrics['B'] = {'accuracy': round(acc, 4), 'avg_cost': COST['cheap'], 'n': len(res)}
print(f'  Acc: {acc:.3f}')

# Config C: speculative execution. Cheap proposal, strong accept/reject check,
# escalate to a full strong-model call only on rejection.
print('\nConfig C: cheap + strong verifier')
res = []
for ex in data:
    cp = predict(cm, ctok, build_msgs(ex), device)
    accepted = accept_reject(sm, stok, cp, ex['messages'], device)
    if accepted:
        pred, cost = cp, COST['cheap'] + COST['verify_check']
    else:
        pred = predict(sm, stok, build_msgs(ex), device)
        cost = COST['cheap'] + COST['verify_check'] + COST['strong']
    res.append({'pred': pred, 'true': ex['action_type'], 'accepted': accepted, 'cost': cost})
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
ar = sum(1 for r in res if r['accepted']) / len(res)
all_metrics['C'] = {'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res) / len(res), 4), 'accept_rate': round(ar, 4), 'n': len(res)}
print(f'  Acc: {acc:.3f} | Accept: {ar:.3f}')
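
# Sanity check on the cost model above: per-example cost is cheap + check, plus
# a strong call on rejection, so the average must equal
# cheap + check + (1 - accept_rate) * strong, up to rounding.
expected_c = COST['cheap'] + COST['verify_check'] + (1 - ar) * COST['strong']
assert abs(expected_c - all_metrics['C']['avg_cost']) < 1e-3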

# Config D: cheap proposal scored by the reward model at a fixed threshold.
# No escalation here; this measures the reward model's accept rate and the
# proposer's accuracy at a flat cost.
THR = -1.0
print(f'\nConfig D: cheap + reward (thr={THR})')
res = []
for ex in data:
    cp = predict(cm, ctok, build_msgs(ex), device)
    score = reward_score(vm, vtok, rw_text(cp, ex), device)
    res.append({'pred': cp, 'true': ex['action_type'], 'cost': COST['cheap'] + COST['verify_check'], 'accepted': score >= THR, 'score': score})
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
ar = sum(1 for r in res if r['accepted']) / len(res)
scores = [r['score'] for r in res]
all_metrics['D'] = {'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res) / len(res), 4), 'accept_rate': round(ar, 4), 'mean_score': round(sum(scores) / len(scores), 3), 'n': len(res)}
print(f'  Acc: {acc:.3f} | Accept: {ar:.3f} | Score: {sum(scores)/len(scores):.3f}')
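
# Since scores were recorded per example, sweeping the threshold costs nothing
# extra. The thresholds below are illustrative, not tuned.
for thr in (-2.0, -1.0, 0.0, 1.0):
    swept = sum(1 for r in res if r['score'] >= thr) / len(res)
    print(f'  thr={thr:+.1f}: accept_rate={swept:.3f}')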

# Config E: sample N proposals from the cheap model, score each distinct one
# with the reward model, keep the best. Cost conservatively charges a verifier
# check per draw, even though duplicate proposals are only scored once.
N = 3
print(f'\nConfig E: multi-proposal (n={N})')
res = []
for i, ex in enumerate(data):
    if i % 20 == 0:
        print(f'  {i}/{len(data)}')
    msgs = build_msgs(ex)
    text = ctok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = ctok(text, return_tensors='pt', truncation=True, max_length=2048).to(device)  # tokenize once, sample N times
    proposals = []
    for _ in range(N):
        outputs = cm.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.8, top_p=0.95, pad_token_id=ctok.pad_token_id)
        proposals.append(parse_action(ctok.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)))
    scored = [(p, reward_score(vm, vtok, rw_text(p, ex), device)) for p in set(proposals)]
    best = max(scored, key=lambda x: x[1])[0]
    res.append({'pred': best, 'true': ex['action_type'], 'cost': COST['cheap'] * N + COST['verify_check'] * N})
acc = sum(1 for r in res if r['pred'] == r['true']) / len(res)
all_metrics['E'] = {'accuracy': round(acc, 4), 'avg_cost': round(sum(r['cost'] for r in res) / len(res), 4), 'n': len(res)}
print(f'  Acc: {acc:.3f}')

# Trivial baselines for context: uniform random and majority class.
dist = Counter(ex['action_type'] for ex in data)
maj = dist.most_common(1)[0][0]
maj_acc = sum(1 for ex in data if ex['action_type'] == maj) / len(data)
rand_acc = 1.0 / len(ACTIONS)
print(f'\nBaselines: random={rand_acc:.3f}, majority({maj})={maj_acc:.3f}')
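
# Per-class counts help spot label skew (top 5 shown):
for action, n in dist.most_common(5):
    print(f'  {action}: {n} ({n / len(data):.1%})')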

print(f'\n{"="*60}')
print(f'{"Config":<6} {"Acc":>8} {"Cost":>8} {"vsRandom":>10} {"vsMaj":>8} {"Accept":>8}')
print('-' * 60)
for c in ['A', 'B', 'C', 'D', 'E']:
    m = all_metrics[c]
    ar = f'{m["accept_rate"]:.3f}' if isinstance(m.get('accept_rate'), float) else '-'
    print(f'{c:<6} {m["accuracy"]:>8.3f} {m["avg_cost"]:>8.3f} {m["accuracy"]/rand_acc:>10.1f}x {m["accuracy"]/maj_acc:>8.1f}x {ar:>8}')

print('\nCOST-QUALITY FRONTIER')
# Iterate items directly; the metrics dicts carry no 'config' key, so the old
# m.get('config', ...) lookup always fell through to an identity scan.
for name, m in sorted(all_metrics.items(), key=lambda kv: kv[1]['avg_cost']):
    print(f"  {name}: cost={m['avg_cost']:.3f} acc={m['accuracy']:.3f}")

out = {'metrics': all_metrics, 'baselines': {'random': rand_acc, 'majority': maj_acc, 'majority_class': maj}, 'n': len(data), 'distribution': dict(dist)}
with open('/tmp/results.json', 'w') as f:
    json.dump(out, f, indent=2)

api = HfApi()
api.upload_file(path_or_fileobj='/tmp/results.json', path_in_repo='eval_results_v2.json', repo_id=f'{HUB_ORG}/speculative-tool-actions', repo_type='model', commit_message='Eval v2 results (cu121 fix)')
print('Done!')