| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
|
|
| import functools |
| from typing import Iterator |
|
|
| from google.generativeai import protos |
|
|
| from google.generativeai import client as client_lib |
| from google.generativeai.types import model_types |
| from google.api_core import operation as operation_lib |
|
|
| import tqdm.auto as tqdm |
|
|
|
|
def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]:
    """List the operations known to the operations client.

    Args:
        client: Operations client to query. When omitted, the client returned
            by `client_lib.get_default_operations_client()` is used.

    Returns:
        A lazy iterator that wraps each listed operation proto in a
        `CreateTunedModelOperation` as it is consumed. Note that
        `client.list_operations` itself is called immediately, when this
        function is invoked, not when iteration starts.
    """
    ops_client = client_lib.get_default_operations_client() if client is None else client
    # Empty `name` and `filter_` values: no name or filter restriction is requested.
    listed = ops_client.list_operations(name="", filter_="")
    return (CreateTunedModelOperation.from_proto(raw, ops_client) for raw in listed)
|
|
|
|
def get_operation(name: str, *, client=None) -> CreateTunedModelOperation:
    """Fetch a single operation by name and wrap it.

    Args:
        name: Resource name of the operation to fetch.
        client: Operations client to use; defaults to
            `client_lib.get_default_operations_client()`.

    Returns:
        A `CreateTunedModelOperation` built from the fetched proto, bound to
        the same client for later refresh/cancel calls.
    """
    ops_client = client if client is not None else client_lib.get_default_operations_client()
    fetched = ops_client.get_operation(name=name)
    return CreateTunedModelOperation.from_proto(fetched, ops_client)
|
|
|
|
def delete_operation(name: str, *, client=None):
    """Delete a single operation by name.

    Args:
        name: Resource name of the operation to delete.
        client: Operations client to use; defaults to
            `client_lib.get_default_operations_client()`.

    Returns:
        Whatever the client's `delete_operation` call returns.
    """
    if client is not None:
        return client.delete_operation(name=name)
    return client_lib.get_default_operations_client().delete_operation(name=name)
|
|
|
|
class CreateTunedModelOperation(operation_lib.Operation):
    """Long-running operation handle for a tuned-model creation request.

    Extends `operation_lib.Operation` with helpers to:
    - build the handle from a raw operation proto or an existing operation,
    - report progress through a tqdm bar, and
    - decode the final `protos.TunedModel` result with
      `model_types.decode_tuned_model`.
    """

    @classmethod
    def from_proto(cls, proto, client):
        """Build an operation handle from a raw operation proto.

        Args:
            proto: The operation message returned by the operations client.
            client: The operations client used to refresh and cancel the
                operation.

        Returns:
            An instance of `cls`. Its result type is `protos.TunedModel` and
            its metadata type is `protos.CreateTunedModelMetadata`.
        """
        # Pass `cls` (not a hard-coded class) so that subclasses calling
        # `from_proto` get instances of themselves back. Supporting
        # subclasses is the reason `from_gapic` is patched in this module.
        return from_gapic(
            cls=cls,
            operation=proto,
            operations_client=client,
            result_type=protos.TunedModel,
            metadata_type=protos.CreateTunedModelMetadata,
        )

    @classmethod
    def from_core_operation(
        cls,
        operation: operation_lib.Operation,
    ):
        """Re-wrap an existing `operation_lib.Operation` as an instance of `cls`.

        Reuses the wrapped proto, the refresh/cancel callables and the
        result/metadata types of `operation`. It also carries over its
        polling configuration when one is present.

        Args:
            operation: An existing api_core operation.

        Returns:
            A new `cls` instance constructed from `operation`'s internals.
        """
        polling = getattr(operation, "_polling", None)
        retry = getattr(operation, "_retry", None)
        if polling is not None:
            # NOTE(review): presumably newer google-api-core versions keep the
            # polling config in `_polling`; confirm against the pinned version.
            kwargs = {"polling": polling}
        elif retry is not None:
            # NOTE(review): presumably older google-api-core versions only keep
            # `_retry`; confirm against the pinned version.
            kwargs = {"retry": retry}
        else:
            kwargs = {}
        return cls(
            operation=operation._operation,
            refresh=operation._refresh,
            cancel=operation._cancel,
            result_type=operation._result_type,
            metadata_type=operation._metadata_type,
            **kwargs,
        )

    @property
    def name(self) -> str:
        """Resource name of the wrapped operation proto."""
        return self._operation.name

    def update(self):
        """Refresh the current statuses in metadata/result/error.

        Delegates to the base class's `_refresh_and_update()`.
        """
        self._refresh_and_update()

    def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]:
        """A tqdm wait bar, yields `Operation` statuses until complete.

        The bar is closed when the operation finishes. It is also closed if
        the caller stops iterating early and the generator is closed.

        Args:
            **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)`

        Yields:
            Operation statuses as `protos.CreateTunedModelMetadata` objects.

        Returns:
            `self.result()` once the operation is done. This is the
            generator's return value, available as `StopIteration.value` or
            via `yield from`.
        """
        bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs)
        try:
            # api_core's `Operation.done()` refreshes from the server before
            # reporting completion, so each iteration sees fresh metadata.
            while not self.done():
                # Decode the metadata once per step and reuse it.
                metadata = self.metadata
                # tqdm.update takes a delta, so subtract what is already shown.
                bar.update(metadata.completed_steps - bar.n)
                yield metadata
            bar.update(self.metadata.completed_steps - bar.n)
        finally:
            # Release the bar even if the consumer abandons the generator.
            bar.close()
        return self.result()

    def set_result(self, result: protos.TunedModel):
        """Decode the raw `protos.TunedModel` before storing it as the result."""
        result = model_types.decode_tuned_model(result)
        super().set_result(result)
|
|
|
|
def from_gapic(
    cls,
    *,
    operation,
    operations_client,
    result_type,
    metadata_type,
    grpc_metadata=None,
    **kwargs,
):
    """Construct `cls` around a GAPIC operation proto.

    A variant of `google.api_core.operation.from_gapic` that takes the class
    to instantiate, so subclasses of `operation_lib.Operation` can be built.

    Args:
        cls: The operation class to instantiate.
        operation: The operation proto; its `name` identifies it to the client.
        operations_client: Object providing `get_operation` and
            `cancel_operation`.
        result_type: Message type of the operation's result.
        metadata_type: Message type of the operation's metadata.
        grpc_metadata: Forwarded as `metadata=` on the refresh and cancel calls.
        **kwargs: Extra keyword arguments forwarded to `cls`.

    Returns:
        `cls(operation, refresh, cancel, result_type, metadata_type, **kwargs)`.
        Here `refresh` and `cancel` are the client's `get_operation` and
        `cancel_operation`, with the operation name and metadata pre-bound.
    """
    def bind(method):
        # Pre-fill the operation name (positionally) and the request metadata.
        return functools.partial(method, operation.name, metadata=grpc_metadata)

    refresh = bind(operations_client.get_operation)
    cancel = bind(operations_client.cancel_operation)
    return cls(operation, refresh, cancel, result_type, metadata_type, **kwargs)
|
|