import argparse
import os
from typing import Optional, Sequence

import tensorflow as tf

try:
    from huggingface_hub import snapshot_download
except Exception:
    snapshot_download = None
|
|
|
|
def load_model_from_hub(
    repo_id: str,
    token: Optional[str] = None,
    revision: Optional[str] = None,
    local_dir: Optional[str] = None,
    filename: str = "lambda_model.keras",
    **load_kwargs,
) -> tf.keras.Model:
    """Download a repo snapshot from the Hugging Face Hub and load a Keras model from it.

    Args:
        repo_id: Repository like `username/lambda-keras-model`.
        token: Optional HF token; when None, huggingface_hub's own
            cached/configured credentials are used.
        revision: Optional git revision, tag, or commit.
        local_dir: Optional directory to place the downloaded snapshot in.
        filename: Path of the saved model file inside the repo. Defaults to
            ``"lambda_model.keras"``.
        **load_kwargs: Extra keyword arguments forwarded unchanged to
            ``tf.keras.models.load_model`` (e.g. ``custom_objects``,
            ``compile``, ``safe_mode``). Models containing ``Lambda`` layers
            may need ``safe_mode=False`` to load with the ``.keras`` format;
            only do that for repos you trust, because it allows arbitrary
            code to be deserialized.

    Returns:
        The loaded ``tf.keras.Model``.

    Raises:
        RuntimeError: If ``huggingface_hub`` could not be imported.
        FileNotFoundError: If ``filename`` is not a regular file in the
            downloaded snapshot.
    """
    if snapshot_download is None:
        raise RuntimeError(
            "huggingface-hub is not installed. Add it to dependencies and reinstall."
        )

    snapshot_dir = snapshot_download(
        repo_id=repo_id,
        token=token,
        revision=revision,
        local_dir=local_dir,
        # Ask for real files in `local_dir` rather than symlinks into the HF
        # cache. NOTE(review): recent huggingface_hub releases deprecate and
        # ignore this argument — confirm against the pinned version.
        local_dir_use_symlinks=False,
    )

    model_path = os.path.join(snapshot_dir, filename)
    # isfile rather than exists: a directory with the same name is not a
    # loadable model and would only fail later, inside load_model.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Model file not found in repo {repo_id}: {model_path}"
        )

    return tf.keras.models.load_model(model_path, **load_kwargs)
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the model-loading CLI.

    Args:
        argv: Arguments to parse, excluding the program name. When None
            (the default), argparse reads ``sys.argv[1:]``, as before.

    Returns:
        Namespace with attributes ``repo_id``, ``hf_token``, ``revision``,
        ``local_dir`` (all ``str`` or None except the required ``repo_id``)
        and ``run`` (bool, False unless ``--run`` is given).
    """
    parser = argparse.ArgumentParser(
        description="Load the Lambda tf.keras model from Hugging Face Hub"
    )
    parser.add_argument(
        "--repo-id", type=str, required=True, help="Repo id, e.g. username/lambda-keras-model"
    )
    parser.add_argument(
        "--hf-token", type=str, default=None, help="Hugging Face token (optional)"
    )
    parser.add_argument(
        "--revision", type=str, default=None, help="Git revision, tag, or commit (optional)"
    )
    parser.add_argument(
        "--local-dir", type=str, default=None, help="Optional local directory for download"
    )
    parser.add_argument(
        "--run", action="store_true", help="Run a quick forward pass as a smoke test"
    )
    return parser.parse_args(argv)
|
|
|
|
def _make_example_inputs(shape_spec, unknown_dim: int = 4):
    """Build all-ones example tensors (batch size 1) matching ``shape_spec``.

    ``shape_spec`` is a model's ``input_shape``. For a single-input model it
    is one shape tuple. For multi-input models it may be a list/tuple or dict
    of shape tuples (TODO confirm for the Keras version in use). In every
    shape the leading (batch) dimension is replaced by 1 and any other
    ``None`` dimension by ``unknown_dim``.
    """
    if isinstance(shape_spec, dict):
        return {
            name: _make_example_inputs(spec, unknown_dim)
            for name, spec in shape_spec.items()
        }
    if isinstance(shape_spec, (list, tuple)) and any(
        isinstance(dim, (list, tuple)) for dim in shape_spec
    ):
        # A sequence whose elements are themselves shapes: one tensor per input.
        return [_make_example_inputs(spec, unknown_dim) for spec in shape_spec]
    dims = tuple(dim if dim is not None else unknown_dim for dim in shape_spec[1:])
    return tf.ones((1,) + dims)


def main() -> None:
    """Command-line entry point.

    Loads the model named by ``--repo-id`` from the Hugging Face Hub and
    prints its summary. With ``--run``, it also feeds an all-ones example
    through the model as a smoke test and prints both the input and the
    output.
    """
    args = parse_args()
    model = load_model_from_hub(
        repo_id=args.repo_id,
        token=args.hf_token,
        revision=args.revision,
        local_dir=args.local_dir,
    )

    model.summary()

    if args.run:
        example = _make_example_inputs(model.input_shape)
        prediction = model(example)

        def to_numpy(structure):
            # Inputs/outputs may be a single tensor or a list/dict of tensors
            # (multi-input / multi-output models); convert every leaf.
            return tf.nest.map_structure(lambda t: t.numpy(), structure)

        print("Example input:", to_numpy(example))
        print("Model output:", to_numpy(prediction))


if __name__ == "__main__":
    main()
|
|
|
|
|
|