narcolepticchicken commited on
Commit
5569d72
·
verified ·
1 Parent(s): d807f4a

Upload aco/trace_schema.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/trace_schema.py +86 -207
aco/trace_schema.py CHANGED
@@ -1,234 +1,113 @@
1
- """Normalized trace schema for Agent Cost Optimizer."""
2
-
3
  from dataclasses import dataclass, field
4
- from typing import Optional, List, Dict, Any
5
- from enum import Enum
6
  from datetime import datetime
7
-
8
-
9
- class TaskType(Enum):
10
- QUICK_ANSWER = "quick_answer"
11
- RESEARCH = "research"
12
- CODING = "coding"
13
- DOCUMENT_DRAFTING = "document_drafting"
14
- LEGAL_REGULATED = "legal_regulated"
15
- TOOL_HEAVY = "tool_heavy"
16
- RETRIEVAL_HEAVY = "retrieval_heavy"
17
- LONG_HORIZON = "long_horizon"
18
- UNKNOWN_AMBIGUOUS = "unknown_ambiguous"
19
-
20
-
21
- class Outcome(Enum):
22
- SUCCESS = "success"
23
- PARTIAL_SUCCESS = "partial_success"
24
- FAILURE = "failure"
25
- FALSE_DONE = "false_done"
26
- BLOCKED = "blocked"
27
- ESCALATED_HUMAN = "escalated_human"
28
- STOPPED_DOOM = "stopped_doom"
29
-
30
-
31
- class FailureTag(Enum):
32
- MODEL_TOO_WEAK = "model_too_weak"
33
- CONTEXT_TOO_SMALL = "context_too_small"
34
- TOOL_FAILED = "tool_failed"
35
- TOOL_UNNECESSARY = "tool_unnecessary"
36
- TOOL_MISSED = "tool_missed"
37
- VERIFIER_FALSE_PASS = "verifier_false_pass"
38
- VERIFIER_FALSE_REJECT = "verifier_false_reject"
39
- RETRY_LOOP = "retry_loop"
40
- CACHE_BREAK = "cache_break"
41
- HALLUCINATION = "hallucination"
42
- TIMEOUT = "timeout"
43
- COST_EXCEEDED = "cost_exceeded"
44
- UNSAFE_CHEAP_MODEL = "unsafe_cheap_model"
45
- MISSED_ESCALATION = "missed_escalation"
46
-
47
-
48
- @dataclass
49
- class ToolCall:
50
- tool_name: str
51
- tool_input: Dict[str, Any]
52
- tool_output: Optional[str] = None
53
- tool_cost: float = 0.0
54
- tool_latency_ms: float = 0.0
55
- cache_hit: bool = False
56
- repeated: bool = False
57
- ignored_result: bool = False
58
- failed: bool = False
59
-
60
 
61
  @dataclass
62
  class ModelCall:
63
  model_id: str
64
  provider: str
 
65
  input_tokens: int = 0
66
  output_tokens: int = 0
67
  reasoning_tokens: int = 0
68
- cost_per_1k_input: float = 0.0
69
- cost_per_1k_output: float = 0.0
70
- cache_hit_input_tokens: int = 0
71
  latency_ms: float = 0.0
72
-
73
- @property
74
- def total_cost(self) -> float:
75
- input_cost = (self.input_tokens / 1000) * self.cost_per_1k_input
76
- output_cost = (self.output_tokens / 1000) * self.cost_per_1k_output
77
- cache_discount = (self.cache_hit_input_tokens / 1000) * self.cost_per_1k_input * 0.5 # 50% discount
78
- return input_cost + output_cost - cache_discount
79
-
80
 
81
  @dataclass
82
- class VerifierCall:
83
- verifier_model_id: str
84
- target_step_id: str
85
- passed: bool = False
86
- confidence: float = 0.0
87
- cost: float = 0.0
88
  latency_ms: float = 0.0
89
-
 
 
 
 
90
 
91
  @dataclass
92
  class TraceStep:
93
- step_id: str
94
- timestamp: datetime
95
- task_type: TaskType
96
- model_call: ModelCall
97
  tool_calls: List[ToolCall] = field(default_factory=list)
98
- verifier_calls: List[VerifierCall] = field(default_factory=list)
99
- context_size_tokens: int = 0
100
  context_sources: List[str] = field(default_factory=list)
101
- cache_boundary_reached: bool = False
102
- retry_count: int = 0
 
 
 
 
103
  recovery_action: Optional[str] = None
104
- planned_next: Optional[str] = None
105
- user_correction: Optional[str] = None
106
- artifacts_created: List[str] = field(default_factory=list)
107
- step_outcome: Optional[Outcome] = None
108
-
109
- @property
110
- def step_cost(self) -> float:
111
- mc = self.model_call.total_cost if self.model_call else 0.0
112
- tc = sum(t.tool_cost for t in self.tool_calls)
113
- vc = sum(v.cost for v in self.verifier_calls)
114
- return mc + tc + vc
115
-
116
- @property
117
- def step_latency_ms(self) -> float:
118
- ml = self.model_call.latency_ms if self.model_call else 0.0
119
- tl = sum(t.tool_latency_ms for t in self.tool_calls)
120
- vl = sum(v.latency_ms for v in self.verifier_calls)
121
- return ml + tl + vl
122
-
123
 
124
  @dataclass
125
  class AgentTrace:
126
- trace_id: str
127
- user_request: str
128
- task_type: TaskType
 
 
129
  steps: List[TraceStep] = field(default_factory=list)
130
- final_outcome: Optional[Outcome] = None
131
- final_artifacts: List[str] = field(default_factory=list)
132
- failure_tags: List[FailureTag] = field(default_factory=list)
133
- user_satisfaction: Optional[float] = None # 0-1
134
- total_cost_saved_vs_frontier: Optional[float] = None
135
- total_cost: Optional[float] = None
136
- optimal_cost: Optional[float] = None # oracle routing cost
137
- metadata: Dict[str, Any] = field(default_factory=dict)
138
-
139
- @property
140
- def total_cost_computed(self) -> float:
141
- return sum(s.step_cost for s in self.steps)
142
-
143
- @property
144
- def total_latency_ms(self) -> float:
145
- return sum(s.step_latency_ms for s in self.steps)
146
-
147
- @property
148
- def total_retries(self) -> int:
149
- return sum(s.retry_count for s in self.steps)
150
-
151
- @property
152
- def total_tool_calls(self) -> int:
153
- return sum(len(s.tool_calls) for s in self.steps)
154
-
155
- @property
156
- def total_verifier_calls(self) -> int:
157
- return sum(len(s.verifier_calls) for s in self.steps)
158
-
159
- @property
160
- def total_context_tokens(self) -> int:
161
- return sum(s.context_size_tokens for s in self.steps)
162
-
163
- @property
164
- def cache_hit_rate(self) -> float:
165
- model_calls = [s.model_call for s in self.steps if s.model_call]
166
- if not model_calls:
167
- return 0.0
168
- total_input = sum(m.input_tokens for m in model_calls)
169
- total_cache_hit = sum(m.cache_hit_input_tokens for m in model_calls)
170
- return total_cache_hit / total_input if total_input > 0 else 0.0
171
-
172
- def to_dict(self) -> Dict[str, Any]:
173
  return {
174
  "trace_id": self.trace_id,
175
- "user_request": self.user_request,
176
- "task_type": self.task_type.value,
177
- "steps": [
178
- {
179
- "step_id": s.step_id,
180
- "timestamp": s.timestamp.isoformat(),
181
- "task_type": s.task_type.value,
182
- "model_call": {
183
- "model_id": s.model_call.model_id,
184
- "provider": s.model_call.provider,
185
- "input_tokens": s.model_call.input_tokens,
186
- "output_tokens": s.model_call.output_tokens,
187
- "reasoning_tokens": s.model_call.reasoning_tokens,
188
- "cost": s.model_call.total_cost,
189
- "latency_ms": s.model_call.latency_ms,
190
- "cache_hit_input_tokens": s.model_call.cache_hit_input_tokens,
191
- },
192
- "tool_calls": [
193
- {
194
- "tool_name": t.tool_name,
195
- "tool_cost": t.tool_cost,
196
- "tool_latency_ms": t.tool_latency_ms,
197
- "cache_hit": t.cache_hit,
198
- "repeated": t.repeated,
199
- "ignored_result": t.ignored_result,
200
- "failed": t.failed,
201
- }
202
- for t in s.tool_calls
203
- ],
204
- "verifier_calls": [
205
- {
206
- "verifier_model_id": v.verifier_model_id,
207
- "passed": v.passed,
208
- "confidence": v.confidence,
209
- "cost": v.cost,
210
- }
211
- for v in s.verifier_calls
212
- ],
213
- "context_size_tokens": s.context_size_tokens,
214
- "retry_count": s.retry_count,
215
- "recovery_action": s.recovery_action,
216
- "step_outcome": s.step_outcome.value if s.step_outcome else None,
217
- "step_cost": s.step_cost,
218
- "step_latency_ms": s.step_latency_ms,
219
- }
220
- for s in self.steps
221
- ],
222
- "final_outcome": self.final_outcome.value if self.final_outcome else None,
223
- "failure_tags": [f.value for f in self.failure_tags],
224
- "total_cost": self.total_cost_computed,
225
- "total_latency_ms": self.total_latency_ms,
226
- "total_retries": self.total_retries,
227
- "total_tool_calls": self.total_tool_calls,
228
- "total_verifier_calls": self.total_verifier_calls,
229
- "total_context_tokens": self.total_context_tokens,
230
- "cache_hit_rate": self.cache_hit_rate,
231
- "user_satisfaction": self.user_satisfaction,
232
- "total_cost_saved_vs_frontier": self.total_cost_saved_vs_frontier,
233
- "optimal_cost": self.optimal_cost,
234
  }
 
1
+ """Normalized Agent Trace Schema for ACO."""
 
2
  from dataclasses import dataclass, field
3
+ from typing import Dict, List, Optional, Any
 
4
  from datetime import datetime
5
+ import json, uuid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  @dataclass
8
  class ModelCall:
9
  model_id: str
10
  provider: str
11
+ tier: int
12
  input_tokens: int = 0
13
  output_tokens: int = 0
14
  reasoning_tokens: int = 0
15
+ cache_hit: bool = False
 
 
16
  latency_ms: float = 0.0
17
+ cost: float = 0.0
18
+ success: bool = True
19
+ error: Optional[str] = None
 
 
 
 
 
20
 
21
  @dataclass
22
+ class ToolCall:
23
+ tool_name: str
24
+ args: Dict[str, Any] = field(default_factory=dict)
25
+ result: Optional[str] = None
 
 
26
  latency_ms: float = 0.0
27
+ cost: float = 0.0
28
+ success: bool = True
29
+ cached: bool = False
30
+ unnecessary: bool = False
31
+ error: Optional[str] = None
32
 
33
  @dataclass
34
  class TraceStep:
35
+ step_num: int
36
+ model_call: Optional[ModelCall] = None
 
 
37
  tool_calls: List[ToolCall] = field(default_factory=list)
38
+ context_size: int = 0
 
39
  context_sources: List[str] = field(default_factory=list)
40
+ context_budget_used: float = 0.0
41
+ cache_prefix_tokens: int = 0
42
+ cache_suffix_tokens: int = 0
43
+ verifier_called: bool = False
44
+ verifier_result: Optional[str] = None
45
+ retry_num: int = 0
46
  recovery_action: Optional[str] = None
47
+ timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  @dataclass
50
  class AgentTrace:
51
+ trace_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])
52
+ request: str = ""
53
+ task_type: str = "unknown_ambiguous"
54
+ difficulty: int = 3
55
+ predicted_tier: int = 4
56
  steps: List[TraceStep] = field(default_factory=list)
57
+ final_outcome: str = "unknown"
58
+ task_success: bool = False
59
+ total_cost: float = 0.0
60
+ total_tokens: int = 0
61
+ total_tool_calls: int = 0
62
+ total_retries: int = 0
63
+ total_verifier_calls: int = 0
64
+ cache_hit_rate: float = 0.0
65
+ latency_total_ms: float = 0.0
66
+ user_correction: bool = False
67
+ failure_tags: List[str] = field(default_factory=list)
68
+ artifacts_created: List[str] = field(default_factory=list)
69
+ meta_tool_used: bool = False
70
+ early_terminated: bool = False
71
+ escalation_occurred: bool = False
72
+ timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
73
+
74
+ def to_dict(self) -> dict:
75
+ import dataclasses
76
+ return dataclasses.asdict(self)
77
+
78
+ def to_json(self) -> str:
79
+ return json.dumps(self.to_dict(), indent=2)
80
+
81
+ def compute_summary(self) -> dict:
82
+ total_cost = sum(
83
+ (s.model_call.cost if s.model_call else 0) +
84
+ sum(tc.cost for tc in s.tool_calls)
85
+ for s in self.steps
86
+ )
87
+ total_tokens = sum(
88
+ (s.model_call.input_tokens + s.model_call.output_tokens if s.model_call else 0)
89
+ for s in self.steps
90
+ )
91
+ total_tool_calls = sum(len(s.tool_calls) for s in self.steps)
92
+ total_retries = sum(s.retry_num for s in self.steps)
93
+ total_verifiers = sum(1 for s in self.steps if s.verifier_called)
94
+ cache_hits = sum(1 for s in self.steps if s.model_call and s.model_call.cache_hit)
95
+ cache_total = sum(1 for s in self.steps if s.model_call)
 
 
 
 
96
  return {
97
  "trace_id": self.trace_id,
98
+ "task_type": self.task_type,
99
+ "difficulty": self.difficulty,
100
+ "predicted_tier": self.predicted_tier,
101
+ "success": self.task_success,
102
+ "total_cost": total_cost,
103
+ "total_tokens": total_tokens,
104
+ "total_tool_calls": total_tool_calls,
105
+ "total_retries": total_retries,
106
+ "total_verifier_calls": total_verifiers,
107
+ "cache_hit_rate": cache_hits / max(cache_total, 1),
108
+ "num_steps": len(self.steps),
109
+ "outcome": self.final_outcome,
110
+ "failure_tags": self.failure_tags,
111
+ "meta_tool_used": self.meta_tool_used,
112
+ "early_terminated": self.early_terminated,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  }