| from typing import Dict, Type |
| from .base import BaseAgent |
| from .openai_agent import OpenAIAgent |
| from .amd_agent import AMDAgent |
| from .groq_agent import GroqAgent |
| from .gemini_agent import GeminiAgent |
| from .local_agent import LocalAgent |
| from .digitalocean_agent import DigitalOceanAgent |
| from services.config import settings |
|
|
| |
# Registry from lower-case provider key to the agent class that serves it.
# AgentFactory.get_agent looks providers up here after lower-casing them.
# "local" and "ollama" are aliases that both resolve to LocalAgent.
PROVIDER_MAP: Dict[str, Type[BaseAgent]] = dict(
    openai=OpenAIAgent,
    amd=AMDAgent,
    groq=GroqAgent,
    gemini=GeminiAgent,
    local=LocalAgent,
    ollama=LocalAgent,
    digitalocean=DigitalOceanAgent,
)
|
|
class AgentFactory:
    """Factory that turns a provider name into a concrete agent instance.

    Provider names are resolved through the module-level ``PROVIDER_MAP``.
    """

    # Substituted when "openai" is requested, settings.OPENAI_API_KEY is empty,
    # and settings.GROQ_API_KEY is set. The caller's model is replaced too,
    # presumably because an OpenAI model name is not served by Groq.
    FALLBACK_PROVIDER = "groq"
    FALLBACK_MODEL = "llama-3.3-70b-versatile"

    @staticmethod
    def get_agent(provider: str, name: str, role: str, model: str, system_prompt: str = None) -> BaseAgent:
        """
        Instantiate the agent class registered for ``provider``.

        If ``provider`` is "openai" and no OpenAI API key is configured, the
        request is redirected to ``AgentFactory.FALLBACK_PROVIDER`` /
        ``AgentFactory.FALLBACK_MODEL`` when a Groq key is available. If no
        Groq key is available either, an OpenAI agent is still built and a
        warning is logged.

        Args:
            provider: Provider key, e.g. "openai", "groq" or "ollama". Matching
                is case-insensitive and ignores surrounding whitespace.
            name: Agent name forwarded to the agent constructor.
            role: Agent role forwarded to the agent constructor.
            model: Model identifier forwarded to the agent constructor.
                Overridden on the Groq fallback path.
            system_prompt: System prompt forwarded unchanged (may be None).

        Returns:
            An instance of the class mapped to the (possibly substituted)
            provider in ``PROVIDER_MAP``.

        Raises:
            ValueError: If ``provider`` is None, empty, or not a key of
                ``PROVIDER_MAP``.
        """
        # Treat None like "" so it reaches the ValueError below instead of
        # raising an AttributeError from .lower().
        normalized = (provider or "").strip().lower()

        if normalized == "openai" and not settings.OPENAI_API_KEY:
            if settings.GROQ_API_KEY:
                normalized = AgentFactory.FALLBACK_PROVIDER
                model = AgentFactory.FALLBACK_MODEL
            else:
                # Best-effort: keep the OpenAI agent, but make the missing
                # key visible instead of continuing silently.
                import logging

                logging.getLogger(__name__).warning(
                    "OPENAI_API_KEY is not set and GROQ_API_KEY is unavailable "
                    "for fallback; creating the OpenAI agent without a key."
                )

        agent_class = PROVIDER_MAP.get(normalized)
        if agent_class is None:
            supported = ", ".join(sorted(PROVIDER_MAP))
            # Keep the original message prefix for callers that match on it.
            raise ValueError(
                f"Unsupported agent provider: {normalized}. "
                f"Supported providers: {supported}"
            )

        return agent_class(name=name, role=role, model=model, system_prompt=system_prompt)
|
|