| from __future__ import annotations |
|
|
| import gzip |
| import os |
| import shutil |
| import tempfile |
| from typing import Optional |
|
|
| import torch |
| from huggingface_hub import HfApi, hf_hub_download, login |
|
|
# Default Hugging Face model repository used by the checkpoint helpers below.
REPO_ID = "WCNegentropy/BitTransformerLM"
# Default checkpoint file name; it is stored under ``checkpoints/`` in the repo.
FILENAME = "model.pt.gz"
|
|
|
|
def hf_login(token: Optional[str] = None) -> None:
    """Authenticate with Hugging Face.

    The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
    variable. If neither supplies a token, ``login`` is called with
    ``token=None`` so the library can attempt an interactive login.

    Args:
        token: Hugging Face access token. When ``None``, the value of the
            ``HF_TOKEN`` environment variable is used if it is set and
            non-empty.
    """
    if token is None:
        # Honour the documented environment-variable fallback. An empty value
        # is treated as "not set" so the interactive path remains reachable.
        token = os.environ.get("HF_TOKEN") or None
    login(token=token)
|
|
|
|
def save_checkpoint(
    model: torch.nn.Module,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> None:
    """Upload the model weights to ``repo_id`` under ``filename``.

    The model's ``state_dict`` is serialized with ``torch.save``, gzip
    compressed, and uploaded to ``checkpoints/<filename>`` in the model
    repository. The file within the repository is overwritten each time to
    avoid accumulating checkpoints.

    Args:
        model: Module whose ``state_dict`` is saved.
        repo_id: Target Hugging Face model repository.
        filename: Name of the uploaded file inside ``checkpoints/``.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # Scratch names are fixed and distinct on purpose: deriving them from
        # ``filename`` breaks when it contains a directory component, and
        # ``filename="model.pt"`` would make both paths identical so the gzip
        # writer would truncate the file it is reading from.
        tmp_pt = os.path.join(tmp, "model.pt")
        tmp_gz = os.path.join(tmp, "model.pt.gz")
        torch.save(model.state_dict(), tmp_pt)
        with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
            # Stream in chunks instead of loading the whole checkpoint in RAM.
            shutil.copyfileobj(src, dst)
        # ``upload_file`` takes no ``overwrite`` argument; committing to an
        # existing ``path_in_repo`` already replaces the file at that path.
        HfApi().upload_file(
            path_or_fileobj=tmp_gz,
            path_in_repo=f"checkpoints/{filename}",
            repo_id=repo_id,
            repo_type="model",
        )
|
|
|
|
def download_checkpoint(
    dest_path: str,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> bool:
    """Download the latest checkpoint to ``dest_path``.

    Fetches ``checkpoints/<filename>`` from ``repo_id`` (re-downloading even
    if a cached copy exists) and copies it verbatim to ``dest_path``. The file
    is not decompressed here.

    Args:
        dest_path: Destination file path. Its parent directory is created if
            it does not exist; a bare file name writes into the current
            working directory.
        repo_id: Source Hugging Face model repository.
        filename: Name of the file inside ``checkpoints/``.

    Returns:
        ``True`` if the checkpoint was successfully retrieved, ``False`` if
        the download failed (the error is printed).
    """
    try:
        cached_path = hf_hub_download(
            repo_id,
            f"checkpoints/{filename}",
            repo_type="model",
            force_download=True,
        )
    except Exception as exc:
        # Deliberate best-effort: any download failure is reported and turned
        # into a ``False`` result rather than propagated to the caller.
        print("Failed to download checkpoint", exc)
        return False
    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        # ``os.makedirs("")`` raises, so skip it when ``dest_path`` has no
        # directory component.
        os.makedirs(parent_dir, exist_ok=True)
    shutil.copyfile(cached_path, dest_path)
    return True
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]
|
|