File size: 3,923 Bytes
5569d72
7268941
5569d72
7268941
5569d72
7268941
 
 
 
 
5569d72
7268941
 
 
5569d72
7268941
5569d72
 
 
7268941
 
5569d72
 
 
 
7268941
5569d72
 
 
 
 
7268941
 
 
5569d72
 
7268941
5569d72
7268941
5569d72
 
 
 
 
 
7268941
5569d72
7268941
 
 
5569d72
 
 
 
 
7268941
5569d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7268941
 
5569d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7268941
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Normalized Agent Trace Schema for ACO."""
import json
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

@dataclass()
class ModelCall:
    """One model invocation recorded inside a trace step.

    Only ``model_id``, ``provider`` and ``tier`` must be supplied; every
    metric below falls back to zero, ``False``, ``True`` (for ``success``)
    or ``None`` so a call can be logged before all numbers are known.
    """

    # Which model served the call.
    model_id: str
    provider: str
    tier: int
    # Token accounting for the call.
    input_tokens: int = field(default=0)
    output_tokens: int = field(default=0)
    reasoning_tokens: int = field(default=0)
    cache_hit: bool = field(default=False)
    # Timing and spend.
    latency_ms: float = field(default=0.0)
    cost: float = field(default=0.0)
    # Outcome flag plus optional error text.
    success: bool = field(default=True)
    error: Optional[str] = field(default=None)

@dataclass()
class ToolCall:
    """One tool invocation recorded inside a trace step.

    Only ``tool_name`` is required. ``args`` gets a fresh empty dict per
    instance; the remaining fields default to zero, ``False``, ``True``
    (for ``success``) or ``None``.
    """

    tool_name: str
    # Arguments passed to the tool and the (stringified) result, if any.
    args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = field(default=None)
    # Timing and spend.
    latency_ms: float = field(default=0.0)
    cost: float = field(default=0.0)
    # Outcome and classification flags.
    success: bool = field(default=True)
    cached: bool = field(default=False)
    unnecessary: bool = field(default=False)
    error: Optional[str] = field(default=None)

@dataclass
class TraceStep:
    """A single step of an agent trace: at most one model call plus tool calls.

    ``timestamp`` is captured at construction as a naive ISO-8601 string in
    UTC (no ``+00:00`` offset suffix), i.e. the same format that
    ``datetime.utcnow().isoformat()`` produces.
    """

    step_num: int
    model_call: Optional[ModelCall] = None
    tool_calls: List[ToolCall] = field(default_factory=list)
    # Context-window bookkeeping for this step.
    context_size: int = 0
    context_sources: List[str] = field(default_factory=list)
    context_budget_used: float = 0.0
    cache_prefix_tokens: int = 0
    cache_suffix_tokens: int = 0
    # Verification / retry / recovery bookkeeping.
    verifier_called: bool = False
    verifier_result: Optional[str] = None
    retry_num: int = 0
    recovery_action: Optional[str] = None
    # datetime.utcnow() is deprecated since Python 3.12; derive the same naive
    # UTC string from a timezone-aware "now" instead.
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    )

@dataclass
class AgentTrace:
    """Normalized record of one agent run: request metadata, steps and outcome.

    The ``total_*`` and ``cache_hit_rate`` fields are plain stored values;
    :meth:`compute_summary` recomputes its own totals from ``steps`` and
    neither reads nor updates them.
    """

    # First 12 characters of a random UUID4 string (8 hex digits, '-', 3 hex digits).
    trace_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])
    request: str = ""
    task_type: str = "unknown_ambiguous"
    difficulty: int = 3
    predicted_tier: int = 4
    steps: List[TraceStep] = field(default_factory=list)
    final_outcome: str = "unknown"
    task_success: bool = False
    # Stored aggregates (not maintained by compute_summary).
    total_cost: float = 0.0
    total_tokens: int = 0
    total_tool_calls: int = 0
    total_retries: int = 0
    total_verifier_calls: int = 0
    cache_hit_rate: float = 0.0
    latency_total_ms: float = 0.0
    user_correction: bool = False
    failure_tags: List[str] = field(default_factory=list)
    artifacts_created: List[str] = field(default_factory=list)
    meta_tool_used: bool = False
    early_terminated: bool = False
    escalation_occurred: bool = False
    # Naive ISO-8601 UTC string, same format as datetime.utcnow().isoformat();
    # utcnow() itself is deprecated since Python 3.12.
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    )

    def to_dict(self) -> dict:
        """Return the trace as a nested ``dict`` (steps and calls included).

        Uses :func:`dataclasses.asdict`, so the result is a copy and can be
        mutated without affecting the trace.
        """
        return asdict(self)

    def to_json(self) -> str:
        """Serialize :meth:`to_dict` as JSON with a 2-space indent.

        Raises:
            TypeError: if a stored value (e.g. an entry of some
                ``ToolCall.args``) is not JSON-serializable.
        """
        return json.dumps(self.to_dict(), indent=2)

    def compute_summary(self) -> dict:
        """Aggregate per-step metrics into a flat summary ``dict``.

        Computed keys:
            ``total_cost``: model-call cost plus tool-call costs over all
                steps, always a float.
            ``total_tokens``: input + output tokens of model calls
                (``reasoning_tokens`` is not added).
            ``total_tool_calls``: number of tool calls over all steps.
            ``total_retries``: sum of each step's ``retry_num``.
            ``total_verifier_calls``: number of steps with ``verifier_called``.
            ``cache_hit_rate``: fraction of steps that have a model call whose
                call was a cache hit; 0.0 when no step has a model call.
            ``num_steps``: ``len(self.steps)``.

        The remaining keys are copied from the trace. ``failure_tags`` is a
        fresh list, so mutating the summary never alters the trace.
        """
        total_cost = 0.0
        total_tokens = 0
        total_tool_calls = 0
        total_retries = 0
        total_verifiers = 0
        cache_hits = 0
        model_steps = 0
        for step in self.steps:
            call = step.model_call
            # Form each step's cost first, then add it to the running total.
            step_cost = (call.cost if call is not None else 0.0) + sum(
                tc.cost for tc in step.tool_calls
            )
            total_cost += step_cost
            if call is not None:
                total_tokens += call.input_tokens + call.output_tokens
                model_steps += 1
                if call.cache_hit:
                    cache_hits += 1
            total_tool_calls += len(step.tool_calls)
            total_retries += step.retry_num
            if step.verifier_called:
                total_verifiers += 1
        return {
            "trace_id": self.trace_id,
            "task_type": self.task_type,
            "difficulty": self.difficulty,
            "predicted_tier": self.predicted_tier,
            "success": self.task_success,
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "total_tool_calls": total_tool_calls,
            "total_retries": total_retries,
            "total_verifier_calls": total_verifiers,
            # max(..., 1) avoids division by zero when no step has a model call.
            "cache_hit_rate": cache_hits / max(model_steps, 1),
            "num_steps": len(self.steps),
            "outcome": self.final_outcome,
            "failure_tags": list(self.failure_tags),
            "meta_tool_used": self.meta_tool_used,
            "early_terminated": self.early_terminated,
        }