File size: 3,667 Bytes
d4271af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Tests for Agent Cost Optimizer."""

import unittest
from aco import AgentCostOptimizer
from aco.config import ACOConfig, ModelConfig, ToolConfig, VerifierConfig, RoutingPolicy
from aco.trace_schema import TaskType, Outcome, FailureTag, ModelCall


class TestAgentCostOptimizer(unittest.TestCase):
    """Tests for ``AgentCostOptimizer`` routing, tracing and scoring.

    Each test runs against a fresh optimizer built in ``setUp`` from a small
    four-model configuration whose strength tiers run from 1 ("tiny") to
    4 ("frontier").
    """

    def setUp(self):
        """Build the shared configuration and a fresh optimizer per test."""
        self.config = ACOConfig(
            models={
                "tiny": ModelConfig("tiny", "local", 0.0001, 0.0002, strength_tier=1),
                "cheap": ModelConfig("cheap", "cloud", 0.0005, 0.001, strength_tier=2),
                "medium": ModelConfig("medium", "cloud", 0.003, 0.006, strength_tier=3),
                "frontier": ModelConfig("frontier", "cloud", 0.01, 0.03, strength_tier=4),
            },
            tools={
                "search": ToolConfig("search", 0.002, 500),
            },
            verifiers={
                "v1": VerifierConfig("v1", 0.005, 800, 0.8),
            },
            routing_policy=RoutingPolicy("test"),
        )
        self.optimizer = AgentCostOptimizer(self.config)

    def test_optimize_quick_answer(self):
        """A short factual question routes to tier 1 with a cost below 0.01."""
        result = self.optimizer.optimize("What is the capital of France?")
        self.assertEqual(result.routing_decision.tier, 1)
        self.assertLess(result.estimated_cost, 0.01)

    def test_optimize_coding_task(self):
        """A coding request routes to tier 2 or 3 with positive confidence."""
        result = self.optimizer.optimize("Write a Python function to reverse a linked list")
        self.assertIn(result.routing_decision.tier, [2, 3])
        self.assertGreater(result.confidence, 0.0)

    def test_optimize_legal_task(self):
        """A legal request routes to tier 4 and carries a verifier decision."""
        result = self.optimizer.optimize("Review this contract for liability clauses")
        self.assertEqual(result.routing_decision.tier, 4)  # frontier for legal
        # assertIsNotNone reports the actual value on failure, unlike
        # assertTrue(x is not None) which only says "False is not true".
        self.assertIsNotNone(result.verifier_decision)

    def test_trace_lifecycle(self):
        """Start a trace, record one model step, finalize, and check the result."""
        from aco.classifier import TaskPrediction

        trace_id = "test-trace-001"
        prediction = TaskPrediction(
            task_type=TaskType.CODING,
            expected_cost=0.05,
            expected_model_tier=3,
            expected_tools_needed=["code_execution"],
            risk_of_failure=0.5,
            retrieval_required=False,
            verifier_required=False,
            expected_latency_ms=1000,
            confidence=0.8,
        )

        trace = self.optimizer.start_trace(trace_id, "Test request", prediction)
        self.assertEqual(trace.trace_id, trace_id)

        model_call = ModelCall(
            model_id="medium",
            provider="cloud",
            input_tokens=1024,
            output_tokens=512,
            cost_per_1k_input=0.003,
            cost_per_1k_output=0.006,
        )

        self.optimizer.record_step(
            trace_id=trace_id,
            model_call=model_call,
            context_size_tokens=1024,
            step_outcome=Outcome.SUCCESS,
        )

        finalized = self.optimizer.finalize_trace(
            trace_id=trace_id,
            outcome=Outcome.SUCCESS,
            user_satisfaction=0.95,
        )

        self.assertIsNotNone(finalized)
        self.assertEqual(finalized.final_outcome, Outcome.SUCCESS)
        self.assertGreater(finalized.total_cost_computed, 0)

    def test_cost_adjusted_score(self):
        """Every synthetic trace yields a float cost-adjusted score."""
        from aco.datasets.synthetic_traces import SyntheticTraceGenerator

        gen = SyntheticTraceGenerator(seed=42)
        # Materialize so the emptiness check below also works if generate()
        # returns a lazy iterable.
        traces = list(gen.generate(10))

        # Without this guard the loop below would pass vacuously on no traces.
        self.assertTrue(traces, "SyntheticTraceGenerator produced no traces")

        for index, trace in enumerate(traces):
            # subTest reports every failing trace instead of stopping at the
            # first one, and identifies which trace failed.
            with self.subTest(trace_index=index):
                score = self.optimizer.compute_cost_adjusted_score(trace)
                self.assertIsInstance(score, float)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()