| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
|
|
| import datetime |
| import textwrap |
| from typing import Iterable, Optional |
|
|
| from google.generativeai import protos |
| from google.generativeai.types import caching_types |
| from google.generativeai.types import content_types |
| from google.generativeai.client import get_default_cache_client |
|
|
| from google.protobuf import field_mask_pb2 |
|
|
# Role label assigned to the final entry of `contents` when that entry has no
# role set (see `CachedContent._prepare_create_request`).
_USER_ROLE = "user"
# Counterpart role label; not referenced in the visible part of this module.
_MODEL_ROLE = "model"
|
|
|
|
class CachedContent:
    """Cached content resource.

    Wraps a `protos.CachedContent` message, stored on `self._proto`, and
    exposes its fields as read-only properties. The `create`, `get`, `list`,
    `update` and `delete` operations go through the client returned by
    `get_default_cache_client()`.
    """

    def __init__(self, name):
        """Fetches a `CachedContent` resource.

        Identical to `CachedContent.get`.

        Args:
            name: The resource name referring to the cached content. A
                `cachedContents/` prefix is added when `name` contains none.
        """
        client = get_default_cache_client()

        request = protos.GetCachedContentRequest(name=self._normalize_name(name))
        response = client.get_cached_content(request)
        self._proto = response

    @staticmethod
    def _normalize_name(name: str) -> str:
        """Returns `name`, prefixed with `cachedContents/` if it contains no such segment."""
        if "cachedContents/" not in name:
            name = "cachedContents/" + name
        return name

    @property
    def name(self) -> str:
        """Resource name of the cached content."""
        return self._proto.name

    @property
    def model(self) -> str:
        """Name of the model stored on the underlying proto."""
        return self._proto.model

    @property
    def display_name(self) -> str:
        """Display name stored on the underlying proto."""
        return self._proto.display_name

    @property
    def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
        """Usage metadata stored on the underlying proto."""
        return self._proto.usage_metadata

    @property
    def create_time(self) -> datetime.datetime:
        """Creation time stored on the underlying proto."""
        return self._proto.create_time

    @property
    def update_time(self) -> datetime.datetime:
        """Last update time stored on the underlying proto."""
        return self._proto.update_time

    @property
    def expire_time(self) -> datetime.datetime:
        """Expiration time stored on the underlying proto."""
        return self._proto.expire_time

    def __str__(self):
        """Returns a multi-line, human-readable summary of the resource."""
        # NOTE(review): the nesting indentation inside this literal was
        # reconstructed from a whitespace-stripped copy -- confirm it matches
        # the expected rendering.
        return textwrap.dedent(
            f"""\
            CachedContent(
                name='{self.name}',
                model='{self.model}',
                display_name='{self.display_name}',
                usage_metadata={'{'}
                    'total_token_count': {self.usage_metadata.total_token_count},
                {'}'},
                create_time={self.create_time},
                update_time={self.update_time},
                expire_time={self.expire_time}
            )"""
        )

    __repr__ = __str__

    @classmethod
    def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
        """Creates an instance of CachedContent from an object, without calling `get`."""
        # Bypass `__init__`, which would issue a `get` request.
        self = cls.__new__(cls)
        self._proto = protos.CachedContent()
        self._update(obj)
        return self

    def _update(self, updates):
        """Updates this instance in place; does not call the API's `update` method.

        Args:
            updates: A `CachedContent`, a proto message, or a dict of field
                names to values. Each entry is set on `self._proto`.
        """
        if isinstance(updates, CachedContent):
            updates = updates._proto

        if not isinstance(updates, dict):
            updates = type(updates).to_dict(updates, including_default_value_fields=False)

        for key, value in updates.items():
            setattr(self._proto, key, value)

    @staticmethod
    def _prepare_create_request(
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> protos.CreateCachedContentRequest:
        """Prepares a CreateCachedContentRequest.

        Validates the arguments, converts each one with the matching
        `content_types` / `caching_types` helper, and wraps the resulting
        `protos.CachedContent` in a create request.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                `display_name` is longer than 128 characters.
        """
        if ttl and expire_time:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )

        # A bare model id (no "/") is expanded to a `models/...` name.
        if "/" not in model:
            model = "models/" + model

        if display_name and len(display_name) > 128:
            raise ValueError("`display_name` must be no more than 128 unicode characters.")

        if system_instruction:
            system_instruction = content_types.to_content(system_instruction)

        tools_lib = content_types.to_function_library(tools)
        if tools_lib:
            tools_lib = tools_lib.to_proto()

        if tool_config:
            tool_config = content_types.to_tool_config(tool_config)

        if contents:
            contents = content_types.to_contents(contents)
            # Default the role of the final entry when none was supplied.
            if not contents[-1].role:
                contents[-1].role = _USER_ROLE

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        cached_content = protos.CachedContent(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools_lib,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        return protos.CreateCachedContentRequest(cached_content=cached_content)

    @classmethod
    def create(
        cls,
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> CachedContent:
        """Creates `CachedContent` resource.

        Args:
            model: The name of the `model` to use for cached content creation.
                   Any `CachedContent` resource can be only used with the
                   `model` it was created for.
            display_name: The user-generated meaningful display name
                          of the cached content. `display_name` must be no
                          more than 128 unicode characters.
            system_instruction: Developer set system instruction.
            contents: Contents to cache.
            tools: A list of `Tools` the model may use to generate response.
            tool_config: Config to apply to all tools.
            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
                 `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                         `ttl` and `expire_time` are exclusive arguments.

        Returns:
            `CachedContent` resource with specified name.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                `display_name` is longer than 128 characters.
        """
        client = get_default_cache_client()

        request = cls._prepare_create_request(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        response = client.create_cached_content(request)
        return cls._from_obj(response)

    @classmethod
    def get(cls, name: str) -> CachedContent:
        """Fetches required `CachedContent` resource.

        Args:
            name: The resource name referring to the cached content. A
                `cachedContents/` prefix is added when `name` contains none.

        Returns:
            `CachedContent` resource with specified `name`.
        """
        client = get_default_cache_client()

        request = protos.GetCachedContentRequest(name=cls._normalize_name(name))
        response = client.get_cached_content(request)
        return cls._from_obj(response)

    @classmethod
    def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
        """Lists `CachedContent` objects associated with the project.

        Args:
            page_size: The maximum number of `CachedContent` objects to return
                (per page). The service may return fewer `CachedContent` objects.

        Yields:
            One `CachedContent` per item produced by the client's
            `list_cached_contents` call.
        """
        client = get_default_cache_client()

        request = protos.ListCachedContentsRequest(page_size=page_size)
        for cached_content in client.list_cached_contents(request):
            yield cls._from_obj(cached_content)

    def delete(self) -> None:
        """Deletes `CachedContent` resource."""
        client = get_default_cache_client()

        request = protos.DeleteCachedContentRequest(name=self.name)
        client.delete_cached_content(request)

    def update(
        self,
        *,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> None:
        """Updates requested `CachedContent` resource.

        On success the response from the service is merged into this instance.

        Args:
            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
                 `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                         `ttl` and `expire_time` are exclusive arguments.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                neither yields a value to update.
        """
        # Reject contradictory arguments before acquiring a client.
        if ttl and expire_time:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )

        client = get_default_cache_client()

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        updates = protos.CachedContent(
            name=self.name,
            ttl=ttl,
            expire_time=expire_time,
        )

        # The mask names the single field the service should modify.
        field_mask = field_mask_pb2.FieldMask()

        if ttl:
            field_mask.paths.append("ttl")
        elif expire_time:
            field_mask.paths.append("expire_time")
        else:
            raise ValueError(
                "Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
            )

        request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
        updated_cc = client.update_cached_content(request)
        self._update(updated_cc)
|
|