Upload tests/test_optimizer.py
Browse files- tests/test_optimizer.py +100 -0
tests/test_optimizer.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Agent Cost Optimizer."""
|
| 2 |
+
|
| 3 |
+
import unittest
|
| 4 |
+
from aco import AgentCostOptimizer
|
| 5 |
+
from aco.config import ACOConfig, ModelConfig, ToolConfig, VerifierConfig, RoutingPolicy
|
| 6 |
+
from aco.trace_schema import TaskType, Outcome, FailureTag, ModelCall
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestAgentCostOptimizer(unittest.TestCase):
    """Exercise ``AgentCostOptimizer`` against a small, fixed configuration.

    The fixture registers four models with ``strength_tier`` 1 through 4,
    one tool and one verifier; the tests then check routing tiers, the
    trace start/record/finalize cycle and cost-adjusted scoring.
    """

    def setUp(self):
        """Create a fresh config and optimizer for every test."""
        # NOTE(review): the positional ModelConfig arguments appear to be
        # (id, provider, input cost, output cost) -- confirm against
        # aco.config before relying on that reading.
        self.config = ACOConfig(
            models={
                "tiny": ModelConfig("tiny", "local", 0.0001, 0.0002, strength_tier=1),
                "cheap": ModelConfig("cheap", "cloud", 0.0005, 0.001, strength_tier=2),
                "medium": ModelConfig("medium", "cloud", 0.003, 0.006, strength_tier=3),
                "frontier": ModelConfig("frontier", "cloud", 0.01, 0.03, strength_tier=4),
            },
            tools={
                "search": ToolConfig("search", 0.002, 500),
            },
            verifiers={
                "v1": VerifierConfig("v1", 0.005, 800, 0.8),
            },
            routing_policy=RoutingPolicy("test"),
        )
        self.optimizer = AgentCostOptimizer(self.config)

    def test_optimize_quick_answer(self):
        """A short factual question is routed to tier 1 at low estimated cost."""
        result = self.optimizer.optimize("What is the capital of France?")
        self.assertEqual(result.routing_decision.tier, 1)
        self.assertLess(result.estimated_cost, 0.01)

    def test_optimize_coding_task(self):
        """A coding request lands on tier 2 or 3 with non-zero confidence."""
        result = self.optimizer.optimize("Write a Python function to reverse a linked list")
        self.assertIn(result.routing_decision.tier, [2, 3])
        self.assertGreater(result.confidence, 0.0)

    def test_optimize_legal_task(self):
        """A legal request is routed to tier 4 and carries a verifier decision."""
        result = self.optimizer.optimize("Review this contract for liability clauses")
        self.assertEqual(result.routing_decision.tier, 4)  # frontier for legal
        # assertIsNotNone reports the offending value on failure, unlike
        # assertTrue(... is not None), which only says "False is not true".
        self.assertIsNotNone(result.verifier_decision)

    def test_trace_lifecycle(self):
        """Start a trace, record one model step, finalize it, check the result."""
        # Imported inside the test, so an import failure here is confined
        # to this test rather than breaking collection of the whole module.
        from aco.classifier import TaskPrediction

        trace_id = "test-trace-001"
        prediction = TaskPrediction(
            task_type=TaskType.CODING,
            expected_cost=0.05,
            expected_model_tier=3,
            expected_tools_needed=["code_execution"],
            risk_of_failure=0.5,
            retrieval_required=False,
            verifier_required=False,
            expected_latency_ms=1000,
            confidence=0.8,
        )

        trace = self.optimizer.start_trace(trace_id, "Test request", prediction)
        self.assertEqual(trace.trace_id, trace_id)

        # The per-1k prices mirror the "medium" entry registered in setUp.
        model_call = ModelCall(
            model_id="medium",
            provider="cloud",
            input_tokens=1024,
            output_tokens=512,
            cost_per_1k_input=0.003,
            cost_per_1k_output=0.006,
        )

        self.optimizer.record_step(
            trace_id=trace_id,
            model_call=model_call,
            context_size_tokens=1024,
            step_outcome=Outcome.SUCCESS,
        )

        finalized = self.optimizer.finalize_trace(
            trace_id=trace_id,
            outcome=Outcome.SUCCESS,
            user_satisfaction=0.95,
        )

        self.assertIsNotNone(finalized)
        self.assertEqual(finalized.final_outcome, Outcome.SUCCESS)
        self.assertGreater(finalized.total_cost_computed, 0)

    def test_cost_adjusted_score(self):
        """Every synthetic trace yields a ``float`` cost-adjusted score."""
        # Imported inside the test, so an import failure here is confined
        # to this test rather than breaking collection of the whole module.
        from aco.datasets.synthetic_traces import SyntheticTraceGenerator

        gen = SyntheticTraceGenerator(seed=42)
        traces = gen.generate(10)

        checked = 0
        for index, trace in enumerate(traces):
            # subTest reports which trace failed and lets the remaining
            # traces still be checked instead of stopping at the first one.
            with self.subTest(trace_index=index):
                score = self.optimizer.compute_cost_adjusted_score(trace)
                self.assertIsInstance(score, float)
            checked += 1

        # Guard against a vacuous pass when no traces are produced.
        self.assertGreater(checked, 0, "generator produced no traces to score")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Run the suite with unittest's command-line runner when this file is
# executed directly; importing the module does not trigger a run.
if __name__ == "__main__":
    unittest.main()
|