narcolepticchicken committed on
Commit
d807f4a
·
verified ·
1 Parent(s): e1cfa35

Upload aco/config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/config.py +37 -127
aco/config.py CHANGED
@@ -1,144 +1,54 @@
1
- """Configuration for Agent Cost Optimizer."""
2
-
3
  from dataclasses import dataclass, field
4
- from typing import Dict, List, Optional, Any
5
- from pathlib import Path
6
- import yaml
7
 
 
 
 
 
 
 
8
 
9
@dataclass
class ModelConfig:
    """Pricing, latency and capability profile of a single model.

    The four leading fields are required; the rest carry defaults.
    Costs are keyed "per 1k" by name (presumably per 1k tokens; the
    currency unit is not stated here - confirm with the cost model).
    """

    model_id: str
    provider: str
    cost_per_1k_input: float
    cost_per_1k_output: float
    cost_per_1k_reasoning: float = 0.0
    latency_ms_estimate: float = 1000.0
    # 1=tiny, 2=cheap, 3=medium, 4=frontier, 5=specialist, 6=verifier
    strength_tier: int = 3
    max_context: int = 128000
    supports_tools: bool = True
    supports_reasoning: bool = False
    # NOTE(review): presumably the fraction discounted for cached input;
    # the consumer of this value is not visible here - confirm.
    cache_discount_rate: float = 0.5
22
-
23
-
24
@dataclass
class ToolConfig:
    """Cost, latency and retry settings for a single tool.

    Only ``tool_name`` is required; every other field has a default.
    """

    tool_name: str
    cost_per_call: float = 0.0
    latency_ms_estimate: float = 500.0
    cacheable: bool = False
    requires_verification: bool = False
    max_retries: int = 3
32
-
33
-
34
@dataclass
class VerifierConfig:
    """Settings for a verifier backed by a model.

    Only ``verifier_model_id`` is required; every other field has a default.
    """

    verifier_model_id: str
    cost_per_call: float = 0.0
    latency_ms_estimate: float = 1000.0
    confidence_threshold: float = 0.8
40
-
41
 
42
@dataclass
class RoutingPolicy:
    """Named routing policy and its cascade/escalation knobs.

    ``name`` is required; every other field has a default.
    """

    name: str
    # One of: cascade, static, learned, prompt_only
    type: str = "cascade"
    threshold_confidence: float = 0.7
    max_cascade_depth: int = 3
    enable_verifier_fallback: bool = True
    enable_escalation: bool = True
50
-
51
 
52
@dataclass
class ACOConfig:
    """Top-level configuration for the Agent Cost Optimizer.

    Groups the model/tool/verifier registries, the routing policy, cost
    weights, per-module toggles and the thresholds for cache layout,
    early termination and meta-tool mining. Every field has a default,
    so ``ACOConfig()`` yields a usable configuration.
    """

    project_name: str = "agent-cost-optimizer"
    trace_storage_path: str = "./traces"
    models: Dict[str, ModelConfig] = field(default_factory=dict)
    tools: Dict[str, ToolConfig] = field(default_factory=dict)
    verifiers: Dict[str, VerifierConfig] = field(default_factory=dict)
    routing_policy: RoutingPolicy = field(default_factory=lambda: RoutingPolicy("default"))

    # Cost weights
    model_cost_weight: float = 1.0
    tool_cost_weight: float = 1.0
    verifier_cost_weight: float = 1.0
    latency_weight: float = 0.1
    retry_penalty_weight: float = 2.0
    false_done_penalty: float = 10.0
    unsafe_cheap_model_penalty: float = 20.0
    missed_escalation_penalty: float = 15.0

    # Module toggles
    enable_telemetry: bool = True
    enable_classifier: bool = True
    enable_router: bool = True
    enable_context_budgeter: bool = True
    enable_cache_layout: bool = True
    enable_tool_gate: bool = True
    enable_verifier_budgeter: bool = True
    enable_retry_optimizer: bool = True
    enable_meta_tool_miner: bool = True
    enable_early_termination: bool = True

    # Cache-aware layout
    cache_prefix_stable: List[str] = field(default_factory=lambda: [
        "system_rules", "tool_descriptions", "user_preferences"
    ])
    cache_suffix_dynamic: List[str] = field(default_factory=lambda: [
        "user_message", "retrieved_docs", "recent_trace", "artifacts"
    ])

    # Early termination
    doom_max_cost_ratio: float = 3.0  # stop if cost > 3x predicted
    doom_max_retries: int = 3
    doom_no_progress_steps: int = 5
    doom_verifier_disagreement_threshold: int = 2

    # Meta-tool mining
    meta_tool_min_frequency: int = 5
    meta_tool_min_success_rate: float = 0.8

    @classmethod
    def from_yaml(cls, path: str) -> "ACOConfig":
        """Build a configuration from the YAML file at ``path``.

        Missing keys fall back to the dataclass defaults. An empty file,
        or a section explicitly set to null, is treated as "nothing
        configured" instead of crashing.

        Raises:
            ValueError: if the document's top level is not a mapping.
            TypeError: if a section contains keys its dataclass does not accept.
        """
        with open(path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            data = yaml.safe_load(f) or {}
        if not isinstance(data, dict):
            raise ValueError(f"Expected a mapping at the top level of {path!r}")

        models = {k: ModelConfig(**v) for k, v in (data.get("models") or {}).items()}
        tools = {k: ToolConfig(**v) for k, v in (data.get("tools") or {}).items()}
        verifiers = {k: VerifierConfig(**v) for k, v in (data.get("verifiers") or {}).items()}
        # RoutingPolicy.name has no default; fall back to the same "default"
        # name the dataclass default uses so a missing section still loads.
        routing = RoutingPolicy(**{"name": "default", **(data.get("routing_policy") or {})})

        config = cls(
            project_name=data.get("project_name", "agent-cost-optimizer"),
            trace_storage_path=data.get("trace_storage_path", "./traces"),
            models=models,
            tools=tools,
            verifiers=verifiers,
            routing_policy=routing,
        )

        # Scalar fields that may be overridden directly from the top level.
        override_keys = (
            "model_cost_weight", "tool_cost_weight", "verifier_cost_weight",
            "latency_weight", "retry_penalty_weight", "false_done_penalty",
            "unsafe_cheap_model_penalty", "missed_escalation_penalty",
            "enable_telemetry", "enable_classifier", "enable_router",
            "enable_context_budgeter", "enable_cache_layout", "enable_tool_gate",
            "enable_verifier_budgeter", "enable_retry_optimizer",
            "enable_meta_tool_miner", "enable_early_termination",
            "doom_max_cost_ratio", "doom_max_retries", "doom_no_progress_steps",
            "doom_verifier_disagreement_threshold",
            "meta_tool_min_frequency", "meta_tool_min_success_rate",
        )
        for key in override_keys:
            if key in data:
                setattr(config, key, data[key])

        return config

    def to_dict(self) -> Dict[str, Any]:
        """Return the project paths, registries and routing policy as plain dicts.

        Only these six entries are serialized; the weights, toggles and
        thresholds are not included. Nested dicts are copies, so mutating
        the result does not alter this configuration.
        """
        return {
            "project_name": self.project_name,
            "trace_storage_path": self.trace_storage_path,
            "models": {k: dict(vars(v)) for k, v in self.models.items()},
            "tools": {k: dict(vars(v)) for k, v in self.tools.items()},
            "verifiers": {k: dict(vars(v)) for k, v in self.verifiers.items()},
            "routing_policy": dict(vars(self.routing_policy)),
        }
 
1
+ """ACO Configuration."""
 
2
  from dataclasses import dataclass, field
3
+ from typing import Dict, List, Optional
4
+ from enum import IntEnum
 
5
 
6
class ModelTier(IntEnum):
    """Model tiers, numbered 1 through 5 in the order declared below.

    Being an ``IntEnum``, members compare and sort as plain integers.
    """

    TINY_LOCAL = 1
    CHEAP_CLOUD = 2
    MEDIUM = 3
    FRONTIER = 4
    SPECIALIST = 5
12
 
13
@dataclass
class ModelConfig:
    """Identity, pricing and capabilities of a single model.

    The first six fields are required and positional order matters:
    ``tier, provider, model_id, cost_per_1k_input, cost_per_1k_output,
    max_context``.
    """

    # Annotated as a plain int; presumably one of the ModelTier values
    # (1-5) - confirm against callers.
    tier: int
    provider: str
    model_id: str
    cost_per_1k_input: float
    cost_per_1k_output: float
    max_context: int
    supports_tools: bool = True
    supports_vision: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
@dataclass
class RoutingPolicy:
    """Thresholds and limits used when routing a task to a model.

    Every field has a default, so ``RoutingPolicy()`` is valid.
    """

    safety_threshold: float = 0.30
    downgrade_threshold: float = 0.90
    max_retries: int = 3
    # NOTE(review): unit is not stated here (presumably dollars) - confirm.
    max_cost_per_task: float = 5.0
    use_dynamic_difficulty: bool = True
    use_ml_confirmation: bool = True
 
32
 
33
@dataclass
class ACOConfig:
    """Top-level ACO configuration.

    Bundles the routing policy, the model registry, per-task tier floors,
    per-tier cost/strength tables, file locations and feature toggles.
    Every field has a default; mutable defaults are built per instance
    through ``default_factory`` so instances never share state.
    """

    routing_policy: RoutingPolicy = field(default_factory=RoutingPolicy)
    # Integer-keyed registry; presumably keyed by tier number - confirm.
    models: Dict[int, ModelConfig] = field(default_factory=dict)
    # Task-type name -> integer floor (values here range over 1-4;
    # presumably the minimum tier allowed for that task type - confirm).
    task_floors: Dict[str, int] = field(default_factory=lambda: {
        "legal_regulated": 4,
        "long_horizon": 3,
        "research": 3,
        "coding": 3,
        "unknown_ambiguous": 3,
        "quick_answer": 1,
        "document_drafting": 2,
        "tool_heavy": 2,
        "retrieval_heavy": 2,
    })
    # Per-tier cost figure for keys 1-5 (unit not stated here).
    tier_costs: Dict[int, float] = field(default_factory=lambda: {
        1: 0.05,
        2: 0.15,
        3: 0.75,
        4: 1.0,
        5: 1.5,
    })
    # Per-tier strength score for keys 1-5.
    tier_strengths: Dict[int, float] = field(default_factory=lambda: {
        1: 0.35,
        2: 0.55,
        3: 0.80,
        4: 0.93,
        5: 0.97,
    })
    # NOTE(review): the .pkl suffix suggests this bundle is unpickled by
    # the router; unpickling runs arbitrary code, so only point this at
    # trusted files - confirm how it is loaded.
    router_model_path: str = "router_models/router_bundle_v8.pkl"

    # Feature toggles
    enable_cache_aware: bool = True
    enable_tool_gate: bool = True
    enable_verifier_budgeter: bool = True
    enable_doom_detector: bool = True
    enable_meta_tools: bool = True

    telemetry_file: str = "aco_traces.jsonl"