narcolepticchicken committed (verified)
Commit 10e5403 · 1 Parent(s): 104a28c

Add evaluation runner script

Files changed (1)
  1. eval_runner.py +273 -0
eval_runner.py ADDED
@@ -0,0 +1,273 @@
"""
Speculative Tool Actions — Evaluation Runner
===============================================
Compare five configurations on a held-out eval set:
    A. always strong model
    B. cheap model only
    C. cheap proposer + strong verifier
    D. cheap proposer + trained trace judge
    E. multi-proposal reranking

Metrics: action accuracy, per-example cost (weighted token count),
and unsafe-action rate (BLOCKED mispredicted in either direction).
"""
import argparse
import json
import re
from collections import defaultdict

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

ACTION_TYPES = [
    "tool_call", "retrieval", "file_read", "file_write",
    "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED",
]

# Per-token cost weights, normalized so the strong model costs 1.0.
# Example: 100 input + 50 output tokens on the cheap model cost (100 + 50) * 0.2 = 30 units.
COST_PER_INPUT_TOK = {"strong": 1.0, "cheap": 0.2}
COST_PER_OUTPUT_TOK = {"strong": 1.0, "cheap": 0.2}


def parse_action(text: str) -> str:
    """Return the first action type (in ACTION_TYPES order) mentioned in the output."""
    for act in ACTION_TYPES:
        if act.lower() in text.lower():
            return act
    return "tool_call"  # default fallback when no known action type is found


class AgentRunner:
    def __init__(
        self,
        strong_model_name="Qwen/Qwen2.5-7B-Instruct",
        cheap_model_name="Qwen/Qwen3-1.7B",
        verifier_model_name=None,
        device="cuda",
    ):
        self.device = device
        self.strong_tokenizer = AutoTokenizer.from_pretrained(strong_model_name, trust_remote_code=True)
        self.strong_model = AutoModelForCausalLM.from_pretrained(
            strong_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

        self.cheap_tokenizer = AutoTokenizer.from_pretrained(cheap_model_name, trust_remote_code=True)
        self.cheap_model = AutoModelForCausalLM.from_pretrained(
            cheap_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

        self.verifier_model_name = verifier_model_name
        if verifier_model_name:
            self.verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_model_name, trust_remote_code=True)
            self.verifier_model = AutoModelForCausalLM.from_pretrained(
                verifier_model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )

        self.cost_log = []

    def _generate(self, model, tokenizer, messages, max_new_tokens=128, temperature=0.0):
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, return_tensors="pt", add_generation_prompt=True
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                do_sample=temperature > 0,
                temperature=temperature if temperature > 0 else None,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            )
        out_text = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        # Return the generated text plus input/output token counts for cost accounting.
        return out_text, inputs.shape[1], outputs.shape[1] - inputs.shape[1]

    def _log_cost(self, config, in_toks, out_toks, model_type="strong"):
        self.cost_log.append({
            "config": config,
            "in_toks": in_toks,
            "out_toks": out_toks,
            "model_type": model_type,
            "cost": in_toks * COST_PER_INPUT_TOK[model_type] + out_toks * COST_PER_OUTPUT_TOK[model_type],
        })

    def config_a_always_strong(self, messages, gold_action_type):
        # A. Always strong model. (gold_action_type is unused; all configs share one signature.)
        prompt = [{"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}] + messages
        out, in_t, out_t = self._generate(self.strong_model, self.strong_tokenizer, prompt)
        self._log_cost("A", in_t, out_t, "strong")
        return parse_action(out)

    def config_b_cheap_only(self, messages, gold_action_type):
        # B. Cheap model only
        prompt = [{"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}] + messages
        out, in_t, out_t = self._generate(self.cheap_model, self.cheap_tokenizer, prompt)
        self._log_cost("B", in_t, out_t, "cheap")
        return parse_action(out)

    def config_c_cheap_plus_strong_verifier(self, messages, gold_action_type):
        # C. Cheap proposer + strong verifier
        prompt = [{"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}] + messages
        proposal, in_t1, out_t1 = self._generate(self.cheap_model, self.cheap_tokenizer, prompt)

        # Strong verifier: judge whether the cheap proposal is correct.
        verify_prompt = messages + [
            {"role": "assistant", "content": proposal},
            {"role": "user", "content": "Is this action correct for the goal? Answer ONLY yes or no."},
        ]
        verdict, in_t2, out_t2 = self._generate(self.strong_model, self.strong_tokenizer, verify_prompt, max_new_tokens=10)

        self._log_cost("C", in_t1, out_t1, "cheap")
        self._log_cost("C", in_t2, out_t2, "strong")

        if "yes" in verdict.lower():
            return parse_action(proposal)
        # Proposal rejected: fall back to the strong model.
        out, in_t3, out_t3 = self._generate(self.strong_model, self.strong_tokenizer, prompt)
        self._log_cost("C", in_t3, out_t3, "strong")
        return parse_action(out)

    def config_d_cheap_plus_trained_judge(self, messages, gold_action_type):
        # D. Cheap proposer + trained trace judge
        if not self.verifier_model_name:
            raise ValueError("Verifier model not loaded for config D")

        prompt = [{"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}] + messages
        proposal, in_t1, out_t1 = self._generate(self.cheap_model, self.cheap_tokenizer, prompt)

        # Trained judge: rate the proposal.
        judge_prompt = messages + [
            {"role": "assistant", "content": proposal},
            {"role": "user", "content": "Rate this action as good or bad."},
        ]
        verdict, in_t2, out_t2 = self._generate(self.verifier_model, self.verifier_tokenizer, judge_prompt, max_new_tokens=10)

        self._log_cost("D", in_t1, out_t1, "cheap")
        self._log_cost("D", in_t2, out_t2, "cheap")  # the trained judge is also a cheap model

        if "good" in verdict.lower():
            return parse_action(proposal)
        # Proposal rejected: fall back to the strong model.
        out, in_t3, out_t3 = self._generate(self.strong_model, self.strong_tokenizer, prompt)
        self._log_cost("D", in_t3, out_t3, "strong")
        return parse_action(out)

    def config_e_multi_proposal_rerank(self, messages, gold_action_type, n_proposals=3):
        # E. Multi-proposal reranking: sample several cheap proposals, rank with the strong model.
        prompt = [{"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}] + messages
        proposals = []
        total_in, total_out = 0, 0
        for _ in range(n_proposals):
            p, i_t, o_t = self._generate(self.cheap_model, self.cheap_tokenizer, prompt, temperature=0.7)
            proposals.append(p)
            total_in += i_t
            total_out += o_t

        self._log_cost("E", total_in, total_out, "cheap")

        # Score each proposal with the strong model.
        scores = []
        for p in proposals:
            rank_prompt = messages + [
                {"role": "assistant", "content": p},
                {"role": "user", "content": "Score this action 1-10."},
            ]
            score_text, i_t, o_t = self._generate(self.strong_model, self.strong_tokenizer, rank_prompt, max_new_tokens=5)
            scores.append(score_text)
            self._log_cost("E", i_t, o_t, "strong")

        # Pick the highest-scoring proposal; default to the first if no score parses.
        best_idx = 0
        best_score = -1
        for idx, s in enumerate(scores):
            m = re.search(r"(\d+)", s)
            if m:
                sc = int(m.group(1))
                if sc > best_score:
                    best_score = sc
                    best_idx = idx

        return parse_action(proposals[best_idx])


def evaluate(dataset_name, configs="ABCDE", limit=200, output_path="results.json", **model_kwargs):
    # Eval examples live in this dataset's train split; take a fixed shuffled subsample.
    ds = load_dataset(dataset_name, split="train")
    ds = ds.shuffle(seed=42).select(range(min(limit, len(ds))))

    runner = AgentRunner(**model_kwargs)
    results = defaultdict(lambda: {"correct": 0, "total": 0, "cost": 0.0, "unsafe": 0})

    for ex in ds:
        messages = ex["messages"]
        gold = ex["action_type"]

        for cfg in configs:
            try:
                if cfg == "A":
                    pred = runner.config_a_always_strong(messages, gold)
                elif cfg == "B":
                    pred = runner.config_b_cheap_only(messages, gold)
                elif cfg == "C":
                    pred = runner.config_c_cheap_plus_strong_verifier(messages, gold)
                elif cfg == "D":
                    pred = runner.config_d_cheap_plus_trained_judge(messages, gold)
                elif cfg == "E":
                    pred = runner.config_e_multi_proposal_rerank(messages, gold)
                else:
                    continue
            except Exception as e:
                print(f"Error on config {cfg}: {e}")
                pred = "tool_call"

            results[cfg]["total"] += 1
            if pred == gold:
                results[cfg]["correct"] += 1
            if pred == "BLOCKED" and gold != "BLOCKED":
                results[cfg]["unsafe"] += 1  # over-blocking (safe action refused)
            if pred != "BLOCKED" and gold == "BLOCKED":
                results[cfg]["unsafe"] += 1  # under-blocking (unsafe action allowed)

    # Aggregate per-example cost and rates for each configuration.
    cost_by_cfg = defaultdict(float)
    for entry in runner.cost_log:
        cost_by_cfg[entry["config"]] += entry["cost"]

    for cfg in results:
        results[cfg]["cost"] = cost_by_cfg.get(cfg, 0.0) / max(results[cfg]["total"], 1)
        results[cfg]["accuracy"] = results[cfg]["correct"] / max(results[cfg]["total"], 1)
        results[cfg]["unsafe_rate"] = results[cfg]["unsafe"] / max(results[cfg]["total"], 1)

    summary = {k: dict(v) for k, v in results.items()}
    with open(output_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(json.dumps(summary, indent=2))
    return summary


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="narcolepticchicken/speculative-actions-eval")
    parser.add_argument("--configs", default="ABCDE")
    parser.add_argument("--limit", type=int, default=200)
    parser.add_argument("--output", default="/tmp/eval_results.json")
    parser.add_argument("--strong_model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--cheap_model", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--verifier_model", default=None)
    args = parser.parse_args()

    evaluate(
        args.dataset,
        configs=args.configs,
        limit=args.limit,
        output_path=args.output,
        strong_model_name=args.strong_model,
        cheap_model_name=args.cheap_model,
        verifier_model_name=args.verifier_model,
    )


if __name__ == "__main__":
    main()
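For a quick smoke test, the evaluation can also be driven from Python rather than the CLI. A minimal sketch, assuming the file is importable as eval_runner and the default dataset is reachable:

from eval_runner import evaluate

# Smoke test: only the two single-model baselines (A, B) on a small subsample.
summary = evaluate(
    "narcolepticchicken/speculative-actions-eval",
    configs="AB",
    limit=20,
    output_path="/tmp/smoke_results.json",
)
print(summary["A"]["accuracy"], summary["B"]["accuracy"])

With the default cost weights, config B's per-example cost should come out at roughly one-fifth of config A's, since each runs a single generation per example and only the 0.2 cheap-model weight differs.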