| """Classes for working with the Gemini models.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Iterable |
| import textwrap |
| from typing import Any, Union, overload |
| import reprlib |
|
|
| |
|
|
|
|
| import google.api_core.exceptions |
| from google.generativeai import protos |
| from google.generativeai import client |
|
|
| from google.generativeai import caching |
| from google.generativeai.types import content_types |
| from google.generativeai.types import generation_types |
| from google.generativeai.types import helper_types |
| from google.generativeai.types import safety_types |
|
|
| _USER_ROLE = "user" |
| _MODEL_ROLE = "model" |
|
|
|
|
| class GenerativeModel: |
| """ |
| The `genai.GenerativeModel` class wraps default parameters for calls to |
| `GenerativeModel.generate_content`, `GenerativeModel.count_tokens`, and |
| `GenerativeModel.start_chat`. |
| |
This family of functionality is designed to support multi-turn conversations and multimodal
requests. Which media types are supported for input and output is model-dependent.
| |
| >>> import google.generativeai as genai |
| >>> import PIL.Image |
| >>> genai.configure(api_key='YOUR_API_KEY') |
| >>> model = genai.GenerativeModel('models/gemini-1.5-flash') |
| >>> result = model.generate_content('Tell me a story about a magic backpack') |
| >>> result.text |
| "In the quaint little town of Lakeside, there lived a young girl named Lily..." |
| |
| Multimodal input: |
| |
| >>> model = genai.GenerativeModel('models/gemini-1.5-flash') |
| >>> result = model.generate_content([ |
| ... "Give me a recipe for these:", PIL.Image.open('scones.jpeg')]) |
| >>> result.text |
| "**Blueberry Scones** ..." |
| |
| Multi-turn conversation: |
| |
| >>> chat = model.start_chat() |
| >>> response = chat.send_message("Hi, I have some questions for you.") |
| >>> response.text |
| "Sure, I'll do my best to answer your questions..." |
| |
| To list the compatible model names use: |
| |
| >>> for m in genai.list_models(): |
| ... if 'generateContent' in m.supported_generation_methods: |
| ... print(m.name) |
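
A default system instruction can also be set when constructing the model (an illustrative
sketch; the instruction text here is only an example):

>>> model = genai.GenerativeModel(
...     'models/gemini-1.5-flash',
...     system_instruction='You are a helpful assistant that answers concisely.')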
| |
Arguments:
model_name: The name of the model to query. To list compatible models use
`genai.list_models`.
safety_settings: Sets the default safety filters. This controls which content is blocked
by the API before being returned.
generation_config: A `genai.GenerationConfig` setting the default generation parameters
to use.
tools: Python callables or `protos.Tool` definitions the model may request to call while
generating a response.
tool_config: Configuration controlling how the supplied `tools` may be used.
system_instruction: Instructions sent with every request to steer the model's behavior.
"""
|
|
| def __init__( |
| self, |
| model_name: str = "gemini-1.5-flash-002", |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| system_instruction: content_types.ContentType | None = None, |
| ): |
| if "/" not in model_name: |
| model_name = "models/" + model_name |
| self._model_name = model_name |
| self._safety_settings = safety_types.to_easy_safety_dict(safety_settings) |
| self._generation_config = generation_types.to_generation_config_dict(generation_config) |
| self._tools = content_types.to_function_library(tools) |
|
|
| if tool_config is None: |
| self._tool_config = None |
| else: |
| self._tool_config = content_types.to_tool_config(tool_config) |
|
|
| if system_instruction is None: |
| self._system_instruction = None |
| else: |
| self._system_instruction = content_types.to_content(system_instruction) |
|
|
| self._client = None |
| self._async_client = None |
|
|
| @property |
def cached_content(self) -> str | None:
"""The resource name of the `CachedContent` used as this model's context, if any."""
return getattr(self, "_cached_content", None)
|
|
| @property |
| def model_name(self): |
| return self._model_name |
|
|
| def __str__(self): |
| def maybe_text(content): |
| if content and len(content.parts) and (t := content.parts[0].text): |
| return repr(t) |
| return content |
|
|
| return textwrap.dedent( |
| f"""\ |
| genai.GenerativeModel( |
| model_name='{self.model_name}', |
| generation_config={self._generation_config}, |
| safety_settings={self._safety_settings}, |
| tools={self._tools}, |
| system_instruction={maybe_text(self._system_instruction)}, |
| cached_content={self.cached_content} |
| )""" |
| ) |
|
|
| __repr__ = __str__ |
|
|
| def _prepare_request( |
| self, |
| *, |
| contents: content_types.ContentsType, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| tools: content_types.FunctionLibraryType | None, |
| tool_config: content_types.ToolConfigType | None, |
| ) -> protos.GenerateContentRequest: |
| """Creates a `protos.GenerateContentRequest` from raw inputs.""" |
| if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]): |
| raise ValueError( |
| "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context." |
| ) |
|
|
| tools_lib = self._get_tools_lib(tools) |
| if tools_lib is not None: |
| tools_lib = tools_lib.to_proto() |
|
|
| if tool_config is None: |
| tool_config = self._tool_config |
| else: |
| tool_config = content_types.to_tool_config(tool_config) |
|
|
| contents = content_types.to_contents(contents) |
|
|
| generation_config = generation_types.to_generation_config_dict(generation_config) |
| merged_gc = self._generation_config.copy() |
| merged_gc.update(generation_config) |
|
|
| safety_settings = safety_types.to_easy_safety_dict(safety_settings) |
| merged_ss = self._safety_settings.copy() |
| merged_ss.update(safety_settings) |
| merged_ss = safety_types.normalize_safety_settings(merged_ss) |
|
|
| return protos.GenerateContentRequest( |
| model=self._model_name, |
| contents=contents, |
| generation_config=merged_gc, |
| safety_settings=merged_ss, |
| tools=tools_lib, |
| tool_config=tool_config, |
| system_instruction=self._system_instruction, |
| cached_content=self.cached_content, |
| ) |
|
|
| def _get_tools_lib( |
| self, tools: content_types.FunctionLibraryType |
| ) -> content_types.FunctionLibrary | None: |
| if tools is None: |
| return self._tools |
| else: |
| return content_types.to_function_library(tools) |
|
|
| @overload |
| @classmethod |
| def from_cached_content( |
| cls, |
| cached_content: str, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| ) -> GenerativeModel: ... |
|
|
| @overload |
| @classmethod |
| def from_cached_content( |
| cls, |
| cached_content: caching.CachedContent, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| ) -> GenerativeModel: ... |
|
|
| @classmethod |
| def from_cached_content( |
| cls, |
| cached_content: str | caching.CachedContent, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| ) -> GenerativeModel: |
| """Creates a model with `cached_content` as model's context. |
| |
| Args: |
| cached_content: context for the model. |
| generation_config: Overrides for the model's generation config. |
| safety_settings: Overrides for the model's safety settings. |
| |
| Returns: |
| `GenerativeModel` object with `cached_content` as its context. |
| """ |
| if isinstance(cached_content, str): |
| cached_content = caching.CachedContent.get(name=cached_content) |
|
|
| |
| |
| self = cls( |
| model_name=cached_content.model, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| ) |
|
|
| |
| setattr(self, "_cached_content", cached_content.name) |
| return self |
|
|
| def generate_content( |
| self, |
| contents: content_types.ContentsType, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| stream: bool = False, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> generation_types.GenerateContentResponse: |
| """A multipurpose function to generate responses from the model. |
| |
This `GenerativeModel.generate_content` method can handle multimodal input and multi-turn
conversations.
| |
| >>> model = genai.GenerativeModel('models/gemini-1.5-flash') |
| >>> response = model.generate_content('Tell me a story about a magic backpack') |
| >>> response.text |
| |
| ### Streaming |
| |
This method supports streaming with `stream=True`. The result has the same type as in the non-streaming case,
| but you can iterate over the response chunks as they become available: |
| |
| >>> response = model.generate_content('Tell me a story about a magic backpack', stream=True) |
| >>> for chunk in response: |
| ... print(chunk.text) |
| |
| ### Multi-turn |
| |
| This method supports multi-turn chats but is **stateless**: the entire conversation history needs to be sent with each |
| request. This takes some manual management but gives you complete control: |
| |
| >>> messages = [{'role':'user', 'parts': ['hello']}] |
| >>> response = model.generate_content(messages) # "Hello, how can I help" |
| >>> messages.append(response.candidates[0].content) |
| >>> messages.append({'role':'user', 'parts': ['How does quantum physics work?']}) |
| >>> response = model.generate_content(messages) |
| |
| For a simpler multi-turn interface see `GenerativeModel.start_chat`. |
| |
| ### Input type flexibility |
| |
While the underlying API strictly expects a `list[protos.Content]`, this method
| will convert the user input into the correct type. The hierarchy of types that can be |
| converted is below. Any of these objects can be passed as an equivalent `dict`. |
| |
| * `Iterable[protos.Content]` |
| * `protos.Content` |
| * `Iterable[protos.Part]` |
| * `protos.Part` |
| * `str`, `Image`, or `protos.Blob` |
| |
| In an `Iterable[protos.Content]` each `content` is a separate message. |
| But note that an `Iterable[protos.Part]` is taken as the parts of a single message. |
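
### Tools

Python callables (or `protos.Tool` definitions) can be passed as `tools`. With plain
`generate_content` the model does not execute them; it may instead reply with a
`function_call` part that you are expected to run yourself (see `GenerativeModel.start_chat`
with `enable_automatic_function_calling=True` for the automated flow). A minimal sketch,
where the function and prompt are only examples:

>>> def add(a: int, b: int) -> int:
...     return a + b
>>> model = genai.GenerativeModel('models/gemini-1.5-flash', tools=[add])
>>> response = model.generate_content('Use add to compute 2301 + 4589.')
>>> response.candidates[0].content.parts[0].function_call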
| |
| Arguments: |
| contents: The contents serving as the model's prompt. |
| generation_config: Overrides for the model's generation config. |
| safety_settings: Overrides for the model's safety settings. |
| stream: If True, yield response chunks as they are generated. |
tools: Python callables or `protos.Tool` definitions the model may request to call while
generating a response.
tool_config: Configuration controlling how the supplied `tools` may be used.
request_options: Options for the request.
| """ |
| if not contents: |
| raise TypeError("contents must not be empty") |
|
|
| request = self._prepare_request( |
| contents=contents, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| tools=tools, |
| tool_config=tool_config, |
| ) |
|
|
| if request.contents and not request.contents[-1].role: |
| request.contents[-1].role = _USER_ROLE |
|
|
| if self._client is None: |
| self._client = client.get_default_generative_client() |
|
|
| if request_options is None: |
| request_options = {} |
|
|
| try: |
| if stream: |
| with generation_types.rewrite_stream_error(): |
| iterator = self._client.stream_generate_content( |
| request, |
| **request_options, |
| ) |
| return generation_types.GenerateContentResponse.from_iterator(iterator) |
| else: |
| response = self._client.generate_content( |
| request, |
| **request_options, |
| ) |
| return generation_types.GenerateContentResponse.from_response(response) |
| except google.api_core.exceptions.InvalidArgument as e: |
| if e.message.startswith("Request payload size exceeds the limit:"): |
| e.message += ( |
| " The file size is too large. Please use the File API to upload your files instead. " |
| "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" |
| ) |
| raise |
|
|
| async def generate_content_async( |
| self, |
| contents: content_types.ContentsType, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| stream: bool = False, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> generation_types.AsyncGenerateContentResponse: |
| """The async version of `GenerativeModel.generate_content`.""" |
| if not contents: |
| raise TypeError("contents must not be empty") |
|
|
| request = self._prepare_request( |
| contents=contents, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| tools=tools, |
| tool_config=tool_config, |
| ) |
|
|
| if request.contents and not request.contents[-1].role: |
| request.contents[-1].role = _USER_ROLE |
|
|
| if self._async_client is None: |
| self._async_client = client.get_default_generative_async_client() |
|
|
| if request_options is None: |
| request_options = {} |
|
|
| try: |
| if stream: |
| with generation_types.rewrite_stream_error(): |
| iterator = await self._async_client.stream_generate_content( |
| request, |
| **request_options, |
| ) |
| return await generation_types.AsyncGenerateContentResponse.from_aiterator(iterator) |
| else: |
| response = await self._async_client.generate_content( |
| request, |
| **request_options, |
| ) |
| return generation_types.AsyncGenerateContentResponse.from_response(response) |
| except google.api_core.exceptions.InvalidArgument as e: |
| if e.message.startswith("Request payload size exceeds the limit:"): |
| e.message += ( |
| " The file size is too large. Please use the File API to upload your files instead. " |
| "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" |
| ) |
| raise |
|
|
| |
| def count_tokens( |
| self, |
| contents: content_types.ContentsType = None, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> protos.CountTokensResponse: |
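"""Counts the tokens that `contents` would use in a request to this model.

The count includes any tools and system instruction configured on the model. A minimal
sketch (the model name and prompt are only examples):

>>> model = genai.GenerativeModel('models/gemini-1.5-flash')
>>> model.count_tokens('The quick brown fox jumps over the lazy dog.')

Returns:
A `protos.CountTokensResponse` containing the `total_tokens` count.
"""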
| if request_options is None: |
| request_options = {} |
|
|
| if self._client is None: |
| self._client = client.get_default_generative_client() |
|
|
| request = protos.CountTokensRequest( |
| model=self.model_name, |
| generate_content_request=self._prepare_request( |
| contents=contents, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| tools=tools, |
| tool_config=tool_config, |
| )) |
| return self._client.count_tokens(request, **request_options) |
|
|
| async def count_tokens_async( |
| self, |
| contents: content_types.ContentsType = None, |
| *, |
| generation_config: generation_types.GenerationConfigType | None = None, |
| safety_settings: safety_types.SafetySettingOptions | None = None, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> protos.CountTokensResponse: |
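"""The async version of `GenerativeModel.count_tokens`."""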
| if request_options is None: |
| request_options = {} |
|
|
| if self._async_client is None: |
| self._async_client = client.get_default_generative_async_client() |
|
|
| request = protos.CountTokensRequest( |
| model=self.model_name, |
| generate_content_request=self._prepare_request( |
| contents=contents, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| tools=tools, |
| tool_config=tool_config, |
| )) |
| return await self._async_client.count_tokens(request, **request_options) |
|
|
| |
|
|
| def start_chat( |
| self, |
| *, |
| history: Iterable[content_types.StrictContentType] | None = None, |
| enable_automatic_function_calling: bool = False, |
| ) -> ChatSession: |
| """Returns a `genai.ChatSession` attached to this model. |
| |
| >>> model = genai.GenerativeModel() |
| >>> chat = model.start_chat(history=[...]) |
| >>> response = chat.send_message("Hello?") |
| |
| Arguments: |
history: An iterable of `protos.Content` objects, or equivalents, to initialize the session.
enable_automatic_function_calling: If True, the returned `ChatSession` will automatically
execute callable tools requested by the model and send the results back.
"""
| if self._generation_config.get("candidate_count", 1) > 1: |
| raise ValueError( |
| "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." |
| ) |
| return ChatSession( |
| model=self, |
| history=history, |
| enable_automatic_function_calling=enable_automatic_function_calling, |
| ) |
|
|
|
|
| class ChatSession: |
| """Contains an ongoing conversation with the model. |
| |
| >>> model = genai.GenerativeModel('models/gemini-1.5-flash') |
| >>> chat = model.start_chat() |
| >>> response = chat.send_message("Hello") |
| >>> print(response.text) |
| >>> response = chat.send_message("Hello again") |
| >>> print(response.text) |
| >>> response = chat.send_message(... |
| |
| This `ChatSession` object collects the messages sent and received, in its |
| `ChatSession.history` attribute. |
| |
| Arguments: |
| model: The model to use in the chat. |
history: A chat history to initialize the object with.
enable_automatic_function_calling: If True, callable tools requested by the model are
executed automatically and their results sent back to the model.
"""
|
|
| def __init__( |
| self, |
| model: GenerativeModel, |
| history: Iterable[content_types.StrictContentType] | None = None, |
| enable_automatic_function_calling: bool = False, |
| ): |
| self.model: GenerativeModel = model |
| self._history: list[protos.Content] = content_types.to_contents(history) |
| self._last_sent: protos.Content | None = None |
| self._last_received: generation_types.BaseGenerateContentResponse | None = None |
| self.enable_automatic_function_calling = enable_automatic_function_calling |
|
|
| def send_message( |
| self, |
| content: content_types.ContentType, |
| *, |
generation_config: generation_types.GenerationConfigType | None = None,
safety_settings: safety_types.SafetySettingOptions | None = None,
| stream: bool = False, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> generation_types.GenerateContentResponse: |
| """Sends the conversation history with the added message and returns the model's response. |
| |
| Appends the request and response to the conversation history. |
| |
| >>> model = genai.GenerativeModel('models/gemini-1.5-flash') |
| >>> chat = model.start_chat() |
| >>> response = chat.send_message("Hello") |
| >>> print(response.text) |
| "Hello! How can I assist you today?" |
| >>> len(chat.history) |
| 2 |
| |
| Call it with `stream=True` to receive response chunks as they are generated: |
| |
| >>> chat = model.start_chat() |
| >>> response = chat.send_message("Explain quantum physics", stream=True) |
| >>> for chunk in response: |
| ... print(chunk.text, end='') |
| |
| Once iteration over chunks is complete, the `response` and `ChatSession` are in states identical to the |
| `stream=False` case. Some properties are not available until iteration is complete. |
| |
| Like `GenerativeModel.generate_content` this method lets you override the model's `generation_config` and |
| `safety_settings`. |
| |
| Arguments: |
| content: The message contents. |
| generation_config: Overrides for the model's generation config. |
| safety_settings: Overrides for the model's safety settings. |
| stream: If True, yield response chunks as they are generated. |
| """ |
| if request_options is None: |
| request_options = {} |
|
|
| if self.enable_automatic_function_calling and stream: |
| raise NotImplementedError( |
| "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." |
| ) |
|
|
| tools_lib = self.model._get_tools_lib(tools) |
|
|
| content = content_types.to_content(content) |
|
|
| if not content.role: |
| content.role = _USER_ROLE |
|
|
| history = self.history[:] |
| history.append(content) |
|
|
| generation_config = generation_types.to_generation_config_dict(generation_config) |
| if generation_config.get("candidate_count", 1) > 1: |
| raise ValueError( |
| "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." |
| ) |
|
|
| response = self.model.generate_content( |
| contents=history, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| stream=stream, |
| tools=tools_lib, |
| tool_config=tool_config, |
| request_options=request_options, |
| ) |
|
|
| self._check_response(response=response, stream=stream) |
|
|
| if self.enable_automatic_function_calling and tools_lib is not None: |
| self.history, content, response = self._handle_afc( |
| response=response, |
| history=history, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| stream=stream, |
| tools_lib=tools_lib, |
| request_options=request_options, |
| ) |
|
|
| self._last_sent = content |
| self._last_received = response |
|
|
| return response |
|
|
| def _check_response(self, *, response, stream): |
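"""Raises if the prompt was blocked or, when not streaming, if the candidate finished abnormally."""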
| if response.prompt_feedback.block_reason: |
| raise generation_types.BlockedPromptException(response.prompt_feedback) |
|
|
| if not stream: |
| if response.candidates[0].finish_reason not in ( |
| protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, |
| protos.Candidate.FinishReason.STOP, |
| protos.Candidate.FinishReason.MAX_TOKENS, |
| ): |
| raise generation_types.StopCandidateException(response.candidates[0]) |
|
|
| def _get_function_calls(self, response) -> list[protos.FunctionCall]: |
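"""Returns the `FunctionCall` parts of the single response candidate."""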
| candidates = response.candidates |
| if len(candidates) != 1: |
| raise ValueError( |
| f"Invalid number of candidates: Automatic function calling only works with 1 candidate, but {len(candidates)} were provided." |
| ) |
| parts = candidates[0].content.parts |
| function_calls = [part.function_call for part in parts if part and "function_call" in part] |
| return function_calls |
|
|
| def _handle_afc( |
| self, |
| *, |
| response, |
| history, |
| generation_config, |
| safety_settings, |
| stream, |
| tools_lib, |
| request_options, |
| ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: |
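"""Runs the automatic function calling loop.

While the latest response requests function calls and every requested tool is callable,
the calls are executed, the call and response turns are appended to `history`, and the
request is re-sent. Returns the updated history without its final entry, the last
user-role `Content` sent to the model, and the final response.
"""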
|
|
| while function_calls := self._get_function_calls(response): |
| if not all(callable(tools_lib[fc]) for fc in function_calls): |
| break |
| history.append(response.candidates[0].content) |
|
|
| function_response_parts: list[protos.Part] = [] |
| for fc in function_calls: |
| fr = tools_lib(fc) |
| assert fr is not None, ( |
| "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " |
| "is not callable, which is checked earlier in the code." |
| ) |
| function_response_parts.append(fr) |
|
|
| send = protos.Content(role=_USER_ROLE, parts=function_response_parts) |
| history.append(send) |
|
|
| response = self.model.generate_content( |
| contents=history, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| stream=stream, |
| tools=tools_lib, |
| request_options=request_options, |
| ) |
|
|
| self._check_response(response=response, stream=stream) |
|
|
| *history, content = history |
| return history, content, response |
|
|
| async def send_message_async( |
| self, |
| content: content_types.ContentType, |
| *, |
generation_config: generation_types.GenerationConfigType | None = None,
safety_settings: safety_types.SafetySettingOptions | None = None,
| stream: bool = False, |
| tools: content_types.FunctionLibraryType | None = None, |
| tool_config: content_types.ToolConfigType | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> generation_types.AsyncGenerateContentResponse: |
| """The async version of `ChatSession.send_message`.""" |
| if request_options is None: |
| request_options = {} |
|
|
| if self.enable_automatic_function_calling and stream: |
| raise NotImplementedError( |
| "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." |
| ) |
|
|
| tools_lib = self.model._get_tools_lib(tools) |
|
|
| content = content_types.to_content(content) |
|
|
| if not content.role: |
| content.role = _USER_ROLE |
|
|
| history = self.history[:] |
| history.append(content) |
|
|
| generation_config = generation_types.to_generation_config_dict(generation_config) |
| if generation_config.get("candidate_count", 1) > 1: |
| raise ValueError( |
| "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." |
| ) |
|
|
| response = await self.model.generate_content_async( |
| contents=history, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| stream=stream, |
| tools=tools_lib, |
| tool_config=tool_config, |
| request_options=request_options, |
| ) |
|
|
| self._check_response(response=response, stream=stream) |
|
|
| if self.enable_automatic_function_calling and tools_lib is not None: |
| self.history, content, response = await self._handle_afc_async( |
| response=response, |
| history=history, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| stream=stream, |
| tools_lib=tools_lib, |
| request_options=request_options, |
| ) |
|
|
| self._last_sent = content |
| self._last_received = response |
|
|
| return response |
|
|
| async def _handle_afc_async( |
| self, |
| *, |
| response, |
| history, |
| generation_config, |
| safety_settings, |
| stream, |
| tools_lib, |
| request_options, |
| ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: |
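"""The async version of `ChatSession._handle_afc`."""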
|
|
| while function_calls := self._get_function_calls(response): |
| if not all(callable(tools_lib[fc]) for fc in function_calls): |
| break |
| history.append(response.candidates[0].content) |
|
|
| function_response_parts: list[protos.Part] = [] |
| for fc in function_calls: |
| fr = tools_lib(fc) |
| assert fr is not None, ( |
| "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " |
| "is not callable, which is checked earlier in the code." |
| ) |
| function_response_parts.append(fr) |
|
|
| send = protos.Content(role=_USER_ROLE, parts=function_response_parts) |
| history.append(send) |
|
|
| response = await self.model.generate_content_async( |
| contents=history, |
| generation_config=generation_config, |
| safety_settings=safety_settings, |
| stream=stream, |
| tools=tools_lib, |
| request_options=request_options, |
| ) |
|
|
| self._check_response(response=response, stream=stream) |
|
|
| *history, content = history |
| return history, content, response |
|
|
def __copy__(self):
return ChatSession(
model=self.model,
# Copy the list so the new session does not share history state with this one.
history=list(self.history),
)
|
|
| def rewind(self) -> tuple[protos.Content, protos.Content]: |
| """Removes the last request/response pair from the chat history.""" |
| if self._last_received is None: |
| result = self._history.pop(-2), self._history.pop() |
| return result |
| else: |
| result = self._last_sent, self._last_received.candidates[0].content |
| self._last_sent = None |
| self._last_received = None |
| return result |
|
|
| @property |
| def last(self) -> generation_types.BaseGenerateContentResponse | None: |
| """returns the last received `genai.GenerateContentResponse`""" |
| return self._last_received |
|
|
| @property |
| def history(self) -> list[protos.Content]: |
| """The chat history.""" |
| last = self._last_received |
| if last is None: |
| return self._history |
|
|
| if last.candidates[0].finish_reason not in ( |
| protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, |
| protos.Candidate.FinishReason.STOP, |
| protos.Candidate.FinishReason.MAX_TOKENS, |
| ): |
| error = generation_types.StopCandidateException(last.candidates[0]) |
| last._error = error |
|
|
| if last._error is not None: |
| raise generation_types.BrokenResponseError( |
| "Unable to build a coherent chat history due to a broken streaming response. " |
| "Refer to the previous exception for details. " |
| "To inspect the last response object, use `chat.last`. " |
| "To remove the last request/response `Content` objects from the chat, " |
| "call `last_send, last_received = chat.rewind()` and continue without it." |
| ) from last._error |
|
|
| sent = self._last_sent |
| received = last.candidates[0].content |
| if not received.role: |
| received.role = _MODEL_ROLE |
| self._history.extend([sent, received]) |
|
|
| self._last_sent = None |
| self._last_received = None |
|
|
| return self._history |
|
|
| @history.setter |
| def history(self, history): |
| self._history = content_types.to_contents(history) |
| self._last_sent = None |
| self._last_received = None |
|
|
| def __repr__(self) -> str: |
| _dict_repr = reprlib.Repr() |
| _model = str(self.model).replace("\n", "\n" + " " * 4) |
|
|
| def content_repr(x): |
| return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})" |
|
|
| try: |
| history = list(self.history) |
| except (generation_types.BrokenResponseError, generation_types.IncompleteIterationError): |
| history = list(self._history) |
|
|
| if self._last_sent is not None: |
| history.append(self._last_sent) |
| history = [content_repr(x) for x in history] |
|
|
| last_received = self._last_received |
| if last_received is not None: |
| if last_received._error is not None: |
| history.append("<STREAMING ERROR>") |
| else: |
| history.append("<STREAMING IN PROGRESS>") |
|
|
| _history = ",\n " + f"history=[{', '.join(history)}]\n)" |
|
|
| return ( |
| textwrap.dedent( |
| f"""\ |
| ChatSession( |
| model=""" |
| ) |
| + _model |
| + _history |
| ) |
|
|