"""Normalized Agent Trace Schema for ACO."""
import dataclasses
import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@dataclass
class ModelCall:
    """One LLM invocation: identity, token usage, latency, cost, and outcome."""

    model_id: str                       # provider-specific model name
    provider: str                       # e.g. vendor/API the call went through
    tier: int                           # routing tier of the model — scale not visible here; confirm with router
    input_tokens: int = 0
    output_tokens: int = 0
    reasoning_tokens: int = 0           # tracked separately from output_tokens
    cache_hit: bool = False             # whether a prompt-cache hit occurred
    latency_ms: float = 0.0
    cost: float = 0.0                   # monetary cost of this call — units (USD?) not visible here
    success: bool = True
    error: Optional[str] = None         # error description when success is False
@dataclass
class ToolCall:
    """One tool invocation: arguments, result, latency, cost, and outcome flags."""

    tool_name: str
    args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None        # stringified tool output, if any
    latency_ms: float = 0.0
    cost: float = 0.0
    success: bool = True
    cached: bool = False                # result served from a cache
    unnecessary: bool = False           # flagged post-hoc as a wasted call — presumably set by an analyzer; confirm
    error: Optional[str] = None         # error description when success is False
@dataclass
class TraceStep:
    """A single step of an agent run: an optional model call, any tool calls,
    plus context-window and caching statistics for that step."""

    step_num: int
    model_call: Optional[ModelCall] = None
    tool_calls: List[ToolCall] = field(default_factory=list)
    context_size: int = 0               # tokens in context at this step — presumably; confirm with producer
    context_sources: List[str] = field(default_factory=list)
    context_budget_used: float = 0.0    # fraction of context budget consumed — assumed 0.0-1.0; TODO confirm
    cache_prefix_tokens: int = 0
    cache_suffix_tokens: int = 0
    verifier_called: bool = False
    verifier_result: Optional[str] = None
    retry_num: int = 0                  # 0 on the first attempt
    recovery_action: Optional[str] = None
    # Fix: datetime.utcnow() is deprecated since Python 3.12 and returns a
    # naive datetime; use an explicit timezone-aware UTC timestamp instead.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@dataclass
class AgentTrace:
    """Normalized end-to-end trace of one agent run.

    Captures the original request, routing metadata (task type, difficulty,
    predicted tier), the per-step execution trace, and aggregate outcome
    metrics. ``compute_summary`` derives aggregates from ``steps``; the stored
    ``total_*`` fields are independent and are NOT updated by it.
    """

    # Short random id (12 hex-ish chars) for cross-referencing in logs.
    trace_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])
    request: str = ""
    task_type: str = "unknown_ambiguous"
    difficulty: int = 3                 # assumed small integer scale — TODO confirm range with producer
    predicted_tier: int = 4
    steps: List[TraceStep] = field(default_factory=list)
    final_outcome: str = "unknown"
    task_success: bool = False
    total_cost: float = 0.0
    total_tokens: int = 0
    total_tool_calls: int = 0
    total_retries: int = 0
    total_verifier_calls: int = 0
    cache_hit_rate: float = 0.0
    latency_total_ms: float = 0.0
    user_correction: bool = False       # user had to correct the agent mid-run
    failure_tags: List[str] = field(default_factory=list)
    artifacts_created: List[str] = field(default_factory=list)
    meta_tool_used: bool = False
    early_terminated: bool = False
    escalation_occurred: bool = False
    # Fix: datetime.utcnow() is deprecated since Python 3.12 and returns a
    # naive datetime; use an explicit timezone-aware UTC timestamp instead.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self) -> dict:
        """Recursively convert this trace (including nested steps and calls)
        to plain dicts suitable for JSON serialization."""
        # Fix: use the module-level `dataclasses` import instead of a
        # redundant function-local `import dataclasses` on every call.
        return dataclasses.asdict(self)

    def to_json(self) -> str:
        """Serialize this trace as pretty-printed (2-space indented) JSON."""
        return json.dumps(self.to_dict(), indent=2)

    def compute_summary(self) -> dict:
        """Aggregate per-step metrics into a flat summary dict.

        Pure computation over ``self.steps`` — does not mutate the instance.
        Token totals count input + output only; ``reasoning_tokens`` are not
        included. ``cache_hit_rate`` is the fraction of steps with a model
        call whose call was a cache hit (0.0 when there are no model calls).
        """
        total_cost = sum(
            (s.model_call.cost if s.model_call else 0)
            + sum(tc.cost for tc in s.tool_calls)
            for s in self.steps
        )
        total_tokens = sum(
            (s.model_call.input_tokens + s.model_call.output_tokens if s.model_call else 0)
            for s in self.steps
        )
        total_tool_calls = sum(len(s.tool_calls) for s in self.steps)
        # retry_num is per-step, so summing yields the total retries across the run.
        total_retries = sum(s.retry_num for s in self.steps)
        total_verifiers = sum(1 for s in self.steps if s.verifier_called)
        cache_hits = sum(1 for s in self.steps if s.model_call and s.model_call.cache_hit)
        cache_total = sum(1 for s in self.steps if s.model_call)
        return {
            "trace_id": self.trace_id,
            "task_type": self.task_type,
            "difficulty": self.difficulty,
            "predicted_tier": self.predicted_tier,
            "success": self.task_success,
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "total_tool_calls": total_tool_calls,
            "total_retries": total_retries,
            "total_verifier_calls": total_verifiers,
            # max(..., 1) guards the zero-model-call case against ZeroDivisionError.
            "cache_hit_rate": cache_hits / max(cache_total, 1),
            "num_steps": len(self.steps),
            "outcome": self.final_outcome,
            "failure_tags": self.failure_tags,
            "meta_tool_used": self.meta_tool_used,
            "early_terminated": self.early_terminated,
        }