# Deploybot
# Deploy from stable branch
# 49574d5
"""Shared Granite Vision model loader.
Split into two stages for ZeroGPU compatibility:
- load_processor(): CPU-only, safe to call at startup or outside @spaces.GPU
- load_model(): requires CUDA, must only be called inside a @spaces.GPU context
The processor and model are cached globally so they are loaded at most once.
"""
import os
from typing import Any

# Module-level caches so the processor and model are loaded at most once per
# process. Populated lazily by load_processor() / load_model(); None means
# "not loaded yet (or last load attempt failed)".
_processor: Any = None
_model: Any = None

# Hugging Face Hub identifier of the Granite Vision checkpoint to load.
MODEL_ID = "ibm-granite/granite-vision-4.1-4b"
def load_processor() -> Any:
    """Load (or return cached) processor. CPU-only — safe outside @spaces.GPU.

    Returns:
        AutoProcessor instance, or None if loading fails.
    """
    global _processor  # noqa: PLW0603
    if _processor is None:
        try:
            from transformers import AutoProcessor

            _processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                token=os.environ.get("HF_TOKEN"),
                use_fast=True,
            )
            print(f"Processor loaded for {MODEL_ID}")
        except Exception as e:  # noqa: BLE001
            import traceback

            print(f"Processor load error: {e}")
            traceback.print_exc()
            return None
    return _processor
def load_model() -> tuple[Any, Any]:
    """Load (or return cached) model to CUDA. Must be called inside @spaces.GPU.

    Returns:
        Tuple of (processor, model), or (None, None) if loading fails.
    """
    global _model  # noqa: PLW0603
    processor = load_processor()
    if processor is None:
        return None, None
    if _model is not None:
        return processor, _model
    try:
        import torch
        from transformers import AutoModelForImageTextToText

        token = os.environ.get("HF_TOKEN")
        # Load on CPU first to avoid caching_allocator_warmup triggering
        # torch._C._cuda_init() before ZeroGPU can intercept it.
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            dtype=torch.bfloat16,
            token=token,
        ).eval()
        model = model.to("cuda")
        if hasattr(model, "merge_lora_adapters"):
            model = model.merge_lora_adapters()
        # Publish to the module-level cache only after every step succeeded.
        # Previously the global was assigned right after from_pretrained(), so
        # a failure in .to("cuda") or merge_lora_adapters() left a
        # half-initialized CPU model cached and returned by later calls.
        _model = model
        print(f"Model loaded: {MODEL_ID} on cuda")
        return processor, _model
    except Exception as e:  # noqa: BLE001
        import traceback

        print(f"Model load error: {e}")
        traceback.print_exc()
        return processor, None