from __future__ import annotations

import dataclasses
from collections.abc import Iterable, Sequence
import itertools
from typing import Any, overload, TypeVar

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_text_client
from google.generativeai import string_utils
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types

DEFAULT_TEXT_MODEL = "models/text-bison-001"
EMBEDDING_MAX_BATCH_SIZE = 100


# `itertools.batched` was added in Python 3.12; fall back to a small local
# implementation on older interpreters.
try:
    _batched = itertools.batched
except AttributeError:
    T = TypeVar("T")

    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
        if n < 1:
            raise ValueError(f"Batch size `n` must be >=1, got: {n}")
        batch = []
        for item in iterable:
            batch.append(item)
            if len(batch) == n:
                yield batch
                batch = []

        if batch:
            yield batch
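
# For illustration, the fallback batches an iterable into fixed-size chunks, e.g.
# `list(_batched(range(5), 2))` yields `[[0, 1], [2, 3], [4]]`. The stdlib
# `itertools.batched` yields tuples rather than lists, which is equivalent for the
# batching done in `generate_embeddings` below.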


def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
    """
    Creates a `glm.TextPrompt` object based on the provided prompt input.

    Args:
        prompt: The prompt input, either a string or a dictionary.

    Returns:
        glm.TextPrompt: A `TextPrompt` object containing the prompt text.

    Raises:
        TypeError: If the provided prompt is neither a string nor a dictionary.
    """
    if isinstance(prompt, str):
        return glm.TextPrompt(text=prompt)
    elif isinstance(prompt, dict):
        return glm.TextPrompt(prompt)
    else:
        raise TypeError("Expected a string or dictionary for the text prompt.")


def _make_generate_text_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL,
    prompt: str | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    max_output_tokens: int | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    stop_sequences: str | Iterable[str] | None = None,
) -> glm.GenerateTextRequest:
    """
    Creates a `glm.GenerateTextRequest` object based on the provided parameters.

    This function generates a `glm.GenerateTextRequest` object with the specified
    parameters. It prepares the input parameters and creates a request that can be
    used for generating text using the chosen model.

    Args:
        model: The model to use for text generation.
        prompt: The prompt for text generation. Defaults to None.
        temperature: The temperature for randomness in generation. Defaults to None.
        candidate_count: The number of candidates to consider. Defaults to None.
        max_output_tokens: The maximum number of output tokens. Defaults to None.
        top_p: The nucleus sampling probability threshold. Defaults to None.
        top_k: The top-k sampling parameter. Defaults to None.
        safety_settings: Safety settings for generated text. Defaults to None.
        stop_sequences: Stop sequences to halt text generation. Can be a string
            or iterable of strings. Defaults to None.

    Returns:
        `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters.
    """
    model = model_types.make_model_name(model)
    prompt = _make_text_prompt(prompt=prompt)
    safety_settings = safety_types.normalize_safety_settings(
        safety_settings, harm_category_set="old"
    )
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    if stop_sequences:
        stop_sequences = list(stop_sequences)

    return glm.GenerateTextRequest(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )
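
# Example (illustrative): building a request for the default model with a plain
# string prompt and a single stop sequence; the prompt text below is made up.
#
#     request = _make_generate_text_request(
#         prompt="Write a tagline for a coffee shop.",
#         temperature=0.5,
#         stop_sequences="\n",
#     )
#     # `request` is a `glm.GenerateTextRequest` ready to pass to `_generate_response`.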


def generate_text(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL,
    prompt: str,
    temperature: float | None = None,
    candidate_count: int | None = None,
    max_output_tokens: int | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    stop_sequences: str | Iterable[str] | None = None,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.Completion:
    """Calls the API and returns a `types.Completion` containing the response.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        prompt: Free-form input text given to the model. Given a prompt, the model will
            generate text that completes the input text.
        temperature: Controls the randomness of the output. Must be positive.
            Typical values are in the range `[0.0, 1.0]`. Higher values produce a
            more random and varied response. A temperature of zero will be deterministic.
        candidate_count: The **maximum** number of generated response messages to return.
            This value must be in the range `[1, 8]`, inclusive. If unset, this
            will default to `1`.

            Note: Only unique candidates are returned. Higher temperatures are more
            likely to produce unique candidates. Setting `temperature=0.0` will always
            return 1 candidate regardless of the `candidate_count`.
        max_output_tokens: Maximum number of tokens to include in a candidate. Must be greater
            than zero. If unset, will default to 64.
        top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_k` sets the maximum number of tokens to sample from on each step.
        top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_p` configures the nucleus sampling. It sets the maximum cumulative
            probability of tokens to sample from.
            For example, if the sorted probabilities are
            `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
            as `[0.625, 0.25, 0.125, 0, 0, 0]`.
        safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content.
            These will be enforced on the `prompt` and `candidates`. There should not
            be more than one setting for each `types.SafetyCategory` type. The API will
            block any prompts and responses that fail to meet the thresholds set by
            these settings. This list overrides the default settings for each
            `SafetyCategory` specified in the list. If there is no `types.SafetySetting`
            for a given `SafetyCategory` provided in the list, the API will use the
            default safety setting for that category.
        stop_sequences: A set of up to 5 character sequences that will stop output generation.
            If specified, the API will stop at the first appearance of a stop
            sequence. The stop sequence will not be included as part of the response.
        client: If you're not relying on the default client, you can pass a
            `glm.TextServiceClient` here instead.
        request_options: Options for the request.

    Returns:
        A `types.Completion` containing the model's text completion response.
    """
    request = _make_generate_text_request(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )

    return _generate_response(client=client, request=request, request_options=request_options)
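
# Example usage (illustrative; assumes the SDK has been configured with an API key,
# e.g. via `google.generativeai.configure(api_key=...)`; the prompt and parameter
# values are made up):
#
#     completion = generate_text(
#         prompt="Write a two-line poem about the ocean.",
#         temperature=0.7,
#         max_output_tokens=64,
#     )
#     print(completion.result)  # first candidate's output, or None if there are no candidates
#     for candidate in completion.candidates:
#         print(candidate["output"])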


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Completion(text_types.Completion):
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.result = None
        if self.candidates:
            self.result = self.candidates[0]["output"]


def _generate_response(
    request: glm.GenerateTextRequest,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> Completion:
    """
    Generates a response using the provided `glm.GenerateTextRequest` and client.

    Args:
        request: The text generation request.
        client: The client to use for text generation. Defaults to None, in which
            case the default text client is used.
        request_options: Options for the request.

    Returns:
        `Completion`: A `Completion` object with the generated text and response information.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_text_client()

    response = client.generate_text(request, **request_options)
    response = type(response).to_dict(response)

    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
    response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
        response["safety_feedback"]
    )
    response["candidates"] = safety_types.convert_candidate_enums(response["candidates"])

    return Completion(_client=client, **response)


def count_text_tokens(
    model: model_types.AnyModelNameOptions,
    prompt: str,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.TokenCount:
    """Calls the API to count the number of tokens the model uses for `prompt`."""
    base_model = models.get_base_model_name(model)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_text_client()

    result = client.count_text_tokens(
        glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}),
        **request_options,
    )

    return type(result).to_dict(result)
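
# Example usage (illustrative; assumes a configured API key; the dictionary key below
# follows the `CountTextTokensResponse.token_count` proto field name):
#
#     info = count_text_tokens(DEFAULT_TEXT_MODEL, "How many tokens is this?")
#     print(info["token_count"])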


@overload
def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: str,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.EmbeddingDict: ...


@overload
def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: Sequence[str],
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.BatchEmbeddingDict: ...


def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: str | Sequence[str],
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """Calls the API to create an embedding for the text passed in.

    Args:
        model: Which model to call, as a string or a `types.Model`.

        text: Free-form input text given to the model. Given a string, the model will
            generate an embedding based on the input text. A sequence of strings is
            embedded as a batch.

        client: If you're not relying on the default client, you can pass a
            `glm.TextServiceClient` here instead.

        request_options: Options for the request.

    Returns:
        Dictionary containing the embedding (a list of float values) for the input
        text, or one embedding per string when a sequence of strings is passed.
    """
    model = model_types.make_model_name(model)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_text_client()

    if isinstance(text, str):
        embedding_request = glm.EmbedTextRequest(model=model, text=text)
        embedding_response = client.embed_text(
            embedding_request,
            **request_options,
        )
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        embedding_dict["embedding"] = embedding_dict["embedding"]["value"]
    else:
        result = {"embedding": []}
        # Send the texts in batches of at most EMBEDDING_MAX_BATCH_SIZE per request.
        for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
            embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch)
            embedding_response = client.batch_embed_text(
                embedding_request,
                **request_options,
            )
            embedding_dict = type(embedding_response).to_dict(embedding_response)
            result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"])
        return result

    return embedding_dict
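
# Example usage (illustrative; assumes a configured API key and an embedding-capable
# PaLM model such as "models/embedding-gecko-001"):
#
#     single = generate_embeddings("models/embedding-gecko-001", "Hello world")
#     print(len(single["embedding"]))   # length of the single embedding vector
#
#     batch = generate_embeddings("models/embedding-gecko-001", ["Hello", "world"])
#     print(len(batch["embedding"]))    # 2: one embedding vector per input string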