| import threading |
| import time |
| from collections import deque |
|
|
| import huggingface_hub |
| from gradio_client import Client |
|
|
| from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name |
|
|
|
|
| class Run: |
| def __init__( |
| self, |
| url: str, |
| project: str, |
| client: Client, |
| name: str | None = None, |
| config: dict | None = None, |
| ): |
| self.url = url |
| self.project = project |
| self._client_lock = threading.Lock() |
| self._client_thread = None |
| self._client = client |
| self.name = name or generate_readable_name() |
| self.config = config or {} |
| self._queued_logs = deque() |
|
|
| if client is None: |
| self._client_thread = threading.Thread(target=self._init_client_background) |
| self._client_thread.start() |
|
|
| def _init_client_background(self): |
| fib = fibo() |
| for sleep_coefficient in fib: |
| try: |
| client = Client(self.url, verbose=False) |
| with self._client_lock: |
| self._client = client |
| if len(self._queued_logs) > 0: |
| for queued_log in self._queued_logs: |
| self._client.predict(**queued_log) |
| self._queued_logs.clear() |
| break |
| except Exception: |
| pass |
| if sleep_coefficient is not None: |
| time.sleep(0.1 * sleep_coefficient) |
|
|
| def log(self, metrics: dict): |
| for k in metrics.keys(): |
| if k in RESERVED_KEYS or k.startswith("__"): |
| raise ValueError( |
| f"Please do not use this reserved key as a metric: {k}" |
| ) |
| with self._client_lock: |
| if self._client is None: |
| |
| |
| self._queued_logs.append( |
| dict( |
| api_name="/log", |
| project=self.project, |
| run=self.name, |
| metrics=metrics, |
| hf_token=huggingface_hub.utils.get_token(), |
| ) |
| ) |
| else: |
| assert ( |
| len(self._queued_logs) == 0 |
| ) |
| |
| self._client.predict( |
| api_name="/log", |
| project=self.project, |
| run=self.name, |
| metrics=metrics, |
| hf_token=huggingface_hub.utils.get_token(), |
| ) |
|
|
| def finish(self): |
| """Cleanup when run is finished.""" |
| |
| if self._client_thread is not None: |
| self._client_thread.join() |
|
|