narcolepticchicken commited on
Commit
d4271af
·
verified ·
1 Parent(s): 00c175d

Upload tests/test_optimizer.py

Browse files
Files changed (1) hide show
  1. tests/test_optimizer.py +100 -0
tests/test_optimizer.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Agent Cost Optimizer."""
2
+
3
+ import unittest
4
+ from aco import AgentCostOptimizer
5
+ from aco.config import ACOConfig, ModelConfig, ToolConfig, VerifierConfig, RoutingPolicy
6
+ from aco.trace_schema import TaskType, Outcome, FailureTag, ModelCall
7
+
8
+
9
+ class TestAgentCostOptimizer(unittest.TestCase):
10
+ def setUp(self):
11
+ self.config = ACOConfig(
12
+ models={
13
+ "tiny": ModelConfig("tiny", "local", 0.0001, 0.0002, strength_tier=1),
14
+ "cheap": ModelConfig("cheap", "cloud", 0.0005, 0.001, strength_tier=2),
15
+ "medium": ModelConfig("medium", "cloud", 0.003, 0.006, strength_tier=3),
16
+ "frontier": ModelConfig("frontier", "cloud", 0.01, 0.03, strength_tier=4),
17
+ },
18
+ tools={
19
+ "search": ToolConfig("search", 0.002, 500),
20
+ },
21
+ verifiers={
22
+ "v1": VerifierConfig("v1", 0.005, 800, 0.8),
23
+ },
24
+ routing_policy=RoutingPolicy("test"),
25
+ )
26
+ self.optimizer = AgentCostOptimizer(self.config)
27
+
28
+ def test_optimize_quick_answer(self):
29
+ result = self.optimizer.optimize("What is the capital of France?")
30
+ self.assertEqual(result.routing_decision.tier, 1)
31
+ self.assertLess(result.estimated_cost, 0.01)
32
+
33
+ def test_optimize_coding_task(self):
34
+ result = self.optimizer.optimize("Write a Python function to reverse a linked list")
35
+ self.assertIn(result.routing_decision.tier, [2, 3])
36
+ self.assertGreater(result.confidence, 0.0)
37
+
38
+ def test_optimize_legal_task(self):
39
+ result = self.optimizer.optimize("Review this contract for liability clauses")
40
+ self.assertEqual(result.routing_decision.tier, 4) # frontier for legal
41
+ self.assertTrue(result.verifier_decision is not None)
42
+
43
+ def test_trace_lifecycle(self):
44
+ from aco.classifier import TaskPrediction
45
+
46
+ trace_id = "test-trace-001"
47
+ prediction = TaskPrediction(
48
+ task_type=TaskType.CODING,
49
+ expected_cost=0.05,
50
+ expected_model_tier=3,
51
+ expected_tools_needed=["code_execution"],
52
+ risk_of_failure=0.5,
53
+ retrieval_required=False,
54
+ verifier_required=False,
55
+ expected_latency_ms=1000,
56
+ confidence=0.8,
57
+ )
58
+
59
+ trace = self.optimizer.start_trace(trace_id, "Test request", prediction)
60
+ self.assertEqual(trace.trace_id, trace_id)
61
+
62
+ model_call = ModelCall(
63
+ model_id="medium",
64
+ provider="cloud",
65
+ input_tokens=1024,
66
+ output_tokens=512,
67
+ cost_per_1k_input=0.003,
68
+ cost_per_1k_output=0.006,
69
+ )
70
+
71
+ self.optimizer.record_step(
72
+ trace_id=trace_id,
73
+ model_call=model_call,
74
+ context_size_tokens=1024,
75
+ step_outcome=Outcome.SUCCESS,
76
+ )
77
+
78
+ finalized = self.optimizer.finalize_trace(
79
+ trace_id=trace_id,
80
+ outcome=Outcome.SUCCESS,
81
+ user_satisfaction=0.95,
82
+ )
83
+
84
+ self.assertIsNotNone(finalized)
85
+ self.assertEqual(finalized.final_outcome, Outcome.SUCCESS)
86
+ self.assertGreater(finalized.total_cost_computed, 0)
87
+
88
+ def test_cost_adjusted_score(self):
89
+ from aco.datasets.synthetic_traces import SyntheticTraceGenerator
90
+
91
+ gen = SyntheticTraceGenerator(seed=42)
92
+ traces = gen.generate(10)
93
+
94
+ for trace in traces:
95
+ score = self.optimizer.compute_cost_adjusted_score(trace)
96
+ self.assertIsInstance(score, float)
97
+
98
+
99
+ if __name__ == "__main__":
100
+ unittest.main()