| """Normalized Agent Trace Schema for ACO.""" |
import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
|
|
@dataclass
class ModelCall:
    """Record of one model invocation made during a trace step.

    ``model_id``, ``provider`` and ``tier`` are required (in that positional
    order); every other field has a default, so a call can be recorded with
    just its identifiers and filled in afterwards.
    """

    # -- identity (required) --------------------------------------------
    model_id: str
    provider: str
    tier: int  # NOTE(review): valid range / meaning of tiers is not defined here

    # -- token accounting -----------------------------------------------
    input_tokens: int = 0
    output_tokens: int = 0
    # Kept as its own counter; AgentTrace.compute_summary() adds only
    # input_tokens + output_tokens into its "total_tokens".
    reasoning_tokens: int = 0

    # -- performance / billing ------------------------------------------
    cache_hit: bool = False  # counted by AgentTrace.compute_summary() for cache_hit_rate
    latency_ms: float = 0.0  # presumably milliseconds, per the field name
    cost: float = 0.0  # NOTE(review): currency/unit is not specified in this file

    # -- outcome --------------------------------------------------------
    success: bool = True
    error: Optional[str] = None  # no error text recorded by default
|
|
@dataclass
class ToolCall:
    """Record of one tool invocation made during a trace step.

    Only ``tool_name`` is required. ``args`` defaults to a fresh empty dict
    per instance (via ``default_factory``), so instances never share it.
    """

    # -- what was called ------------------------------------------------
    tool_name: str
    args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None  # no result recorded by default

    # -- performance / billing ------------------------------------------
    latency_ms: float = 0.0  # presumably milliseconds, per the field name
    cost: float = 0.0  # summed into "total_cost" by AgentTrace.compute_summary()

    # -- outcome and classification flags -------------------------------
    success: bool = True
    cached: bool = False
    # NOTE(review): who sets this flag, and by what criterion, is not
    # visible in this file.
    unnecessary: bool = False
    error: Optional[str] = None  # no error text recorded by default
|
|
@dataclass
class TraceStep:
    """One step of an agent trace: at most one model call plus its tool calls.

    Only ``step_num`` is required. List-valued fields use ``default_factory``
    so every instance gets its own list.
    """

    step_num: int  # step number supplied by the caller

    # -- work performed in this step -------------------------------------
    model_call: Optional[ModelCall] = None  # None when the step made no model call
    tool_calls: List[ToolCall] = field(default_factory=list)

    # -- context accounting ----------------------------------------------
    # NOTE(review): units (tokens vs. characters) and the scale of
    # context_budget_used (fraction vs. percent) are not defined here.
    context_size: int = 0
    context_sources: List[str] = field(default_factory=list)
    context_budget_used: float = 0.0
    cache_prefix_tokens: int = 0
    cache_suffix_tokens: int = 0

    # -- verification / recovery -----------------------------------------
    verifier_called: bool = False  # counted by AgentTrace.compute_summary()
    verifier_result: Optional[str] = None
    retry_num: int = 0  # summed by AgentTrace.compute_summary() into "total_retries"
    recovery_action: Optional[str] = None

    # Creation time as a naive-UTC ISO-8601 string (no "+00:00" suffix).
    # Built from an aware "now" with tzinfo stripped, which yields the same
    # format as the deprecated datetime.utcnow().isoformat() without
    # triggering its DeprecationWarning on Python 3.12+.
    timestamp: str = field(
        default_factory=lambda: (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
    )
|
|
@dataclass
class AgentTrace:
    """Top-level record of one agent run: request metadata, steps, outcome.

    Every field has a default, so ``AgentTrace()`` yields an empty trace with
    a generated ``trace_id`` and ``timestamp``.

    The ``total_*``, ``cache_hit_rate`` and ``latency_total_ms`` fields are
    plain stored values. ``compute_summary()`` recomputes its figures from
    ``steps`` and does NOT write them back to these fields.
    """

    # First 12 characters of a UUID4 string. That slice includes the hyphen
    # at index 8, so the id carries 11 hex digits.
    trace_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])

    # -- request classification -------------------------------------------
    request: str = ""
    task_type: str = "unknown_ambiguous"
    # NOTE(review): scales for difficulty / predicted_tier are not defined here.
    difficulty: int = 3
    predicted_tier: int = 4

    # -- execution --------------------------------------------------------
    steps: List[TraceStep] = field(default_factory=list)
    final_outcome: str = "unknown"
    task_success: bool = False

    # -- stored aggregates (not auto-updated; see class docstring) --------
    total_cost: float = 0.0
    total_tokens: int = 0
    total_tool_calls: int = 0
    total_retries: int = 0
    total_verifier_calls: int = 0
    cache_hit_rate: float = 0.0
    latency_total_ms: float = 0.0

    # -- post-hoc annotations ---------------------------------------------
    user_correction: bool = False
    failure_tags: List[str] = field(default_factory=list)
    artifacts_created: List[str] = field(default_factory=list)
    meta_tool_used: bool = False
    early_terminated: bool = False
    escalation_occurred: bool = False

    # Creation time as a naive-UTC ISO-8601 string (no "+00:00" suffix).
    # Built from an aware "now" with tzinfo stripped, which yields the same
    # format as the deprecated datetime.utcnow().isoformat() without
    # triggering its DeprecationWarning on Python 3.12+.
    timestamp: str = field(
        default_factory=lambda: (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
    )

    def to_dict(self) -> dict:
        """Return the trace, including nested dataclass steps, as plain dicts."""
        # Local import: the module header only imports `dataclass` and `field`.
        import dataclasses

        return dataclasses.asdict(self)

    def to_json(self) -> str:
        """Return ``to_dict()`` serialized as JSON with a 2-space indent.

        Raises:
            TypeError: if any value in the trace (e.g. inside a tool call's
                ``args``) is not JSON-serializable.
        """
        return json.dumps(self.to_dict(), indent=2)

    def compute_summary(self) -> dict:
        """Aggregate per-step metrics from ``self.steps`` into a flat dict.

        The stored ``total_*`` fields on the instance are neither read nor
        updated; all figures are derived from the steps.

        Returns:
            A dict with identifying fields (``trace_id``, ``task_type``,
            ``difficulty``, ``predicted_tier``), outcome fields (``success``,
            ``outcome``, ``failure_tags``, ``meta_tool_used``,
            ``early_terminated``) and computed totals. ``failure_tags`` is a
            copy, so mutating the summary does not alter the trace.
        """
        steps = self.steps
        # Steps without a model call contribute no tokens and are excluded
        # from the cache-hit denominator.
        model_calls = [s.model_call for s in steps if s.model_call]

        # Cost is model cost plus all tool costs, evaluated per step.
        total_cost = sum(
            (s.model_call.cost if s.model_call else 0)
            + sum(tc.cost for tc in s.tool_calls)
            for s in steps
        )
        # Only input + output tokens are counted; reasoning_tokens is not.
        total_tokens = sum(mc.input_tokens + mc.output_tokens for mc in model_calls)
        cache_hits = sum(1 for mc in model_calls if mc.cache_hit)

        return {
            "trace_id": self.trace_id,
            "task_type": self.task_type,
            "difficulty": self.difficulty,
            "predicted_tier": self.predicted_tier,
            "success": self.task_success,
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "total_tool_calls": sum(len(s.tool_calls) for s in steps),
            "total_retries": sum(s.retry_num for s in steps),
            "total_verifier_calls": sum(1 for s in steps if s.verifier_called),
            # max(..., 1) keeps the rate at 0 when no step made a model call.
            "cache_hit_rate": cache_hits / max(len(model_calls), 1),
            "num_steps": len(steps),
            "outcome": self.final_outcome,
            # Copy so callers editing the summary cannot mutate the trace.
            "failure_tags": list(self.failure_tags),
            "meta_tool_used": self.meta_tool_used,
            "early_terminated": self.early_terminated,
        }
|
|