| import yaml |
| from pydantic import BaseModel, model_validator |
| from typing import Optional, Type, Union, List, Any |
| |
| from .models.model_configs import LLMConfig |
| from .core.registry import MODEL_REGISTRY |
|
|
|
|
| class Config(BaseModel): |
|
|
| llm_config: dict |
| agents: Optional[Union[str, List[dict]]] = [] |
| model_config = {"arbitrary_types_allowed": True, "extra": "allow", "protected_namespaces": ()} |
|
|
| @classmethod |
| def from_file(cls, path: str): |
| with open(path, mode="r", encoding="utf-8") as file: |
| data = yaml.safe_load(file.read()) |
| config = cls.model_validate(data) |
| return config |
| |
| @property |
| def kwargs(self): |
| return self.model_extra |
|
|
| @model_validator(mode="before") |
| @classmethod |
| def validate_config_data(cls, data: Any) -> Any: |
|
|
| |
| llm_config_data = data.get("llm_config", None) |
| if not llm_config_data: |
| raise ValueError("config file must contain 'llm_config'") |
| data["llm_config"] = cls.process_llm_config(data=data["llm_config"]) |
| |
| |
| agents_data = data.get("agents", None) |
| if agents_data: |
| data["agents"] = cls.process_agents_data(agents=agents_data, llm_config=data["llm_config"]) |
|
|
| return data |
|
|
| @classmethod |
| def process_llm_config(cls, data: dict) -> dict: |
|
|
| llm_type = data.get("llm_type", None) |
| if not llm_type: |
| raise ValueError("must specify `llm_type` in in `llm_config`!") |
| llm_config_cls: Type[LLMConfig] = MODEL_REGISTRY.get_model_config(llm_type) |
| if "class_name" in data: |
| assert data["class_name"] == llm_config_cls.__name__, \ |
| "the 'class_name' specified in 'llm_config' ({}) doesn't match the LLMConfig class ({}) registered for {} model. You should either remove 'class_name' or set it to {}.".format( |
| data["class_name"], llm_config_cls.__name__, llm_type, llm_config_cls.__name__ |
| ) |
| else: |
| data["class_name"] = llm_config_cls.__name__ |
| |
| return data |
| |
| @classmethod |
| def process_agents_data(cls, agents: List[dict], llm_config=dict) -> List[dict]: |
|
|
| for agent in agents: |
| if "llm_config" not in agent: |
| agent["llm_config"] = llm_config |
| return agents |
|
|
|
|