File size: 4,451 Bytes
de021eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Cost Telemetry Collector - Module 1."""

import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Dict, Any, Optional

from .trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall, TaskType, Outcome, FailureTag


class CostTelemetryCollector:
    """Collects structured telemetry from agent runs and persists normalized traces.

    A trace lives in memory from ``start_trace`` until ``finalize_trace``,
    which writes it to ``<storage_path>/<trace_id>.json``.
    """

    def __init__(self, storage_path: str = "./traces"):
        """Create the collector and ensure the storage directory exists.

        Args:
            storage_path: Directory that receives one JSON file per
                finalized trace. Created (with parents) if missing.
        """
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        # Traces that were started but not yet finalized, keyed by trace_id.
        self._pending: Dict[str, AgentTrace] = {}

    def start_trace(self, trace_id: str, user_request: str, task_type: TaskType) -> AgentTrace:
        """Create a new in-memory trace and register it as pending.

        An already-pending trace with the same ``trace_id`` is replaced.

        Args:
            trace_id: Identifier of the trace; also used as the file stem
                when the trace is persisted.
            user_request: The request text stored on the trace.
            task_type: Task category stored on the trace and copied onto
                every step added later.

        Returns:
            The newly created ``AgentTrace``.
        """
        trace = AgentTrace(
            trace_id=trace_id,
            user_request=user_request,
            task_type=task_type,
        )
        self._pending[trace_id] = trace
        return trace

    def add_step(
        self,
        trace_id: str,
        step_id: str,
        model_call: ModelCall,
        tool_calls: Optional[List[ToolCall]] = None,
        verifier_calls: Optional[List[VerifierCall]] = None,
        context_size_tokens: int = 0,
        context_sources: Optional[List[str]] = None,
        retry_count: int = 0,
        recovery_action: Optional[str] = None,
        artifacts_created: Optional[List[str]] = None,
        step_outcome: Optional[Outcome] = None,
    ) -> None:
        """Append one step to a pending trace.

        Optional list arguments left as ``None`` are stored as empty lists.
        The step is timestamped with the current UTC time as a naive
        ``datetime``.

        Args:
            trace_id: Identifier of a trace previously passed to
                ``start_trace`` and not yet finalized.
            step_id: Identifier of the new step.
            model_call: The model call recorded for this step.
            tool_calls: Tool calls made during the step.
            verifier_calls: Verifier calls made during the step.
            context_size_tokens: Context size recorded for the step.
            context_sources: Context sources recorded for the step.
            retry_count: Retry count recorded for the step.
            recovery_action: Recovery action recorded for the step, if any.
            artifacts_created: Artifacts recorded for the step.
            step_outcome: Outcome recorded for the step, if known.

        Raises:
            ValueError: If no pending trace has the given ``trace_id``.
        """
        trace = self._pending.get(trace_id)
        if trace is None:
            raise ValueError(f"Trace {trace_id} not found")

        step = TraceStep(
            step_id=step_id,
            # Naive UTC, i.e. the same value datetime.utcnow() produced,
            # without relying on that deprecated call.
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None),
            task_type=trace.task_type,
            model_call=model_call,
            tool_calls=tool_calls or [],
            verifier_calls=verifier_calls or [],
            context_size_tokens=context_size_tokens,
            context_sources=context_sources or [],
            retry_count=retry_count,
            recovery_action=recovery_action,
            artifacts_created=artifacts_created or [],
            step_outcome=step_outcome,
        )
        trace.steps.append(step)

    def finalize_trace(
        self,
        trace_id: str,
        final_outcome: Outcome,
        failure_tags: Optional[List[FailureTag]] = None,
        user_satisfaction: Optional[float] = None,
        total_cost_saved_vs_frontier: Optional[float] = None,
        optimal_cost: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> AgentTrace:
        """Record the final fields on a pending trace and write it to disk.

        The trace is removed from the pending set only after it has been
        written successfully, so a failed write leaves it available for a
        retry instead of silently dropping it.

        Args:
            trace_id: Identifier of a pending trace.
            final_outcome: Outcome stored on the trace.
            failure_tags: Failure tags stored on the trace; ``None`` is
                stored as an empty list.
            user_satisfaction: Stored on the trace as given.
            total_cost_saved_vs_frontier: Stored on the trace as given.
            optimal_cost: Stored on the trace as given.
            metadata: Extra data stored on the trace; ``None`` is stored
                as an empty dict.

        Returns:
            The finalized trace.

        Raises:
            KeyError: If no pending trace has the given ``trace_id``.
        """
        trace = self._pending[trace_id]
        trace.final_outcome = final_outcome
        trace.failure_tags = failure_tags or []
        trace.user_satisfaction = user_satisfaction
        trace.total_cost_saved_vs_frontier = total_cost_saved_vs_frontier
        trace.optimal_cost = optimal_cost
        trace.metadata = metadata or {}
        trace.total_cost = trace.total_cost_computed

        self._persist(trace)
        # Drop the pending entry only once the trace is safely on disk.
        self._pending.pop(trace_id, None)
        return trace

    def _persist(self, trace: AgentTrace) -> None:
        """Write ``trace.to_dict()`` as JSON to ``<trace_id>.json``.

        Values the JSON encoder cannot handle natively are written via
        ``str()``.
        """
        filepath = self.storage_path / f"{trace.trace_id}.json"
        # Serialize before opening the file so a serialization error does
        # not leave an empty or truncated trace file behind.
        payload = json.dumps(trace.to_dict(), indent=2, default=str)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(payload)

    def load_trace(self, trace_id: str) -> Optional[Dict[str, Any]]:
        """Load a persisted trace as the raw decoded JSON.

        Args:
            trace_id: Identifier (file stem) of the trace to load.

        Returns:
            The decoded JSON content of the trace file, or ``None`` when
            no file exists for ``trace_id``. No ``AgentTrace`` object is
            reconstructed.
        """
        filepath = self.storage_path / f"{trace_id}.json"
        if not filepath.exists():
            return None
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Simplified deserialization - full version would reconstruct dataclasses
        return data

    def list_traces(self) -> List[str]:
        """Return the stems of all ``*.json`` files in the storage directory."""
        return [p.stem for p in self.storage_path.glob("*.json")]

    def get_stats(self) -> Dict[str, Any]:
        """Compute aggregate statistics over all persisted traces.

        Only non-empty JSON objects are counted; files holding anything
        else are skipped. A missing or ``null`` ``total_cost`` counts as 0
        and a missing or ``null`` ``steps`` counts as no steps.

        Returns:
            ``{"count": 0}`` when there are no usable traces, otherwise a
            dict with ``count``, ``avg_cost``, ``avg_steps`` and
            ``success_rate`` (the share of traces whose ``final_outcome``
            equals ``"success"``).
        """
        traces = []
        for tid in self.list_traces():
            record = self.load_trace(tid)
            # Keep numerators and the denominator on the same population.
            if isinstance(record, dict) and record:
                traces.append(record)

        if not traces:
            return {"count": 0}

        count = len(traces)
        total_cost = sum(t.get("total_cost") or 0 for t in traces)
        total_steps = sum(len(t.get("steps") or []) for t in traces)
        successes = sum(1 for t in traces if t.get("final_outcome") == "success")

        return {
            "count": count,
            "avg_cost": total_cost / count,
            "avg_steps": total_steps / count,
            "success_rate": successes / count,
        }