narcolepticchicken commited on
Commit
f8a2d6d
·
verified ·
1 Parent(s): 7268941

Upload aco/config.py

Browse files
Files changed (1) hide show
  1. aco/config.py +144 -0
aco/config.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration for Agent Cost Optimizer."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional, Any
5
+ from pathlib import Path
6
+ import yaml
7
+
8
+
9
+ @dataclass
10
+ class ModelConfig:
11
+ model_id: str
12
+ provider: str
13
+ cost_per_1k_input: float
14
+ cost_per_1k_output: float
15
+ cost_per_1k_reasoning: float = 0.0
16
+ latency_ms_estimate: float = 1000.0
17
+ strength_tier: int = 3 # 1=tiny, 2=cheap, 3=medium, 4=frontier, 5=specialist, 6=verifier
18
+ max_context: int = 128000
19
+ supports_tools: bool = True
20
+ supports_reasoning: bool = False
21
+ cache_discount_rate: float = 0.5
22
+
23
+
24
+ @dataclass
25
+ class ToolConfig:
26
+ tool_name: str
27
+ cost_per_call: float = 0.0
28
+ latency_ms_estimate: float = 500.0
29
+ cacheable: bool = False
30
+ requires_verification: bool = False
31
+ max_retries: int = 3
32
+
33
+
34
+ @dataclass
35
+ class VerifierConfig:
36
+ verifier_model_id: str
37
+ cost_per_call: float = 0.0
38
+ latency_ms_estimate: float = 1000.0
39
+ confidence_threshold: float = 0.8
40
+
41
+
42
+ @dataclass
43
+ class RoutingPolicy:
44
+ name: str
45
+ type: str = "cascade" # cascade, static, learned, prompt_only
46
+ threshold_confidence: float = 0.7
47
+ max_cascade_depth: int = 3
48
+ enable_verifier_fallback: bool = True
49
+ enable_escalation: bool = True
50
+
51
+
52
+ @dataclass
53
+ class ACOConfig:
54
+ project_name: str = "agent-cost-optimizer"
55
+ trace_storage_path: str = "./traces"
56
+ models: Dict[str, ModelConfig] = field(default_factory=dict)
57
+ tools: Dict[str, ToolConfig] = field(default_factory=dict)
58
+ verifiers: Dict[str, VerifierConfig] = field(default_factory=dict)
59
+ routing_policy: RoutingPolicy = field(default_factory=lambda: RoutingPolicy("default"))
60
+
61
+ # Cost weights
62
+ model_cost_weight: float = 1.0
63
+ tool_cost_weight: float = 1.0
64
+ verifier_cost_weight: float = 1.0
65
+ latency_weight: float = 0.1
66
+ retry_penalty_weight: float = 2.0
67
+ false_done_penalty: float = 10.0
68
+ unsafe_cheap_model_penalty: float = 20.0
69
+ missed_escalation_penalty: float = 15.0
70
+
71
+ # Module toggles
72
+ enable_telemetry: bool = True
73
+ enable_classifier: bool = True
74
+ enable_router: bool = True
75
+ enable_context_budgeter: bool = True
76
+ enable_cache_layout: bool = True
77
+ enable_tool_gate: bool = True
78
+ enable_verifier_budgeter: bool = True
79
+ enable_retry_optimizer: bool = True
80
+ enable_meta_tool_miner: bool = True
81
+ enable_early_termination: bool = True
82
+
83
+ # Cache-aware layout
84
+ cache_prefix_stable: List[str] = field(default_factory=lambda: [
85
+ "system_rules", "tool_descriptions", "user_preferences"
86
+ ])
87
+ cache_suffix_dynamic: List[str] = field(default_factory=lambda: [
88
+ "user_message", "retrieved_docs", "recent_trace", "artifacts"
89
+ ])
90
+
91
+ # Early termination
92
+ doom_max_cost_ratio: float = 3.0 # stop if cost > 3x predicted
93
+ doom_max_retries: int = 3
94
+ doom_no_progress_steps: int = 5
95
+ doom_verifier_disagreement_threshold: int = 2
96
+
97
+ # Meta-tool mining
98
+ meta_tool_min_frequency: int = 5
99
+ meta_tool_min_success_rate: float = 0.8
100
+
101
+ @classmethod
102
+ def from_yaml(cls, path: str) -> "ACOConfig":
103
+ with open(path, "r") as f:
104
+ data = yaml.safe_load(f)
105
+
106
+ models = {k: ModelConfig(**v) for k, v in data.get("models", {}).items()}
107
+ tools = {k: ToolConfig(**v) for k, v in data.get("tools", {}).items()}
108
+ verifiers = {k: VerifierConfig(**v) for k, v in data.get("verifiers", {}).items()}
109
+ routing = RoutingPolicy(**data.get("routing_policy", {}))
110
+
111
+ config = cls(
112
+ project_name=data.get("project_name", "agent-cost-optimizer"),
113
+ trace_storage_path=data.get("trace_storage_path", "./traces"),
114
+ models=models,
115
+ tools=tools,
116
+ verifiers=verifiers,
117
+ routing_policy=routing,
118
+ )
119
+
120
+ # Override any direct fields
121
+ for key in ["model_cost_weight", "tool_cost_weight", "verifier_cost_weight",
122
+ "latency_weight", "retry_penalty_weight", "false_done_penalty",
123
+ "unsafe_cheap_model_penalty", "missed_escalation_penalty",
124
+ "enable_telemetry", "enable_classifier", "enable_router",
125
+ "enable_context_budgeter", "enable_cache_layout", "enable_tool_gate",
126
+ "enable_verifier_budgeter", "enable_retry_optimizer",
127
+ "enable_meta_tool_miner", "enable_early_termination",
128
+ "doom_max_cost_ratio", "doom_max_retries", "doom_no_progress_steps",
129
+ "doom_verifier_disagreement_threshold",
130
+ "meta_tool_min_frequency", "meta_tool_min_success_rate"]:
131
+ if key in data:
132
+ setattr(config, key, data[key])
133
+
134
+ return config
135
+
136
+ def to_dict(self) -> Dict[str, Any]:
137
+ return {
138
+ "project_name": self.project_name,
139
+ "trace_storage_path": self.trace_storage_path,
140
+ "models": {k: vars(v) for k, v in self.models.items()},
141
+ "tools": {k: vars(v) for k, v in self.tools.items()},
142
+ "verifiers": {k: vars(v) for k, v in self.verifiers.items()},
143
+ "routing_policy": vars(self.routing_policy),
144
+ }