File size: 4,354 Bytes
0b2427a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ac8a9d
 
0b2427a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ac8a9d
0b2427a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ac8a9d
0b2427a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Base agent class for all agents in the system."""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

from src.utils.config import get_settings
from src.utils.cost_tracker import CostTracker
from src.utils.logging import setup_logger

# Module-level logger used by BaseAgent for initialization, LLM-call and failure messages.
logger = setup_logger(__name__)


class BaseAgent(ABC):
    """
    Base class for all agents in the multi-agent system.

    Provides common functionality for LLM interaction (an OpenAI-compatible
    client pointed at OpenRouter), token-usage/cost tracking, and error
    logging. Subclasses must implement :meth:`get_system_prompt` and
    :meth:`run`.
    """

    def __init__(
        self,
        name: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        cost_tracker: Optional[CostTracker] = None,
    ):
        """
        Initialize base agent.

        Args:
            name: Agent name, used as a prefix in log messages.
            model: LLM model to use. Falls back to ``settings.default_model``
                when ``None`` or empty.
            temperature: LLM sampling temperature, passed unchanged to
                ``ChatOpenAI`` (the accepted range is defined by the provider).
            cost_tracker: Optional cost tracker instance, e.g. to share one
                tracker between several agents. A new ``CostTracker`` is
                created when ``None``.
        """
        self.name = name
        # Explicit None check: a caller-supplied tracker must never be replaced
        # just because it happens to evaluate as falsy.
        self.cost_tracker = (
            cost_tracker if cost_tracker is not None else CostTracker()
        )

        settings = get_settings()
        self.model_name = model or settings.default_model

        # Initialize LLM via OpenRouter (OpenAI-compatible API).
        self.llm = ChatOpenAI(
            model=self.model_name,
            temperature=temperature,
            openai_api_key=settings.openrouter_api_key,  # type: ignore[call-arg]
            openai_api_base=settings.openrouter_base_url,  # type: ignore[call-arg]
        )

        logger.info(f"Initialized {name} with model {self.model_name}")

    @abstractmethod
    def get_system_prompt(self) -> str:
        """
        Get the system prompt for this agent.

        Returns:
            System prompt string, used by :meth:`_create_messages` when no
            explicit prompt is supplied.
        """
        pass

    @abstractmethod
    async def run(self, **kwargs) -> Dict[str, Any]:
        """
        Execute the agent's main task.

        Args:
            **kwargs: Agent-specific parameters.

        Returns:
            Dictionary with results.
        """
        pass

    @staticmethod
    def _extract_token_usage(response: Any) -> Optional[tuple[int, int]]:
        """
        Extract ``(input_tokens, output_tokens)`` from an LLM response message.

        Sources are checked in order:

        1. ``response.usage_metadata``: LangChain's standardized usage field
           with ``input_tokens`` / ``output_tokens`` keys.
        2. ``response.response_metadata["token_usage"]``: where
           ``langchain_openai`` places the raw OpenAI ``usage`` block
           (``prompt_tokens`` / ``completion_tokens``).
        3. ``response.response_metadata["usage"]``: kept as a fallback for
           providers/versions that expose the block under that key.

        Args:
            response: The message returned by ``ChatOpenAI.ainvoke``.

        Returns:
            A tuple of token counts, or ``None`` if the response reports no
            usage information.
        """
        usage_metadata = getattr(response, "usage_metadata", None)
        if usage_metadata:
            return (
                int(usage_metadata.get("input_tokens", 0) or 0),
                int(usage_metadata.get("output_tokens", 0) or 0),
            )

        metadata = getattr(response, "response_metadata", None) or {}
        for key in ("token_usage", "usage"):
            usage = metadata.get(key)
            if usage:
                return (
                    int(usage.get("prompt_tokens", 0) or 0),
                    int(usage.get("completion_tokens", 0) or 0),
                )
        return None

    async def _invoke_llm(
        self,
        messages: list[BaseMessage],
        **llm_kwargs,
    ) -> str:
        """
        Invoke the LLM, record token usage with the cost tracker, and return text.

        Args:
            messages: List of messages to send.
            **llm_kwargs: Additional keyword arguments forwarded to
                ``ChatOpenAI.ainvoke``.

        Returns:
            The response content converted to ``str``.

        Raises:
            Exception: Any error raised by the LLM call itself is logged and
                re-raised unchanged.
        """
        # Keep the try narrow so only genuine LLM-call failures are reported
        # as such; cost-tracking errors propagate with their own traceback.
        try:
            response = await self.llm.ainvoke(messages, **llm_kwargs)
        except Exception as e:
            logger.error(f"{self.name} LLM call failed: {e}")
            raise

        # Track usage if the provider reported it.
        usage = self._extract_token_usage(response)
        if usage is not None:
            input_tokens, output_tokens = usage
            self.cost_tracker.track_usage(
                model=self.model_name,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )

        logger.info(
            f"{self.name} LLM call complete",
            extra={
                "extra_fields": {
                    "model": self.model_name,
                    "total_cost": self.cost_tracker.total_cost,
                }
            },
        )

        return str(response.content)

    def _create_messages(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
    ) -> list[BaseMessage]:
        """
        Create the message list for an LLM call.

        Args:
            user_message: User message content.
            system_prompt: Optional system prompt. When ``None`` or empty,
                :meth:`get_system_prompt` is used instead.

        Returns:
            ``[SystemMessage, HumanMessage]`` in that order.
        """
        prompt = system_prompt or self.get_system_prompt()
        messages: list[BaseMessage] = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_message),
        ]
        return messages

    def get_cost_summary(self) -> Dict[str, Any]:
        """
        Get cost summary for this agent's operations.

        Returns:
            The dictionary produced by ``self.cost_tracker.get_summary()``.
        """
        return self.cost_tracker.get_summary()