pkgprateek's picture
fix: configure app host and dockerfile for HF deployment
8ac8a9d
"""Base agent class for all agents in the system."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from src.utils.config import get_settings
from src.utils.cost_tracker import CostTracker
from src.utils.logging import setup_logger
# Module-level logger; BaseAgent uses it to report initialization,
# completed LLM calls, and LLM call failures.
logger = setup_logger(__name__)
class BaseAgent(ABC):
    """
    Base class for all agents in the multi-agent system.

    Provides common functionality for LLM interaction, cost tracking,
    and error handling. Subclasses must implement ``get_system_prompt``
    and ``run``.
    """

    def __init__(
        self,
        name: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        cost_tracker: Optional[CostTracker] = None,
    ):
        """
        Initialize base agent.

        Args:
            name: Agent name for logging
            model: LLM model to use (defaults to config default)
            temperature: LLM sampling temperature (0-1)
            cost_tracker: Optional cost tracker instance; a new
                CostTracker is created when None is given
        """
        self.name = name

        # Compare against None explicitly so a caller-supplied tracker is
        # never discarded merely because it evaluates as falsy.
        if cost_tracker is None:
            cost_tracker = CostTracker()
        self.cost_tracker = cost_tracker

        settings = get_settings()
        self.model_name = model or settings.default_model

        # Initialize LLM via OpenRouter
        self.llm = ChatOpenAI(
            model=self.model_name,
            temperature=temperature,
            openai_api_key=settings.openrouter_api_key,  # type: ignore[call-arg]
            openai_api_base=settings.openrouter_base_url,  # type: ignore[call-arg]
        )

        logger.info(f"Initialized {name} with model {self.model_name}")

    @abstractmethod
    def get_system_prompt(self) -> str:
        """
        Get the system prompt for this agent.

        Returns:
            System prompt string
        """
        pass

    @abstractmethod
    async def run(self, **kwargs) -> Dict[str, Any]:
        """
        Execute the agent's main task.

        Args:
            **kwargs: Agent-specific parameters

        Returns:
            Dictionary with results
        """
        pass

    @staticmethod
    def _extract_token_usage(response: Any) -> Optional[tuple[int, int]]:
        """
        Pull token counts out of an LLM response, if it reports any.

        Sources are checked in order and the first non-empty one wins:
        ``response.usage_metadata`` (``input_tokens`` / ``output_tokens``),
        then ``response.response_metadata["token_usage"]``, then
        ``response.response_metadata["usage"]`` (both using
        ``prompt_tokens`` / ``completion_tokens``).

        Args:
            response: Object returned by the LLM call

        Returns:
            ``(input_tokens, output_tokens)``, or None when the response
            carries no usage information
        """
        usage_metadata = getattr(response, "usage_metadata", None)
        if usage_metadata:
            return (
                usage_metadata.get("input_tokens", 0) or 0,
                usage_metadata.get("output_tokens", 0) or 0,
            )

        response_metadata = getattr(response, "response_metadata", None) or {}
        usage = response_metadata.get("token_usage") or response_metadata.get("usage")
        if usage:
            return (
                usage.get("prompt_tokens", 0) or 0,
                usage.get("completion_tokens", 0) or 0,
            )

        return None

    async def _invoke_llm(
        self,
        messages: list[BaseMessage],
        **llm_kwargs,
    ) -> str:
        """
        Invoke LLM and track costs.

        Args:
            messages: List of messages to send
            **llm_kwargs: Additional LLM parameters

        Returns:
            LLM response text

        Raises:
            Exception: Any error raised during the call is logged and
                re-raised unchanged
        """
        try:
            response = await self.llm.ainvoke(messages, **llm_kwargs)

            # Track usage if the response reports it
            token_usage = self._extract_token_usage(response)
            if token_usage is not None:
                input_tokens, output_tokens = token_usage
                self.cost_tracker.track_usage(
                    model=self.model_name,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                )

            logger.info(
                f"{self.name} LLM call complete",
                extra={
                    "extra_fields": {
                        "model": self.model_name,
                        "total_cost": self.cost_tracker.total_cost,
                    }
                },
            )
            return str(response.content)
        except Exception as e:
            logger.error(f"{self.name} LLM call failed: {e}")
            raise

    def _create_messages(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
    ) -> list[BaseMessage]:
        """
        Create message list for LLM.

        Args:
            user_message: User message content
            system_prompt: Optional system prompt (the agent's default
                from ``get_system_prompt`` is used when None or empty)

        Returns:
            List of messages: the system message followed by the user message
        """
        messages: list[BaseMessage] = []

        # Add system message
        prompt = system_prompt or self.get_system_prompt()
        messages.append(SystemMessage(content=prompt))

        # Add user message
        messages.append(HumanMessage(content=user_message))

        return messages

    def get_cost_summary(self) -> Dict[str, Any]:
        """
        Get cost summary for this agent's operations.

        Returns:
            Cost summary dictionary as produced by the cost tracker
        """
        return self.cost_tracker.get_summary()