narcolepticchicken committed on
Commit
d18f367
·
verified ·
1 Parent(s): c55ff1a

Upload aco/datasets/synthetic_traces.py

Browse files
Files changed (1) hide show
  1. aco/datasets/synthetic_traces.py +416 -0
aco/datasets/synthetic_traces.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Synthetic Trace Generator.
2
+
3
+ Generates 10,000+ agent traces with:
4
+ - task type
5
+ - model used
6
+ - tool calls
7
+ - context size
8
+ - cost
9
+ - latency
10
+ - failure mode
11
+ - final outcome
12
+ - optimal cheaper alternative
13
+ - recovery action
14
+ - verifier need
15
+ - escalation decision
16
+
17
+ Includes traces with:
18
+ - cheap model success
19
+ - cheap model failure
20
+ - frontier model unnecessary
21
+ - tool overuse
22
+ - tool underuse
23
+ - retrieval overuse
24
+ - verifier overuse
25
+ - retry loops
26
+ - cache breaks
27
+ - false-DONE
28
+ - successful meta-tool reuse
29
+ - bad meta-tool reuse
30
+ """
31
+
32
+ import uuid
33
+ import random
34
+ import json
35
+ from datetime import datetime, timedelta
36
+ from typing import List, Dict, Any, Optional
37
+ from dataclasses import asdict
38
+
39
+ from aco.trace_schema import (
40
+ AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall,
41
+ TaskType, Outcome, FailureTag,
42
+ )
43
+
44
+
45
class SyntheticTraceGenerator:
    """Generates diverse synthetic agent traces for training and benchmarking.

    The class-level tables below parameterise generation: which models and
    tools exist and what they cost, how often each task type is drawn, and
    which success/failure scenarios are sampled.
    """

    # Per-model settings keyed by model id.  "cost_input" / "cost_output" are
    # handed to ModelCall as the per-1k-token input/output prices, and
    # "latency" is the base latency that each step scales by a random factor.
    # "strength" is not read anywhere in the code visible in this class.
    MODEL_CONFIGS = {
        "tiny_local": {
            "tier": 1,
            "cost_input": 0.0001,
            "cost_output": 0.0002,
            "latency": 200,
            "strength": 0.3,
        },
        "cheap_cloud": {
            "tier": 2,
            "cost_input": 0.0005,
            "cost_output": 0.001,
            "latency": 500,
            "strength": 0.5,
        },
        "medium": {
            "tier": 3,
            "cost_input": 0.003,
            "cost_output": 0.006,
            "latency": 800,
            "strength": 0.75,
        },
        "frontier": {
            "tier": 4,
            "cost_input": 0.01,
            "cost_output": 0.03,
            "latency": 1500,
            "strength": 0.95,
        },
        "specialist": {
            "tier": 5,
            "cost_input": 0.015,
            "cost_output": 0.045,
            "latency": 2000,
            "strength": 0.98,
        },
    }

    # Cost charged per tool invocation, keyed by tool name.  Tools missing
    # from this table fall back to a default cost where the table is read.
    TOOL_COSTS = {
        "search": 0.002,
        "retrieve": 0.001,
        "fetch": 0.003,
        "code_execution": 0.005,
        "linter": 0.001,
        "test_runner": 0.003,
        "file_read": 0.0005,
        "file_write": 0.0005,
        "calculator": 0.0001,
        "database_query": 0.004,
        "compliance_check": 0.01,
        "summarize": 0.002,
        "task_planner": 0.001,
        "progress_tracker": 0.0005,
    }

    # Sampling weight of each task type.
    TASK_TYPE_DISTRIBUTION = {
        TaskType.QUICK_ANSWER: 0.20,
        TaskType.CODING: 0.20,
        TaskType.RESEARCH: 0.15,
        TaskType.DOCUMENT_DRAFTING: 0.10,
        TaskType.LEGAL_REGULATED: 0.05,
        TaskType.TOOL_HEAVY: 0.10,
        TaskType.RETRIEVAL_HEAVY: 0.10,
        TaskType.LONG_HORIZON: 0.08,
        TaskType.UNKNOWN_AMBIGUOUS: 0.02,
    }

    # Scenario templates for generating realistic traces.  "prob" is used as
    # a relative sampling weight, "tier" lists the model tiers the scenario
    # may run on, and the optional extra keys tune step generation.
    SCENARIOS = [
        # cheap model success
        {
            "name": "cheap_success",
            "prob": 0.15,
            "tier": [1, 2],
            "outcome": Outcome.SUCCESS,
            "failure_tags": [],
        },
        # cheap model failure -> should have escalated
        {
            "name": "cheap_failure",
            "prob": 0.10,
            "tier": [1, 2],
            "outcome": Outcome.FAILURE,
            "failure_tags": [FailureTag.MODEL_TOO_WEAK],
        },
        # frontier model unnecessary -> overpaid
        {
            "name": "frontier_unnecessary",
            "prob": 0.08,
            "tier": [4],
            "outcome": Outcome.SUCCESS,
            "failure_tags": [],
            "optimal_tier": [1, 2],
        },
        # tool overuse
        {
            "name": "tool_overuse",
            "prob": 0.07,
            "tier": [3, 4],
            "outcome": Outcome.PARTIAL_SUCCESS,
            "failure_tags": [FailureTag.TOOL_UNNECESSARY],
            "extra_tools": 3,
        },
        # tool underuse
        {
            "name": "tool_underuse",
            "prob": 0.05,
            "tier": [3, 4],
            "outcome": Outcome.FAILURE,
            "failure_tags": [FailureTag.TOOL_MISSED],
            "missing_tools": 2,
        },
        # retrieval overuse
        {
            "name": "retrieval_overuse",
            "prob": 0.04,
            "tier": [3, 4],
            "outcome": Outcome.SUCCESS,
            "failure_tags": [],
            "extra_retrievals": 5,
        },
        # verifier overuse
        {
            "name": "verifier_overuse",
            "prob": 0.03,
            "tier": [3, 4],
            "outcome": Outcome.SUCCESS,
            "failure_tags": [],
            "extra_verifiers": 2,
        },
        # retry loops
        {
            "name": "retry_loop",
            "prob": 0.05,
            "tier": [3, 4],
            "outcome": Outcome.FAILURE,
            "failure_tags": [FailureTag.RETRY_LOOP],
            "retries": 5,
        },
        # cache breaks
        {
            "name": "cache_break",
            "prob": 0.04,
            "tier": [3, 4],
            "outcome": Outcome.PARTIAL_SUCCESS,
            "failure_tags": [FailureTag.CACHE_BREAK],
        },
        # false DONE
        {
            "name": "false_done",
            "prob": 0.05,
            "tier": [3, 4],
            "outcome": Outcome.FALSE_DONE,
            "failure_tags": [FailureTag.VERIFIER_FALSE_PASS],
        },
        # meta-tool success
        {
            "name": "meta_tool_success",
            "prob": 0.06,
            "tier": [2, 3],
            "outcome": Outcome.SUCCESS,
            "failure_tags": [],
            "uses_meta_tool": True,
        },
        # meta-tool bad reuse
        {
            "name": "meta_tool_bad",
            "prob": 0.02,
            "tier": [2, 3],
            "outcome": Outcome.FAILURE,
            "failure_tags": [FailureTag.MODEL_TOO_WEAK],
            "uses_meta_tool": True,
        },
        # normal success
        {
            "name": "normal_success",
            "prob": 0.20,
            "tier": [3, 4],
            "outcome": Outcome.SUCCESS,
            "failure_tags": [],
        },
        # blocked
        {
            "name": "blocked",
            "prob": 0.03,
            "tier": [4],
            "outcome": Outcome.BLOCKED,
            "failure_tags": [FailureTag.MISSED_ESCALATION],
        },
        # human escalation
        {
            "name": "human_escalation",
            "prob": 0.02,
            "tier": [4, 5],
            "outcome": Outcome.ESCALATED_HUMAN,
            "failure_tags": [FailureTag.MISSED_ESCALATION],
        },
        # stopped doom
        {
            "name": "stopped_doom",
            "prob": 0.03,
            "tier": [3, 4],
            "outcome": Outcome.STOPPED_DOOM,
            "failure_tags": [FailureTag.COST_EXCEEDED],
        },
    ]
120
+
121
def __init__(self, seed: int = 42):
    """Set up the generator's private random source.

    Args:
        seed: Seed for the ``random.Random`` instance stored on
            ``self.rng``, which the generation methods draw from.
    """
    seeded_rng = random.Random(seed)
    self.rng = seeded_rng
123
+
124
def generate(self, n: int = 10000) -> List[AgentTrace]:
    """Generate n synthetic traces.

    Args:
        n: Number of traces to build.

    Returns:
        The traces in generation order; the trace at position ``i`` is
        built by ``_generate_trace(i)``.
    """
    return [self._generate_trace(position) for position in range(n)]
131
+
132
def _generate_trace(self, idx: int) -> AgentTrace:
    """Assemble one synthetic trace for position *idx*.

    A task type and a scenario template are sampled, a model tier is chosen
    from the scenario's allowed tiers, and the individual steps are
    generated.  The trace outcome and failure tags come straight from the
    scenario template.

    Step count: 1-8 by default, 5-12 for the ``retry_loop`` and
    ``false_done`` scenarios, then halved (minimum 2) for scenarios that
    use a meta-tool.

    Args:
        idx: Position of the trace in the generated batch; embedded in the
            trace id.

    Returns:
        The populated ``AgentTrace``.
    """
    # Local import: ``timezone`` is not among this module's top-level imports.
    from datetime import timezone

    # Draw the id suffix from the seeded RNG (not uuid4) so that a given seed
    # reproduces the same trace ids, and therefore the same step ids.
    trace_id = f"synth_{idx}_{self.rng.getrandbits(32):08x}"

    # Pick task type
    task_type = self.rng.choices(
        list(self.TASK_TYPE_DISTRIBUTION.keys()),
        weights=list(self.TASK_TYPE_DISTRIBUTION.values()),
    )[0]

    # Pick scenario
    scenario = self._pick_scenario()

    # Generate user request based on task type
    user_request = self._generate_request(task_type, scenario["name"])

    # Number of steps
    base_steps = self.rng.randint(1, 8)
    if scenario["name"] in ("retry_loop", "false_done"):
        base_steps = self.rng.randint(5, 12)
    if scenario.get("uses_meta_tool"):
        base_steps = max(2, base_steps // 2)  # meta-tools compress steps

    tier = self.rng.choice(scenario["tier"])
    model_key = self._tier_to_model(tier)
    model_cfg = self.MODEL_CONFIGS[model_key]

    last_idx = base_steps - 1
    steps = [
        self._generate_step(
            trace_id, step_idx, task_type, model_key, model_cfg,
            scenario, step_idx == last_idx,
        )
        for step_idx in range(base_steps)
    ]
    total_cost = sum((step.step_cost for step in steps), 0.0)

    # Final outcome; copy the tag list so the shared template is never
    # mutated through the returned trace.
    outcome = scenario["outcome"]
    failure_tags = list(scenario["failure_tags"])

    # NOTE(review): "cost_input" is passed to ModelCall as a per-1k-token
    # price in _generate_step, yet the two estimates below multiply it by
    # 2000 as if it were a per-token price.  Confirm the intended units
    # before relying on optimal_cost / cost_saved in absolute terms.

    # Optimal cheaper alternative
    optimal_tiers = scenario.get("optimal_tier")
    if optimal_tiers:
        optimal_tier = self.rng.choice(optimal_tiers)
        optimal_model = self._tier_to_model(optimal_tier)
        optimal_cost = self.MODEL_CONFIGS[optimal_model]["cost_input"] * 2000  # rough estimate
    else:
        optimal_tier = tier
        optimal_cost = total_cost * 0.6

    # Compute cost saved vs always frontier
    frontier_cost = self.MODEL_CONFIGS["frontier"]["cost_input"] * 2000 * base_steps
    cost_saved = frontier_cost - total_cost

    # Naive UTC timestamp: same value and ISO format as the deprecated
    # datetime.utcnow().
    generated_at = datetime.now(timezone.utc).replace(tzinfo=None)

    return AgentTrace(
        trace_id=trace_id,
        user_request=user_request,
        task_type=task_type,
        steps=steps,
        final_outcome=outcome,
        failure_tags=failure_tags,
        total_cost=total_cost,
        total_cost_saved_vs_frontier=cost_saved,
        optimal_cost=optimal_cost,
        metadata={
            "scenario": scenario["name"],
            "synthetic": True,
            "generation_timestamp": generated_at.isoformat(),
            # Report the tier that optimal_cost was actually computed for.
            "optimal_tier": optimal_tier,
        },
    )
204
+
205
def _pick_scenario(self) -> Dict:
    """Draw one scenario template from ``SCENARIOS``.

    Each template's ``"prob"`` value is used as its relative sampling
    weight.

    Returns:
        The selected scenario dict itself (not a copy).
    """
    weights = [scenario["prob"] for scenario in self.SCENARIOS]
    return self.rng.choices(self.SCENARIOS, weights=weights)[0]
209
+
210
def _generate_request(self, task_type: TaskType, scenario: str) -> str:
    """Pick a canned user request that fits *task_type*.

    Args:
        task_type: Task category whose request templates are sampled.
        scenario: Scenario name; accepted but not used when choosing the
            request text.

    Returns:
        One template string chosen at random, or ``"Generic request"`` when
        the task type has no templates.
    """
    requests_by_type = {
        TaskType.QUICK_ANSWER: [
            "What is the capital of France?",
            "Briefly explain quantum computing.",
            "Summarize the key points of article X.",
            "What is 237 * 452?",
        ],
        TaskType.CODING: [
            "Write a Python function to reverse a linked list.",
            "Fix the bug in this React component.",
            "Refactor the authentication module to use JWT.",
            "Implement a LRU cache in Go.",
        ],
        TaskType.RESEARCH: [
            "Research the latest advancements in transformer architectures.",
            "Find sources comparing LoRA and full fine-tuning.",
            "Investigate the climate impact of data centers.",
            "What does the literature say about speculative decoding?",
        ],
        TaskType.DOCUMENT_DRAFTING: [
            "Draft a project proposal for the ML pipeline.",
            "Write an email to the team about the deployment schedule.",
            "Create a technical report on system performance.",
        ],
        TaskType.LEGAL_REGULATED: [
            "Review this contract for liability clauses.",
            "Check compliance with GDPR for this data processing pipeline.",
            "Draft a privacy policy section for user data.",
        ],
        TaskType.TOOL_HEAVY: [
            "Search for open issues in the repo and create a summary.",
            "Fetch the latest API documentation and generate client code.",
            "Query the database for Q3 sales and produce a chart.",
        ],
        TaskType.RETRIEVAL_HEAVY: [
            "Answer based on the attached 50-page document.",
            "Find all mentions of 'payment processing' in my files.",
            "Retrieve relevant cases for this legal query.",
        ],
        TaskType.LONG_HORIZON: [
            "Plan a 3-month roadmap for the agent framework.",
            "Orchestrate the deployment of the multi-region system.",
            "Project: redesign the data architecture end-to-end.",
        ],
        TaskType.UNKNOWN_AMBIGUOUS: [
            "Help me with this thing.",
            "I need something done about the server.",
            "Can you look into that issue we discussed?",
        ],
    }
    candidates = requests_by_type.get(task_type)
    if candidates is None:
        # Task type without templates: fall back to a single generic prompt.
        candidates = ["Generic request"]
    return self.rng.choice(candidates)
263
+
264
def _tier_to_model(self, tier: int) -> str:
    """Translate a numeric model tier into its model id.

    Args:
        tier: Tier number; 1 through 5 are recognised.

    Returns:
        The model id for *tier*, or ``"medium"`` for any other value.
    """
    models_by_tier = {
        1: "tiny_local",
        2: "cheap_cloud",
        3: "medium",
        4: "frontier",
        5: "specialist",
    }
    try:
        return models_by_tier[tier]
    except KeyError:
        return "medium"
267
+
268
def _generate_step(
    self,
    trace_id: str,
    step_idx: int,
    task_type: TaskType,
    model_key: str,
    model_cfg: Dict,
    scenario: Dict,
    is_last: bool,
) -> TraceStep:
    """Generate one step of a trace: model call, tool calls and verifiers.

    Args:
        trace_id: Id of the owning trace; prefixes the step id.
        step_idx: Zero-based position of the step within the trace.
        task_type: Task category; selects the candidate tools and whether
            a verifier may run.
        model_key: Id of the model making the call.
        model_cfg: Settings for that model (``"cost_input"``,
            ``"cost_output"`` and ``"latency"`` are read).
        scenario: Scenario template driving the step.  The optional keys
            ``"extra_tools"``, ``"missing_tools"``, ``"extra_verifiers"``
            and ``"retries"`` are honoured here.
        is_last: Whether this is the final step; the final step carries the
            scenario's outcome.

    Returns:
        The populated ``TraceStep``.

    NOTE(review): the ``"extra_retrievals"`` scenario key is not consumed
    by this method, so ``retrieval_overuse`` steps are generated like any
    other step -- confirm whether that is intended.
    """
    # Local import: ``timezone`` is not among this module's top-level imports.
    from datetime import timezone

    step_id = f"{trace_id}_step_{step_idx}"

    # Model call
    input_tokens = self.rng.randint(500, 8000)
    output_tokens = self.rng.randint(100, 4000)
    cache_hit = self.rng.random() < 0.3
    # On a cache hit, up to half of the input tokens count as cached.
    cache_hit_tokens = int(input_tokens * self.rng.random() * 0.5) if cache_hit else 0

    model_call = ModelCall(
        model_id=model_key,
        provider="synthetic",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        reasoning_tokens=output_tokens // 5 if model_key == "frontier" else 0,
        cost_per_1k_input=model_cfg["cost_input"],
        cost_per_1k_output=model_cfg["cost_output"],
        cache_hit_input_tokens=cache_hit_tokens,
        latency_ms=model_cfg["latency"] * self.rng.uniform(0.8, 1.5),
    )

    # Tool calls
    base_tools = self._get_tools_for_task(task_type)
    num_tools = self.rng.randint(0, len(base_tools))
    if scenario.get("extra_tools"):
        num_tools += scenario["extra_tools"]
    if scenario.get("missing_tools"):
        num_tools = max(0, num_tools - scenario["missing_tools"])

    failure_rate = 0.2 if scenario["name"] in ("retry_loop", "tool_underuse") else 0.05
    tool_calls = []
    # Calls beyond the base tool set wrap around and re-invoke earlier tools,
    # so "extra_tools" yields genuinely surplus calls instead of being
    # capped at len(base_tools).
    for call_idx in range(num_tools if base_tools else 0):
        beyond_base = call_idx >= len(base_tools)
        tool_name = base_tools[call_idx % len(base_tools)]
        tool_cost = self.TOOL_COSTS.get(tool_name, 0.001)
        # The random draw always happens first so the RNG stream stays the
        # same for steps that never exceed the base tool set.
        drew_repeat = self.rng.random() < 0.1
        repeated = drew_repeat or beyond_base
        ignored = self.rng.random() < 0.05
        failed = self.rng.random() < failure_rate

        tool_calls.append(ToolCall(
            tool_name=tool_name,
            tool_input={"query": f"auto_{tool_name}"},
            tool_cost=tool_cost,
            tool_latency_ms=self.rng.uniform(100, 1000),
            cache_hit=self.rng.random() < 0.2,
            repeated=repeated,
            ignored_result=ignored,
            failed=failed,
        ))

    # Verifier calls: a coin flip for the task types listed below, plus any
    # scenario-mandated extras.
    num_verifiers = 0
    if task_type in (TaskType.LEGAL_REGULATED, TaskType.CODING, TaskType.RESEARCH):
        num_verifiers = 1 if self.rng.random() < 0.5 else 0
    if scenario.get("extra_verifiers"):
        num_verifiers += scenario["extra_verifiers"]

    verifier_calls = []
    for _ in range(num_verifiers):
        passed = self.rng.random() < 0.8
        verifier_calls.append(VerifierCall(
            verifier_model_id="verifier_medium",
            target_step_id=step_id,
            passed=passed,
            confidence=self.rng.uniform(0.6, 0.99),
            cost=0.005,
            latency_ms=500,
        ))

    # Context size
    context_size = self.rng.randint(1000, 15000)
    if scenario["name"] == "cache_break":
        context_size += self.rng.randint(5000, 20000)  # excessive context

    # Retry count: scenario value +/- 1 when given, otherwise an occasional
    # one or two retries.
    retries = 0
    if scenario.get("retries"):
        retries = self.rng.randint(scenario["retries"] - 1, scenario["retries"] + 1)
    elif self.rng.random() < 0.15:
        retries = self.rng.randint(1, 2)

    # Recovery action, only recorded when the step was retried.
    recovery = None
    if retries > 0:
        recovery = self.rng.choice([
            "retry_same", "retry_changed_prompt", "repair_tool",
            "retrieve_more_context", "switch_model", "ask_clarification",
        ])

    # Outcome per step: the last step carries the scenario outcome (which
    # covers "false_done"); retry loops start failing from the third step.
    if is_last:
        step_outcome = scenario["outcome"]
    elif scenario["name"] == "retry_loop" and step_idx >= 2:
        step_outcome = Outcome.FAILURE
    else:
        step_outcome = Outcome.SUCCESS

    # Naive UTC "now" (same value as the deprecated datetime.utcnow()),
    # spaced 30 seconds apart per step.
    started_at = datetime.now(timezone.utc).replace(tzinfo=None)

    return TraceStep(
        step_id=step_id,
        timestamp=started_at + timedelta(seconds=step_idx * 30),
        task_type=task_type,
        model_call=model_call,
        tool_calls=tool_calls,
        verifier_calls=verifier_calls,
        context_size_tokens=context_size,
        context_sources=["system_rules", "tool_descriptions", "user_preferences", "recent_messages"],
        retry_count=retries,
        recovery_action=recovery,
        artifacts_created=[f"artifact_{step_idx}"] if self.rng.random() < 0.3 else [],
        step_outcome=step_outcome,
    )
388
+
389
def _get_tools_for_task(self, task_type: TaskType) -> List[str]:
    """List the tool names a step of *task_type* may call, in call order.

    Args:
        task_type: Task category to look up.

    Returns:
        The tool names for the task type, or ``["search"]`` when the task
        type has no entry.
    """
    tools_by_task = {
        TaskType.QUICK_ANSWER: ["calculator", "search"],
        TaskType.CODING: [
            "file_read",
            "file_write",
            "code_execution",
            "linter",
            "test_runner",
        ],
        TaskType.RESEARCH: ["search", "retrieve", "fetch", "summarize"],
        TaskType.DOCUMENT_DRAFTING: ["file_read", "summarize"],
        TaskType.LEGAL_REGULATED: [
            "document_retrieval",
            "compliance_check",
            "search",
        ],
        TaskType.TOOL_HEAVY: ["search", "fetch", "api_call", "database_query"],
        TaskType.RETRIEVAL_HEAVY: ["retrieve", "search", "fetch"],
        TaskType.LONG_HORIZON: ["task_planner", "progress_tracker", "file_read"],
        TaskType.UNKNOWN_AMBIGUOUS: ["search"],
    }
    if task_type in tools_by_task:
        return tools_by_task[task_type]
    return ["search"]
402
+
403
def to_dicts(self, traces: List[AgentTrace]) -> List[Dict[str, Any]]:
    """Convert each trace to a dict by calling its ``to_dict()`` method.

    Args:
        traces: Traces to convert.

    Returns:
        One dict per trace, in the same order as *traces*.
    """
    return [trace.to_dict() for trace in traces]
405
+
406
def save(self, traces: List[AgentTrace], path: str) -> None:
    """Write *traces* to *path* as JSON Lines, one trace per line.

    Each trace is serialised with ``json.dumps(trace.to_dict())``.  An
    existing file at *path* is overwritten.

    Args:
        traces: Traces to write.
        path: Destination file path.
    """
    # Explicit UTF-8 keeps the file independent of the platform's default
    # encoding and matches what ``load`` should read back.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(trace.to_dict()) + "\n" for trace in traces)
410
+
411
def load(self, path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records as plain dicts.

    Blank and whitespace-only lines (for example a trailing empty line) are
    skipped instead of being passed to the JSON parser.

    Args:
        path: Path of the JSON Lines file to read.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 so decoding does not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]