| """Shared Granite Vision model loader. | |
| Split into two stages for ZeroGPU compatibility: | |
| - load_processor(): CPU-only, safe to call at startup or outside @spaces.GPU | |
| - load_model(): requires CUDA, must only be called inside a @spaces.GPU context | |
| The processor and model are cached globally so they are loaded at most once. | |
| """ | |
| import os | |
| from typing import Any | |
| _processor: Any = None | |
| _model: Any = None | |
| MODEL_ID = "ibm-granite/granite-vision-4.1-4b" | |
| def load_processor() -> Any: | |
| """Load (or return cached) processor. CPU-only — safe outside @spaces.GPU. | |
| Returns: | |
| AutoProcessor instance, or None if loading fails. | |
| """ | |
| global _processor # noqa: PLW0603 | |
| if _processor is not None: | |
| return _processor | |
| try: | |
| from transformers import AutoProcessor | |
| token = os.environ.get("HF_TOKEN") | |
| _processor = AutoProcessor.from_pretrained( | |
| MODEL_ID, trust_remote_code=True, token=token, use_fast=True | |
| ) | |
| print(f"Processor loaded for {MODEL_ID}") | |
| return _processor | |
| except Exception as e: # noqa: BLE001 | |
| import traceback | |
| print(f"Processor load error: {e}") | |
| traceback.print_exc() | |
| return None | |
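
# Example (sketch, not part of the original module): warm the processor cache
# at Space startup, outside any @spaces.GPU context. This is safe because
# load_processor() never touches CUDA. The module/file names below are an
# assumed layout, not confirmed by this file.
#
#     # app.py
#     from model import load_processor  # hypothetical module name
#     load_processor()  # runs at import time, on CPU
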
def load_model() -> tuple[Any, Any]:
    """Load (or return the cached) model on CUDA. Must be called inside @spaces.GPU.

    Returns:
        Tuple of (processor, model), or (None, None) if loading fails.
    """
    global _model  # noqa: PLW0603
    processor = load_processor()
    if processor is None:
        return None, None
    if _model is not None:
        return processor, _model
    try:
        import torch
        from transformers import AutoModelForImageTextToText

        token = os.environ.get("HF_TOKEN")
        # Load on CPU first to avoid caching_allocator_warmup triggering
        # torch._C._cuda_init() before ZeroGPU can intercept it.
        _model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            dtype=torch.bfloat16,
            token=token,
        ).eval()
        _model = _model.to("cuda")
        # If the checkpoint ships LoRA adapters and exposes a merge hook,
        # fold them into the base weights so inference runs on merged weights.
        if hasattr(_model, "merge_lora_adapters"):
            _model = _model.merge_lora_adapters()
        print(f"Model loaded: {MODEL_ID} on cuda")
        return processor, _model
    except Exception as e:  # noqa: BLE001
        import traceback

        print(f"Model load error: {e}")
        traceback.print_exc()
        return processor, None
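
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): how a ZeroGPU
# Space would typically wire these loaders into an inference function. The
# @spaces.GPU decorator is the real Hugging Face ZeroGPU API; the chat-template
# call below follows the standard transformers image-text-to-text interface,
# but the exact message format for this checkpoint is an assumption.
#
#     import spaces
#
#     @spaces.GPU
#     def describe(image, prompt: str) -> str:
#         processor, model = load_model()  # cached after the first call
#         if model is None:
#             return "Model failed to load."
#         conversation = [{
#             "role": "user",
#             "content": [
#                 {"type": "image", "image": image},
#                 {"type": "text", "text": prompt},
#             ],
#         }]
#         inputs = processor.apply_chat_template(
#             conversation,
#             add_generation_prompt=True,
#             tokenize=True,
#             return_dict=True,
#             return_tensors="pt",
#         ).to("cuda")
#         output = model.generate(**inputs, max_new_tokens=256)
#         return processor.decode(output[0], skip_special_tokens=True)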