agent-cost-optimizer / aco /trace_schema.py
narcolepticchicken's picture
Upload aco/trace_schema.py with huggingface_hub
5569d72 verified
"""Normalized Agent Trace Schema for ACO."""
import json
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@dataclass
class ModelCall:
    """A single model invocation recorded inside a trace step.

    Only the identity fields (``model_id``, ``provider``, ``tier``) are
    required. Every usage/accounting field defaults to zero, and the outcome
    defaults to a successful call with no error.
    """

    # Which model served the call and the tier it was assigned.
    model_id: str
    provider: str
    tier: int

    # Token usage and prompt-cache outcome.
    input_tokens: int = 0
    output_tokens: int = 0
    reasoning_tokens: int = 0
    cache_hit: bool = False

    # Observed latency (milliseconds) and cost of the call.
    latency_ms: float = 0.0
    cost: float = 0.0

    # Outcome; presumably ``error`` carries a message when ``success`` is False.
    success: bool = True
    error: Optional[str] = None
@dataclass
class ToolCall:
    """A single tool invocation made during a trace step.

    ``tool_name`` is the only required field. ``args`` gets a fresh, empty
    dict per instance. The remaining fields default to a successful,
    uncached call with no result, error, latency or cost recorded.
    """

    tool_name: str

    # Call input and textual output.
    args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None

    # Observed latency (milliseconds) and cost of the call.
    latency_ms: float = 0.0
    cost: float = 0.0

    # Status flags. ``unnecessary`` presumably marks a call judged redundant
    # after the fact -- confirm against the code that sets it.
    success: bool = True
    cached: bool = False
    unnecessary: bool = False
    error: Optional[str] = None
@dataclass
class TraceStep:
    """One step of an agent trace: at most one model call plus any tool calls.

    ``step_num`` is the only required field. ``timestamp`` defaults to the
    creation time as a naive ISO-8601 string in UTC with no offset suffix.
    This is the same format ``datetime.utcnow().isoformat()`` produced.
    """

    step_num: int
    model_call: Optional[ModelCall] = None
    tool_calls: List[ToolCall] = field(default_factory=list)

    # Context-related bookkeeping for this step.
    context_size: int = 0
    context_sources: List[str] = field(default_factory=list)
    context_budget_used: float = 0.0
    cache_prefix_tokens: int = 0
    cache_suffix_tokens: int = 0

    # Verification and retry/recovery metadata.
    verifier_called: bool = False
    verifier_result: Optional[str] = None
    retry_num: int = 0
    recovery_action: Optional[str] = None

    # datetime.utcnow() is deprecated since Python 3.12. Build an aware UTC
    # time, then drop tzinfo so the serialized format stays unchanged for
    # existing consumers.
    timestamp: str = field(
        default_factory=lambda: (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
    )
@dataclass
class AgentTrace:
    """Normalized record of one agent run and its ordered ``TraceStep``s.

    The stored aggregate fields (``total_cost``, ``total_tokens``,
    ``total_tool_calls``, ``total_retries``, ``total_verifier_calls``,
    ``cache_hit_rate``, ``latency_total_ms``) are plain fields set by the
    caller. ``compute_summary`` neither reads nor updates them; it
    recomputes its aggregates from ``steps``.
    """

    # First 12 characters of a UUID4 string (8 hex digits, '-', 3 hex digits).
    trace_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])
    request: str = ""
    task_type: str = "unknown_ambiguous"
    difficulty: int = 3
    predicted_tier: int = 4
    steps: List[TraceStep] = field(default_factory=list)
    final_outcome: str = "unknown"
    task_success: bool = False
    total_cost: float = 0.0
    total_tokens: int = 0
    total_tool_calls: int = 0
    total_retries: int = 0
    total_verifier_calls: int = 0
    cache_hit_rate: float = 0.0
    latency_total_ms: float = 0.0
    user_correction: bool = False
    failure_tags: List[str] = field(default_factory=list)
    artifacts_created: List[str] = field(default_factory=list)
    meta_tool_used: bool = False
    early_terminated: bool = False
    escalation_occurred: bool = False
    # Naive ISO-8601 UTC timestamp (no offset suffix). datetime.utcnow() is
    # deprecated since Python 3.12, so build an aware UTC time and strip
    # tzinfo to keep the serialized format unchanged.
    timestamp: str = field(
        default_factory=lambda: (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
    )

    def to_dict(self) -> dict:
        """Return the trace, including nested steps and calls, as plain dicts/lists."""
        return asdict(self)

    def to_json(self) -> str:
        """Serialize ``to_dict()`` as JSON indented by two spaces.

        Raises:
            TypeError: if any value (e.g. an entry in a ``ToolCall.args``
                dict) is not JSON-serializable.
        """
        return json.dumps(self.to_dict(), indent=2)

    def compute_summary(self) -> dict:
        """Recompute aggregate metrics from ``steps`` and return a flat dict.

        - ``total_cost``: each step's model-call cost plus its tool-call
          costs. It is always a float, including 0.0 for an empty trace.
        - ``total_tokens``: ``input_tokens + output_tokens`` of each model
          call; ``reasoning_tokens`` are not added.
        - ``total_retries``: sum of every step's ``retry_num``.
        - ``total_verifier_calls``: number of steps with ``verifier_called``.
        - ``cache_hit_rate``: cache hits divided by the number of steps that
          have a model call; 0.0 when no step has one.

        ``failure_tags`` is returned as a copy, so mutating the summary does
        not alter the trace.
        """
        total_cost = 0.0  # float start so an empty trace reports 0.0, not int 0
        total_tokens = 0
        total_tool_calls = 0
        total_retries = 0
        total_verifiers = 0
        cache_hits = 0
        model_call_steps = 0
        for step in self.steps:
            call = step.model_call
            # Same per-step summation order as before: model cost first,
            # then the step's summed tool costs, then accumulate.
            step_cost = call.cost if call is not None else 0.0
            step_cost += sum(tc.cost for tc in step.tool_calls)
            total_cost += step_cost
            total_tool_calls += len(step.tool_calls)
            total_retries += step.retry_num
            if step.verifier_called:
                total_verifiers += 1
            if call is not None:
                model_call_steps += 1
                total_tokens += call.input_tokens + call.output_tokens
                if call.cache_hit:
                    cache_hits += 1
        return {
            "trace_id": self.trace_id,
            "task_type": self.task_type,
            "difficulty": self.difficulty,
            "predicted_tier": self.predicted_tier,
            "success": self.task_success,
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "total_tool_calls": total_tool_calls,
            "total_retries": total_retries,
            "total_verifier_calls": total_verifiers,
            # max(..., 1) avoids division by zero when no step called a model.
            "cache_hit_rate": cache_hits / max(model_call_steps, 1),
            "num_steps": len(self.steps),
            "outcome": self.final_outcome,
            "failure_tags": list(self.failure_tags),
            "meta_tool_used": self.meta_tool_used,
            "early_terminated": self.early_terminated,
        }