| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
|
|
| import typing |
| from typing import Any, Literal |
|
|
| import google.ai.generativelanguage as glm |
|
|
| from google.generativeai import protos |
| from google.generativeai import operations |
| from google.generativeai.client import get_default_model_client |
| from google.generativeai.types import model_types |
| from google.generativeai.types import helper_types |
| from google.api_core import operation |
| from google.api_core import protobuf_helpers |
| from google.protobuf import field_mask_pb2 |
| from google.generativeai.utils import flatten_update_paths |
|
|
|
|
| def get_model( |
| name: model_types.AnyModelNameOptions, |
| *, |
| client=None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.Model | model_types.TunedModel: |
| """Calls the API to fetch a model by name. |
| |
| ``` |
| import pprint |
| model = genai.get_model('models/gemini-1.5-flash') |
| pprint.pprint(model) |
| ``` |
| |
| Args: |
| name: The name of the model to fetch. Should start with `models/` |
| client: The client to use. |
| request_options: Options for the request. |
| |
| Returns: |
| A `types.Model` |
| """ |
| name = model_types.make_model_name(name) |
| if name.startswith("models/"): |
| return get_base_model(name, client=client, request_options=request_options) |
| elif name.startswith("tunedModels/"): |
| return get_tuned_model(name, client=client, request_options=request_options) |
| else: |
| raise ValueError( |
| f"Invalid model name: Model names must start with `models/` or `tunedModels/`. Received: {name}" |
| ) |
|
|
|
|
| def get_base_model( |
| name: model_types.BaseModelNameOptions, |
| *, |
| client=None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.Model: |
| """Calls the API to fetch a base model by name. |
| |
| ``` |
| import pprint |
| model = genai.get_base_model('models/chat-bison-001') |
| pprint.pprint(model) |
| ``` |
| |
| Args: |
| name: The name of the model to fetch. Should start with `models/` |
| client: The client to use. |
| request_options: Options for the request. |
| |
| Returns: |
| A `types.Model`. |
| """ |
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| name = model_types.make_model_name(name) |
| if not name.startswith("models/"): |
| raise ValueError( |
| f"Invalid model name: Base model names must start with `models/`. Received: {name}" |
| ) |
|
|
| result = client.get_model(name=name, **request_options) |
| result = type(result).to_dict(result) |
| return model_types.Model(**result) |
|
|
|
|
| def get_tuned_model( |
| name: model_types.TunedModelNameOptions, |
| *, |
| client=None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.TunedModel: |
| """Calls the API to fetch a tuned model by name. |
| |
| ``` |
| import pprint |
| model = genai.get_tuned_model('tunedModels/gemini-1.5-flash') |
| pprint.pprint(model) |
| ``` |
| |
| Args: |
| name: The name of the model to fetch. Should start with `tunedModels/` |
| client: The client to use. |
| request_options: Options for the request. |
| |
| Returns: |
| A `types.TunedModel`. |
| """ |
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| name = model_types.make_model_name(name) |
|
|
| if not name.startswith("tunedModels/"): |
| raise ValueError( |
| f"Invalid model name: Tuned model names must start with `tunedModels/`. Received: {name}" |
| ) |
|
|
| result = client.get_tuned_model(name=name, **request_options) |
|
|
| return model_types.decode_tuned_model(result) |
|
|
|
|
| def get_base_model_name( |
| model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None |
| ): |
| """Calls the API to fetch the base model name of a model.""" |
|
|
| if isinstance(model, str): |
| if model.startswith("tunedModels/"): |
| model = get_model(model, client=client) |
| base_model = model.base_model |
| else: |
| base_model = model |
| elif isinstance(model, model_types.TunedModel): |
| base_model = model.base_model |
| elif isinstance(model, model_types.Model): |
| base_model = model.name |
| elif isinstance(model, protos.Model): |
| base_model = model.name |
| elif isinstance(model, protos.TunedModel): |
| base_model = getattr(model, "base_model", None) |
| if not base_model: |
| base_model = model.tuned_model_source.base_model |
| else: |
| raise TypeError( |
| f"Invalid model: The provided model '{model}' is not recognized or supported. " |
| "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel." |
| ) |
|
|
| return base_model |
|
|
|
|
| def list_models( |
| *, |
| page_size: int | None = 50, |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.ModelsIterable: |
| """Calls the API to list all available models. |
| |
| ``` |
| import pprint |
| for model in genai.list_models(): |
| pprint.pprint(model) |
| ``` |
| |
| Args: |
| page_size: How many `types.Models` to fetch per page (api call). |
| client: You may pass a `glm.ModelServiceClient` instead of using the default client. |
| request_options: Options for the request. |
| |
| Yields: |
| `types.Model` objects. |
| |
| """ |
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| for model in client.list_models(page_size=page_size, **request_options): |
| model = type(model).to_dict(model) |
| yield model_types.Model(**model) |
|
|
|
|
| def list_tuned_models( |
| *, |
| page_size: int | None = 50, |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.TunedModelsIterable: |
| """Calls the API to list all tuned models. |
| |
| ``` |
| import pprint |
| for model in genai.list_tuned_models(): |
| pprint.pprint(model) |
| ``` |
| |
| Args: |
| page_size: How many `types.Models` to fetch per page (api call). |
| client: You may pass a `glm.ModelServiceClient` instead of using the default client. |
| request_options: Options for the request. |
| |
| Yields: |
| `types.TunedModel` objects. |
| """ |
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| for model in client.list_tuned_models( |
| page_size=page_size, |
| **request_options, |
| ): |
| model = type(model).to_dict(model) |
| yield model_types.decode_tuned_model(model) |
|
|
|
|
| def create_tuned_model( |
| source_model: model_types.AnyModelNameOptions, |
| training_data: model_types.TuningDataOptions, |
| *, |
| id: str | None = None, |
| display_name: str | None = None, |
| description: str | None = None, |
| temperature: float | None = None, |
| top_p: float | None = None, |
| top_k: int | None = None, |
| epoch_count: int | None = None, |
| batch_size: int | None = None, |
| learning_rate: float | None = None, |
| input_key: str = "text_input", |
| output_key: str = "output", |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> operations.CreateTunedModelOperation: |
| """Calls the API to initiate a tuning process that optimizes a model for specific data, returning an operation object to track and manage the tuning progress. |
| |
| Since tuning a model can take significant time, this API doesn't wait for the tuning to complete. |
| Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the |
| status of the tuning job, or wait for it to complete, and check the result. |
| |
| After the job completes you can either find the resulting `TunedModel` object in |
| `Operation.result()` or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`. |
| |
| ``` |
| my_id = "my-tuned-model-id" |
| operation = palm.create_tuned_model( |
| id = my_id, |
| source_model="models/text-bison-001", |
| training_data=[{'text_input': 'example input', 'output': 'example output'},...] |
| ) |
| tuned_model=operation.result() # Wait for tuning to finish |
| |
| palm.generate_text(f"tunedModels/{my_id}", prompt="...") |
| ``` |
| |
| Args: |
| source_model: The name of the model to tune. |
| training_data: The dataset to tune the model on. This must be either: |
| * A `protos.Dataset`, or |
| * An `Iterable` of: |
| *`protos.TuningExample`, |
| * `{'text_input': text_input, 'output': output}` dicts |
| * `(text_input, output)` tuples. |
| * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which |
| columns to use as the input/output |
| * A csv file (will be read with `pd.read_csv` and handles as a `Mapping` |
| above). This can be: |
| * A local path as a `str` or `pathlib.Path`. |
| * A url for a csv file. |
| * The url of a Google Sheets file. |
| * A JSON file - Its contents will be handled either as an `Iterable` or `Mapping` |
| above. This can be: |
| * A local path as a `str` or `pathlib.Path`. |
| id: The model identifier, used to refer to the model in the API |
| `tunedModels/{id}`. Must be unique. |
| display_name: A human-readable name for display. |
| description: A description of the tuned model. |
| temperature: The default temperature for the tuned model, see `types.Model` for details. |
| top_p: The default `top_p` for the model, see `types.Model` for details. |
| top_k: The default `top_k` for the model, see `types.Model` for details. |
| epoch_count: The number of tuning epochs to run. An epoch is a pass over the whole dataset. |
| batch_size: The number of examples to use in each training batch. |
| learning_rate: The step size multiplier for the gradient updates. |
| client: Which client to use. |
| request_options: Options for the request. |
| |
| Returns: |
| A [`google.api_core.operation.Operation`](https://googleapis.dev/python/google-api-core/latest/operation.html) |
| """ |
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| source_model_name = model_types.make_model_name(source_model) |
| base_model_name = get_base_model_name(source_model) |
| if source_model_name.startswith("models/"): |
| source_model = {"base_model": source_model_name} |
| elif source_model_name.startswith("tunedModels/"): |
| source_model = { |
| "tuned_model_source": { |
| "tuned_model": source_model_name, |
| "base_model": base_model_name, |
| } |
| } |
| else: |
| raise ValueError( |
| f"Invalid model name: The provided model '{source_model}' does not match any known model patterns such as 'models/' or 'tunedModels/'" |
| ) |
|
|
| training_data = model_types.encode_tuning_data( |
| training_data, input_key=input_key, output_key=output_key |
| ) |
|
|
| hyperparameters = protos.Hyperparameters( |
| epoch_count=epoch_count, |
| batch_size=batch_size, |
| learning_rate=learning_rate, |
| ) |
| tuning_task = protos.TuningTask( |
| training_data=training_data, |
| hyperparameters=hyperparameters, |
| ) |
|
|
| tuned_model = protos.TunedModel( |
| **source_model, |
| display_name=display_name, |
| description=description, |
| temperature=temperature, |
| top_p=top_p, |
| top_k=top_k, |
| tuning_task=tuning_task, |
| ) |
|
|
| operation = client.create_tuned_model( |
| dict(tuned_model_id=id, tuned_model=tuned_model), **request_options |
| ) |
|
|
| return operations.CreateTunedModelOperation.from_core_operation(operation) |
|
|
|
|
| @typing.overload |
| def update_tuned_model( |
| tuned_model: protos.TunedModel, |
| updates: None = None, |
| *, |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.TunedModel: |
| pass |
|
|
|
|
| @typing.overload |
| def update_tuned_model( |
| tuned_model: str, |
| updates: dict[str, Any], |
| *, |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.TunedModel: |
| pass |
|
|
|
|
| def update_tuned_model( |
| tuned_model: str | protos.TunedModel, |
| updates: dict[str, Any] | None = None, |
| *, |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> model_types.TunedModel: |
| """Calls the API to push updates to a specified tuned model where only certain attributes are updatable.""" |
|
|
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| if isinstance(tuned_model, str): |
| name = tuned_model |
| if not isinstance(updates, dict): |
| raise TypeError( |
| f"Invalid argument type: In the function `update_tuned_model(name:str, updates: dict)`, the `updates` argument must be of type `dict`. Received type: {type(updates).__name__}." |
| ) |
|
|
| tuned_model = client.get_tuned_model(name=name, **request_options) |
|
|
| updates = flatten_update_paths(updates) |
| field_mask = field_mask_pb2.FieldMask() |
| for path in updates.keys(): |
| field_mask.paths.append(path) |
| for path, value in updates.items(): |
| _apply_update(tuned_model, path, value) |
| elif isinstance(tuned_model, protos.TunedModel): |
| if updates is not None: |
| raise ValueError( |
| "Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, " |
| "the `updates` argument must not be set." |
| ) |
|
|
| name = tuned_model.name |
| was = client.get_tuned_model(name=name) |
| field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) |
| else: |
| raise TypeError( |
| "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the " |
| f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}." |
| ) |
|
|
| result = client.update_tuned_model( |
| protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), |
| **request_options, |
| ) |
| return model_types.decode_tuned_model(result) |
|
|
|
|
| def _apply_update(thing, path, value): |
| parts = path.split(".") |
| for part in parts[:-1]: |
| thing = getattr(thing, part) |
| setattr(thing, parts[-1], value) |
|
|
|
|
| def delete_tuned_model( |
| tuned_model: model_types.TunedModelNameOptions, |
| client: glm.ModelServiceClient | None = None, |
| request_options: helper_types.RequestOptionsType | None = None, |
| ) -> None: |
| """Calls the API to delete a specified tuned model""" |
|
|
| if request_options is None: |
| request_options = {} |
|
|
| if client is None: |
| client = get_default_model_client() |
|
|
| name = model_types.make_model_name(tuned_model) |
| client.delete_tuned_model(name=name, **request_options) |
|
|