narcolepticchicken committed
Commit 69159fc · verified
1 Parent(s): 20985a9

Upload eval_base_models.py

Files changed (1)
  1. eval_base_models.py +202 -0
eval_base_models.py ADDED
@@ -0,0 +1,202 @@
"""
Evaluation using base models with prompt engineering.
Since fine-tuning takes hours, we use carefully crafted prompts
on base models to simulate the speculative decoding pipeline.
"""
import json

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

HUB_ORG = 'narcolepticchicken'
EVAL_DS = f'{HUB_ORG}/speculative-actions-eval'

def load_model(name, device):
    print(f'Loading {name}...', flush=True)
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if device == 'cuda':
        model = model.to(device)
    return model, tok

def predict_action(model, tokenizer, prompt, device, max_new_tokens=15,
                   do_sample=False, temperature=1.0):
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024)
        if device == 'cuda':
            inputs = {k: v.to(device) for k, v in inputs.items()}
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
        )
        if do_sample:
            gen_kwargs['temperature'] = temperature
        outputs = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
    return text

def parse_action(text):
    text_lower = text.lower()
    actions = ['tool_call', 'retrieval', 'file_read', 'file_write', 'repair',
               'verifier', 'ask_clarification', 'final_answer', 'blocked']
    # Return the first known label found in the output; fall back to
    # 'tool_call' when the model names no recognizable action.
    for a in actions:
        if a in text_lower:
            return a
    return 'tool_call'

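# e.g. parse_action('I would pick final_answer.') -> 'final_answer';
# parse_action('unsure') falls back to 'tool_call'.
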
def load_eval_data(n=50):
    ds = load_dataset(EVAL_DS)['test']
    n = min(n, len(ds))
    return [ds[i] for i in range(n)]

def build_proposer_prompt(context, task_type):
    return f"""You are an AI agent. Choose ONE action from:
-tool_call
-retrieval
-file_read
-file_write
-repair
-verifier
-ask_clarification
-final_answer
-blocked

Task: {task_type}
Context: {context}

Action:"""

def build_verifier_prompt(context, task_type, proposed):
    return f"""You verify if an agent action is correct.
Task: {task_type}
Context: {context}
Proposed action: {proposed}

Is this the best action? Answer with just YES or NO.

Answer:"""

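# With hypothetical values task_type='code_fix', context='<traceback>',
# proposed='repair', the verifier model sees:
#
#   You verify if an agent action is correct.
#   Task: code_fix
#   Context: <traceback>
#   Proposed action: repair
#
#   Is this the best action? Answer with just YES or NO.
#
#   Answer:
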
def run_eval(data, proposer, proposer_tok, verifier, verifier_tok, strong, strong_tok, device):
    results = {'A': [], 'B': [], 'C': [], 'D': [], 'E': []}

    for i, ex in enumerate(data):
        print(f'Processing {i+1}/{len(data)}...', flush=True)

        # Config B: cheap proposer only
        prompt_b = build_proposer_prompt(ex['context'], ex['task_type'])
        pred_b = parse_action(predict_action(proposer, proposer_tok, prompt_b, device))
        results['B'].append({'pred': pred_b, 'true': ex['action'], 'cost': 0.2})

        # Config A: strong model only
        prompt_a = build_proposer_prompt(ex['context'], ex['task_type'])
        pred_a = parse_action(predict_action(strong, strong_tok, prompt_a, device))
        results['A'].append({'pred': pred_a, 'true': ex['action'], 'cost': 1.0})

        # Config C: cheap proposer + strong verifier; on rejection the
        # strong model re-predicts from scratch.
        prompt_c1 = build_proposer_prompt(ex['context'], ex['task_type'])
        cheap_pred = parse_action(predict_action(proposer, proposer_tok, prompt_c1, device))
        prompt_c2 = build_verifier_prompt(ex['context'], ex['task_type'], cheap_pred)
        verify_text = predict_action(strong, strong_tok, prompt_c2, device, max_new_tokens=5)
        accepted = 'yes' in verify_text.lower()
        if accepted:
            pred_c = cheap_pred
            cost_c = 0.2 + 0.3
        else:
            pred_c = parse_action(predict_action(strong, strong_tok, prompt_c1, device))
            cost_c = 0.2 + 0.3 + 1.0
        results['C'].append({'pred': pred_c, 'true': ex['action'], 'cost': cost_c})

        # Config D: cheap proposer + trained verifier (simulated with base Qwen3-4B)
        prompt_d1 = build_proposer_prompt(ex['context'], ex['task_type'])
        cheap_pred_d = parse_action(predict_action(proposer, proposer_tok, prompt_d1, device))
        prompt_d2 = build_verifier_prompt(ex['context'], ex['task_type'], cheap_pred_d)
        verify_text_d = predict_action(verifier, verifier_tok, prompt_d2, device, max_new_tokens=5)
        accepted_d = 'yes' in verify_text_d.lower()
        if accepted_d:
            pred_d = cheap_pred_d
            cost_d = 0.2 + 0.15
        else:
            # Fall back to the verifier's own judgment
            pred_d = parse_action(predict_action(verifier, verifier_tok, prompt_d1, device))
            cost_d = 0.2 + 0.15 + 0.6
        results['D'].append({'pred': pred_d, 'true': ex['action'], 'cost': cost_d})

        # Config E: multi-proposal reranking. Sample the proposer so the
        # three proposals can differ; with greedy decoding all three would
        # be identical and reranking would be a no-op.
        proposals = []
        for _ in range(3):
            prompt_e = build_proposer_prompt(ex['context'], ex['task_type'])
            proposals.append(parse_action(predict_action(
                proposer, proposer_tok, prompt_e, device, do_sample=True, temperature=0.8)))

        scores = []
        for prop in proposals:
            score_prompt = f"""Rate this action 1-10 for the task.
Task: {ex['task_type']}
Context: {ex['context']}
Action: {prop}

Score (1-10):"""
            score_text = predict_action(strong, strong_tok, score_prompt, device, max_new_tokens=5)
            # Take the first integer in the reply; default to a neutral 5.
            score = 5
            for word in score_text.split():
                try:
                    score = int(word.strip('.,!?'))
                    break
                except ValueError:
                    pass
            scores.append(score)

        best_idx = scores.index(max(scores))
        pred_e = proposals[best_idx]
        cost_e = 0.2 * 3 + 0.3 * 3
        results['E'].append({'pred': pred_e, 'true': ex['action'], 'cost': cost_e})

    return results

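# Expected cost per example implied by the constants in run_eval, with
# p = the verifier's acceptance rate:
#   A (strong only):           1.0
#   B (cheap only):            0.2
#   C (cheap + strong verify): 0.5 + (1 - p) * 1.0
#   D (cheap + small verify):  0.35 + (1 - p) * 0.6
#   E (3 proposals + rerank):  3 * (0.2 + 0.3) = 1.5
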
def compute_metrics(results_list):
    correct = sum(1 for r in results_list if r['pred'] == r['true'])
    total = len(results_list)
    accuracy = correct / total
    avg_cost = sum(r['cost'] for r in results_list) / total

    # Per-class recall: of the examples whose true label is `a`, the
    # fraction predicted as `a`.
    by_action = {}
    for a in set(r['true'] for r in results_list):
        subset = [r for r in results_list if r['true'] == a]
        by_action[a] = sum(1 for r in subset if r['pred'] == a) / len(subset)

    return {'accuracy': accuracy, 'avg_cost': avg_cost, 'n': total, 'by_action': by_action}

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}', flush=True)

    # Load a smaller subset for faster evaluation
    data = load_eval_data(n=50)
    print(f'Loaded {len(data)} eval examples', flush=True)

    # Load models: cheap proposer, mid-size verifier, strong reference model
    proposer, proposer_tok = load_model('Qwen/Qwen3-1.7B', device)
    verifier, verifier_tok = load_model('Qwen/Qwen3-4B', device)
    strong, strong_tok = load_model('Qwen/Qwen2.5-7B', device)

    print('Running evaluation...', flush=True)
    all_results = run_eval(data, proposer, proposer_tok, verifier, verifier_tok, strong, strong_tok, device)

    summary = {}
    print('\n=== RESULTS ===', flush=True)
    for cfg in ['A', 'B', 'C', 'D', 'E']:
        metrics = compute_metrics(all_results[cfg])
        summary[cfg] = metrics
        print(f"Config {cfg}: Accuracy={metrics['accuracy']:.3f}, Cost={metrics['avg_cost']:.2f}", flush=True)

    with open('/tmp/eval_results_empirical.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print('\nSaved to /tmp/eval_results_empirical.json', flush=True)

if __name__ == '__main__':
    main()
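
# A minimal sketch of how the saved summary might be inspected afterwards,
# assuming only the keys this script writes ('accuracy', 'avg_cost', 'n',
# 'by_action'); kept as a comment so nothing extra runs with the script:
#
#   import json
#   with open('/tmp/eval_results_empirical.json') as f:
#       summary = json.load(f)
#   for cfg, m in sorted(summary.items()):
#       print(f"{cfg}: acc={m['accuracy']:.3f}  cost={m['avg_cost']:.2f}  "
#             f"acc/cost={m['accuracy'] / m['avg_cost']:.2f}")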