hequ commited on
Commit
25842fe
·
verified ·
1 Parent(s): faf764f

Update config_loader.py

Browse files
Files changed (1) hide show
  1. config_loader.py +232 -229
config_loader.py CHANGED
@@ -1,230 +1,233 @@
1
- # SPDX-License-Identifier: GPL-3.0-or-later
2
- #
3
- # Toolify: Empower any LLM with function calling capabilities.
4
- # Copyright (C) 2025 FunnyCups (https://github.com/funnycups)
5
-
6
- import os
7
- import yaml
8
- from typing import List, Dict, Any, Set, Optional
9
- from pydantic import BaseModel, Field, field_validator
10
-
11
-
12
- class ServerConfig(BaseModel):
13
- """Server configuration"""
14
- port: int = Field(default=8000, ge=1, le=65535, description="Server port")
15
- host: str = Field(default="0.0.0.0", description="Server host address")
16
- timeout: int = Field(default=180, ge=1, description="Request timeout (seconds)")
17
-
18
-
19
- class UpstreamService(BaseModel):
20
- """Upstream service configuration"""
21
- name: str = Field(description="Service name")
22
- base_url: str = Field(description="Service base URL")
23
- api_key: str = Field(description="API key")
24
- models: List[str] = Field(description="List of supported models")
25
- description: str = Field(default="", description="Service description")
26
- is_default: bool = Field(default=False, description="Is default service")
27
-
28
- @field_validator('base_url')
29
- def validate_base_url(cls, v):
30
- if not v.startswith(('http://', 'https://')):
31
- raise ValueError('base_url must start with http:// or https://')
32
- return v.rstrip('/')
33
-
34
- @field_validator('api_key')
35
- def validate_api_key(cls, v):
36
- if not v or v.strip() == "":
37
- raise ValueError('api_key cannot be empty')
38
- return v
39
-
40
- @field_validator('models')
41
- def validate_models(cls, v):
42
- if not v or len(v) == 0:
43
- raise ValueError('models list cannot be empty')
44
- for model in v:
45
- if not model or model.strip() == "":
46
- raise ValueError('model name cannot be empty')
47
- return v
48
-
49
-
50
- class ClientAuthConfig(BaseModel):
51
- """Client authentication configuration"""
52
- allowed_keys: List[str] = Field(description="List of allowed client API keys")
53
-
54
- @field_validator('allowed_keys')
55
- def validate_allowed_keys(cls, v):
56
- if not v or len(v) == 0:
57
- raise ValueError('allowed_keys cannot be empty')
58
- for key in v:
59
- if not key or key.strip() == "":
60
- raise ValueError('API key cannot be empty')
61
- return v
62
-
63
-
64
- class FeaturesConfig(BaseModel):
65
- """Feature configuration"""
66
- enable_function_calling: bool = Field(default=True, description="Enable function calling")
67
- log_level: str = Field(default="INFO", description="Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL, or DISABLED")
68
- convert_developer_to_system: bool = Field(default=True, description="Convert developer role to system role")
69
- prompt_template: Optional[str] = Field(default=None, description="Custom prompt template for function calling")
70
- key_passthrough: bool = Field(default=False, description="If true, directly forward client-provided API key to upstream instead of using configured upstream key")
71
- model_passthrough: bool = Field(default=False, description="If true, forward all requests directly to the 'openai' upstream service, ignoring model-based routing")
72
-
73
- @field_validator('log_level')
74
- def validate_log_level(cls, v):
75
- valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "DISABLED"]
76
- if v.upper() not in valid_levels:
77
- raise ValueError(f"log_level must be one of {valid_levels}")
78
- return v.upper()
79
-
80
- @field_validator('prompt_template')
81
- def validate_prompt_template(cls, v):
82
- if v:
83
- if "{tools_list}" not in v or "{trigger_signal}" not in v:
84
- raise ValueError("prompt_template must contain {tools_list} and {trigger_signal} placeholders")
85
- return v
86
-
87
-
88
- class AppConfig(BaseModel):
89
- """Application full configuration"""
90
- server: ServerConfig = Field(default_factory=ServerConfig)
91
- upstream_services: List[UpstreamService] = Field(description="List of upstream services")
92
- client_authentication: ClientAuthConfig = Field(description="Client authentication configuration")
93
- features: FeaturesConfig = Field(default_factory=FeaturesConfig)
94
-
95
- @field_validator('upstream_services')
96
- def validate_upstream_services(cls, v):
97
- if not v or len(v) == 0:
98
- raise ValueError('upstream_services cannot be empty')
99
-
100
- default_services = [service for service in v if service.is_default]
101
- if len(default_services) == 0:
102
- raise ValueError('Must have at least one default upstream service (is_default: true)')
103
- if len(default_services) > 1:
104
- raise ValueError('Only one upstream service can be marked as default')
105
-
106
- all_models = set()
107
- all_aliases = set()
108
-
109
- for service in v:
110
- for model in service.models:
111
- if model in all_models:
112
- raise ValueError(f'Duplicate model entry found: {model}')
113
- all_models.add(model)
114
-
115
- if ':' in model:
116
- parts = model.split(':', 1)
117
- if len(parts) == 2:
118
- alias, real_model = parts
119
- if not alias.strip() or not real_model.strip():
120
- raise ValueError(f"Invalid alias format in '{model}'. Both parts must not be empty.")
121
- all_aliases.add(alias)
122
- else:
123
- raise ValueError(f"Invalid model format with colon: {model}")
124
-
125
- regular_models = {m for m in all_models if ':' not in m}
126
- conflicts = all_aliases.intersection(regular_models)
127
- if conflicts:
128
- raise ValueError(f"Alias names {conflicts} conflict with model names.")
129
-
130
- return v
131
-
132
-
133
- class ConfigLoader:
134
- """Configuration loader"""
135
-
136
- def __init__(self, config_path: str = "config.yaml"):
137
- self.config_path = config_path
138
- self._config: AppConfig = None
139
-
140
- def load_config(self) -> AppConfig:
141
- """Load configuration file"""
142
- if not os.path.exists(self.config_path):
143
- raise FileNotFoundError(
144
- f"Configuration file '{self.config_path}' not found. "
145
- f"Please copy 'config.example.yaml' to '{self.config_path}' and modify the configuration as needed."
146
- )
147
-
148
- try:
149
- with open(self.config_path, 'r', encoding='utf-8') as f:
150
- config_data = yaml.safe_load(f)
151
- except yaml.YAMLError as e:
152
- raise ValueError(f"Configuration file format error: {e}")
153
- except Exception as e:
154
- raise ValueError(f"Failed to read configuration file: {e}")
155
-
156
- if not config_data:
157
- raise ValueError("Configuration file is empty")
158
-
159
- try:
160
- self._config = AppConfig(**config_data)
161
- return self._config
162
- except Exception as e:
163
- raise ValueError(f"Configuration validation failed: {e}")
164
-
165
- @property
166
- def config(self) -> AppConfig:
167
- """Get configuration object"""
168
- if self._config is None:
169
- self.load_config()
170
- return self._config
171
-
172
- def get_model_to_service_mapping(self) -> tuple[Dict[str, Dict[str, Any]], Dict[str, List[str]]]:
173
- """Get model to service mapping and model aliases"""
174
- config = self.config
175
- model_mapping = {}
176
- alias_mapping = {}
177
-
178
- for service in config.upstream_services:
179
- service_info = {
180
- "name": service.name,
181
- "base_url": service.base_url,
182
- "api_key": service.api_key,
183
- "description": service.description,
184
- "is_default": service.is_default
185
- }
186
-
187
- for model_entry in service.models:
188
- model_mapping[model_entry] = service_info
189
- if ':' in model_entry:
190
- parts = model_entry.split(':', 1)
191
- if len(parts) == 2:
192
- alias, _ = parts
193
- if alias not in alias_mapping:
194
- alias_mapping[alias] = []
195
- alias_mapping[alias].append(model_entry)
196
-
197
- return model_mapping, alias_mapping
198
-
199
- def get_default_service(self) -> Dict[str, Any]:
200
- """Get default service configuration"""
201
- config = self.config
202
- for service in config.upstream_services:
203
- if service.is_default:
204
- return {
205
- "name": service.name,
206
- "base_url": service.base_url,
207
- "api_key": service.api_key,
208
- "description": service.description,
209
- "is_default": service.is_default
210
- }
211
- raise ValueError("No default service configured")
212
-
213
- def get_allowed_client_keys(self) -> Set[str]:
214
- """Get set of allowed client keys"""
215
- return set(self.config.client_authentication.allowed_keys)
216
-
217
- def get_log_level(self) -> str:
218
- """Get configured log level"""
219
- return self.config.features.log_level
220
-
221
- def get_features_config(self) -> Dict[str, Any]:
222
- """Get feature configuration"""
223
- return {
224
- "function_calling": self.config.features.enable_function_calling,
225
- "log_level": self.config.features.log_level,
226
- "convert_developer_to_system": self.config.features.convert_developer_to_system
227
- }
228
-
229
-
 
 
 
230
  config_loader = ConfigLoader()
 
1
+ # SPDX-License-Identifier: GPL-3.0-or-later
2
+ #
3
+ # Toolify: Empower any LLM with function calling capabilities.
4
+ # Copyright (C) 2025 FunnyCups (https://github.com/funnycups)
5
+
6
+ import os
7
+ import yaml
8
+ from typing import List, Dict, Any, Set, Optional
9
+ from pydantic import BaseModel, Field, field_validator
10
+
11
+
12
+ class ServerConfig(BaseModel):
13
+ """Server configuration"""
14
+ port: int = Field(default=8000, ge=1, le=65535, description="Server port")
15
+ host: str = Field(default="0.0.0.0", description="Server host address")
16
+ timeout: int = Field(default=180, ge=1, description="Request timeout (seconds)")
17
+
18
+
19
+ class UpstreamService(BaseModel):
20
+ """Upstream service configuration"""
21
+ name: str = Field(description="Service name")
22
+ base_url: str = Field(description="Service base URL")
23
+ api_key: str = Field(description="API key")
24
+ models: List[str] = Field(description="List of supported models")
25
+ description: str = Field(default="", description="Service description")
26
+ is_default: bool = Field(default=False, description="Is default service")
27
+
28
+ @field_validator('base_url')
29
+ def validate_base_url(cls, v):
30
+ if not v.startswith(('http://', 'https://')):
31
+ raise ValueError('base_url must start with http:// or https://')
32
+ return v.rstrip('/')
33
+
34
+ @field_validator('api_key')
35
+ def validate_api_key(cls, v):
36
+ if not v or v.strip() == "":
37
+ raise ValueError('api_key cannot be empty')
38
+ return v
39
+
40
+ @field_validator('models')
41
+ def validate_models(cls, v):
42
+ if not v or len(v) == 0:
43
+ raise ValueError('models list cannot be empty')
44
+ for model in v:
45
+ if not model or model.strip() == "":
46
+ raise ValueError('model name cannot be empty')
47
+ return v
48
+
49
+
50
+ class ClientAuthConfig(BaseModel):
51
+ """Client authentication configuration"""
52
+ allowed_keys: List[str] = Field(description="List of allowed client API keys")
53
+
54
+ @field_validator('allowed_keys')
55
+ def validate_allowed_keys(cls, v):
56
+ if not v or len(v) == 0:
57
+ raise ValueError('allowed_keys cannot be empty')
58
+ for key in v:
59
+ if not key or key.strip() == "":
60
+ raise ValueError('API key cannot be empty')
61
+ return v
62
+
63
+
64
+ class FeaturesConfig(BaseModel):
65
+ """Feature configuration"""
66
+ enable_function_calling: bool = Field(default=True, description="Enable function calling")
67
+ log_level: str = Field(default="INFO", description="Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL, or DISABLED")
68
+ convert_developer_to_system: bool = Field(default=True, description="Convert developer role to system role")
69
+ prompt_template: Optional[str] = Field(default=None, description="Custom prompt template for function calling")
70
+ key_passthrough: bool = Field(default=False, description="If true, directly forward client-provided API key to upstream instead of using configured upstream key")
71
+ model_passthrough: bool = Field(default=False, description="If true, forward all requests directly to the 'openai' upstream service, ignoring model-based routing")
72
+
73
+ @field_validator('log_level')
74
+ def validate_log_level(cls, v):
75
+ valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "DISABLED"]
76
+ if v.upper() not in valid_levels:
77
+ raise ValueError(f"log_level must be one of {valid_levels}")
78
+ return v.upper()
79
+
80
+ @field_validator('prompt_template')
81
+ def validate_prompt_template(cls, v):
82
+ if v:
83
+ if "{tools_list}" not in v or "{trigger_signal}" not in v:
84
+ raise ValueError("prompt_template must contain {tools_list} and {trigger_signal} placeholders")
85
+ return v
86
+
87
+
88
+ class AppConfig(BaseModel):
89
+ """Application full configuration"""
90
+ server: ServerConfig = Field(default_factory=ServerConfig)
91
+ upstream_services: List[UpstreamService] = Field(description="List of upstream services")
92
+ client_authentication: ClientAuthConfig = Field(description="Client authentication configuration")
93
+ features: FeaturesConfig = Field(default_factory=FeaturesConfig)
94
+
95
+ @field_validator('upstream_services')
96
+ def validate_upstream_services(cls, v):
97
+ if not v or len(v) == 0:
98
+ raise ValueError('upstream_services cannot be empty')
99
+
100
+ default_services = [service for service in v if service.is_default]
101
+ if len(default_services) == 0:
102
+ raise ValueError('Must have at least one default upstream service (is_default: true)')
103
+ if len(default_services) > 1:
104
+ raise ValueError('Only one upstream service can be marked as default')
105
+
106
+ all_models = set()
107
+ all_aliases = set()
108
+
109
+ for service in v:
110
+ for model in service.models:
111
+ if model in all_models:
112
+ raise ValueError(f'Duplicate model entry found: {model}')
113
+ all_models.add(model)
114
+
115
+ if ':' in model:
116
+ parts = model.split(':', 1)
117
+ if len(parts) == 2:
118
+ alias, real_model = parts
119
+ if not alias.strip() or not real_model.strip():
120
+ raise ValueError(f"Invalid alias format in '{model}'. Both parts must not be empty.")
121
+ all_aliases.add(alias)
122
+ else:
123
+ raise ValueError(f"Invalid model format with colon: {model}")
124
+
125
+ regular_models = {m for m in all_models if ':' not in m}
126
+ conflicts = all_aliases.intersection(regular_models)
127
+ if conflicts:
128
+ raise ValueError(f"Alias names {conflicts} conflict with model names.")
129
+
130
+ return v
131
+
132
+
133
+ class ConfigLoader:
134
+ """Configuration loader"""
135
+
136
+ def __init__(self, config_path: str = "config.yaml"):
137
+ env_config_path = os.getenv("CONFIG_PATH")
138
+ if env_config_path:
139
+ config_path = env_config_path
140
+ self.config_path = config_path
141
+ self._config: AppConfig = None
142
+
143
+ def load_config(self) -> AppConfig:
144
+ """Load configuration file"""
145
+ if not os.path.exists(self.config_path):
146
+ raise FileNotFoundError(
147
+ f"Configuration file '{self.config_path}' not found. "
148
+ f"Please copy 'config.example.yaml' to '{self.config_path}' and modify the configuration as needed."
149
+ )
150
+
151
+ try:
152
+ with open(self.config_path, 'r', encoding='utf-8') as f:
153
+ config_data = yaml.safe_load(f)
154
+ except yaml.YAMLError as e:
155
+ raise ValueError(f"Configuration file format error: {e}")
156
+ except Exception as e:
157
+ raise ValueError(f"Failed to read configuration file: {e}")
158
+
159
+ if not config_data:
160
+ raise ValueError("Configuration file is empty")
161
+
162
+ try:
163
+ self._config = AppConfig(**config_data)
164
+ return self._config
165
+ except Exception as e:
166
+ raise ValueError(f"Configuration validation failed: {e}")
167
+
168
+ @property
169
+ def config(self) -> AppConfig:
170
+ """Get configuration object"""
171
+ if self._config is None:
172
+ self.load_config()
173
+ return self._config
174
+
175
+ def get_model_to_service_mapping(self) -> tuple[Dict[str, Dict[str, Any]], Dict[str, List[str]]]:
176
+ """Get model to service mapping and model aliases"""
177
+ config = self.config
178
+ model_mapping = {}
179
+ alias_mapping = {}
180
+
181
+ for service in config.upstream_services:
182
+ service_info = {
183
+ "name": service.name,
184
+ "base_url": service.base_url,
185
+ "api_key": service.api_key,
186
+ "description": service.description,
187
+ "is_default": service.is_default
188
+ }
189
+
190
+ for model_entry in service.models:
191
+ model_mapping[model_entry] = service_info
192
+ if ':' in model_entry:
193
+ parts = model_entry.split(':', 1)
194
+ if len(parts) == 2:
195
+ alias, _ = parts
196
+ if alias not in alias_mapping:
197
+ alias_mapping[alias] = []
198
+ alias_mapping[alias].append(model_entry)
199
+
200
+ return model_mapping, alias_mapping
201
+
202
+ def get_default_service(self) -> Dict[str, Any]:
203
+ """Get default service configuration"""
204
+ config = self.config
205
+ for service in config.upstream_services:
206
+ if service.is_default:
207
+ return {
208
+ "name": service.name,
209
+ "base_url": service.base_url,
210
+ "api_key": service.api_key,
211
+ "description": service.description,
212
+ "is_default": service.is_default
213
+ }
214
+ raise ValueError("No default service configured")
215
+
216
+ def get_allowed_client_keys(self) -> Set[str]:
217
+ """Get set of allowed client keys"""
218
+ return set(self.config.client_authentication.allowed_keys)
219
+
220
+ def get_log_level(self) -> str:
221
+ """Get configured log level"""
222
+ return self.config.features.log_level
223
+
224
+ def get_features_config(self) -> Dict[str, Any]:
225
+ """Get feature configuration"""
226
+ return {
227
+ "function_calling": self.config.features.enable_function_calling,
228
+ "log_level": self.config.features.log_level,
229
+ "convert_developer_to_system": self.config.features.convert_developer_to_system
230
+ }
231
+
232
+
233
  config_loader = ConfigLoader()