File size: 1,666 Bytes
81ff144
 
 
 
 
 
 
e047946
81ff144
 
 
 
 
 
 
 
 
e047946
 
81ff144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import logging
from typing import Dict, Optional, Type

from .base import BaseAgent
from .openai_agent import OpenAIAgent
from .amd_agent import AMDAgent
from .groq_agent import GroqAgent
from .gemini_agent import GeminiAgent
from .local_agent import LocalAgent
from .digitalocean_agent import DigitalOceanAgent
from services.config import settings

# Registry: provider key -> concrete agent class.
# Keys must stay lowercase because lookups lowercase the requested provider.
# "ollama" is an alias that resolves to the same class as "local".
PROVIDER_MAP: Dict[str, Type[BaseAgent]] = dict(
    openai=OpenAIAgent,
    amd=AMDAgent,
    groq=GroqAgent,
    gemini=GeminiAgent,
    local=LocalAgent,
    ollama=LocalAgent,
    digitalocean=DigitalOceanAgent,
)

class AgentFactory:
    """Factory that resolves a provider name to a ``BaseAgent`` instance."""

    # Model used when an "openai" request is redirected to Groq. The caller's
    # model name is replaced because an OpenAI model ID is not a Groq model.
    GROQ_FALLBACK_MODEL: str = "llama-3.3-70b-versatile"

    @staticmethod
    def get_agent(
        provider: str,
        name: str,
        role: str,
        model: str,
        system_prompt: Optional[str] = None,
    ) -> BaseAgent:
        """
        Instantiate the agent class registered for ``provider``.

        Groq fallback: if ``provider`` is "openai", ``settings.OPENAI_API_KEY``
        is falsy and ``settings.GROQ_API_KEY`` is truthy, the Groq agent is
        built instead and ``model`` is replaced with ``GROQ_FALLBACK_MODEL``
        (a warning is logged). If neither key is set, no redirect happens and
        the OpenAI agent is constructed as requested.

        Args:
            provider: Provider key, matched case-insensitively and with
                surrounding whitespace ignored against ``PROVIDER_MAP``.
            name: Agent name, forwarded to the agent constructor.
            role: Agent role, forwarded to the agent constructor.
            model: Model identifier, forwarded unless the Groq fallback applies.
            system_prompt: Optional system prompt, forwarded unchanged.

        Returns:
            A new instance of the resolved ``BaseAgent`` subclass.

        Raises:
            ValueError: If the (normalized) provider is not in ``PROVIDER_MAP``.
        """
        provider = provider.strip().lower()

        # Redirect OpenAI -> Groq only when OpenAI has no key but Groq does;
        # otherwise keep the original provider.
        if provider == "openai" and not settings.OPENAI_API_KEY and settings.GROQ_API_KEY:
            fallback_model = AgentFactory.GROQ_FALLBACK_MODEL
            logging.getLogger(__name__).warning(
                "OPENAI_API_KEY is not set; redirecting provider 'openai' to "
                "'groq' with model %s (requested model: %s)",
                fallback_model,
                model,
            )
            provider = "groq"
            model = fallback_model

        agent_class = PROVIDER_MAP.get(provider)
        if agent_class is None:
            raise ValueError(f"Unsupported agent provider: {provider}")

        return agent_class(name=name, role=role, model=model, system_prompt=system_prompt)