narcolepticchicken committed
Commit 7ca55e0 · verified · Parent: e125869

Upload eval_final_v2.py

Files changed (1)
  1. eval_final_v2.py +369 -0
eval_final_v2.py ADDED
@@ -0,0 +1,369 @@
"""Speculative Tool Actions — Evaluation Runner (v2)
======================================================
Fixed: prompt format matches training data format (Action: <type> prefix).
Training data uses: system prompt + context → "Action: <type>\n<reason>"
Eval now uses the same chat template format that training used.

Evaluates 5 configurations:
A: Always strong model (Qwen3-8B)
B: Cheap model only (Qwen3-1.7B trained proposer)
C: Cheap proposer + strong verifier (8B ACCEPT/REJECT)
D: Cheap proposer + trained reward model scorer
E: Multi-proposal reranking (reward model scores N proposals)
"""

import json, os, time, re
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset

# --- Configuration -----------------------------------------------------------
HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'
MAX_EVAL = int(os.environ.get('MAX_EVAL', '200'))

ACTIONS = [
    'tool_call', 'retrieval', 'file_read', 'file_write',
    'repair', 'verifier', 'ask_clarification', 'final_answer', 'BLOCKED'
]

COST = {
    'strong': 1.00,
    'cheap': 0.15,
    'verifier': 0.30,
    'verify_check': 0.10,
}
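
# For reference, the per-example costs these relative weights imply for each
# configuration (derived from the evaluate_config_* functions below):
#   A: 1.00                           (strong model on every example)
#   B: 0.15                           (cheap proposer only)
#   C: 0.15 + 0.10 = 0.25 if the verifier accepts,
#      0.15 + 0.10 + 1.00 = 1.25 if it rejects and the strong model reruns
#   D: 0.15 + 0.10 = 0.25             (reward check, no escalation)
#   E: n * (0.15 + 0.10) = 0.75 for the default n=3
# ('verifier' = 0.30 is currently unused by any configuration.)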

# --- Model Loading (unchanged) ------------------------------------------------
def load_lm(model_id, device):
    # Note: `device` is unused here; device_map='auto' handles placement.
    print(f" Loading LM: {model_id}")
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map='auto',
        trust_remote_code=True,
    )
    model.eval()
    return model, tok

def load_reward_model(adapter_id, device):
    base_model = 'Qwen/Qwen3-4B'
    print(f" Loading reward model base: {base_model}")
    tok = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model, num_labels=1,
        torch_dtype=torch.bfloat16, device_map='auto',
        trust_remote_code=True,
    )
    # The classification head needs an explicit pad token id to locate the
    # last non-padding token in padded inputs.
    model.config.pad_token_id = tok.pad_token_id
    print(f" Loading LoRA adapter: {adapter_id}")
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tok

# --- FIXED: Parse "Action: <type>" from output -------------------------------
def parse_action(text):
    """Parse action from model output. Looks for 'Action: <type>' prefix."""
    m = re.search(r'Action:\s*(tool_call|retrieval|file_read|file_write|repair|verifier|ask_clarification|final_answer|BLOCKED)', text, re.IGNORECASE)
    if m:
        matched = m.group(1).lower()
        # Preserve canonical casing: downstream checks compare against 'BLOCKED',
        # so lowercasing the match blindly would break them.
        return 'BLOCKED' if matched == 'blocked' else matched
    # Fallback: try finding any action name
    lower = text.lower()
    for a in ACTIONS:
        if a.lower() in lower:
            return a
    return 'tool_call'
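
# A few illustrative inputs/outputs for parse_action:
#   parse_action("Action: retrieval\nNeed to look up the API docs")  -> 'retrieval'
#   parse_action("Action: BLOCKED\nUnsafe request")                  -> 'BLOCKED'
#   parse_action("no recognizable action here")                      -> 'tool_call' (default)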

# --- FIXED: Build prompts matching training format ----------------------------
SYSTEM_PROMPT = """You are an agent action predictor. Predict the next action from: tool_call, retrieval, file_read, file_write, repair, verifier, ask_clarification, final_answer, BLOCKED.

Format your response as:
Action: <action_name>
<brief reason>"""

def build_proposer_messages(example):
    """Build messages list matching training format: system + context."""
    msgs = example['messages']
    # Build context from conversation
    context_lines = []
    for m in msgs[-4:]:  # last 4 messages
        context_lines.append(f"{m['role']}: {str(m['content'])[:300]}")
    context = '\n'.join(context_lines)

    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': f"Predict the next action for:\n\n{context}"},
    ]
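
# For a hypothetical example whose messages end with
# {'role': 'user', 'content': 'run the tests'}, the user turn comes out as:
#   "Predict the next action for:\n\nuser: run the tests"
# (each of the last 4 turns contributes one "role: content" line, with content
# truncated to 300 characters).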

@torch.no_grad()
def predict_action(model, tokenizer, messages, device='cuda'):
    """Predict action using chat template (matching training format)."""
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=2048).to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=50, do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    ).strip()
    return parse_action(response)

@torch.no_grad()
def get_reward_score(model, tokenizer, text, device='cuda'):
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=1024).to(device)
    score = model(**inputs).logits.squeeze().item()
    return score

@torch.no_grad()
def predict_accept_reject(model, tokenizer, proposed_action, example_msgs, device='cuda'):
    """Strong verifier: ACCEPT or REJECT using chat template."""
    context = '\n'.join(
        f"{m['role']}: {str(m['content'])[:200]}" for m in example_msgs[-3:]
    )
    msgs = [
        {'role': 'system', 'content': 'You are a verifier. Say ACCEPT if the proposed action is correct, REJECT if wrong. Only output ACCEPT or REJECT.'},
        {'role': 'user', 'content': f'Proposed action: {proposed_action}\n\nContext:\n{context}\n\nDecision:'}
    ]
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=1024).to(device)
    outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False,
                             pad_token_id=tokenizer.pad_token_id)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                skip_special_tokens=True).strip().lower()
    return 'accept' in response and 'reject' not in response
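
# The decision parse is fail-closed: "ACCEPT" -> True, "REJECT" -> False, and
# any other (or mixed) output -> False, which in config C escalates the example
# to the strong model rather than trusting the cheap proposal.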

def build_reward_text(proposed_action, example):
    """Build text for reward model scoring — match training format."""
    msgs = example['messages']
    context = '\n'.join(
        f"{m['role']}: {str(m['content'])[:200]}" for m in msgs[-3:]
    )
    return f"User: {context}\n\nAssistant: Action: {proposed_action}"
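
# e.g. for a hypothetical example ending with a 'open config.yaml' user turn:
#   build_reward_text('file_read', ex)
#   -> "User: user: open config.yaml\n\nAssistant: Action: file_read"
# (the last 3 turns of ex['messages'] are flattened into the "User:" context,
# so role labels appear inside it by design).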

# --- Eval Configs (updated to use new prompt format) --------------------------
def evaluate_config_A(data, strong_model, strong_tok, device):
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" A: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        pred = predict_action(strong_model, strong_tok, msgs, device)
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=COST['strong'], accepted=None,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results

def evaluate_config_B(data, cheap_model, cheap_tok, device):
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" B: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        pred = predict_action(cheap_model, cheap_tok, msgs, device)
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=COST['cheap'], accepted=None,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results

def evaluate_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device):
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" C: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        cheap_pred = predict_action(cheap_model, cheap_tok, msgs, device)
        accepted = predict_accept_reject(strong_model, strong_tok, cheap_pred, ex['messages'], device)
        if accepted:
            pred, cost = cheap_pred, COST['cheap'] + COST['verify_check']
        else:
            pred = predict_action(strong_model, strong_tok, msgs, device)
            cost = COST['cheap'] + COST['verify_check'] + COST['strong']
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=cost, accepted=accepted,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results

def evaluate_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device):
    THRESHOLD = -1.0  # calibrated from prior run: all scores are negative
    results = []
    for i, ex in enumerate(data):
        if i % 20 == 0: print(f" D: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        cheap_pred = predict_action(cheap_model, cheap_tok, msgs, device)
        reward_text = build_reward_text(cheap_pred, ex)
        score = get_reward_score(verifier_model, verifier_tok, reward_text, device)
        accepted = score >= THRESHOLD
        # Unlike config C, D only records the gate decision; the cheap
        # prediction is kept even on rejection, so the cost is constant.
        pred = cheap_pred
        cost = COST['cheap'] + COST['verify_check']
        results.append(dict(pred=pred, true=ex['action_type'],
                            cost=cost, accepted=accepted, score=score,
                            safe=not (ex['action_type'] == 'BLOCKED' and pred != 'BLOCKED')))
    return results

def evaluate_config_E(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device, n=3):
    results = []
    for i, ex in enumerate(data):
        if i % 10 == 0: print(f" E: {i}/{len(data)}")
        msgs = build_proposer_messages(ex)
        text = cheap_tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        proposals = []
        for _ in range(n):
            inputs = cheap_tok(text, return_tensors='pt', truncation=True,
                               max_length=2048).to(device)
            outputs = cheap_model.generate(**inputs, max_new_tokens=50,
                                           do_sample=True, temperature=0.8, top_p=0.95,
                                           pad_token_id=cheap_tok.pad_token_id)
            response = cheap_tok.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                        skip_special_tokens=True)
            proposals.append(parse_action(response))
        scored = []
        # Only unique proposals are scored; the cost below still charges for n
        # reward checks, so it slightly overestimates when samples repeat.
        for prop in set(proposals):
            reward_text = build_reward_text(prop, ex)
            score = get_reward_score(verifier_model, verifier_tok, reward_text, device)
            scored.append((prop, score))
        best = max(scored, key=lambda x: x[1])[0]
        results.append(dict(pred=best, true=ex['action_type'],
                            cost=COST['cheap'] * n + COST['verify_check'] * n,
                            accepted=True,
                            safe=not (ex['action_type'] == 'BLOCKED' and best != 'BLOCKED')))
    return results

# --- Metrics ------------------------------------------------------------------
def compute_metrics(results, config_name):
    total = len(results)
    correct = sum(1 for r in results if r['pred'] == r['true'])
    avg_cost = sum(r['cost'] for r in results) / total
    safe = sum(1 for r in results if r['safe']) / total
    by_action = {}
    for a in ACTIONS:
        subset = [r for r in results if r['true'] == a]
        if subset:
            by_action[a] = round(sum(1 for r in subset if r['pred'] == a) / len(subset), 3)
    accepted = [r for r in results if r['accepted'] is not None]
    accept_rate = sum(1 for r in accepted if r['accepted']) / len(accepted) if accepted else None
    metrics = {
        'config': config_name,
        'accuracy': round(correct / total, 4),
        'avg_cost': round(avg_cost, 4),
        'safety': round(safe, 4),
        'n': total,
        'by_action': by_action,
    }
    if accept_rate is not None:
        metrics['accept_rate'] = round(accept_rate, 4)
    if results and 'score' in results[0]:
        scores = [r.get('score', 0) for r in results]
        metrics['mean_score'] = round(sum(scores)/len(scores), 3)
        metrics['min_score'] = round(min(scores), 3)
        metrics['max_score'] = round(max(scores), 3)
    return metrics
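
# Shape of the returned dict (numbers here are illustrative, not real results):
#   {'config': 'C', 'accuracy': 0.85, 'avg_cost': 0.45, 'safety': 1.0, 'n': 200,
#    'by_action': {'tool_call': 0.9, ...}, 'accept_rate': 0.8}
# 'by_action' is per-action recall; config D additionally reports
# mean/min/max of the reward scores.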

# --- Main ---------------------------------------------------------------------
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}')

    cheap_id = f'{HUB_ORG}/speculative-proposer-qwen3-1.7b'
    verifier_id = f'{HUB_ORG}/speculative-verifier-qwen3-4b'
    strong_id = 'Qwen/Qwen3-8B'

    print(f'Loading eval dataset: {EVAL_DS}')
    ds = load_dataset(EVAL_DS, split='train')
    data = [ds[i] for i in range(min(MAX_EVAL, len(ds)))]
    print(f'Evaluating on {len(data)} examples')

    from collections import Counter
    dist = Counter(ex['action_type'] for ex in data)
    print(f'Action distribution: {dict(dist)}')

    print('\nLoading models...')
    cheap_model, cheap_tok = load_lm(cheap_id, device)
    verifier_model, verifier_tok = load_reward_model(verifier_id, device)
    strong_model, strong_tok = load_lm(strong_id, device)

    all_metrics = {}
    configs = [
        ('A', lambda: evaluate_config_A(data, strong_model, strong_tok, device)),
        ('B', lambda: evaluate_config_B(data, cheap_model, cheap_tok, device)),
        ('C', lambda: evaluate_config_C(data, cheap_model, cheap_tok, strong_model, strong_tok, device)),
        ('D', lambda: evaluate_config_D(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)),
        ('E', lambda: evaluate_config_E(data, cheap_model, cheap_tok, verifier_model, verifier_tok, device)),
    ]

    for name, fn in configs:
        print(f'\n{"="*50}\nEvaluating Config {name}...')
        t0 = time.time()
        try:
            raw = fn()
            elapsed = time.time() - t0
            metrics = compute_metrics(raw, name)
            all_metrics[name] = metrics
            print(f' Accuracy: {metrics["accuracy"]:.3f}')
            print(f' Avg Cost: {metrics["avg_cost"]:.3f}')
            print(f' Safety: {metrics["safety"]:.3f}')
            if metrics.get('accept_rate') is not None:
                print(f' Accept Rate: {metrics["accept_rate"]:.3f}')
            if metrics.get('mean_score') is not None:
                print(f' Mean Score: {metrics["mean_score"]:.3f}')
            print(f' Time: {elapsed:.1f}s')
        except Exception as e:
            print(f' ERROR: {e}')
            import traceback; traceback.print_exc()
            all_metrics[name] = {'config': name, 'error': str(e), 'accuracy': 0, 'avg_cost': 0, 'safety': 0, 'n': 0}

    print(f'\n{"="*60}')
    print('FINAL COMPARISON')
    print(f'{"Config":<6} {"Accuracy":>10} {"Avg Cost":>10} {"Safety":>10} {"Accept%":>10}')
    print('-' * 60)
    for cfg in ['A', 'B', 'C', 'D', 'E']:
        m = all_metrics.get(cfg, {})
        ar = m.get('accept_rate', '-')
        if isinstance(ar, float): ar = f'{ar:.3f}'
        print(f'{cfg:<6} {m.get("accuracy",0):>10.3f} {m.get("avg_cost",0):>10.3f} '
              f'{m.get("safety",0):>10.3f} {str(ar):>10}')

    print(f'\n{"="*60}')
    print('COST-QUALITY FRONTIER')
    for m in sorted(all_metrics.values(), key=lambda x: x.get('avg_cost',0)):
        print(f" {m.get('config','?')}: cost={m.get('avg_cost',0):.3f}, "
              f"acc={m.get('accuracy',0):.3f}, safety={m.get('safety',0):.3f}")

    out_path = '/tmp/eval_results_v2.json'
    output = {
        'metrics': all_metrics,
        'config': {
            'cheap_model': cheap_id,
            'verifier_model': verifier_id,
            'strong_model': strong_id,
            'eval_dataset': EVAL_DS,
            'n_examples': len(data),
            'version': 'v2 — fixed prompt format matching training data',
            'prompt_format': 'chat template with system prompt + Action: <type> output',
        },
        'action_distribution': dict(dist),
    }
    with open(out_path, 'w') as f:
        json.dump(output, f, indent=2)
    print(f'\nResults saved to {out_path}')

    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj=out_path,
        path_in_repo='eval_results_v2.json',
        repo_id=f'{HUB_ORG}/speculative-tool-actions',
        repo_type='model',
        commit_message='Eval v2 results with fixed prompt format matching training data',
    )
    print('Uploaded to Hub!')

if __name__ == '__main__':
    main()
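
# Usage (assuming GPU access and that the Hub repos referenced above exist):
#   MAX_EVAL=50 python eval_final_v2.py
# MAX_EVAL caps the number of eval examples (default 200); results are written
# to /tmp/eval_results_v2.json and uploaded to the Hub model repo.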