narcolepticchicken commited on
Commit
ba7590b
·
verified ·
1 Parent(s): 8be73c9

Add synthetic data + full training pipeline

Browse files
Files changed (1) hide show
  1. synthetic_data_and_train.py +535 -0
synthetic_data_and_train.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Speculative Tool Actions — Synthetic Data + Training Pipeline
4
+ ==============================================================
5
+ Generates synthetic agent traces with 9 action types, trains proposer + verifier,
6
+ evaluates all 5 configs, and produces cost-quality frontier report.
7
+
8
+ Designed to run as a single HF Job on GPU hardware.
9
+ """
10
+ import os, sys, json, re, random, math, subprocess, time
11
+ from collections import Counter, defaultdict
12
+ from datetime import datetime
13
+
14
+ # Install required packages
15
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet",
16
+ "datasets", "transformers", "trl", "peft", "accelerate", "huggingface_hub", "trackio", "torch"])
17
+
18
+ import torch
19
+ from datasets import Dataset, DatasetDict
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
21
+ from peft import LoraConfig, get_peft_model
22
+ from trl import SFTTrainer, SFTConfig, RewardTrainer, RewardConfig
23
+
24
+ set_seed(42)
25
+ torch.manual_seed(42)
26
+ random.seed(42)
27
+
28
+ HUB_ORG = "narcolepticchicken"
29
+ ACTION_TYPES = [
30
+ "tool_call", "retrieval", "file_read", "file_write",
31
+ "repair", "verifier", "ask_clarification", "final_answer", "BLOCKED",
32
+ ]
33
+
34
+ COST = {"strong_in": 1.0, "strong_out": 1.0, "cheap_in": 0.2, "cheap_out": 0.2}
35
+
36
+ # ========================================================================
37
+ # Synthetic Data Generator
38
+ # ========================================================================
39
+ TASK_TEMPLATES = [
40
+ "Fix a bug in the authentication module.",
41
+ "Implement a new search feature.",
42
+ "Write unit tests for the API layer.",
43
+ "Refactor the database connection pool.",
44
+ "Add logging to the payment gateway.",
45
+ "Update documentation for the CLI tool.",
46
+ "Debug a memory leak in the worker process.",
47
+ "Optimize the image processing pipeline.",
48
+ "Integrate a third-party OAuth provider.",
49
+ "Set up CI/CD for the microservice.",
50
+ "Migrate from REST to GraphQL.",
51
+ "Add rate limiting to the public API.",
52
+ "Create a backup strategy for the database.",
53
+ "Audit the codebase for security vulnerabilities.",
54
+ "Implement caching for frequently accessed data.",
55
+ ]
56
+
57
+ STATE_TEMPLATES = {
58
+ "tool_call": [
59
+ "I need to call the API to fetch user data.",
60
+ "Let me invoke the linter to check syntax.",
61
+ "I'll execute the test runner now.",
62
+ "Time to trigger the deployment script.",
63
+ ],
64
+ "retrieval": [
65
+ "I should search for similar issues in the tracker.",
66
+ "Let me look up the documentation for this function.",
67
+ "I'll query the knowledge base for best practices.",
68
+ "Need to find examples of this pattern online.",
69
+ ],
70
+ "file_read": [
71
+ "I need to read the configuration file first.",
72
+ "Let me check the existing implementation.",
73
+ "I'll examine the log file for clues.",
74
+ "Need to view the schema definition.",
75
+ ],
76
+ "file_write": [
77
+ "I'll write the fix to the source file.",
78
+ "Let me save the test cases to disk.",
79
+ "Need to update the requirements file.",
80
+ "I'll create a new migration script.",
81
+ ],
82
+ "repair": [
83
+ "The build is failing; let me fix the import error.",
84
+ "There's a null pointer exception to patch.",
85
+ "I need to correct the regex pattern.",
86
+ "Let me resolve the merge conflict.",
87
+ ],
88
+ "verifier": [
89
+ "Let me verify the fix by running tests.",
90
+ "I should check if the output is valid JSON.",
91
+ "Need to validate the schema changes.",
92
+ "I'll confirm the permissions are correct.",
93
+ ],
94
+ "ask_clarification": [
95
+ "I'm unsure about the expected behavior—could you clarify?",
96
+ "What is the target environment for this change?",
97
+ "Do you want me to preserve backward compatibility?",
98
+ "Which branch should I base this on?",
99
+ ],
100
+ "final_answer": [
101
+ "The task is complete. Summary of changes: ...",
102
+ "All tests pass. Here's the final solution.",
103
+ "Deployment successful. Verification complete.",
104
+ "Issue resolved. Closing the ticket.",
105
+ ],
106
+ "BLOCKED": [
107
+ "This request appears unsafe and I cannot proceed.",
108
+ "I'm sorry, but I cannot execute this command.",
109
+ "Blocked: the action violates safety policies.",
110
+ "Unsafe operation detected. Refusing to continue.",
111
+ ],
112
+ }
113
+
114
+ OBSERVATION_TEMPLATES = {
115
+ "tool_call": "Tool returned: status=200, data={...}",
116
+ "retrieval": "Found 3 relevant documents. Top result: ...",
117
+ "file_read": "File contents: 142 lines, class Foo { ... }",
118
+ "file_write": "File saved successfully. 3 lines changed.",
119
+ "repair": "Build passing. 0 errors, 2 warnings.",
120
+ "verifier": "Validation passed. Schema matches.",
121
+ "ask_clarification": "User replied: please use the main branch.",
122
+ "final_answer": "(no further action)",
123
+ "BLOCKED": "(no further action)",
124
+ }
125
+
126
+
127
+ def generate_trace(length=5, resolved_prob=0.8):
128
+ """Generate a single synthetic agent trace."""
129
+ task = random.choice(TASK_TEMPLATES)
130
+ messages = [{"role": "user", "content": task}]
131
+ gold_actions = []
132
+ for step in range(length):
133
+ # Choose action based on position in trace
134
+ if step == length - 1:
135
+ action = random.choices(
136
+ ["final_answer", "BLOCKED"],
137
+ weights=[0.85, 0.15]
138
+ )[0]
139
+ elif step == 0:
140
+ action = random.choices(
141
+ ["tool_call", "retrieval", "file_read", "ask_clarification"],
142
+ weights=[0.3, 0.25, 0.25, 0.2]
143
+ )[0]
144
+ else:
145
+ action = random.choice(ACTION_TYPES[:-2]) # exclude final_answer, BLOCKED
146
+
147
+ content = random.choice(STATE_TEMPLATES[action])
148
+ messages.append({"role": "assistant", "content": content})
149
+ gold_actions.append(action)
150
+ if action not in ("final_answer", "BLOCKED", "ask_clarification"):
151
+ messages.append({"role": "tool", "content": OBSERVATION_TEMPLATES[action]})
152
+
153
+ resolved = random.random() < resolved_prob
154
+ return messages, gold_actions, resolved
155
+
156
+
157
+ def build_synthetic_datasets(n_train=5000, n_test=500):
158
+ print("=== Generating Synthetic Datasets ===")
159
+ p_rows, v_rows, e_rows = [], [], []
160
+
161
+ for _ in range(n_train + n_test):
162
+ msgs, actions, resolved = generate_trace(
163
+ length=random.randint(2, 6),
164
+ resolved_prob=0.75 if _ < n_train else 0.5
165
+ )
166
+ state = []
167
+ for i, msg in enumerate(msgs):
168
+ if msg["role"] == "assistant":
169
+ action = actions[len([m for m in msgs[:i] if m["role"] == "assistant"]) - 1]
170
+ comp = [{"role": "assistant", "content": msg["content"]}]
171
+ p_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "action_type": action})
172
+ v_rows.append({"prompt": [m.copy() for m in state], "completion": comp, "label": resolved, "action_type": action})
173
+ e_rows.append({"messages": [m.copy() for m in state] + comp, "resolved": resolved, "action_type": action})
174
+ state.append(msg)
175
+
176
+ print(f"Total: proposer={len(p_rows)}, verifier={len(v_rows)}, eval={len(e_rows)}")
177
+ print("Distribution:", Counter(r["action_type"] for r in p_rows).most_common())
178
+
179
+ def fmt_proposer(r):
180
+ sys_msg = {"role": "system", "content": (
181
+ "You are an agent action predictor. Predict the next action from: "
182
+ + ", ".join(ACTION_TYPES) + ". Respond with exactly the action name.")}
183
+ prompt = [sys_msg] + r["prompt"]
184
+ if prompt:
185
+ prompt[-1]["content"] += "\n\n[Next Action] Choose one: " + ", ".join(ACTION_TYPES)
186
+ comp = r["completion"]
187
+ comp[0]["content"] = f"Action: {r['action_type']}\n" + comp[0]["content"]
188
+ return {"prompt": prompt, "completion": comp}
189
+
190
+ proposer_all = [fmt_proposer(r) for r in p_rows]
191
+ random.shuffle(proposer_all)
192
+ proposer_ds = DatasetDict({
193
+ "train": Dataset.from_list(proposer_all[:n_train]),
194
+ "test": Dataset.from_list(proposer_all[n_train:]),
195
+ })
196
+ proposer_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-proposer-sft")
197
+ print("Pushed proposer dataset")
198
+
199
+ rng = random.Random(42)
200
+ good = [r for r in v_rows if r["label"]]
201
+ bad = [r for r in v_rows if not r["label"]]
202
+ if len(bad) < len(good) * 0.2:
203
+ for r in good:
204
+ wa = rng.choice([a for a in ACTION_TYPES if a != r["action_type"]])
205
+ bad.append({
206
+ "prompt": [m.copy() for m in r["prompt"]],
207
+ "completion": [{"role": "assistant", "content": f"Action: {wa}\n(incorrect action)"}],
208
+ "label": False, "action_type": wa,
209
+ })
210
+ pairs = []
211
+ for g in good:
212
+ b = rng.choice(bad)
213
+ pairs.append({
214
+ "prompt": [m.copy() for m in g["prompt"]],
215
+ "chosen": g["completion"],
216
+ "rejected": b["completion"],
217
+ "action_type": g["action_type"],
218
+ })
219
+ random.shuffle(pairs)
220
+ verifier_ds = DatasetDict({
221
+ "train": Dataset.from_list(pairs[:n_train]),
222
+ "test": Dataset.from_list(pairs[n_train:]),
223
+ })
224
+ verifier_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-verifier-pref")
225
+ print("Pushed verifier dataset")
226
+
227
+ eval_all = e_rows
228
+ random.shuffle(eval_all)
229
+ eval_ds = Dataset.from_list(eval_all[:n_test])
230
+ eval_ds.push_to_hub(f"{HUB_ORG}/speculative-actions-eval")
231
+ print("Pushed eval dataset")
232
+
233
+ return proposer_ds, verifier_ds, eval_ds
234
+
235
+
236
+ # ========================================================================
237
+ # Training
238
+ # ========================================================================
239
+ def train_proposer():
240
+ print("\n=== Training Proposer ===")
241
+ ds = DatasetDict.load_from_disk(f"{HUB_ORG}/speculative-actions-proposer-sft") if False else None
242
+ # load from hub
243
+ from datasets import load_dataset
244
+ ds = load_dataset(f"{HUB_ORG}/speculative-actions-proposer-sft")
245
+
246
+ peft_config = LoraConfig(
247
+ r=16, lora_alpha=32,
248
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
249
+ modules_to_save=["embed_tokens", "lm_head"],
250
+ )
251
+ config = SFTConfig(
252
+ output_dir="/tmp/proposer-out",
253
+ hub_model_id=f"{HUB_ORG}/speculative-proposer-qwen3-1.7b",
254
+ push_to_hub=True,
255
+ learning_rate=2e-4,
256
+ per_device_train_batch_size=4,
257
+ gradient_accumulation_steps=4,
258
+ num_train_epochs=2,
259
+ max_seq_length=2048,
260
+ bf16=True,
261
+ gradient_checkpointing=True,
262
+ logging_strategy="steps",
263
+ logging_steps=10,
264
+ logging_first_step=True,
265
+ disable_tqdm=True,
266
+ report_to="trackio",
267
+ run_name="proposer-sft-qwen3-1.7b",
268
+ )
269
+ trainer = SFTTrainer(
270
+ model="Qwen/Qwen3-1.7B",
271
+ train_dataset=ds["train"],
272
+ eval_dataset=ds["test"],
273
+ args=config,
274
+ peft_config=peft_config,
275
+ )
276
+ trainer.train()
277
+ trainer.push_to_hub()
278
+ print("Proposer training done.")
279
+
280
+
281
+ def train_verifier():
282
+ print("\n=== Training Verifier ===")
283
+ from datasets import load_dataset
284
+ ds = load_dataset(f"{HUB_ORG}/speculative-actions-verifier-pref")
285
+
286
+ peft_config = LoraConfig(
287
+ r=16, lora_alpha=32,
288
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
289
+ modules_to_save=["score"],
290
+ )
291
+ config = RewardConfig(
292
+ output_dir="/tmp/verifier-out",
293
+ hub_model_id=f"{HUB_ORG}/speculative-verifier-qwen3-4b",
294
+ push_to_hub=True,
295
+ learning_rate=1e-3,
296
+ per_device_train_batch_size=2,
297
+ gradient_accumulation_steps=8,
298
+ num_train_epochs=2,
299
+ max_seq_length=2048,
300
+ bf16=True,
301
+ gradient_checkpointing=True,
302
+ logging_strategy="steps",
303
+ logging_steps=10,
304
+ logging_first_step=True,
305
+ disable_tqdm=True,
306
+ report_to="trackio",
307
+ run_name="verifier-reward-qwen3-4b",
308
+ )
309
+ trainer = RewardTrainer(
310
+ model="Qwen/Qwen3-4B",
311
+ train_dataset=ds["train"],
312
+ eval_dataset=ds["test"],
313
+ args=config,
314
+ peft_config=peft_config,
315
+ )
316
+ trainer.train()
317
+ trainer.push_to_hub()
318
+ print("Verifier training done.")
319
+
320
+
321
+ # ========================================================================
322
+ # Evaluation
323
+ # ========================================================================
324
+ def parse_action(text):
325
+ for a in ACTION_TYPES:
326
+ if a.lower() in text.lower():
327
+ return a
328
+ return "tool_call"
329
+
330
+
331
+ class EvalRunner:
332
+ def __init__(self, strong_name, cheap_name, verifier_name, device="cuda"):
333
+ self.device = device
334
+ print(f"Loading strong model: {strong_name}")
335
+ self.strong_tok = AutoTokenizer.from_pretrained(strong_name, trust_remote_code=True)
336
+ self.strong_model = AutoModelForCausalLM.from_pretrained(
337
+ strong_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
338
+ print(f"Loading cheap model: {cheap_name}")
339
+ self.cheap_tok = AutoTokenizer.from_pretrained(cheap_name, trust_remote_code=True)
340
+ self.cheap_model = AutoModelForCausalLM.from_pretrained(
341
+ cheap_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
342
+ self.verifier_name = verifier_name
343
+ if verifier_name:
344
+ print(f"Loading verifier: {verifier_name}")
345
+ self.v_tok = AutoTokenizer.from_pretrained(verifier_name, trust_remote_code=True)
346
+ self.v_model = AutoModelForCausalLM.from_pretrained(
347
+ verifier_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
348
+
349
+ def _gen(self, model, tokenizer, messages, max_new=64, temp=0.0):
350
+ inputs = tokenizer.apply_chat_template(
351
+ messages, tokenize=True, return_tensors="pt", add_generation_prompt=True
352
+ ).to(model.device)
353
+ with torch.no_grad():
354
+ out = model.generate(
355
+ inputs, max_new_tokens=max_new, do_sample=temp > 0,
356
+ temperature=temp if temp > 0 else None,
357
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
358
+ )
359
+ text = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
360
+ return text, inputs.shape[1], out.shape[1] - inputs.shape[1]
361
+
362
+ def run_a(self, messages):
363
+ s = {"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}
364
+ out, i, o = self._gen(self.strong_model, self.strong_tok, [s] + messages)
365
+ return parse_action(out), i, o, "strong"
366
+
367
+ def run_b(self, messages):
368
+ s = {"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}
369
+ out, i, o = self._gen(self.cheap_model, self.cheap_tok, [s] + messages)
370
+ return parse_action(out), i, o, "cheap"
371
+
372
+ def run_c(self, messages):
373
+ s = {"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}
374
+ prop, i1, o1 = self._gen(self.cheap_model, self.cheap_tok, [s] + messages)
375
+ vp = messages + [{"role": "assistant", "content": prop},
376
+ {"role": "user", "content": "Is this action correct? Answer ONLY yes or no."}]
377
+ ver, i2, o2 = self._gen(self.strong_model, self.strong_tok, vp, max_new=10)
378
+ if "yes" in ver.lower():
379
+ return parse_action(prop), i1 + i2, o1 + o2, "mixed"
380
+ out, i3, o3 = self._gen(self.strong_model, self.strong_tok, [s] + messages)
381
+ return parse_action(out), i1 + i2 + i3, o1 + o2 + o3, "mixed"
382
+
383
+ def run_d(self, messages):
384
+ if not self.verifier_name:
385
+ raise ValueError("Need verifier")
386
+ s = {"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}
387
+ prop, i1, o1 = self._gen(self.cheap_model, self.cheap_tok, [s] + messages)
388
+ vp = messages + [{"role": "assistant", "content": prop},
389
+ {"role": "user", "content": "Rate this action: good or bad."}]
390
+ ver, i2, o2 = self._gen(self.v_model, self.v_tok, vp, max_new=10)
391
+ if "good" in ver.lower():
392
+ return parse_action(prop), i1 + i2, o1 + o2, "cheap"
393
+ out, i3, o3 = self._gen(self.strong_model, self.strong_tok, [s] + messages)
394
+ return parse_action(out), i1 + i2 + i3, o1 + o2 + o3, "mixed"
395
+
396
+ def run_e(self, messages, n=3):
397
+ s = {"role": "system", "content": f"Predict next action from: {', '.join(ACTION_TYPES)}"}
398
+ props = []
399
+ ti, to = 0, 0
400
+ for _ in range(n):
401
+ p, i, o = self._gen(self.cheap_model, self.cheap_tok, [s] + messages, temp=0.7)
402
+ props.append(p); ti += i; to += o
403
+ best = props[0]; best_score = -1
404
+ for p in props:
405
+ rp = messages + [{"role": "assistant", "content": p},
406
+ {"role": "user", "content": "Score 1-10."}]
407
+ st, i, o = self._gen(self.strong_model, self.strong_tok, rp, max_new=5)
408
+ ti += i; to += o
409
+ m = re.search(r'(\d+)', st)
410
+ if m:
411
+ sc = int(m.group(1))
412
+ if sc > best_score:
413
+ best_score = sc; best = p
414
+ return parse_action(best), ti, to, "mixed"
415
+
416
+
417
+ def evaluate(limit=200):
418
+ print("\n=== Evaluation ===")
419
+ from datasets import load_dataset
420
+ ds = load_dataset(f"{HUB_ORG}/speculative-actions-eval", split="train")
421
+ ds = ds.shuffle(seed=42).select(range(min(limit, len(ds))))
422
+
423
+ runner = EvalRunner(
424
+ strong_name="Qwen/Qwen2.5-7B-Instruct",
425
+ cheap_name="Qwen/Qwen3-1.7B",
426
+ verifier_name=f"{HUB_ORG}/speculative-verifier-qwen3-4b",
427
+ )
428
+
429
+ results = defaultdict(lambda: {"correct": 0, "total": 0, "cost": 0.0, "unsafe": 0})
430
+ for idx, ex in enumerate(ds):
431
+ msgs = ex["messages"]; gold = ex["action_type"]
432
+ for cfg, func in [("A", runner.run_a), ("B", runner.run_b),
433
+ ("C", runner.run_c), ("D", runner.run_d),
434
+ ("E", lambda m: runner.run_e(m, n=3))]:
435
+ try:
436
+ pred, i_t, o_t, mtype = func(msgs)
437
+ except Exception as e:
438
+ print(f"Error {cfg} idx {idx}: {e}")
439
+ pred, i_t, o_t, mtype = "tool_call", 0, 0, "unknown"
440
+ results[cfg]["total"] += 1
441
+ if pred == gold:
442
+ results[cfg]["correct"] += 1
443
+ if pred == "BLOCKED" and gold != "BLOCKED":
444
+ results[cfg]["unsafe"] += 1
445
+ if pred != "BLOCKED" and gold == "BLOCKED":
446
+ results[cfg]["unsafe"] += 1
447
+ results[cfg]["cost"] += i_t * COST.get(f"{mtype}_in", 1.0) + o_t * COST.get(f"{mtype}_out", 1.0)
448
+ if (idx + 1) % 20 == 0:
449
+ print(f" Evaluated {idx + 1}/{min(limit, len(ds))}")
450
+
451
+ for cfg in results:
452
+ t = max(results[cfg]["total"], 1)
453
+ results[cfg]["accuracy"] = results[cfg]["correct"] / t
454
+ results[cfg]["avg_cost"] = results[cfg]["cost"] / t
455
+ results[cfg]["unsafe_rate"] = results[cfg]["unsafe"] / t
456
+
457
+ summary = {k: dict(v) for k, v in results.items()}
458
+ with open("/tmp/eval_results.json", "w") as f:
459
+ json.dump(summary, f, indent=2)
460
+ print(json.dumps(summary, indent=2))
461
+ return summary
462
+
463
+
464
+ # ========================================================================
465
+ # Report
466
+ # ========================================================================
467
+ def generate_report(eval_results):
468
+ print("\n=== Generating Report ===")
469
+ lines = ["# Speculative Tool Actions — Ablation Report\n\n"]
470
+ lines.append("## Configurations\n\n")
471
+ lines.append("- **A**: Always strong model (Qwen2.5-7B)\n")
472
+ lines.append("- **B**: Cheap model only (Qwen3-1.7B)\n")
473
+ lines.append("- **C**: Cheap proposer + strong verifier\n")
474
+ lines.append("- **D**: Cheap proposer + trained trace judge (Qwen3-4B reward model)\n")
475
+ lines.append("- **E**: Multi-proposal reranking (3 cheap proposals + strong scoring)\n\n")
476
+
477
+ lines.append("## Results\n\n")
478
+ lines.append("| Config | Accuracy | Avg Cost | Unsafe-Action Rate |\n")
479
+ lines.append("|--------|----------|----------|-------------------|\n")
480
+ for cfg in sorted(eval_results):
481
+ r = eval_results[cfg]
482
+ lines.append(f"| {cfg} | {r['accuracy']:.3f} | {r['avg_cost']:.2f} | {r['unsafe_rate']:.3f} |\n")
483
+
484
+ lines.append("\n## Cost-Quality Frontier\n\n")
485
+ points = [(r["avg_cost"], r["accuracy"], cfg) for cfg, r in eval_results.items()]
486
+ points.sort()
487
+ frontier = []
488
+ max_acc = -1
489
+ for cost, acc, cfg in points:
490
+ if acc > max_acc:
491
+ frontier.append((cost, acc, cfg)); max_acc = acc
492
+ lines.append("Pareto-optimal configs:\n")
493
+ for cost, acc, cfg in frontier:
494
+ lines.append(f"- **{cfg}**: cost={cost:.2f}, accuracy={acc:.3f}\n")
495
+
496
+ lines.append("\n## Recommendations\n\n")
497
+ best_ratio = None; best_cfg = None
498
+ for cfg, r in eval_results.items():
499
+ ratio = r["accuracy"] / max(r["avg_cost"], 0.01)
500
+ if best_ratio is None or ratio > best_ratio:
501
+ best_ratio = ratio; best_cfg = cfg
502
+ lines.append(f"- **Best accuracy/cost ratio**: Config {best_cfg} (ratio={best_ratio:.3f})\n")
503
+
504
+ best_acc_cfg = max(eval_results, key=lambda c: eval_results[c]["accuracy"])
505
+ lines.append(f"- **Highest accuracy**: Config {best_acc_cfg} ({eval_results[best_acc_cfg]['accuracy']:.3f})\n")
506
+
507
+ best_acc = eval_results[best_acc_cfg]["accuracy"]
508
+ threshold = best_acc * 0.9
509
+ cheap = {c: r for c, r in eval_results.items() if r["accuracy"] >= threshold}
510
+ if cheap:
511
+ cheapest = min(cheap, key=lambda c: cheap[c]["avg_cost"])
512
+ lines.append(f"- **Cheapest within 90% of best accuracy**: Config {cheapest} "
513
+ f"(cost={cheap[cheapest]['avg_cost']:.2f}, acc={cheap[cheapest]['accuracy']:.3f})\n")
514
+
515
+ report = "".join(lines)
516
+ with open("/tmp/ablation_report.md", "w") as f:
517
+ f.write(report)
518
+ print(report)
519
+ return report
520
+
521
+
522
+ # ========================================================================
523
+ # Main
524
+ # ========================================================================
525
+ def main():
526
+ build_synthetic_datasets(n_train=5000, n_test=500)
527
+ train_proposer()
528
+ train_verifier()
529
+ results = evaluate(limit=200)
530
+ generate_report(results)
531
+ print("\n=== Pipeline Complete ===")
532
+
533
+
534
+ if __name__ == "__main__":
535
+ main()