| from enum import Enum |
| from typing import Any, List, Literal, Optional, Union |
|
|
| from pydantic import BaseModel, Field |
|
|
|
|
class Role(str, Enum):
    """Author of a chat message; the members' string values populate ``Message.role``."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"  # used by Message.tool_message for tool-result messages


# Every Role value as a plain string, in declaration order, plus a Literal type
# built from them (used as the annotation of Message.role).
# Subscripting Literal with a tuple is the same as passing each value as its own
# argument, i.e. Literal["system", "user", "assistant", "tool"]; note that static
# type checkers cannot evaluate this dynamically built form.
ROLE_VALUES = tuple(member.value for member in Role)
ROLE_TYPE = Literal[ROLE_VALUES]
|
|
|
|
class ToolChoice(str, Enum):
    """Tool-choice options as string-valued members.

    NOTE(review): presumably forwarded as the ``tool_choice`` option of an LLM
    request — confirm against callers.
    """

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"


# Every ToolChoice value as a plain string, in declaration order, plus a Literal
# type built from them. Literal[<tuple>] equals Literal["none", "auto", "required"]
# at runtime, but static type checkers cannot evaluate this dynamic form.
TOOL_CHOICE_VALUES = tuple(member.value for member in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]
|
|
|
|
class AgentState(str, Enum):
    """States an agent can be in while executing.

    Unlike Role and ToolChoice, each value is upper-case and identical to the
    member name.
    """

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
|
|
|
|
class Function(BaseModel):
    """The function portion of a ToolCall: which function to run, and its arguments.

    ``arguments`` is stored as an undecoded string. NOTE(review): presumably the
    JSON-encoded argument object emitted by the LLM — confirm at call sites.
    """

    name: str  # name of the function/tool to invoke
    arguments: str  # raw serialized arguments; not parsed by this model
|
|
|
|
class ToolCall(BaseModel):
    """One tool/function invocation carried by a message's ``tool_calls`` list."""

    # Identifier of this call; presumably matched later by a tool message's
    # ``tool_call_id`` — confirm against callers.
    id: str
    # Kind of call; defaults to "function", the only kind Message.from_tool_calls builds.
    type: str = "function"
    function: Function  # the function name plus its raw argument string
|
|
|
|
class Message(BaseModel):
    """A single chat message in the conversation.

    Only ``role`` is required; the optional fields are filled in according to
    the kind of message (see the ``*_message`` and ``from_tool_calls``
    constructors below). Messages combine with ``+`` into plain lists::

        msg_a + msg_b      # -> [msg_a, msg_b]
        msg_a + [msg_b]    # -> [msg_a, msg_b]
        [msg_a] + msg_b    # -> [msg_a, msg_b]
    """

    role: ROLE_TYPE = Field(...)  # required; one of the Role string values
    content: Optional[str] = Field(default=None)
    # Set on assistant messages built by from_tool_calls.
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    name: Optional[str] = Field(default=None)  # set by tool_message
    tool_call_id: Optional[str] = Field(default=None)  # set by tool_message
    base64_image: Optional[str] = Field(default=None)  # optional base64-encoded image

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``.

        Always returns a new list; neither operand is mutated. For any other
        operand type this returns ``NotImplemented`` (per the Python operator
        protocol), so Python can still try ``other.__radd__`` before raising
        the usual ``TypeError: unsupported operand type(s) for +``.
        """
        if isinstance(other, list):
            return [self] + other
        if isinstance(other, Message):
            return [self, other]
        return NotImplemented

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``; returns a new list with this message appended.

        Other left-hand operand types get ``NotImplemented`` so Python raises
        its standard ``TypeError``.
        """
        if isinstance(other, list):
            return other + [self]
        return NotImplemented

    def to_dict(self) -> dict:
        """Convert the message to a plain dict, omitting fields that are ``None``.

        Keys appear in the order role, content, tool_calls, name, tool_call_id,
        base64_image. Each tool call is serialized with ``model_dump()``.
        """
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            # model_dump() is the pydantic v2 API (as used in from_tool_calls);
            # the v1-style .dict() is deprecated there and emits a warning.
            message["tool_calls"] = [
                tool_call.model_dump() for tool_call in self.tool_calls
            ]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message with optional base64-encoded image."""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message."""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message; ``content`` may be ``None``."""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls,
        content: str,
        name: Optional[str],
        tool_call_id: str,
        base64_image: Optional[str] = None,
    ) -> "Message":
        """Create a tool message reporting the result of a tool call.

        Args:
            content: The tool's output text.
            name: Name of the tool that produced the output.
            tool_call_id: Identifier of the tool call this message answers.
            base64_image: Optional base64-encoded image.
        """
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """Create an assistant message carrying the given raw tool calls.

        Args:
            tool_calls: Raw tool calls from the LLM. Each item must expose
                ``.id`` and ``.function``, and ``.function`` must provide
                ``model_dump()`` (presumably SDK pydantic objects — verify
                against callers).
            content: Message text. NOTE(review): the annotation allows a list
                of strings, but the ``content`` field is ``Optional[str]``, so a
                list will presumably be rejected by validation.
            base64_image: Optional base64-encoded image.
            **kwargs: Additional ``Message`` fields forwarded to the constructor.

        Returns:
            A ``Message`` with ``role=Role.ASSISTANT`` and ``tool_calls`` set.
        """
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
|
|
|
|
class Memory(BaseModel):
    """Ordered store of conversation messages with a size limit.

    ``add_message``/``add_messages`` drop the oldest messages once the store
    exceeds ``max_messages``. A list passed directly to the constructor is not
    trimmed.
    """

    messages: List[Message] = Field(default_factory=list)
    max_messages: int = Field(default=100)

    def add_message(self, message: Message) -> None:
        """Append one message, then drop the oldest beyond ``max_messages``."""
        self.messages.append(message)
        self._trim_to_limit()

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order, then drop the oldest beyond ``max_messages``."""
        self.messages.extend(messages)
        self._trim_to_limit()

    def clear(self) -> None:
        """Remove all messages (the list is emptied in place)."""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return up to the ``n`` most recent messages, oldest first.

        ``n <= 0`` yields an empty list. (A bare ``self.messages[-n:]`` would
        return *every* message for ``n == 0``, because ``-0 == 0``, and would
        drop leading messages for negative ``n``.)
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert every stored message to a dict via ``Message.to_dict``."""
        return [msg.to_dict() for msg in self.messages]

    def _trim_to_limit(self) -> None:
        """Keep only the newest ``max_messages`` messages.

        Slicing off the excess count (rather than ``[-max_messages:]``) makes
        ``max_messages == 0`` empty the list instead of keeping everything,
        since ``[-0:]`` is the whole list.
        """
        excess = len(self.messages) - self.max_messages
        if excess > 0:
            self.messages = self.messages[excess:]
|
|