| |
| |
|
|
| from typing import Any, ClassVar, Literal, cast |
|
|
| from pydantic import ( |
| BaseModel, |
| GetCoreSchemaHandler, |
| PrivateAttr, |
| model_serializer, |
| model_validator, |
| ) |
| from pydantic_core import core_schema |
|
|
|
|
class ContentPart(BaseModel):
    """A part of the content in a message.

    ``ContentPart`` is the polymorphic base of the content-part models. Each
    concrete subclass declares a ``str`` default for ``type`` and is registered
    under that string when the subclass is created. Validating against the base
    ``ContentPart`` returns existing part instances unchanged and dispatches a
    ``dict`` to the registered subclass named by its ``"type"`` key.
    """

    # Maps each subclass's ``type`` string to the subclass; populated by
    # ``__init_subclass__``. The double underscore name-mangles it to
    # ``_ContentPart__content_part_registry``, so every subclass shares one dict.
    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}

    type: Literal["text", "think", "image_url", "audio_url"]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register ``cls`` in the content-part registry under its ``type`` default.

        Raises:
            ValueError: If the subclass has no ``str`` value for ``type``.
        """
        super().__init_subclass__(**kwargs)

        invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"

        # NOTE(review): relies on the field's default still being readable as a
        # class attribute while __init_subclass__ runs (pydantic internals) —
        # re-verify on pydantic upgrades.
        type_value = getattr(cls, "type", None)
        if not isinstance(type_value, str):
            raise ValueError(invalid_subclass_error_msg)

        # A later subclass with the same ``type`` string replaces the earlier one.
        cls.__content_part_registry[type_value] = cls

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Build the pydantic-core schema for ``cls``.

        For the base ``ContentPart`` this returns a plain validator that:

        * returns any ``ContentPart`` instance unchanged;
        * validates a ``dict`` with a string ``"type"`` key through the
          registered subclass for that type;
        * raises ``ValueError`` for anything else, including an unregistered
          ``"type"``. Pydantic reports that ``ValueError`` as a ``ValidationError``.

        Subclasses use pydantic's default model schema.
        """
        # Compared by name rather than identity, presumably because the
        # ``ContentPart`` global is not bound yet when pydantic builds the base
        # class's own schema during class creation.
        if cls.__name__ == "ContentPart":

            def validate_content_part(value: Any) -> Any:
                # Already-constructed parts (any subclass) pass through untouched.
                if isinstance(value, cls):
                    return value

                if isinstance(value, dict) and "type" in value:
                    type_value: Any | None = cast(dict[str, Any], value).get("type")
                    if not isinstance(type_value, str):
                        raise ValueError(f"Cannot validate {value} as ContentPart")
                    target_class = cls.__content_part_registry.get(type_value)
                    if target_class is None:
                        # ValueError (not KeyError) so pydantic surfaces a
                        # ValidationError instead of leaking a raw KeyError.
                        raise ValueError(
                            f"Unknown ContentPart type {type_value!r} in {value}"
                        )
                    return target_class.model_validate(value)

                raise ValueError(f"Cannot validate {value} as ContentPart")

            return core_schema.no_info_plain_validator_function(validate_content_part)

        return handler(source_type)
|
|
|
|
class TextPart(ContentPart):
    """Plain-text content part.

    >>> TextPart(text="Hello, world!").model_dump()
    {'type': 'text', 'text': 'Hello, world!'}
    """

    # ``type`` default doubles as this part's key in the ContentPart registry.
    type: str = "text"
    # The text payload itself.
    text: str
|
|
|
|
class ThinkPart(ContentPart):
    """Reasoning ("thinking") content part.

    >>> ThinkPart(think="I think I need to think about this.").model_dump()
    {'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
    """

    type: str = "think"
    think: str
    encrypted: str | None = None
    """Encrypted thinking content, or signature."""

    def merge_in_place(self, other: Any) -> bool:
        """Append ``other``'s thinking text onto this part.

        The merge is refused, leaving ``self`` untouched, when ``other`` is not
        a ``ThinkPart`` or when this part already has a truthy ``encrypted``
        value. Otherwise ``other.think`` is appended to ``self.think``. If
        ``other`` carries a truthy ``encrypted`` value, it is copied over.

        Returns:
            ``True`` if the merge was performed, ``False`` otherwise.
        """
        can_merge = isinstance(other, ThinkPart) and not self.encrypted
        if not can_merge:
            return False

        self.think = self.think + other.think
        if other.encrypted:
            self.encrypted = other.encrypted
        return True
|
|
|
|
class ImageURLPart(ContentPart):
    """Image content part that references an image by URL.

    ``image_url`` is a nested ``ImageURLPart.ImageURL`` model, not a bare string:

    >>> ImageURLPart(image_url=ImageURLPart.ImageURL(url="http://example.com/image.jpg")).model_dump()
    {'type': 'image_url', 'image_url': {'url': 'http://example.com/image.jpg', 'id': None}}
    """

    class ImageURL(BaseModel):
        # Location of the image plus an optional identifier.
        url: str
        """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
        id: str | None = None
        """The ID of the image, to allow LLMs to distinguish different images."""

    type: str = "image_url"
    image_url: ImageURL
|
|
|
|
class AudioURLPart(ContentPart):
    """Audio content part that references an audio clip by URL.

    >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
    {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
    """

    class AudioURL(BaseModel):
        # Location of the audio clip plus an optional identifier.
        url: str
        """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
        id: str | None = None
        """The ID of the audio, to allow LLMs to distinguish different audios."""

    # Registry key for this part, followed by the nested URL payload.
    type: str = "audio_url"
    audio_url: AudioURL
|
|
|
|
class ToolCall(BaseModel):
    """A tool invocation requested by the assistant.

    ``extra_content`` is dropped from the serialized output while it is ``None``:

    >>> ToolCall(
    ...     id="123",
    ...     function=ToolCall.FunctionBody(
    ...         name="function",
    ...         arguments="{}"
    ...     ),
    ... ).model_dump()
    {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
    """

    class FunctionBody(BaseModel):
        # Name of the function to call and its raw argument string.
        # ``arguments`` has no default, so it must be passed (possibly as None).
        name: str
        arguments: str | None

    type: Literal["function"] = "function"

    id: str
    """The ID of the tool call."""
    function: FunctionBody
    """The function body of the tool call."""
    extra_content: dict[str, Any] | None = None
    """Extra metadata for the tool call."""

    @model_serializer(mode="wrap")
    def serialize(self, handler):
        """Run the default serializer, then omit ``extra_content`` when it is ``None``."""
        payload = handler(self)
        if self.extra_content is not None:
            return payload
        payload.pop("extra_content", None)
        return payload
|
|
|
|
class ToolCallPart(BaseModel):
    """A fragment of a tool call, carrying only a slice of its argument text."""

    # Defaults to None when the fragment carries no argument text.
    arguments_part: str | None = None
    """A part of the arguments of the tool call."""
|
|
|
|
class Message(BaseModel):
    """A single message in a conversation.

    ``content`` may be ``None`` only for an assistant message whose
    ``tool_calls`` is set. When serialized, ``tool_calls`` and ``tool_call_id``
    are left out whenever they are ``None``.
    """

    role: Literal[
        "system",
        "user",
        "assistant",
        "tool",
    ]

    content: str | list[ContentPart] | None = None
    """The content of the message."""

    tool_calls: list[ToolCall] | list[dict] | None = None
    """The tool calls of the message."""

    tool_call_id: str | None = None
    """The ID of the tool call."""

    # Private (PrivateAttr) flag, defaulting to False; not a model field.
    # Presumably marks messages that should not be persisted — confirm with callers.
    _no_save: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def check_content_required(self):
        """Reject a missing ``content`` unless this is an assistant tool-call message.

        Raises:
            ValueError: If ``content`` is ``None`` and the message is not an
                assistant message with non-``None`` ``tool_calls``.
        """
        is_assistant_tool_call = self.role == "assistant" and self.tool_calls is not None
        if not is_assistant_tool_call and self.content is None:
            raise ValueError(
                "content is required unless role='assistant' and tool_calls is not None"
            )
        return self

    @model_serializer(mode="wrap")
    def serialize(self, handler):
        """Run the default serializer, then drop the optional tool fields that are ``None``."""
        result = handler(self)
        for optional_key in ("tool_calls", "tool_call_id"):
            if getattr(self, optional_key) is None:
                result.pop(optional_key, None)
        return result
|
|
|
|
class AssistantMessageSegment(Message):
    """A message segment from the assistant; ``role`` is fixed to ``"assistant"``."""

    # Narrowed to a single literal so it can be omitted at construction time.
    role: Literal["assistant"] = "assistant"
|
|
|
|
class ToolCallMessageSegment(Message):
    """A message segment representing a tool call; ``role`` is fixed to ``"tool"``."""

    # Narrowed to a single literal so it can be omitted at construction time.
    role: Literal["tool"] = "tool"
|
|
|
|
class UserMessageSegment(Message):
    """A message segment from the user; ``role`` is fixed to ``"user"``."""

    # Narrowed to a single literal so it can be omitted at construction time.
    role: Literal["user"] = "user"
|
|
|
|
class SystemMessageSegment(Message):
    """A message segment from the system; ``role`` is fixed to ``"system"``."""

    # Narrowed to a single literal so it can be omitted at construction time.
    role: Literal["system"] = "system"
|
|