narcolepticchicken commited on
Commit
7268941
·
verified ·
1 Parent(s): 6d0aa94

Upload aco/trace_schema.py

Browse files
Files changed (1) hide show
  1. aco/trace_schema.py +234 -0
aco/trace_schema.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Normalized trace schema for Agent Cost Optimizer."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional, List, Dict, Any
5
+ from enum import Enum
6
+ from datetime import datetime
7
+
8
+
9
+ class TaskType(Enum):
10
+ QUICK_ANSWER = "quick_answer"
11
+ RESEARCH = "research"
12
+ CODING = "coding"
13
+ DOCUMENT_DRAFTING = "document_drafting"
14
+ LEGAL_REGULATED = "legal_regulated"
15
+ TOOL_HEAVY = "tool_heavy"
16
+ RETRIEVAL_HEAVY = "retrieval_heavy"
17
+ LONG_HORIZON = "long_horizon"
18
+ UNKNOWN_AMBIGUOUS = "unknown_ambiguous"
19
+
20
+
21
+ class Outcome(Enum):
22
+ SUCCESS = "success"
23
+ PARTIAL_SUCCESS = "partial_success"
24
+ FAILURE = "failure"
25
+ FALSE_DONE = "false_done"
26
+ BLOCKED = "blocked"
27
+ ESCALATED_HUMAN = "escalated_human"
28
+ STOPPED_DOOM = "stopped_doom"
29
+
30
+
31
+ class FailureTag(Enum):
32
+ MODEL_TOO_WEAK = "model_too_weak"
33
+ CONTEXT_TOO_SMALL = "context_too_small"
34
+ TOOL_FAILED = "tool_failed"
35
+ TOOL_UNNECESSARY = "tool_unnecessary"
36
+ TOOL_MISSED = "tool_missed"
37
+ VERIFIER_FALSE_PASS = "verifier_false_pass"
38
+ VERIFIER_FALSE_REJECT = "verifier_false_reject"
39
+ RETRY_LOOP = "retry_loop"
40
+ CACHE_BREAK = "cache_break"
41
+ HALLUCINATION = "hallucination"
42
+ TIMEOUT = "timeout"
43
+ COST_EXCEEDED = "cost_exceeded"
44
+ UNSAFE_CHEAP_MODEL = "unsafe_cheap_model"
45
+ MISSED_ESCALATION = "missed_escalation"
46
+
47
+
48
+ @dataclass
49
+ class ToolCall:
50
+ tool_name: str
51
+ tool_input: Dict[str, Any]
52
+ tool_output: Optional[str] = None
53
+ tool_cost: float = 0.0
54
+ tool_latency_ms: float = 0.0
55
+ cache_hit: bool = False
56
+ repeated: bool = False
57
+ ignored_result: bool = False
58
+ failed: bool = False
59
+
60
+
61
+ @dataclass
62
+ class ModelCall:
63
+ model_id: str
64
+ provider: str
65
+ input_tokens: int = 0
66
+ output_tokens: int = 0
67
+ reasoning_tokens: int = 0
68
+ cost_per_1k_input: float = 0.0
69
+ cost_per_1k_output: float = 0.0
70
+ cache_hit_input_tokens: int = 0
71
+ latency_ms: float = 0.0
72
+
73
+ @property
74
+ def total_cost(self) -> float:
75
+ input_cost = (self.input_tokens / 1000) * self.cost_per_1k_input
76
+ output_cost = (self.output_tokens / 1000) * self.cost_per_1k_output
77
+ cache_discount = (self.cache_hit_input_tokens / 1000) * self.cost_per_1k_input * 0.5 # 50% discount
78
+ return input_cost + output_cost - cache_discount
79
+
80
+
81
+ @dataclass
82
+ class VerifierCall:
83
+ verifier_model_id: str
84
+ target_step_id: str
85
+ passed: bool = False
86
+ confidence: float = 0.0
87
+ cost: float = 0.0
88
+ latency_ms: float = 0.0
89
+
90
+
91
+ @dataclass
92
+ class TraceStep:
93
+ step_id: str
94
+ timestamp: datetime
95
+ task_type: TaskType
96
+ model_call: ModelCall
97
+ tool_calls: List[ToolCall] = field(default_factory=list)
98
+ verifier_calls: List[VerifierCall] = field(default_factory=list)
99
+ context_size_tokens: int = 0
100
+ context_sources: List[str] = field(default_factory=list)
101
+ cache_boundary_reached: bool = False
102
+ retry_count: int = 0
103
+ recovery_action: Optional[str] = None
104
+ planned_next: Optional[str] = None
105
+ user_correction: Optional[str] = None
106
+ artifacts_created: List[str] = field(default_factory=list)
107
+ step_outcome: Optional[Outcome] = None
108
+
109
+ @property
110
+ def step_cost(self) -> float:
111
+ mc = self.model_call.total_cost if self.model_call else 0.0
112
+ tc = sum(t.tool_cost for t in self.tool_calls)
113
+ vc = sum(v.cost for v in self.verifier_calls)
114
+ return mc + tc + vc
115
+
116
+ @property
117
+ def step_latency_ms(self) -> float:
118
+ ml = self.model_call.latency_ms if self.model_call else 0.0
119
+ tl = sum(t.tool_latency_ms for t in self.tool_calls)
120
+ vl = sum(v.latency_ms for v in self.verifier_calls)
121
+ return ml + tl + vl
122
+
123
+
124
+ @dataclass
125
+ class AgentTrace:
126
+ trace_id: str
127
+ user_request: str
128
+ task_type: TaskType
129
+ steps: List[TraceStep] = field(default_factory=list)
130
+ final_outcome: Optional[Outcome] = None
131
+ final_artifacts: List[str] = field(default_factory=list)
132
+ failure_tags: List[FailureTag] = field(default_factory=list)
133
+ user_satisfaction: Optional[float] = None # 0-1
134
+ total_cost_saved_vs_frontier: Optional[float] = None
135
+ total_cost: Optional[float] = None
136
+ optimal_cost: Optional[float] = None # oracle routing cost
137
+ metadata: Dict[str, Any] = field(default_factory=dict)
138
+
139
+ @property
140
+ def total_cost_computed(self) -> float:
141
+ return sum(s.step_cost for s in self.steps)
142
+
143
+ @property
144
+ def total_latency_ms(self) -> float:
145
+ return sum(s.step_latency_ms for s in self.steps)
146
+
147
+ @property
148
+ def total_retries(self) -> int:
149
+ return sum(s.retry_count for s in self.steps)
150
+
151
+ @property
152
+ def total_tool_calls(self) -> int:
153
+ return sum(len(s.tool_calls) for s in self.steps)
154
+
155
+ @property
156
+ def total_verifier_calls(self) -> int:
157
+ return sum(len(s.verifier_calls) for s in self.steps)
158
+
159
+ @property
160
+ def total_context_tokens(self) -> int:
161
+ return sum(s.context_size_tokens for s in self.steps)
162
+
163
+ @property
164
+ def cache_hit_rate(self) -> float:
165
+ model_calls = [s.model_call for s in self.steps if s.model_call]
166
+ if not model_calls:
167
+ return 0.0
168
+ total_input = sum(m.input_tokens for m in model_calls)
169
+ total_cache_hit = sum(m.cache_hit_input_tokens for m in model_calls)
170
+ return total_cache_hit / total_input if total_input > 0 else 0.0
171
+
172
+ def to_dict(self) -> Dict[str, Any]:
173
+ return {
174
+ "trace_id": self.trace_id,
175
+ "user_request": self.user_request,
176
+ "task_type": self.task_type.value,
177
+ "steps": [
178
+ {
179
+ "step_id": s.step_id,
180
+ "timestamp": s.timestamp.isoformat(),
181
+ "task_type": s.task_type.value,
182
+ "model_call": {
183
+ "model_id": s.model_call.model_id,
184
+ "provider": s.model_call.provider,
185
+ "input_tokens": s.model_call.input_tokens,
186
+ "output_tokens": s.model_call.output_tokens,
187
+ "reasoning_tokens": s.model_call.reasoning_tokens,
188
+ "cost": s.model_call.total_cost,
189
+ "latency_ms": s.model_call.latency_ms,
190
+ "cache_hit_input_tokens": s.model_call.cache_hit_input_tokens,
191
+ },
192
+ "tool_calls": [
193
+ {
194
+ "tool_name": t.tool_name,
195
+ "tool_cost": t.tool_cost,
196
+ "tool_latency_ms": t.tool_latency_ms,
197
+ "cache_hit": t.cache_hit,
198
+ "repeated": t.repeated,
199
+ "ignored_result": t.ignored_result,
200
+ "failed": t.failed,
201
+ }
202
+ for t in s.tool_calls
203
+ ],
204
+ "verifier_calls": [
205
+ {
206
+ "verifier_model_id": v.verifier_model_id,
207
+ "passed": v.passed,
208
+ "confidence": v.confidence,
209
+ "cost": v.cost,
210
+ }
211
+ for v in s.verifier_calls
212
+ ],
213
+ "context_size_tokens": s.context_size_tokens,
214
+ "retry_count": s.retry_count,
215
+ "recovery_action": s.recovery_action,
216
+ "step_outcome": s.step_outcome.value if s.step_outcome else None,
217
+ "step_cost": s.step_cost,
218
+ "step_latency_ms": s.step_latency_ms,
219
+ }
220
+ for s in self.steps
221
+ ],
222
+ "final_outcome": self.final_outcome.value if self.final_outcome else None,
223
+ "failure_tags": [f.value for f in self.failure_tags],
224
+ "total_cost": self.total_cost_computed,
225
+ "total_latency_ms": self.total_latency_ms,
226
+ "total_retries": self.total_retries,
227
+ "total_tool_calls": self.total_tool_calls,
228
+ "total_verifier_calls": self.total_verifier_calls,
229
+ "total_context_tokens": self.total_context_tokens,
230
+ "cache_hit_rate": self.cache_hit_rate,
231
+ "user_satisfaction": self.user_satisfaction,
232
+ "total_cost_saved_vs_frontier": self.total_cost_saved_vs_frontier,
233
+ "optimal_cost": self.optimal_cost,
234
+ }